From 08358d623b4938956526df840ef9e466bf281b6a Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Fri, 7 Aug 2020 17:37:40 +0300 Subject: [PATCH] fix: handle query exceptions gracefully (#10548) * fix: handle query exceptions gracefully * add more recasts * add test * disable test for presto * switch to SQLA error --- superset/common/query_context.py | 6 ++- superset/connectors/sqla/models.py | 85 ++++++++++++++++++++++-------- superset/views/core.py | 7 +-- superset/viz.py | 50 ++++++++++++------ tests/sqla_models_tests.py | 25 +++++++++ 5 files changed, 132 insertions(+), 41 deletions(-) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index e602fbfac..0d33f9c4a 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -27,6 +27,7 @@ from superset import app, cache, db, security_manager from superset.common.query_object import QueryObject from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry +from superset.exceptions import QueryObjectValidationError from superset.stats_logger import BaseStatsLogger from superset.utils import core as utils from superset.utils.core import DTTM_ALIAS @@ -244,10 +245,13 @@ class QueryContext: if not self.force: stats_logger.incr("loaded_from_source_without_force") is_loaded = True + except QueryObjectValidationError as ex: + error_message = str(ex) + status = utils.QueryStatus.FAILED except Exception as ex: # pylint: disable=broad-except logger.exception(ex) if not error_message: - error_message = "{}".format(ex) + error_message = str(ex) status = utils.QueryStatus.FAILED stacktrace = utils.get_stacktrace() diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 530a2e10a..cfc807d1f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -25,6 +25,7 @@ import sqlparse from flask import escape, Markup from flask_appbuilder import Model from flask_babel import lazy_gettext as _ +from jinja2.exceptions import TemplateError from sqlalchemy import ( and_, asc, @@ -40,7 +41,7 @@ from sqlalchemy import ( Table, Text, ) -from sqlalchemy.exc import CompileError +from sqlalchemy.exc import CompileError, SQLAlchemyError from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.schema import UniqueConstraint @@ -51,7 +52,7 @@ from superset import app, db, is_feature_enabled, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression -from superset.exceptions import DatabaseNotFound +from superset.exceptions import DatabaseNotFound, QueryObjectValidationError from superset.jinja_context import ( BaseTemplateProcessor, ExtraCache, @@ -634,7 +635,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at if self.fetch_values_predicate: tp = self.get_template_processor() - qry = qry.where(text(tp.process_template(self.fetch_values_predicate))) + try: + qry = qry.where(text(tp.process_template(self.fetch_values_predicate))) + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in fetch values predicate: %(msg)s", + msg=ex.message, + ) + ) engine = self.database.get_sqla_engine() sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True})) @@ -684,7 +693,16 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at if self.sql: from_sql = self.sql if template_processor: - from_sql = template_processor.process_template(from_sql) + try: + from_sql = template_processor.process_template(from_sql) + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in FROM clause: %(msg)s", + msg=ex.message, + ) + ) + from_sql = sqlparse.format(from_sql, strip_comments=True) return TextAsFrom(sa.text(from_sql), []).alias("expr_qry") return self.get_sqla_table() @@ -730,10 +748,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at :returns: A list of SQL clauses to be ANDed together. :rtype: List[str] """ - return [ - text("({})".format(template_processor.process_template(f.clause))) - for f in security_manager.get_rls_filters(self) - ] + try: + return [ + text("({})".format(template_processor.process_template(f.clause))) + for f in security_manager.get_rls_filters(self) + ] + except TemplateError as ex: + raise QueryObjectValidationError( + _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,) + ) def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements self, @@ -791,7 +814,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} if not granularity and is_timeseries: - raise Exception( + raise QueryObjectValidationError( _( "Datetime column not provided as part table configuration " "and is required by this type of chart" @@ -802,7 +825,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at and not columns and (is_sip_38 or (not is_sip_38 and not groupby)) ): - raise Exception(_("Empty query?")) + raise QueryObjectValidationError(_("Empty query?")) metrics_exprs: List[ColumnElement] = [] for metric in metrics: if utils.is_adhoc_metric(metric): @@ -811,7 +834,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at elif isinstance(metric, str) and metric in metrics_by_name: metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) else: - raise Exception(_("Metric '%(metric)s' does not exist", metric=metric)) + raise QueryObjectValidationError( + _("Metric '%(metric)s' does not exist", metric=metric) + ) if metrics_exprs: main_metric_expr = metrics_exprs[0] else: @@ -958,7 +983,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at != None ) else: - raise Exception( + raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) ) if config["ENABLE_ROW_LEVEL_SECURITY"]: @@ -966,11 +991,27 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at if extras: where = extras.get("where") if where: - where = template_processor.process_template(where) + try: + where = template_processor.process_template(where) + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in WHERE clause: %(msg)s", + msg=ex.message, + ) + ) where_clause_and += [sa.text("({})".format(where))] having = extras.get("having") if having: - having = template_processor.process_template(having) + try: + having = template_processor.process_template(having) + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in HAVING clause: %(msg)s", + msg=ex.message, + ) + ) having_clause_and += [sa.text("({})".format(having))] if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) @@ -1117,7 +1158,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at ): ob = metrics_by_name[timeseries_limit_metric].get_sqla_col() else: - raise Exception( + raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric) ) @@ -1159,7 +1200,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at labels_expected = query_str_ext.labels_expected if df is not None and not df.empty: if len(df.columns) != len(labels_expected): - raise Exception( + raise QueryObjectValidationError( f"For {sql}, df.columns: {df.columns}" f" differs from {labels_expected}" ) @@ -1193,13 +1234,13 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at """Fetches the metadata for the table and merges it in""" try: table_ = self.get_sqla_table_object() - except Exception as ex: - logger.exception(ex) - raise Exception( + except SQLAlchemyError: + raise QueryObjectValidationError( _( - "Table [{}] doesn't seem to exist in the specified database, " - "couldn't fetch column information" - ).format(self.table_name) + "Table %(table)s doesn't seem to exist in the specified database, " + "couldn't fetch column information", + table=self.table_name, + ) ) metrics = [] diff --git a/superset/views/core.py b/superset/views/core.py index f3a226459..48cd4583d 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -32,6 +32,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access, has_access_api from flask_appbuilder.security.sqla import models as ab_models from flask_babel import gettext as __, lazy_gettext as _ +from jinja2.exceptions import TemplateError from sqlalchemy import and_, or_, select from sqlalchemy.engine.url import make_url from sqlalchemy.exc import ( @@ -535,7 +536,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return self.generate_json(viz_obj, response_type) except SupersetException as ex: - return json_error_response(utils.error_msg_from_exception(ex)) + return json_error_response(utils.error_msg_from_exception(ex), 400) @event_logger.log_this @has_access @@ -2300,10 +2301,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods rendered_query = template_processor.process_template( query.sql, **template_params ) - except Exception as ex: # pylint: disable=broad-except + except TemplateError as ex: error_msg = utils.error_msg_from_exception(ex) return json_error_response( - f"Query {query_id}: Template rendering failed: {error_msg}" + f"Query {query_id}: Template syntax error: {error_msg}" ) # Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set diff --git a/superset/viz.py b/superset/viz.py index 27b6ad9e5..88d67bb03 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -329,13 +329,17 @@ class BaseViz: # default order direction order_desc = form_data.get("order_desc", True) - since, until = utils.get_since_until( - relative_start=relative_start, - relative_end=relative_end, - time_range=form_data.get("time_range"), - since=form_data.get("since"), - until=form_data.get("until"), - ) + try: + since, until = utils.get_since_until( + relative_start=relative_start, + relative_end=relative_end, + time_range=form_data.get("time_range"), + since=form_data.get("since"), + until=form_data.get("until"), + ) + except ValueError as ex: + raise QueryObjectValidationError(str(ex)) + time_shift = form_data.get("time_shift", "") self.time_shift = utils.parse_past_timedelta(time_shift) from_dttm = None if since is None else (since - self.time_shift) @@ -475,6 +479,16 @@ class BaseViz: if not self.force: stats_logger.incr("loaded_from_source_without_force") is_loaded = True + except QueryObjectValidationError as ex: + error = dataclasses.asdict( + SupersetError( + message=str(ex), + level=ErrorLevel.ERROR, + error_type=SupersetErrorType.VIZ_GET_DF_ERROR, + ) + ) + self.errors.append(error) + self.status = utils.QueryStatus.FAILED except Exception as ex: logger.exception(ex) @@ -889,13 +903,16 @@ class CalHeatmapViz(BaseViz): values[str(v / 10 ** 9)] = obj.get(metric) data[metric] = values - start, end = utils.get_since_until( - relative_start=relative_start, - relative_end=relative_end, - time_range=form_data.get("time_range"), - since=form_data.get("since"), - until=form_data.get("until"), - ) + try: + start, end = utils.get_since_until( + relative_start=relative_start, + relative_end=relative_end, + time_range=form_data.get("time_range"), + since=form_data.get("since"), + until=form_data.get("until"), + ) + except ValueError as ex: + raise QueryObjectValidationError(str(ex)) if not start or not end: raise QueryObjectValidationError( "Please provide both time bounds (Since and Until)" @@ -1288,7 +1305,10 @@ class NVD3TimeSeriesViz(NVD3Viz): for option in time_compare: query_object = self.query_obj() - delta = utils.parse_past_timedelta(option) + try: + delta = utils.parse_past_timedelta(option) + except ValueError as ex: + raise QueryObjectValidationError(str(ex)) query_object["inner_from_dttm"] = query_object["from_dttm"] query_object["inner_to_dttm"] = query_object["to_dttm"] diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 3b9a84f6e..4666fd7f2 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -17,10 +17,12 @@ # isort:skip_file from typing import Any, Dict, NamedTuple, List, Tuple, Union from unittest.mock import patch +import pytest import tests.test_app from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.db_engine_specs.druid import DruidEngineSpec +from superset.exceptions import QueryObjectValidationError from superset.models.core import Database from superset.utils.core import DbColumnType, get_example_database, FilterOperator @@ -170,3 +172,26 @@ class TestDatabaseModel(SupersetTestCase): sqla_query = table.get_sqla_query(**query_obj) sql = table.database.compile_sqla_query(sqla_query.sqla_query) self.assertIn(filter_.expected, sql) + + def test_incorrect_jinja_syntax_raises_correct_exception(self): + query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["user"], + "metrics": [], + "is_timeseries": False, + "filter": [], + "extras": {}, + } + + # Table with Jinja callable. + table = SqlaTable( + table_name="test_table", + sql="SELECT '{{ abcd xyz + 1 ASDF }}' as user", + database=get_example_database(), + ) + # TODO(villebro): make it work with presto + if get_example_database().backend != "presto": + with pytest.raises(QueryObjectValidationError): + table.get_sqla_query(**query_obj)