From 07a76f83b14df6c2fb343f7933dc87f53ecd5155 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Tue, 23 Jul 2019 22:13:58 +0300 Subject: [PATCH] [Bugfix] Remove prequery properties from query_obj (#7896) * Create query_obj for every filter * Deprecate is_prequery and prequeries from query_obj * Fix tests * Fix typos and remove redundant ; from sql * Add typing to namedtuples and move all query str logic to one place * Fix unit test --- superset/common/query_object.py | 6 --- superset/connectors/druid/models.py | 2 - superset/connectors/sqla/models.py | 61 +++++++++++++++-------------- superset/db_engine_specs/base.py | 6 +-- superset/db_engine_specs/druid.py | 4 +- superset/db_engine_specs/gsheets.py | 4 +- superset/db_engine_specs/pinot.py | 6 +-- superset/models/core.py | 2 +- superset/views/core.py | 7 +--- superset/viz.py | 2 - tests/model_tests.py | 13 ++---- tests/sqla_models_tests.py | 1 - 12 files changed, 48 insertions(+), 66 deletions(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 7d72aa5eb..28f2303c2 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -48,8 +48,6 @@ class QueryObject: timeseries_limit_metric: Optional[Dict] = None, order_desc: bool = True, extras: Optional[Dict] = None, - prequeries: Optional[List[Dict]] = None, - is_prequery: bool = False, columns: List[str] = None, orderby: List[List] = None, relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"), @@ -78,8 +76,6 @@ class QueryObject: self.timeseries_limit = timeseries_limit self.timeseries_limit_metric = timeseries_limit_metric self.order_desc = order_desc - self.prequeries = prequeries if prequeries is not None else [] - self.is_prequery = is_prequery self.extras = extras if extras is not None else {} self.columns = columns if columns is not None else [] self.orderby = orderby if orderby is not None else [] @@ -97,8 +93,6 @@ class QueryObject: "timeseries_limit": self.timeseries_limit, "timeseries_limit_metric": self.timeseries_limit_metric, "order_desc": self.order_desc, - "prequeries": self.prequeries, - "is_prequery": self.is_prequery, "extras": self.extras, "columns": self.columns, "orderby": self.orderby, diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 6a0873cf7..d7b00c373 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1120,8 +1120,6 @@ class DruidDatasource(Model, BaseDatasource): phase=2, client=None, order_desc=True, - prequeries=None, - is_prequery=False, ): """Runs a query against Druid and returns a dataframe. """ diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index dfedca3a8..35b8590fe 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=C,R,W -from collections import namedtuple, OrderedDict +from collections import OrderedDict from datetime import datetime import logging -from typing import Any, List, Optional, Union +from typing import Any, List, NamedTuple, Optional, Union from flask import escape, Markup from flask_appbuilder import Model @@ -45,7 +45,7 @@ from sqlalchemy.orm import backref, relationship from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, literal_column, table, text -from sqlalchemy.sql.expression import Label, TextAsFrom +from sqlalchemy.sql.expression import Label, Select, TextAsFrom import sqlparse from superset import app, db, security_manager @@ -61,10 +61,18 @@ from superset.utils import core as utils, import_datasource config = app.config metadata = Model.metadata # pylint: disable=no-member -SqlaQuery = namedtuple( - "SqlaQuery", ["sqla_query", "labels_expected", "extra_cache_keys"] -) -QueryStringExtended = namedtuple("QueryStringExtended", ["sql", "labels_expected"]) + +class SqlaQuery(NamedTuple): + extra_cache_keys: List[Any] + labels_expected: List[str] + prequeries: List[str] + sqla_query: Select + + +class QueryStringExtended(NamedTuple): + labels_expected: List[str] + prequeries: List[str] + sql: str class AnnotationDatasource(BaseDatasource): @@ -351,7 +359,7 @@ class SqlaTable(Model, BaseDatasource): """ label_expected = label or sqla_col.name db_engine_spec = self.database.db_engine_spec - if db_engine_spec.supports_column_aliases: + if db_engine_spec.allows_column_aliases: label = db_engine_spec.make_label_compatible(label_expected) sqla_col = sqla_col.label(label) sqla_col._df_label_expected = label_expected @@ -532,18 +540,20 @@ class SqlaTable(Model, BaseDatasource): def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) - def get_query_str_extended(self, query_obj): + def get_query_str_extended(self, query_obj) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) logging.info(sql) sql = sqlparse.format(sql, reindent=True) - if query_obj["is_prequery"]: - query_obj["prequeries"].append(sql) sql = self.mutate_query_from_config(sql) - return QueryStringExtended(labels_expected=sqlaq.labels_expected, sql=sql) + return QueryStringExtended( + labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries + ) def get_query_str(self, query_obj): - return self.get_query_str_extended(query_obj).sql + query_str_ext = self.get_query_str_extended(query_obj) + all_queries = query_str_ext.prequeries + [query_str_ext.sql] + return ";\n\n".join(all_queries) + ";" def get_sqla_table(self): tbl = table(self.table_name) @@ -606,8 +616,6 @@ class SqlaTable(Model, BaseDatasource): extras=None, columns=None, order_desc=True, - prequeries=None, - is_prequery=False, ): """Querying any sqla table from this common interface""" template_kwargs = { @@ -624,6 +632,7 @@ class SqlaTable(Model, BaseDatasource): template_kwargs["extra_cache_keys"] = extra_cache_keys template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.database.db_engine_spec + prequeries: List[str] = [] orderby = orderby or [] @@ -793,7 +802,7 @@ class SqlaTable(Model, BaseDatasource): qry = qry.limit(row_limit) if is_timeseries and timeseries_limit and groupby and not time_groupby_inline: - if self.database.db_engine_spec.inner_joins: + if self.database.db_engine_spec.allows_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. vertica, # require a unique inner alias @@ -844,10 +853,8 @@ class SqlaTable(Model, BaseDatasource): ) ] - # run subquery to get top groups - subquery_obj = { - "prequeries": prequeries, - "is_prequery": True, + # run prequery to get top groups + prequery_obj = { "is_timeseries": False, "row_limit": timeseries_limit, "groupby": groupby, @@ -861,7 +868,8 @@ class SqlaTable(Model, BaseDatasource): "columns": columns, "order_desc": True, } - result = self.query(subquery_obj) + result = self.query(prequery_obj) + prequeries.append(result.query) dimensions = [ c for c in result.df.columns @@ -873,9 +881,10 @@ class SqlaTable(Model, BaseDatasource): qry = qry.where(top_groups) return SqlaQuery( - sqla_query=qry.select_from(tbl), - labels_expected=labels_expected, extra_cache_keys=extra_cache_keys, + labels_expected=labels_expected, + sqla_query=qry.select_from(tbl), + prequeries=prequeries, ) def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols): @@ -929,12 +938,6 @@ class SqlaTable(Model, BaseDatasource): db_engine_spec = self.database.db_engine_spec error_message = db_engine_spec.extract_error_message(e) - # if this is a main query with prequeries, combine them together - if not query_obj["is_prequery"]: - query_obj["prequeries"].append(sql) - sql = ";\n\n".join(query_obj["prequeries"]) - sql += ";" - return QueryResult( status=status, df=df, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index eced6bb91..d4a550c51 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -117,9 +117,9 @@ class BaseEngineSpec(object): time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False - inner_joins = True - allows_subquery = True - supports_column_aliases = True + allows_joins = True + allows_subqueries = True + allows_column_aliases = True force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 46095eea6..8cd0c4cb9 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -22,8 +22,8 @@ class DruidEngineSpec(BaseEngineSpec): """Engine spec for Druid.io""" engine = "druid" - inner_joins = False - allows_subquery = True + allows_joins = False + allows_subqueries = True time_grain_functions = { None: "{col}", diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index ef8067987..d7b3bc7a1 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -22,5 +22,5 @@ class GSheetsEngineSpec(SqliteEngineSpec): """Engine for Google spreadsheets""" engine = "gsheets" - inner_joins = False - allows_subquery = False + allows_joins = False + allows_subqueries = False diff --git a/superset/db_engine_specs/pinot.py b/superset/db_engine_specs/pinot.py index 1a01fa3a5..132cb48f1 100644 --- a/superset/db_engine_specs/pinot.py +++ b/superset/db_engine_specs/pinot.py @@ -24,9 +24,9 @@ from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression class PinotEngineSpec(BaseEngineSpec): engine = "pinot" - allows_subquery = False - inner_joins = False - supports_column_aliases = False + allows_subqueries = False + allows_joins = False + allows_column_aliases = False # Pinot does its own conversion below time_grain_functions: Dict[Optional[str], str] = { diff --git a/superset/models/core.py b/superset/models/core.py index cb8061197..5a98d5efa 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -771,7 +771,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): @property def allows_subquery(self): - return self.db_engine_spec.allows_subquery + return self.db_engine_spec.allows_subqueries @property def data(self): diff --git a/superset/views/core.py b/superset/views/core.py index c1bb28555..3c652c3fe 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1008,12 +1008,7 @@ class Superset(BaseSupersetView): logging.exception(e) return json_error_response(e) - if query_obj and query_obj["prequeries"]: - query_obj["prequeries"].append(query) - query = ";\n\n".join(query_obj["prequeries"]) - if query: - query += ";" - else: + if not query: query = "No query." return self.json_response( diff --git a/superset/viz.py b/superset/viz.py index 52075be02..b804ec9cb 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -320,8 +320,6 @@ class BaseViz(object): "extras": extras, "timeseries_limit_metric": timeseries_limit_metric, "order_desc": order_desc, - "prequeries": [], - "is_prequery": False, } return d diff --git a/tests/model_tests.py b/tests/model_tests.py index 7acd4b6b0..cd43a8c04 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -195,11 +195,11 @@ class SqlaTableModelTestCase(SupersetTestCase): ds_col.expression = None ds_col.python_date_format = None spec = self.get_database_by_id(tbl.database_id).db_engine_spec - if not spec.inner_joins and inner_join: + if not spec.allows_joins and inner_join: # if the db does not support inner joins, we cannot force it so return None - old_inner_join = spec.inner_joins - spec.inner_joins = inner_join + old_inner_join = spec.allows_joins + spec.allows_joins = inner_join arbitrary_gby = "state || gender || '_test'" arbitrary_metric = dict( label="arbitrary", expressionType="SQL", sqlExpression="COUNT(1)" @@ -209,12 +209,10 @@ class SqlaTableModelTestCase(SupersetTestCase): metrics=[arbitrary_metric], filter=[], is_timeseries=is_timeseries, - prequeries=[], columns=[], granularity="ds", from_dttm=None, to_dttm=None, - is_prequery=False, extras=dict(time_grain_sqla="P1Y"), ) qr = tbl.query(query_obj) @@ -226,7 +224,7 @@ class SqlaTableModelTestCase(SupersetTestCase): self.assertIn("JOIN", sql.upper()) else: self.assertNotIn("JOIN", sql.upper()) - spec.inner_joins = old_inner_join + spec.allows_joins = old_inner_join self.assertIsNotNone(qr.df) return qr.df @@ -258,7 +256,6 @@ class SqlaTableModelTestCase(SupersetTestCase): granularity=None, from_dttm=None, to_dttm=None, - is_prequery=False, extras={}, ) sql = tbl.get_query_str(query_obj) @@ -285,7 +282,6 @@ class SqlaTableModelTestCase(SupersetTestCase): granularity=None, from_dttm=None, to_dttm=None, - is_prequery=False, extras={}, ) @@ -306,7 +302,6 @@ class SqlaTableModelTestCase(SupersetTestCase): granularity=None, from_dttm=None, to_dttm=None, - is_prequery=False, extras={}, ) diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index bbf0d2236..8f6212708 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -52,7 +52,6 @@ class DatabaseModelTestCase(SupersetTestCase): "metrics": [], "is_timeseries": False, "filter": [], - "is_prequery": False, "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"}, } extra_cache_keys = table.get_extra_cache_keys(query_obj)