diff --git a/superset/charts/dao.py b/superset/charts/dao.py index 2a80f82b8..8e16f3b44 100644 --- a/superset/charts/dao.py +++ b/superset/charts/dao.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=arguments-renamed import logging from typing import List, Optional, TYPE_CHECKING diff --git a/superset/dao/base.py b/superset/dao/base.py index 79ece40c9..ebd6a8908 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=isinstance-second-argument-not-valid-type from typing import Any, Dict, List, Optional, Type from flask_appbuilder.models.filters import BaseFilter @@ -89,6 +90,19 @@ class BaseDAO: ).apply(query, None) return query.all() + @classmethod + def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]: + """ + Get the first that fit the `base_filter` + """ + query = db.session.query(cls.model_cls) + if cls.base_filter: + data_model = SQLAInterface(cls.model_cls, db.session) + query = cls.base_filter( # pylint: disable=not-callable + "id", data_model + ).apply(query, None) + return query.filter_by(**filter_by).one_or_none() + @classmethod def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model: """ @@ -109,6 +123,27 @@ class BaseDAO: raise DAOCreateFailedError(exception=ex) from ex return model + @classmethod + def save(cls, instance_model: Model, commit: bool = True) -> Model: + """ + Generic for saving models + :raises: DAOCreateFailedError + """ + if cls.model_cls is None: + raise DAOConfigError() + if not isinstance(instance_model, cls.model_cls): + raise DAOCreateFailedError( + "the instance model is not a type of the model class" + ) + try: + db.session.add(instance_model) + if commit: + db.session.commit() + except SQLAlchemyError as ex: # pragma: no cover + db.session.rollback() + raise DAOCreateFailedError(exception=ex) from ex + return instance_model + @classmethod def update( cls, model: Model, properties: Dict[str, Any], commit: bool = True diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index b58622b9d..363e89b8b 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -143,7 +143,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods return len(dataset_query) == 0 @classmethod - def update( # pylint: disable=arguments-differ + def update( cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True ) -> Optional[SqlaTable]: """ diff --git a/superset/models/core.py b/superset/models/core.py index 28ba2eb96..d4e78503d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=line-too-long """A collection of ORM sqlalchemy models for Superset""" import enum import json @@ -253,7 +254,7 @@ class Database( @property def parameters_schema(self) -> Dict[str, Any]: try: - parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore # pylint: disable=line-too-long + parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore except Exception: # pylint: disable=broad-except parameters_schema = {} return parameters_schema diff --git a/superset/sqllab/command.py b/superset/sqllab/command.py index c9b9df419..cb4854965 100644 --- a/superset/sqllab/command.py +++ b/superset/sqllab/command.py @@ -14,73 +14,71 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=line-too-long +# pylint: disable=too-few-public-methods, too-many-arguments from __future__ import annotations -import dataclasses import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, TYPE_CHECKING -import simplejson as json -from flask import g -from flask_babel import gettext as __, ngettext -from jinja2.exceptions import TemplateError -from jinja2.meta import find_undeclared_variables -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm.session import Session +from flask_babel import gettext as __ -from superset import app, db, is_feature_enabled, sql_lab from superset.commands.base import BaseCommand from superset.common.db_query_status import QueryStatus -from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import ( - SupersetErrorException, - SupersetErrorsException, - SupersetGenericDBErrorException, - SupersetGenericErrorException, - SupersetSecurityException, - SupersetTemplateParamsErrorException, - SupersetTimeoutException, -) -from superset.jinja_context import BaseTemplateProcessor, get_template_processor +from superset.dao.exceptions import DAOCreateFailedError +from superset.errors import SupersetErrorType +from superset.exceptions import SupersetErrorsException, SupersetGenericErrorException from superset.models.core import Database from superset.models.sql_lab import Query -from superset.queries.dao import QueryDAO from superset.sqllab.command_status import SqlJsonExecutionStatus -from superset.sqllab.exceptions import SqlLabException -from superset.sqllab.limiting_factor import LimitingFactor -from superset.sqllab.utils import apply_display_max_row_configuration_if_require -from superset.utils import core as utils -from superset.utils.dates import now_as_float -from superset.utils.sqllab_execution_context import SqlJsonExecutionContext - -config = app.config -logger = logging.getLogger(__name__) - -PARAMETER_MISSING_ERR = ( - "Please check your template parameters for syntax errors and make sure " - "they match across your SQL query and Set Parameters. Then, try running " - "your query again." +from superset.sqllab.exceptions import ( + QueryIsForbiddenToAccessException, + SqlLabException, ) +from superset.sqllab.limiting_factor import LimitingFactor -SqlResults = Dict[str, Any] +if TYPE_CHECKING: + from superset.sqllab.sql_json_executer import SqlJsonExecutor + from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext + from superset.queries.dao import QueryDAO + from superset.databases.dao import DatabaseDAO + +logger = logging.getLogger(__name__) CommandResult = Dict[str, Any] class ExecuteSqlCommand(BaseCommand): _execution_context: SqlJsonExecutionContext + _query_dao: QueryDAO + _database_dao: DatabaseDAO + _access_validator: CanAccessQueryValidator + _sql_query_render: SqlQueryRender + _sql_json_executor: SqlJsonExecutor + _execution_context_convertor: ExecutionContextConvertor + _sqllab_ctas_no_limit: bool _log_params: Optional[Dict[str, Any]] = None - _session: Session def __init__( self, execution_context: SqlJsonExecutionContext, + query_dao: QueryDAO, + database_dao: DatabaseDAO, + access_validator: CanAccessQueryValidator, + sql_query_render: SqlQueryRender, + sql_json_executor: SqlJsonExecutor, + execution_context_convertor: ExecutionContextConvertor, + sqllab_ctas_no_limit_flag: bool, log_params: Optional[Dict[str, Any]] = None, ) -> None: self._execution_context = execution_context + self._query_dao = query_dao + self._database_dao = database_dao + self._access_validator = access_validator + self._sql_query_render = sql_query_render + self._sql_json_executor = sql_json_executor + self._execution_context_convertor = execution_context_convertor + self._sqllab_ctas_no_limit = sqllab_ctas_no_limit_flag self._log_params = log_params - self._session = db.session() def validate(self) -> None: pass @@ -90,7 +88,7 @@ class ExecuteSqlCommand(BaseCommand): ) -> CommandResult: """Runs arbitrary sql and returns data as json""" try: - query = self._get_existing_query() + query = self._try_get_existing_query() if self.is_query_handled(query): self._execution_context.set_query(query) # type: ignore status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED @@ -98,24 +96,21 @@ class ExecuteSqlCommand(BaseCommand): status = self._run_sql_json_exec_from_scratch() return { "status": status, - "payload": self._create_payload_from_execution_context(status), + "payload": self._execution_context_convertor.to_payload( + self._execution_context, status + ), } except (SqlLabException, SupersetErrorsException) as ex: raise ex except Exception as ex: raise SqlLabException(self._execution_context, exception=ex) from ex - def _get_existing_query(self) -> Optional[Query]: - query = ( - self._session.query(Query) - .filter_by( - client_id=self._execution_context.client_id, - user_id=self._execution_context.user_id, - sql_editor_id=self._execution_context.sql_editor_id, - ) - .one_or_none() + def _try_get_existing_query(self) -> Optional[Query]: + return self._query_dao.find_one_or_none( + client_id=self._execution_context.client_id, + user_id=self._execution_context.user_id, + sql_editor_id=self._execution_context.sql_editor_id, ) - return query @classmethod def is_query_handled(cls, query: Optional[Query]) -> bool: @@ -130,20 +125,20 @@ class ExecuteSqlCommand(BaseCommand): query = self._execution_context.create_query() self._save_new_query(query) try: - self._save_new_query(query) logger.info("Triggering query_id: %i", query.id) self._validate_access(query) self._execution_context.set_query(query) - rendered_query = self._render_query() + rendered_query = self._sql_query_render.render(self._execution_context) self._set_query_limit_if_required(rendered_query) - return self._execute_query(rendered_query) + return self._sql_json_executor.execute( + self._execution_context, rendered_query, self._log_params + ) except Exception as ex: - query.status = QueryStatus.FAILED - self._session.commit() + self._query_dao.update(query, {"status": QueryStatus.FAILED}) raise ex def _get_the_query_db(self) -> Database: - mydb = self._session.query(Database).get(self._execution_context.database_id) + mydb = self._database_dao.find_by_id(self._execution_context.database_id) self._validate_query_db(mydb) return mydb @@ -159,74 +154,21 @@ class ExecuteSqlCommand(BaseCommand): def _save_new_query(self, query: Query) -> None: try: - self._session.add(query) - self._session.flush() - self._session.commit() # shouldn't be necessary - except SQLAlchemyError as ex: - logger.error("Errors saving query details %s", str(ex), exc_info=True) - self._session.rollback() - if not query.id: - raise SupersetGenericErrorException( - __( - "The query record was not created as expected. Please " - "contact an administrator for further assistance or try again." - ) - ) + self._query_dao.save(query) + except DAOCreateFailedError as ex: + raise SqlLabException( + self._execution_context, + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + "The query record was not created as expected", + ex, + "Please contact an administrator for further assistance or try again.", + ) from ex def _validate_access(self, query: Query) -> None: try: - query.raise_for_access() - except SupersetSecurityException as ex: - query.set_extra_json_key("errors", [dataclasses.asdict(ex.error)]) - query.status = QueryStatus.FAILED - query.error_message = ex.error.message - self._session.commit() - raise SupersetErrorException(ex.error, status=403) from ex - - def _render_query(self) -> str: - def validate( - rendered_query: str, template_processor: BaseTemplateProcessor - ) -> None: - if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"): - # pylint: disable=protected-access - ast = template_processor._env.parse(rendered_query) - undefined_parameters = find_undeclared_variables(ast) # type: ignore - if undefined_parameters: - raise SupersetTemplateParamsErrorException( - message=ngettext( - "The parameter %(parameters)s in your query is undefined.", - "The following parameters in your query are undefined: %(parameters)s.", - len(undefined_parameters), - parameters=utils.format_list(undefined_parameters), - ) - + " " - + PARAMETER_MISSING_ERR, - error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR, - extra={ - "undefined_parameters": list(undefined_parameters), - "template_parameters": self._execution_context.template_params, - }, - ) - - query = self._execution_context.query - - try: - template_processor = get_template_processor( - database=query.database, query=query - ) - rendered_query = template_processor.process_template( - query.sql, **self._execution_context.template_params - ) - validate(rendered_query, template_processor) - except TemplateError as ex: - raise SupersetTemplateParamsErrorException( - message=__( - 'The query contains one or more malformed template parameters. Please check your query and confirm that all template parameters are surround by double braces, for example, "{{ ds }}". Then, try running your query again.' - ), - error=SupersetErrorType.INVALID_TEMPLATE_PARAMS_ERROR, - ) from ex - - return rendered_query + self._access_validator.validate(query) + except Exception as ex: + raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex def _set_query_limit_if_required(self, rendered_query: str,) -> None: if self._is_required_to_set_limit(): @@ -234,7 +176,7 @@ class ExecuteSqlCommand(BaseCommand): def _is_required_to_set_limit(self) -> bool: return not ( - config.get("SQLLAB_CTAS_NO_LIMIT") and self._execution_context.select_as_cta + self._sqllab_ctas_no_limit and self._execution_context.select_as_cta ) def _set_query_limit(self, rendered_query: str) -> None: @@ -255,161 +197,21 @@ class ExecuteSqlCommand(BaseCommand): lim for lim in limits if lim is not None ) - def _execute_query(self, rendered_query: str,) -> SqlJsonExecutionStatus: - # Flag for whether or not to expand data - # (feature that will expand Presto row objects and arrays) - # Async request. - if self._execution_context.is_run_asynchronous(): - return self._sql_json_async(rendered_query) - return self._sql_json_sync(rendered_query) +class CanAccessQueryValidator: + def validate(self, query: Query) -> None: + raise NotImplementedError() - def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus: - """ - Send SQL JSON query to celery workers. - :param rendered_query: the rendered query to perform by workers - :return: A Flask Response - """ - query = self._execution_context.query - logger.info("Query %i: Running query on a Celery worker", query.id) - # Ignore the celery future object and the request may time out. - query_id = query.id - try: - task = sql_lab.get_sql_results.delay( - query.id, - rendered_query, - return_results=False, - store_results=not query.select_as_cta, - user_name=g.user.username - if g.user and hasattr(g.user, "username") - else None, - start_time=now_as_float(), - expand_data=self._execution_context.expand_data, - log_params=self._log_params, - ) - # Explicitly forget the task to ensure the task metadata is removed from the - # Celery results backend in a timely manner. - try: - task.forget() - except NotImplementedError: - logger.warning( - "Unable to forget Celery task as backend" - "does not support this operation" - ) - except Exception as ex: - logger.exception("Query %i: %s", query.id, str(ex)) +class SqlQueryRender: + def render(self, execution_context: SqlJsonExecutionContext) -> str: + raise NotImplementedError() - message = __("Failed to start remote query on a worker.") - error = SupersetError( - message=message, - error_type=SupersetErrorType.ASYNC_WORKERS_ERROR, - level=ErrorLevel.ERROR, - ) - error_payload = dataclasses.asdict(error) - query.set_extra_json_key("errors", [error_payload]) - query.status = QueryStatus.FAILED - query.error_message = message - self._session.commit() - - raise SupersetErrorException(error) from ex - - # Update saved query with execution info from the query execution - QueryDAO.update_saved_query_exec_info(query_id) - - self._session.commit() - return SqlJsonExecutionStatus.QUERY_IS_RUNNING - - def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus: - """ - Execute SQL query (sql json). - - :param rendered_query: The rendered query (included templates) - :raises: SupersetTimeoutException - """ - query = self._execution_context.query - try: - timeout = config["SQLLAB_TIMEOUT"] - timeout_msg = f"The query exceeded the {timeout} seconds timeout." - query_id = query.id - data = self._get_sql_results_with_timeout( - timeout, rendered_query, timeout_msg, - ) - # Update saved query if needed - QueryDAO.update_saved_query_exec_info(query_id) - self._execution_context.set_execution_result(data) - except SupersetTimeoutException as ex: - # re-raise exception for api exception handler - raise ex - except Exception as ex: - logger.exception("Query %i failed unexpectedly", query.id) - raise SupersetGenericDBErrorException( - utils.error_msg_from_exception(ex) - ) from ex - - if data is not None and data.get("status") == QueryStatus.FAILED: - # new error payload with rich context - if data["errors"]: - raise SupersetErrorsException( - [SupersetError(**params) for params in data["errors"]] - ) - # old string-only error message - raise SupersetGenericDBErrorException(data["error"]) - return SqlJsonExecutionStatus.HAS_RESULTS - - def _get_sql_results_with_timeout( - self, timeout: int, rendered_query: str, timeout_msg: str, - ) -> Optional[SqlResults]: - query = self._execution_context.query - with utils.timeout(seconds=timeout, error_message=timeout_msg): - # pylint: disable=no-value-for-parameter - return sql_lab.get_sql_results( - query.id, - rendered_query, - return_results=True, - store_results=self._is_store_results(query), - user_name=g.user.username - if g.user and hasattr(g.user, "username") - else None, - expand_data=self._execution_context.expand_data, - log_params=self._log_params, - ) - - @classmethod - def _is_store_results(cls, query: Query) -> bool: - return ( - is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta - ) - - def _create_payload_from_execution_context( # pylint: disable=invalid-name - self, status: SqlJsonExecutionStatus, +class ExecutionContextConvertor: + def to_payload( + self, + execution_context: SqlJsonExecutionContext, + execution_status: SqlJsonExecutionStatus, ) -> str: - - if status == SqlJsonExecutionStatus.HAS_RESULTS: - return self._to_payload_results_based( - self._execution_context.get_execution_result() or {} - ) - return self._to_payload_query_based(self._execution_context.query) - - def _to_payload_results_based( # pylint: disable=no-self-use - self, execution_result: SqlResults - ) -> str: - display_max_row = config["DISPLAY_MAX_ROW"] - return json.dumps( - apply_display_max_row_configuration_if_require( - execution_result, display_max_row - ), - default=utils.pessimistic_json_iso_dttm_ser, - ignore_nan=True, - encoding=None, - ) - - def _to_payload_query_based( # pylint: disable=no-self-use - self, query: Query - ) -> str: - return json.dumps( - {"query": query.to_dict()}, - default=utils.json_int_dttm_ser, - ignore_nan=True, - ) + raise NotImplementedError() diff --git a/superset/sqllab/exceptions.py b/superset/sqllab/exceptions.py index 6b1736c91..ac632d731 100644 --- a/superset/sqllab/exceptions.py +++ b/superset/sqllab/exceptions.py @@ -81,3 +81,20 @@ class SqlLabException(SupersetException): return ": {}".format(exception.message) # type: ignore return ": {}".format(str(exception)) return "" + + +QUERY_IS_FORBIDDEN_TO_ACCESS_REASON_MESSAGE = "can not access the query" + + +class QueryIsForbiddenToAccessException(SqlLabException): + def __init__( + self, + sql_json_execution_context: SqlJsonExecutionContext, + exception: Optional[Exception] = None, + ) -> None: + super().__init__( + sql_json_execution_context, + SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR, + QUERY_IS_FORBIDDEN_TO_ACCESS_REASON_MESSAGE, + exception, + ) diff --git a/superset/sqllab/execution_context_convertor.py b/superset/sqllab/execution_context_convertor.py new file mode 100644 index 000000000..6d52355d2 --- /dev/null +++ b/superset/sqllab/execution_context_convertor.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import simplejson as json + +import superset.utils.core as utils +from superset.sqllab.command import ExecutionContextConvertor +from superset.sqllab.command_status import SqlJsonExecutionStatus +from superset.sqllab.utils import apply_display_max_row_configuration_if_require + +if TYPE_CHECKING: + from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext + from superset.sqllab.sql_json_executer import SqlResults + from superset.models.sql_lab import Query + + +class ExecutionContextConvertorImpl(ExecutionContextConvertor): + _max_row_in_display_configuration: int # pylint: disable=invalid-name + + def set_max_row_in_display(self, value: int) -> None: + self._max_row_in_display_configuration = value # pylint: disable=invalid-name + + def to_payload( + self, + execution_context: SqlJsonExecutionContext, + execution_status: SqlJsonExecutionStatus, + ) -> str: + + if execution_status == SqlJsonExecutionStatus.HAS_RESULTS: + return self._to_payload_results_based( + execution_context.get_execution_result() or {} + ) + return self._to_payload_query_based(execution_context.query) + + def _to_payload_results_based(self, execution_result: SqlResults) -> str: + return json.dumps( + apply_display_max_row_configuration_if_require( + execution_result, self._max_row_in_display_configuration + ), + default=utils.pessimistic_json_iso_dttm_ser, + ignore_nan=True, + encoding=None, + ) + + def _to_payload_query_based( # pylint: disable=no-self-use + self, query: Query + ) -> str: + return json.dumps( + {"query": query.to_dict()}, default=utils.json_int_dttm_ser, ignore_nan=True + ) diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py new file mode 100644 index 000000000..b03b21d83 --- /dev/null +++ b/superset/sqllab/query_render.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-self-use, too-few-public-methods, too-many-arguments +from __future__ import annotations + +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING + +from flask_babel import gettext as __, ngettext +from jinja2 import TemplateError +from jinja2.meta import find_undeclared_variables + +from superset import is_feature_enabled +from superset.errors import SupersetErrorType +from superset.sqllab.command import SqlQueryRender +from superset.sqllab.exceptions import SqlLabException +from superset.utils import core as utils + +MSG_OF_1006 = "Issue 1006 - One or more parameters specified in the query are missing." + +if TYPE_CHECKING: + from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext + from superset.jinja_context import BaseTemplateProcessor + +PARAMETER_MISSING_ERR = ( + "Please check your template parameters for syntax errors and make sure " + "they match across your SQL query and Set Parameters. Then, try running " + "your query again." +) + + +class SqlQueryRenderImpl(SqlQueryRender): + _sql_template_processor_factory: Callable[..., BaseTemplateProcessor] + + def __init__( + self, sql_template_factory: Callable[..., BaseTemplateProcessor] + ) -> None: + + self._sql_template_processor_factory = sql_template_factory # type: ignore + + def render(self, execution_context: SqlJsonExecutionContext) -> str: + query_model = execution_context.query + try: + sql_template_processor = self._sql_template_processor_factory( + database=query_model.database, query=query_model + ) + + rendered_query = sql_template_processor.process_template( + query_model.sql, **execution_context.template_params + ) + self._validate(execution_context, rendered_query, sql_template_processor) + return rendered_query + except TemplateError as ex: + self._raise_template_exception(ex, execution_context) + return "NOT_REACHABLE_CODE" + + def _validate( + self, + execution_context: SqlJsonExecutionContext, + rendered_query: str, + sql_template_processor: BaseTemplateProcessor, + ) -> None: + if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"): + # pylint: disable=protected-access + syntax_tree = sql_template_processor._env.parse(rendered_query) + undefined_parameters = find_undeclared_variables( # type: ignore + syntax_tree + ) + if undefined_parameters: + self._raise_undefined_parameter_exception( + execution_context, undefined_parameters + ) + + def _raise_undefined_parameter_exception( + self, execution_context: SqlJsonExecutionContext, undefined_parameters: Any + ) -> None: + raise SqlQueryRenderException( + sql_json_execution_context=execution_context, + error_type=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR, + reason_message=ngettext( + "The parameter %(parameters)s in your query is undefined.", + "The following parameters in your query are undefined: %(parameters)s.", + len(undefined_parameters), + parameters=utils.format_list(undefined_parameters), + ), + suggestion_help_msg=PARAMETER_MISSING_ERR, + extra={ + "undefined_parameters": list(undefined_parameters), + "template_parameters": execution_context.template_params, + "issue_codes": [{"code": 1006, "message": MSG_OF_1006,}], + }, + ) + + def _raise_template_exception( + self, ex: Exception, execution_context: SqlJsonExecutionContext + ) -> None: + raise SqlQueryRenderException( + sql_json_execution_context=execution_context, + error_type=SupersetErrorType.INVALID_TEMPLATE_PARAMS_ERROR, + reason_message=__( + "The query contains one or more malformed template parameters." + ), + suggestion_help_msg=__( + "Please check your query and confirm that all template " + "parameters are surround by double braces, for example, " + '"{{ ds }}". Then, try running your query again.' + ), + ) from ex + + +class SqlQueryRenderException(SqlLabException): + _extra: Optional[Dict[str, Any]] + + def __init__( + self, + sql_json_execution_context: SqlJsonExecutionContext, + error_type: SupersetErrorType, + reason_message: Optional[str] = None, + exception: Optional[Exception] = None, + suggestion_help_msg: Optional[str] = None, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__( + sql_json_execution_context, + error_type, + reason_message, + exception, + suggestion_help_msg, + ) + self._extra = extra + + @property + def extra(self) -> Optional[Dict[str, Any]]: + return self._extra + + def to_dict(self) -> Dict[str, Any]: + rv = super().to_dict() + if self._extra: + rv["extra"] = self._extra + return rv diff --git a/superset/sqllab/sql_json_executer.py b/superset/sqllab/sql_json_executer.py new file mode 100644 index 000000000..77023b341 --- /dev/null +++ b/superset/sqllab/sql_json_executer.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=too-few-public-methods, invalid-name +from __future__ import annotations + +import dataclasses +import logging +from abc import ABC +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING + +from flask import g +from flask_babel import gettext as __ + +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import ( + SupersetErrorException, + SupersetErrorsException, + SupersetGenericDBErrorException, + SupersetTimeoutException, +) +from superset.sqllab.command_status import SqlJsonExecutionStatus +from superset.utils import core as utils +from superset.utils.dates import now_as_float + +if TYPE_CHECKING: + from superset.queries.dao import QueryDAO + from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext + +QueryStatus = utils.QueryStatus +logger = logging.getLogger(__name__) + +SqlResults = Dict[str, Any] + +GetSqlResultsTask = Callable[..., SqlResults] + + +class SqlJsonExecutor: + def execute( + self, + execution_context: SqlJsonExecutionContext, + rendered_query: str, + log_params: Optional[Dict[str, Any]], + ) -> SqlJsonExecutionStatus: + raise NotImplementedError() + + +class SqlJsonExecutorBase(SqlJsonExecutor, ABC): + _query_dao: QueryDAO + _get_sql_results_task: GetSqlResultsTask + + def __init__(self, query_dao: QueryDAO, get_sql_results_task: GetSqlResultsTask): + self._query_dao = query_dao + self._get_sql_results_task = get_sql_results_task # type: ignore + + +class SynchronousSqlJsonExecutor(SqlJsonExecutorBase): + _timeout_duration_in_seconds: int + _sqllab_backend_persistence_feature_enable: bool + + def __init__( + self, + query_dao: QueryDAO, + get_sql_results_task: GetSqlResultsTask, + timeout_duration_in_seconds: int, + sqllab_backend_persistence_feature_enable: bool, + ): + super().__init__(query_dao, get_sql_results_task) + self._timeout_duration_in_seconds = timeout_duration_in_seconds + self._sqllab_backend_persistence_feature_enable = ( + sqllab_backend_persistence_feature_enable + ) + + def execute( + self, + execution_context: SqlJsonExecutionContext, + rendered_query: str, + log_params: Optional[Dict[str, Any]], + ) -> SqlJsonExecutionStatus: + query_id = execution_context.query.id + try: + data = self._get_sql_results_with_timeout( + execution_context, rendered_query, log_params + ) + self._query_dao.update_saved_query_exec_info(query_id) + execution_context.set_execution_result(data) + except SupersetTimeoutException as ex: + raise ex + except Exception as ex: + logger.exception("Query %i failed unexpectedly", query_id) + raise SupersetGenericDBErrorException( + utils.error_msg_from_exception(ex) + ) from ex + + if data.get("status") == QueryStatus.FAILED: # type: ignore + # new error payload with rich context + if data["errors"]: # type: ignore + raise SupersetErrorsException( + [SupersetError(**params) for params in data["errors"]] # type: ignore + ) + # old string-only error message + raise SupersetGenericDBErrorException(data["error"]) # type: ignore + + return SqlJsonExecutionStatus.HAS_RESULTS + + def _get_sql_results_with_timeout( + self, + execution_context: SqlJsonExecutionContext, + rendered_query: str, + log_params: Optional[Dict[str, Any]], + ) -> Optional[SqlResults]: + with utils.timeout( + seconds=self._timeout_duration_in_seconds, + error_message=self._get_timeout_error_msg(), + ): + return self._get_sql_results(execution_context, rendered_query, log_params) + + def _get_sql_results( + self, + execution_context: SqlJsonExecutionContext, + rendered_query: str, + log_params: Optional[Dict[str, Any]], + ) -> Optional[SqlResults]: + return self._get_sql_results_task( + execution_context.query.id, + rendered_query, + return_results=True, + store_results=self._is_store_results(execution_context), + user_name=g.user.username + if g.user and hasattr(g.user, "username") + else None, + expand_data=execution_context.expand_data, + log_params=log_params, + ) + + def _is_store_results(self, execution_context: SqlJsonExecutionContext) -> bool: + return ( + self._sqllab_backend_persistence_feature_enable + and not execution_context.select_as_cta + ) + + def _get_timeout_error_msg(self) -> str: + return "The query exceeded the {timeout} seconds timeout.".format( + timeout=self._timeout_duration_in_seconds + ) + + +class ASynchronousSqlJsonExecutor(SqlJsonExecutorBase): + def execute( + self, + execution_context: SqlJsonExecutionContext, + rendered_query: str, + log_params: Optional[Dict[str, Any]], + ) -> SqlJsonExecutionStatus: + + query_id = execution_context.query.id + logger.info("Query %i: Running query on a Celery worker", query_id) + try: + task = self._get_sql_results_task.delay( # type: ignore + query_id, + rendered_query, + return_results=False, + store_results=not execution_context.select_as_cta, + user_name=g.user.username + if g.user and hasattr(g.user, "username") + else None, + start_time=now_as_float(), + expand_data=execution_context.expand_data, + log_params=log_params, + ) + try: + task.forget() + except NotImplementedError: + logger.warning( + "Unable to forget Celery task as backend" + "does not support this operation" + ) + except Exception as ex: + logger.exception("Query %i: %s", query_id, str(ex)) + + message = __("Failed to start remote query on a worker.") + error = SupersetError( + message=message, + error_type=SupersetErrorType.ASYNC_WORKERS_ERROR, + level=ErrorLevel.ERROR, + ) + error_payload = dataclasses.asdict(error) + query = execution_context.query + query.set_extra_json_key("errors", [error_payload]) + query.status = QueryStatus.FAILED + query.error_message = message + raise SupersetErrorException(error) from ex + self._query_dao.update_saved_query_exec_info(query_id) + return SqlJsonExecutionStatus.QUERY_IS_RUNNING diff --git a/superset/utils/sqllab_execution_context.py b/superset/sqllab/sqllab_execution_context.py similarity index 100% rename from superset/utils/sqllab_execution_context.py rename to superset/sqllab/sqllab_execution_context.py diff --git a/superset/sqllab/validators.py b/superset/sqllab/validators.py new file mode 100644 index 000000000..726a2760e --- /dev/null +++ b/superset/sqllab/validators.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=too-few-public-methods +from __future__ import annotations + +from typing import TYPE_CHECKING + +from superset import security_manager +from superset.sqllab.command import CanAccessQueryValidator + +if TYPE_CHECKING: + from superset.models.sql_lab import Query + + +class CanAccessQueryValidatorImpl(CanAccessQueryValidator): + def validate(self, query: Query) -> None: + security_manager.raise_for_access(query=query) diff --git a/superset/views/core.py b/superset/views/core.py index dadda30a7..c9bcac7f3 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines, invalid-name from __future__ import annotations import logging @@ -95,14 +95,28 @@ from superset.models.datasource_access_request import DatasourceAccessRequest from superset.models.slice import Slice from superset.models.sql_lab import Query, TabState from superset.models.user_attributes import UserAttribute +from superset.queries.dao import QueryDAO from superset.security.analytics_db_safety import check_sqlalchemy_uri +from superset.sql_lab import get_sql_results from superset.sql_parse import ParsedQuery, Table from superset.sql_validators import get_validator_by_name from superset.sqllab.command import CommandResult, ExecuteSqlCommand from superset.sqllab.command_status import SqlJsonExecutionStatus -from superset.sqllab.exceptions import SqlLabException +from superset.sqllab.exceptions import ( + QueryIsForbiddenToAccessException, + SqlLabException, +) +from superset.sqllab.execution_context_convertor import ExecutionContextConvertorImpl from superset.sqllab.limiting_factor import LimitingFactor +from superset.sqllab.query_render import SqlQueryRenderImpl +from superset.sqllab.sql_json_executer import ( + ASynchronousSqlJsonExecutor, + SqlJsonExecutor, + SynchronousSqlJsonExecutor, +) +from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext from superset.sqllab.utils import apply_display_max_row_configuration_if_require +from superset.sqllab.validators import CanAccessQueryValidatorImpl from superset.tasks.async_queries import load_explore_json_into_cache from superset.typing import FlaskResponse from superset.utils import core as utils, csv @@ -111,7 +125,6 @@ from superset.utils.cache import etag_cache from superset.utils.core import apply_max_row_limit, ReservedUrlParameters from superset.utils.dates import now_as_float from superset.utils.decorators import check_dashboard_access -from superset.utils.sqllab_execution_context import SqlJsonExecutionContext from superset.views.base import ( api, BaseSupersetView, @@ -2440,13 +2453,60 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods "user_agent": cast(Optional[str], request.headers.get("USER_AGENT")) } execution_context = SqlJsonExecutionContext(request.json) - command = ExecuteSqlCommand(execution_context, log_params) + command = self._create_sql_json_command(execution_context, log_params) command_result: CommandResult = command.run() return self._create_response_from_execution_context(command_result) except SqlLabException as ex: + logger.error(ex.message) + self._set_http_status_into_Sql_lab_exception(ex) payload = {"errors": [ex.to_dict()]} return json_error_response(status=ex.status, payload=payload) + @staticmethod + def _create_sql_json_command( + execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]] + ) -> ExecuteSqlCommand: + query_dao = QueryDAO() + sql_json_executor = Superset._create_sql_json_executor( + execution_context, query_dao + ) + execution_context_convertor = ExecutionContextConvertorImpl() + execution_context_convertor.set_max_row_in_display( + int(config.get("DISPLAY_MAX_ROW")) # type: ignore + ) + return ExecuteSqlCommand( + execution_context, + query_dao, + DatabaseDAO(), + CanAccessQueryValidatorImpl(), + SqlQueryRenderImpl(get_template_processor), + sql_json_executor, + execution_context_convertor, + config.get("SQLLAB_CTAS_NO_LIMIT"), # type: ignore + log_params, + ) + + @staticmethod + def _create_sql_json_executor( + execution_context: SqlJsonExecutionContext, query_dao: QueryDAO + ) -> SqlJsonExecutor: + sql_json_executor: SqlJsonExecutor + if execution_context.is_run_asynchronous(): + sql_json_executor = ASynchronousSqlJsonExecutor(query_dao, get_sql_results) + else: + sql_json_executor = SynchronousSqlJsonExecutor( + query_dao, + get_sql_results, + config.get("SQLLAB_TIMEOUT"), # type: ignore + is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"), + ) + return sql_json_executor + + @staticmethod + def _set_http_status_into_Sql_lab_exception(ex: SqlLabException) -> None: + if isinstance(ex, QueryIsForbiddenToAccessException): + ex.status = 403 + def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use self, command_result: CommandResult, ) -> FlaskResponse: diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index f4224d20a..c7465bcbe 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -217,7 +217,7 @@ def test_run_sync_query_cta_no_data(setup_sqllab): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( - "superset.utils.sqllab_execution_context.get_cta_schema_name", + "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) def test_run_sync_query_cta_config(setup_sqllab, ctas_method): @@ -245,7 +245,7 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( - "superset.utils.sqllab_execution_context.get_cta_schema_name", + "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) def test_run_async_query_cta_config(setup_sqllab, ctas_method): diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index f91ab454f..c35eb1bc0 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -749,10 +749,11 @@ class TestCore(SupersetTestCase): data = self.run_sql(sql, "fdaklj3ws") self.assertEqual(data["data"][0]["test"], "2") + @pytest.mark.ofek @mock.patch( "tests.integration_tests.superset_test_custom_template_processors.datetime" ) - @mock.patch("superset.sql_lab.get_sql_results") + @mock.patch("superset.views.core.get_sql_results") def test_custom_templated_sql_json(self, sql_lab_mock, mock_dt) -> None: """Test sqllab receives macros expanded query.""" mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1)) diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 1e9751414..b6dea6cf8 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -189,7 +189,7 @@ class TestSqlLab(SupersetTestCase): return with mock.patch( - "superset.utils.sqllab_execution_context.get_cta_schema_name", + "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: f"{u.username}_database", ): old_allow_ctas = examples_db.allow_ctas