refactor: sql lab command: separate concerns into different modules (#16917)

* chore move sql_execution_context to sqllab package

* add new helper methods into base Dao

* refactor separate get existing query concern from command

* refactor separate query access validation concern

* refactor separate get query's database concern from command

* refactor separate get query rendering concern from command

* refactor sqllab_execution_context

* refactor separate creating payload for view

* chore decouple command from superset app

* fix pylint issues

* fix failed tests

* fix pylint issues

* fix failed test

* fix failed black

* fix failed black

* fix failed test
This commit is contained in:
ofekisr 2021-10-03 11:15:46 +03:00 committed by GitHub
parent f0060a63c0
commit 0d0c759cfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 661 additions and 286 deletions

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=arguments-renamed
import logging
from typing import List, Optional, TYPE_CHECKING

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=isinstance-second-argument-not-valid-type
from typing import Any, Dict, List, Optional, Type
from flask_appbuilder.models.filters import BaseFilter
@ -89,6 +90,19 @@ class BaseDAO:
).apply(query, None)
return query.all()
@classmethod
def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
"""
Get the first that fit the `base_filter`
"""
query = db.session.query(cls.model_cls)
if cls.base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
"id", data_model
).apply(query, None)
return query.filter_by(**filter_by).one_or_none()
@classmethod
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
"""
@ -109,6 +123,27 @@ class BaseDAO:
raise DAOCreateFailedError(exception=ex) from ex
return model
@classmethod
def save(cls, instance_model: Model, commit: bool = True) -> Model:
"""
Generic for saving models
:raises: DAOCreateFailedError
"""
if cls.model_cls is None:
raise DAOConfigError()
if not isinstance(instance_model, cls.model_cls):
raise DAOCreateFailedError(
"the instance model is not a type of the model class"
)
try:
db.session.add(instance_model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return instance_model
@classmethod
def update(
cls, model: Model, properties: Dict[str, Any], commit: bool = True

View File

@ -143,7 +143,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return len(dataset_query) == 0
@classmethod
def update( # pylint: disable=arguments-differ
def update(
cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlaTable]:
"""

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long
"""A collection of ORM sqlalchemy models for Superset"""
import enum
import json
@ -253,7 +254,7 @@ class Database(
@property
def parameters_schema(self) -> Dict[str, Any]:
try:
parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore # pylint: disable=line-too-long
parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore
except Exception: # pylint: disable=broad-except
parameters_schema = {}
return parameters_schema

View File

@ -14,73 +14,71 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long
# pylint: disable=too-few-public-methods, too-many-arguments
from __future__ import annotations
import dataclasses
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, TYPE_CHECKING
import simplejson as json
from flask import g
from flask_babel import gettext as __, ngettext
from jinja2.exceptions import TemplateError
from jinja2.meta import find_undeclared_variables
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.session import Session
from flask_babel import gettext as __
from superset import app, db, is_feature_enabled, sql_lab
from superset.commands.base import BaseCommand
from superset.common.db_query_status import QueryStatus
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetErrorException,
SupersetErrorsException,
SupersetGenericDBErrorException,
SupersetGenericErrorException,
SupersetSecurityException,
SupersetTemplateParamsErrorException,
SupersetTimeoutException,
)
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.dao.exceptions import DAOCreateFailedError
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetErrorsException, SupersetGenericErrorException
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.queries.dao import QueryDAO
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.utils import core as utils
from superset.utils.dates import now_as_float
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
config = app.config
logger = logging.getLogger(__name__)
PARAMETER_MISSING_ERR = (
"Please check your template parameters for syntax errors and make sure "
"they match across your SQL query and Set Parameters. Then, try running "
"your query again."
from superset.sqllab.exceptions import (
QueryIsForbiddenToAccessException,
SqlLabException,
)
from superset.sqllab.limiting_factor import LimitingFactor
SqlResults = Dict[str, Any]
if TYPE_CHECKING:
from superset.sqllab.sql_json_executer import SqlJsonExecutor
from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
from superset.queries.dao import QueryDAO
from superset.databases.dao import DatabaseDAO
logger = logging.getLogger(__name__)
CommandResult = Dict[str, Any]
class ExecuteSqlCommand(BaseCommand):
_execution_context: SqlJsonExecutionContext
_query_dao: QueryDAO
_database_dao: DatabaseDAO
_access_validator: CanAccessQueryValidator
_sql_query_render: SqlQueryRender
_sql_json_executor: SqlJsonExecutor
_execution_context_convertor: ExecutionContextConvertor
_sqllab_ctas_no_limit: bool
_log_params: Optional[Dict[str, Any]] = None
_session: Session
def __init__(
self,
execution_context: SqlJsonExecutionContext,
query_dao: QueryDAO,
database_dao: DatabaseDAO,
access_validator: CanAccessQueryValidator,
sql_query_render: SqlQueryRender,
sql_json_executor: SqlJsonExecutor,
execution_context_convertor: ExecutionContextConvertor,
sqllab_ctas_no_limit_flag: bool,
log_params: Optional[Dict[str, Any]] = None,
) -> None:
self._execution_context = execution_context
self._query_dao = query_dao
self._database_dao = database_dao
self._access_validator = access_validator
self._sql_query_render = sql_query_render
self._sql_json_executor = sql_json_executor
self._execution_context_convertor = execution_context_convertor
self._sqllab_ctas_no_limit = sqllab_ctas_no_limit_flag
self._log_params = log_params
self._session = db.session()
def validate(self) -> None:
pass
@ -90,7 +88,7 @@ class ExecuteSqlCommand(BaseCommand):
) -> CommandResult:
"""Runs arbitrary sql and returns data as json"""
try:
query = self._get_existing_query()
query = self._try_get_existing_query()
if self.is_query_handled(query):
self._execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
@ -98,24 +96,21 @@ class ExecuteSqlCommand(BaseCommand):
status = self._run_sql_json_exec_from_scratch()
return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
"payload": self._execution_context_convertor.to_payload(
self._execution_context, status
),
}
except (SqlLabException, SupersetErrorsException) as ex:
raise ex
except Exception as ex:
raise SqlLabException(self._execution_context, exception=ex) from ex
def _get_existing_query(self) -> Optional[Query]:
query = (
self._session.query(Query)
.filter_by(
client_id=self._execution_context.client_id,
user_id=self._execution_context.user_id,
sql_editor_id=self._execution_context.sql_editor_id,
)
.one_or_none()
def _try_get_existing_query(self) -> Optional[Query]:
return self._query_dao.find_one_or_none(
client_id=self._execution_context.client_id,
user_id=self._execution_context.user_id,
sql_editor_id=self._execution_context.sql_editor_id,
)
return query
@classmethod
def is_query_handled(cls, query: Optional[Query]) -> bool:
@ -130,20 +125,20 @@ class ExecuteSqlCommand(BaseCommand):
query = self._execution_context.create_query()
self._save_new_query(query)
try:
self._save_new_query(query)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query)
self._execution_context.set_query(query)
rendered_query = self._render_query()
rendered_query = self._sql_query_render.render(self._execution_context)
self._set_query_limit_if_required(rendered_query)
return self._execute_query(rendered_query)
return self._sql_json_executor.execute(
self._execution_context, rendered_query, self._log_params
)
except Exception as ex:
query.status = QueryStatus.FAILED
self._session.commit()
self._query_dao.update(query, {"status": QueryStatus.FAILED})
raise ex
def _get_the_query_db(self) -> Database:
mydb = self._session.query(Database).get(self._execution_context.database_id)
mydb = self._database_dao.find_by_id(self._execution_context.database_id)
self._validate_query_db(mydb)
return mydb
@ -159,74 +154,21 @@ class ExecuteSqlCommand(BaseCommand):
def _save_new_query(self, query: Query) -> None:
try:
self._session.add(query)
self._session.flush()
self._session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex), exc_info=True)
self._session.rollback()
if not query.id:
raise SupersetGenericErrorException(
__(
"The query record was not created as expected. Please "
"contact an administrator for further assistance or try again."
)
)
self._query_dao.save(query)
except DAOCreateFailedError as ex:
raise SqlLabException(
self._execution_context,
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
"The query record was not created as expected",
ex,
"Please contact an administrator for further assistance or try again.",
) from ex
def _validate_access(self, query: Query) -> None:
try:
query.raise_for_access()
except SupersetSecurityException as ex:
query.set_extra_json_key("errors", [dataclasses.asdict(ex.error)])
query.status = QueryStatus.FAILED
query.error_message = ex.error.message
self._session.commit()
raise SupersetErrorException(ex.error, status=403) from ex
def _render_query(self) -> str:
def validate(
rendered_query: str, template_processor: BaseTemplateProcessor
) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
# pylint: disable=protected-access
ast = template_processor._env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(ast) # type: ignore
if undefined_parameters:
raise SupersetTemplateParamsErrorException(
message=ngettext(
"The parameter %(parameters)s in your query is undefined.",
"The following parameters in your query are undefined: %(parameters)s.",
len(undefined_parameters),
parameters=utils.format_list(undefined_parameters),
)
+ " "
+ PARAMETER_MISSING_ERR,
error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": self._execution_context.template_params,
},
)
query = self._execution_context.query
try:
template_processor = get_template_processor(
database=query.database, query=query
)
rendered_query = template_processor.process_template(
query.sql, **self._execution_context.template_params
)
validate(rendered_query, template_processor)
except TemplateError as ex:
raise SupersetTemplateParamsErrorException(
message=__(
'The query contains one or more malformed template parameters. Please check your query and confirm that all template parameters are surround by double braces, for example, "{{ ds }}". Then, try running your query again.'
),
error=SupersetErrorType.INVALID_TEMPLATE_PARAMS_ERROR,
) from ex
return rendered_query
self._access_validator.validate(query)
except Exception as ex:
raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex
def _set_query_limit_if_required(self, rendered_query: str,) -> None:
if self._is_required_to_set_limit():
@ -234,7 +176,7 @@ class ExecuteSqlCommand(BaseCommand):
def _is_required_to_set_limit(self) -> bool:
return not (
config.get("SQLLAB_CTAS_NO_LIMIT") and self._execution_context.select_as_cta
self._sqllab_ctas_no_limit and self._execution_context.select_as_cta
)
def _set_query_limit(self, rendered_query: str) -> None:
@ -255,161 +197,21 @@ class ExecuteSqlCommand(BaseCommand):
lim for lim in limits if lim is not None
)
def _execute_query(self, rendered_query: str,) -> SqlJsonExecutionStatus:
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
# Async request.
if self._execution_context.is_run_asynchronous():
return self._sql_json_async(rendered_query)
return self._sql_json_sync(rendered_query)
class CanAccessQueryValidator:
def validate(self, query: Query) -> None:
raise NotImplementedError()
def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
"""
Send SQL JSON query to celery workers.
:param rendered_query: the rendered query to perform by workers
:return: A Flask Response
"""
query = self._execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
try:
task = sql_lab.get_sql_results.delay(
query.id,
rendered_query,
return_results=False,
store_results=not query.select_as_cta,
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)
# Explicitly forget the task to ensure the task metadata is removed from the
# Celery results backend in a timely manner.
try:
task.forget()
except NotImplementedError:
logger.warning(
"Unable to forget Celery task as backend"
"does not support this operation"
)
except Exception as ex:
logger.exception("Query %i: %s", query.id, str(ex))
class SqlQueryRender:
def render(self, execution_context: SqlJsonExecutionContext) -> str:
raise NotImplementedError()
message = __("Failed to start remote query on a worker.")
error = SupersetError(
message=message,
error_type=SupersetErrorType.ASYNC_WORKERS_ERROR,
level=ErrorLevel.ERROR,
)
error_payload = dataclasses.asdict(error)
query.set_extra_json_key("errors", [error_payload])
query.status = QueryStatus.FAILED
query.error_message = message
self._session.commit()
raise SupersetErrorException(error) from ex
# Update saved query with execution info from the query execution
QueryDAO.update_saved_query_exec_info(query_id)
self._session.commit()
return SqlJsonExecutionStatus.QUERY_IS_RUNNING
def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
"""
Execute SQL query (sql json).
:param rendered_query: The rendered query (included templates)
:raises: SupersetTimeoutException
"""
query = self._execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
query_id = query.id
data = self._get_sql_results_with_timeout(
timeout, rendered_query, timeout_msg,
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)
self._execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
except Exception as ex:
logger.exception("Query %i failed unexpectedly", query.id)
raise SupersetGenericDBErrorException(
utils.error_msg_from_exception(ex)
) from ex
if data is not None and data.get("status") == QueryStatus.FAILED:
# new error payload with rich context
if data["errors"]:
raise SupersetErrorsException(
[SupersetError(**params) for params in data["errors"]]
)
# old string-only error message
raise SupersetGenericDBErrorException(data["error"])
return SqlJsonExecutionStatus.HAS_RESULTS
def _get_sql_results_with_timeout(
self, timeout: int, rendered_query: str, timeout_msg: str,
) -> Optional[SqlResults]:
query = self._execution_context.query
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
query.id,
rendered_query,
return_results=True,
store_results=self._is_store_results(query),
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)
@classmethod
def _is_store_results(cls, query: Query) -> bool:
return (
is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta
)
def _create_payload_from_execution_context( # pylint: disable=invalid-name
self, status: SqlJsonExecutionStatus,
class ExecutionContextConvertor:
def to_payload(
self,
execution_context: SqlJsonExecutionContext,
execution_status: SqlJsonExecutionStatus,
) -> str:
if status == SqlJsonExecutionStatus.HAS_RESULTS:
return self._to_payload_results_based(
self._execution_context.get_execution_result() or {}
)
return self._to_payload_query_based(self._execution_context.query)
def _to_payload_results_based( # pylint: disable=no-self-use
self, execution_result: SqlResults
) -> str:
display_max_row = config["DISPLAY_MAX_ROW"]
return json.dumps(
apply_display_max_row_configuration_if_require(
execution_result, display_max_row
),
default=utils.pessimistic_json_iso_dttm_ser,
ignore_nan=True,
encoding=None,
)
def _to_payload_query_based( # pylint: disable=no-self-use
self, query: Query
) -> str:
return json.dumps(
{"query": query.to_dict()},
default=utils.json_int_dttm_ser,
ignore_nan=True,
)
raise NotImplementedError()

View File

@ -81,3 +81,20 @@ class SqlLabException(SupersetException):
return ": {}".format(exception.message) # type: ignore
return ": {}".format(str(exception))
return ""
QUERY_IS_FORBIDDEN_TO_ACCESS_REASON_MESSAGE = "can not access the query"
class QueryIsForbiddenToAccessException(SqlLabException):
def __init__(
self,
sql_json_execution_context: SqlJsonExecutionContext,
exception: Optional[Exception] = None,
) -> None:
super().__init__(
sql_json_execution_context,
SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
QUERY_IS_FORBIDDEN_TO_ACCESS_REASON_MESSAGE,
exception,
)

View File

@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import simplejson as json
import superset.utils.core as utils
from superset.sqllab.command import ExecutionContextConvertor
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
if TYPE_CHECKING:
from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
from superset.sqllab.sql_json_executer import SqlResults
from superset.models.sql_lab import Query
class ExecutionContextConvertorImpl(ExecutionContextConvertor):
_max_row_in_display_configuration: int # pylint: disable=invalid-name
def set_max_row_in_display(self, value: int) -> None:
self._max_row_in_display_configuration = value # pylint: disable=invalid-name
def to_payload(
self,
execution_context: SqlJsonExecutionContext,
execution_status: SqlJsonExecutionStatus,
) -> str:
if execution_status == SqlJsonExecutionStatus.HAS_RESULTS:
return self._to_payload_results_based(
execution_context.get_execution_result() or {}
)
return self._to_payload_query_based(execution_context.query)
def _to_payload_results_based(self, execution_result: SqlResults) -> str:
return json.dumps(
apply_display_max_row_configuration_if_require(
execution_result, self._max_row_in_display_configuration
),
default=utils.pessimistic_json_iso_dttm_ser,
ignore_nan=True,
encoding=None,
)
def _to_payload_query_based( # pylint: disable=no-self-use
self, query: Query
) -> str:
return json.dumps(
{"query": query.to_dict()}, default=utils.json_int_dttm_ser, ignore_nan=True
)

View File

@ -0,0 +1,153 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-self-use, too-few-public-methods, too-many-arguments
from __future__ import annotations
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
from flask_babel import gettext as __, ngettext
from jinja2 import TemplateError
from jinja2.meta import find_undeclared_variables
from superset import is_feature_enabled
from superset.errors import SupersetErrorType
from superset.sqllab.command import SqlQueryRender
from superset.sqllab.exceptions import SqlLabException
from superset.utils import core as utils
MSG_OF_1006 = "Issue 1006 - One or more parameters specified in the query are missing."
if TYPE_CHECKING:
from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
from superset.jinja_context import BaseTemplateProcessor
PARAMETER_MISSING_ERR = (
"Please check your template parameters for syntax errors and make sure "
"they match across your SQL query and Set Parameters. Then, try running "
"your query again."
)
class SqlQueryRenderImpl(SqlQueryRender):
_sql_template_processor_factory: Callable[..., BaseTemplateProcessor]
def __init__(
self, sql_template_factory: Callable[..., BaseTemplateProcessor]
) -> None:
self._sql_template_processor_factory = sql_template_factory # type: ignore
def render(self, execution_context: SqlJsonExecutionContext) -> str:
query_model = execution_context.query
try:
sql_template_processor = self._sql_template_processor_factory(
database=query_model.database, query=query_model
)
rendered_query = sql_template_processor.process_template(
query_model.sql, **execution_context.template_params
)
self._validate(execution_context, rendered_query, sql_template_processor)
return rendered_query
except TemplateError as ex:
self._raise_template_exception(ex, execution_context)
return "NOT_REACHABLE_CODE"
def _validate(
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
sql_template_processor: BaseTemplateProcessor,
) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
# pylint: disable=protected-access
syntax_tree = sql_template_processor._env.parse(rendered_query)
undefined_parameters = find_undeclared_variables( # type: ignore
syntax_tree
)
if undefined_parameters:
self._raise_undefined_parameter_exception(
execution_context, undefined_parameters
)
def _raise_undefined_parameter_exception(
self, execution_context: SqlJsonExecutionContext, undefined_parameters: Any
) -> None:
raise SqlQueryRenderException(
sql_json_execution_context=execution_context,
error_type=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
reason_message=ngettext(
"The parameter %(parameters)s in your query is undefined.",
"The following parameters in your query are undefined: %(parameters)s.",
len(undefined_parameters),
parameters=utils.format_list(undefined_parameters),
),
suggestion_help_msg=PARAMETER_MISSING_ERR,
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": execution_context.template_params,
"issue_codes": [{"code": 1006, "message": MSG_OF_1006,}],
},
)
def _raise_template_exception(
self, ex: Exception, execution_context: SqlJsonExecutionContext
) -> None:
raise SqlQueryRenderException(
sql_json_execution_context=execution_context,
error_type=SupersetErrorType.INVALID_TEMPLATE_PARAMS_ERROR,
reason_message=__(
"The query contains one or more malformed template parameters."
),
suggestion_help_msg=__(
"Please check your query and confirm that all template "
"parameters are surround by double braces, for example, "
'"{{ ds }}". Then, try running your query again.'
),
) from ex
class SqlQueryRenderException(SqlLabException):
_extra: Optional[Dict[str, Any]]
def __init__(
self,
sql_json_execution_context: SqlJsonExecutionContext,
error_type: SupersetErrorType,
reason_message: Optional[str] = None,
exception: Optional[Exception] = None,
suggestion_help_msg: Optional[str] = None,
extra: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
sql_json_execution_context,
error_type,
reason_message,
exception,
suggestion_help_msg,
)
self._extra = extra
@property
def extra(self) -> Optional[Dict[str, Any]]:
return self._extra
def to_dict(self) -> Dict[str, Any]:
rv = super().to_dict()
if self._extra:
rv["extra"] = self._extra
return rv

View File

@ -0,0 +1,207 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-few-public-methods, invalid-name
from __future__ import annotations
import dataclasses
import logging
from abc import ABC
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
from flask import g
from flask_babel import gettext as __
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetErrorException,
SupersetErrorsException,
SupersetGenericDBErrorException,
SupersetTimeoutException,
)
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.utils import core as utils
from superset.utils.dates import now_as_float
if TYPE_CHECKING:
from superset.queries.dao import QueryDAO
from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
QueryStatus = utils.QueryStatus
logger = logging.getLogger(__name__)
SqlResults = Dict[str, Any]
GetSqlResultsTask = Callable[..., SqlResults]
class SqlJsonExecutor:
def execute(
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
log_params: Optional[Dict[str, Any]],
) -> SqlJsonExecutionStatus:
raise NotImplementedError()
class SqlJsonExecutorBase(SqlJsonExecutor, ABC):
_query_dao: QueryDAO
_get_sql_results_task: GetSqlResultsTask
def __init__(self, query_dao: QueryDAO, get_sql_results_task: GetSqlResultsTask):
self._query_dao = query_dao
self._get_sql_results_task = get_sql_results_task # type: ignore
class SynchronousSqlJsonExecutor(SqlJsonExecutorBase):
_timeout_duration_in_seconds: int
_sqllab_backend_persistence_feature_enable: bool
def __init__(
self,
query_dao: QueryDAO,
get_sql_results_task: GetSqlResultsTask,
timeout_duration_in_seconds: int,
sqllab_backend_persistence_feature_enable: bool,
):
super().__init__(query_dao, get_sql_results_task)
self._timeout_duration_in_seconds = timeout_duration_in_seconds
self._sqllab_backend_persistence_feature_enable = (
sqllab_backend_persistence_feature_enable
)
def execute(
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
log_params: Optional[Dict[str, Any]],
) -> SqlJsonExecutionStatus:
query_id = execution_context.query.id
try:
data = self._get_sql_results_with_timeout(
execution_context, rendered_query, log_params
)
self._query_dao.update_saved_query_exec_info(query_id)
execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
raise ex
except Exception as ex:
logger.exception("Query %i failed unexpectedly", query_id)
raise SupersetGenericDBErrorException(
utils.error_msg_from_exception(ex)
) from ex
if data.get("status") == QueryStatus.FAILED: # type: ignore
# new error payload with rich context
if data["errors"]: # type: ignore
raise SupersetErrorsException(
[SupersetError(**params) for params in data["errors"]] # type: ignore
)
# old string-only error message
raise SupersetGenericDBErrorException(data["error"]) # type: ignore
return SqlJsonExecutionStatus.HAS_RESULTS
def _get_sql_results_with_timeout(
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
log_params: Optional[Dict[str, Any]],
) -> Optional[SqlResults]:
with utils.timeout(
seconds=self._timeout_duration_in_seconds,
error_message=self._get_timeout_error_msg(),
):
return self._get_sql_results(execution_context, rendered_query, log_params)
def _get_sql_results(
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
log_params: Optional[Dict[str, Any]],
) -> Optional[SqlResults]:
return self._get_sql_results_task(
execution_context.query.id,
rendered_query,
return_results=True,
store_results=self._is_store_results(execution_context),
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
expand_data=execution_context.expand_data,
log_params=log_params,
)
def _is_store_results(self, execution_context: SqlJsonExecutionContext) -> bool:
return (
self._sqllab_backend_persistence_feature_enable
and not execution_context.select_as_cta
)
def _get_timeout_error_msg(self) -> str:
return "The query exceeded the {timeout} seconds timeout.".format(
timeout=self._timeout_duration_in_seconds
)
class ASynchronousSqlJsonExecutor(SqlJsonExecutorBase):
def execute(
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
log_params: Optional[Dict[str, Any]],
) -> SqlJsonExecutionStatus:
query_id = execution_context.query.id
logger.info("Query %i: Running query on a Celery worker", query_id)
try:
task = self._get_sql_results_task.delay( # type: ignore
query_id,
rendered_query,
return_results=False,
store_results=not execution_context.select_as_cta,
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=execution_context.expand_data,
log_params=log_params,
)
try:
task.forget()
except NotImplementedError:
logger.warning(
"Unable to forget Celery task as backend"
"does not support this operation"
)
except Exception as ex:
logger.exception("Query %i: %s", query_id, str(ex))
message = __("Failed to start remote query on a worker.")
error = SupersetError(
message=message,
error_type=SupersetErrorType.ASYNC_WORKERS_ERROR,
level=ErrorLevel.ERROR,
)
error_payload = dataclasses.asdict(error)
query = execution_context.query
query.set_extra_json_key("errors", [error_payload])
query.status = QueryStatus.FAILED
query.error_message = message
raise SupersetErrorException(error) from ex
self._query_dao.update_saved_query_exec_info(query_id)
return SqlJsonExecutionStatus.QUERY_IS_RUNNING

View File

@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-few-public-methods
from __future__ import annotations
from typing import TYPE_CHECKING
from superset import security_manager
from superset.sqllab.command import CanAccessQueryValidator
if TYPE_CHECKING:
from superset.models.sql_lab import Query
class CanAccessQueryValidatorImpl(CanAccessQueryValidator):
def validate(self, query: Query) -> None:
security_manager.raise_for_access(query=query)

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
# pylint: disable=too-many-lines, invalid-name
from __future__ import annotations
import logging
@ -95,14 +95,28 @@ from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.models.slice import Slice
from superset.models.sql_lab import Query, TabState
from superset.models.user_attributes import UserAttribute
from superset.queries.dao import QueryDAO
from superset.security.analytics_db_safety import check_sqlalchemy_uri
from superset.sql_lab import get_sql_results
from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.sqllab.command import CommandResult, ExecuteSqlCommand
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.exceptions import (
QueryIsForbiddenToAccessException,
SqlLabException,
)
from superset.sqllab.execution_context_convertor import ExecutionContextConvertorImpl
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.query_render import SqlQueryRenderImpl
from superset.sqllab.sql_json_executer import (
ASynchronousSqlJsonExecutor,
SqlJsonExecutor,
SynchronousSqlJsonExecutor,
)
from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.sqllab.validators import CanAccessQueryValidatorImpl
from superset.tasks.async_queries import load_explore_json_into_cache
from superset.typing import FlaskResponse
from superset.utils import core as utils, csv
@ -111,7 +125,6 @@ from superset.utils.cache import etag_cache
from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
from superset.utils.dates import now_as_float
from superset.utils.decorators import check_dashboard_access
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
from superset.views.base import (
api,
BaseSupersetView,
@ -2440,13 +2453,60 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
}
execution_context = SqlJsonExecutionContext(request.json)
command = ExecuteSqlCommand(execution_context, log_params)
command = self._create_sql_json_command(execution_context, log_params)
command_result: CommandResult = command.run()
return self._create_response_from_execution_context(command_result)
except SqlLabException as ex:
logger.error(ex.message)
self._set_http_status_into_Sql_lab_exception(ex)
payload = {"errors": [ex.to_dict()]}
return json_error_response(status=ex.status, payload=payload)
@staticmethod
def _create_sql_json_command(
execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]]
) -> ExecuteSqlCommand:
query_dao = QueryDAO()
sql_json_executor = Superset._create_sql_json_executor(
execution_context, query_dao
)
execution_context_convertor = ExecutionContextConvertorImpl()
execution_context_convertor.set_max_row_in_display(
int(config.get("DISPLAY_MAX_ROW")) # type: ignore
)
return ExecuteSqlCommand(
execution_context,
query_dao,
DatabaseDAO(),
CanAccessQueryValidatorImpl(),
SqlQueryRenderImpl(get_template_processor),
sql_json_executor,
execution_context_convertor,
config.get("SQLLAB_CTAS_NO_LIMIT"), # type: ignore
log_params,
)
@staticmethod
def _create_sql_json_executor(
execution_context: SqlJsonExecutionContext, query_dao: QueryDAO
) -> SqlJsonExecutor:
sql_json_executor: SqlJsonExecutor
if execution_context.is_run_asynchronous():
sql_json_executor = ASynchronousSqlJsonExecutor(query_dao, get_sql_results)
else:
sql_json_executor = SynchronousSqlJsonExecutor(
query_dao,
get_sql_results,
config.get("SQLLAB_TIMEOUT"), # type: ignore
is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"),
)
return sql_json_executor
@staticmethod
def _set_http_status_into_Sql_lab_exception(ex: SqlLabException) -> None:
if isinstance(ex, QueryIsForbiddenToAccessException):
ex.status = 403
def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use
self, command_result: CommandResult,
) -> FlaskResponse:

View File

@ -217,7 +217,7 @@ def test_run_sync_query_cta_no_data(setup_sqllab):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.utils.sqllab_execution_context.get_cta_schema_name",
"superset.sqllab.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
@ -245,7 +245,7 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.utils.sqllab_execution_context.get_cta_schema_name",
"superset.sqllab.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_async_query_cta_config(setup_sqllab, ctas_method):

View File

@ -749,10 +749,11 @@ class TestCore(SupersetTestCase):
data = self.run_sql(sql, "fdaklj3ws")
self.assertEqual(data["data"][0]["test"], "2")
@pytest.mark.ofek
@mock.patch(
"tests.integration_tests.superset_test_custom_template_processors.datetime"
)
@mock.patch("superset.sql_lab.get_sql_results")
@mock.patch("superset.views.core.get_sql_results")
def test_custom_templated_sql_json(self, sql_lab_mock, mock_dt) -> None:
"""Test sqllab receives macros expanded query."""
mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1))

View File

@ -189,7 +189,7 @@ class TestSqlLab(SupersetTestCase):
return
with mock.patch(
"superset.utils.sqllab_execution_context.get_cta_schema_name",
"superset.sqllab.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: f"{u.username}_database",
):
old_allow_ctas = examples_db.allow_ctas