[Bugfix] Remove prequery properties from query_obj (#7896)

* Create query_obj for every filter

* Deprecate is_prequery and prequeries from query_obj

* Fix tests

* Fix typos and remove redundant ; from sql

* Add typing to namedtuples and move all query str logic to one place

* Fix unit test
This commit is contained in:
Ville Brofeldt 2019-07-23 22:13:58 +03:00 committed by GitHub
parent 2221445f44
commit 07a76f83b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 48 additions and 66 deletions

View File

@ -48,8 +48,6 @@ class QueryObject:
timeseries_limit_metric: Optional[Dict] = None,
order_desc: bool = True,
extras: Optional[Dict] = None,
prequeries: Optional[List[Dict]] = None,
is_prequery: bool = False,
columns: List[str] = None,
orderby: List[List] = None,
relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"),
@ -78,8 +76,6 @@ class QueryObject:
self.timeseries_limit = timeseries_limit
self.timeseries_limit_metric = timeseries_limit_metric
self.order_desc = order_desc
self.prequeries = prequeries if prequeries is not None else []
self.is_prequery = is_prequery
self.extras = extras if extras is not None else {}
self.columns = columns if columns is not None else []
self.orderby = orderby if orderby is not None else []
@ -97,8 +93,6 @@ class QueryObject:
"timeseries_limit": self.timeseries_limit,
"timeseries_limit_metric": self.timeseries_limit_metric,
"order_desc": self.order_desc,
"prequeries": self.prequeries,
"is_prequery": self.is_prequery,
"extras": self.extras,
"columns": self.columns,
"orderby": self.orderby,

View File

@ -1120,8 +1120,6 @@ class DruidDatasource(Model, BaseDatasource):
phase=2,
client=None,
order_desc=True,
prequeries=None,
is_prequery=False,
):
"""Runs a query against Druid and returns a dataframe.
"""

View File

@ -15,10 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from collections import namedtuple, OrderedDict
from collections import OrderedDict
from datetime import datetime
import logging
from typing import Any, List, Optional, Union
from typing import Any, List, NamedTuple, Optional, Union
from flask import escape, Markup
from flask_appbuilder import Model
@ -45,7 +45,7 @@ from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, literal_column, table, text
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
import sqlparse
from superset import app, db, security_manager
@ -61,10 +61,18 @@ from superset.utils import core as utils, import_datasource
config = app.config
metadata = Model.metadata # pylint: disable=no-member
SqlaQuery = namedtuple(
"SqlaQuery", ["sqla_query", "labels_expected", "extra_cache_keys"]
)
QueryStringExtended = namedtuple("QueryStringExtended", ["sql", "labels_expected"])
class SqlaQuery(NamedTuple):
    """A SQLAlchemy query together with the metadata needed to execute it
    and render its final SQL (see ``get_query_str_extended``)."""
    # Extra values to mix into the cache key (populated via template kwargs).
    extra_cache_keys: List[Any]
    # Column labels expected in the resulting dataframe.
    labels_expected: List[str]
    # SQL strings of queries run before the main query (e.g. the
    # top-groups prequery for timeseries limits).
    prequeries: List[str]
    # The SQLAlchemy Select statement for the main query.
    sqla_query: Select
class QueryStringExtended(NamedTuple):
    """The rendered SQL of a query plus metadata carried over from the
    compiled ``SqlaQuery`` it was produced from."""
    # Column labels expected in the resulting dataframe.
    labels_expected: List[str]
    # SQL strings of prequeries to be joined ahead of ``sql`` when
    # building the full query string (see ``get_query_str``).
    prequeries: List[str]
    # The formatted, mutated SQL of the main query.
    sql: str
class AnnotationDatasource(BaseDatasource):
@ -351,7 +359,7 @@ class SqlaTable(Model, BaseDatasource):
"""
label_expected = label or sqla_col.name
db_engine_spec = self.database.db_engine_spec
if db_engine_spec.supports_column_aliases:
if db_engine_spec.allows_column_aliases:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
sqla_col._df_label_expected = label_expected
@ -532,18 +540,20 @@ class SqlaTable(Model, BaseDatasource):
def get_template_processor(self, **kwargs):
return get_template_processor(table=self, database=self.database, **kwargs)
def get_query_str_extended(self, query_obj):
def get_query_str_extended(self, query_obj) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
logging.info(sql)
sql = sqlparse.format(sql, reindent=True)
if query_obj["is_prequery"]:
query_obj["prequeries"].append(sql)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(labels_expected=sqlaq.labels_expected, sql=sql)
return QueryStringExtended(
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
)
def get_query_str(self, query_obj):
return self.get_query_str_extended(query_obj).sql
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"
def get_sqla_table(self):
tbl = table(self.table_name)
@ -606,8 +616,6 @@ class SqlaTable(Model, BaseDatasource):
extras=None,
columns=None,
order_desc=True,
prequeries=None,
is_prequery=False,
):
"""Querying any sqla table from this common interface"""
template_kwargs = {
@ -624,6 +632,7 @@ class SqlaTable(Model, BaseDatasource):
template_kwargs["extra_cache_keys"] = extra_cache_keys
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec
prequeries: List[str] = []
orderby = orderby or []
@ -793,7 +802,7 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.limit(row_limit)
if is_timeseries and timeseries_limit and groupby and not time_groupby_inline:
if self.database.db_engine_spec.inner_joins:
if self.database.db_engine_spec.allows_joins:
# some sql dialects require for order by expressions
# to also be in the select clause -- others, e.g. vertica,
# require a unique inner alias
@ -844,10 +853,8 @@ class SqlaTable(Model, BaseDatasource):
)
]
# run subquery to get top groups
subquery_obj = {
"prequeries": prequeries,
"is_prequery": True,
# run prequery to get top groups
prequery_obj = {
"is_timeseries": False,
"row_limit": timeseries_limit,
"groupby": groupby,
@ -861,7 +868,8 @@ class SqlaTable(Model, BaseDatasource):
"columns": columns,
"order_desc": True,
}
result = self.query(subquery_obj)
result = self.query(prequery_obj)
prequeries.append(result.query)
dimensions = [
c
for c in result.df.columns
@ -873,9 +881,10 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.where(top_groups)
return SqlaQuery(
sqla_query=qry.select_from(tbl),
labels_expected=labels_expected,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
sqla_query=qry.select_from(tbl),
prequeries=prequeries,
)
def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
@ -929,12 +938,6 @@ class SqlaTable(Model, BaseDatasource):
db_engine_spec = self.database.db_engine_spec
error_message = db_engine_spec.extract_error_message(e)
# if this is a main query with prequeries, combine them together
if not query_obj["is_prequery"]:
query_obj["prequeries"].append(sql)
sql = ";\n\n".join(query_obj["prequeries"])
sql += ";"
return QueryResult(
status=status,
df=df,

View File

@ -117,9 +117,9 @@ class BaseEngineSpec(object):
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
inner_joins = True
allows_subquery = True
supports_column_aliases = True
allows_joins = True
allows_subqueries = True
allows_column_aliases = True
force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0

View File

@ -22,8 +22,8 @@ class DruidEngineSpec(BaseEngineSpec):
"""Engine spec for Druid.io"""
engine = "druid"
inner_joins = False
allows_subquery = True
allows_joins = False
allows_subqueries = True
time_grain_functions = {
None: "{col}",

View File

@ -22,5 +22,5 @@ class GSheetsEngineSpec(SqliteEngineSpec):
"""Engine for Google spreadsheets"""
engine = "gsheets"
inner_joins = False
allows_subquery = False
allows_joins = False
allows_subqueries = False

View File

@ -24,9 +24,9 @@ from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
class PinotEngineSpec(BaseEngineSpec):
engine = "pinot"
allows_subquery = False
inner_joins = False
supports_column_aliases = False
allows_subqueries = False
allows_joins = False
allows_column_aliases = False
# Pinot does its own conversion below
time_grain_functions: Dict[Optional[str], str] = {

View File

@ -771,7 +771,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
@property
def allows_subquery(self):
return self.db_engine_spec.allows_subquery
return self.db_engine_spec.allows_subqueries
@property
def data(self):

View File

@ -1008,12 +1008,7 @@ class Superset(BaseSupersetView):
logging.exception(e)
return json_error_response(e)
if query_obj and query_obj["prequeries"]:
query_obj["prequeries"].append(query)
query = ";\n\n".join(query_obj["prequeries"])
if query:
query += ";"
else:
if not query:
query = "No query."
return self.json_response(

View File

@ -320,8 +320,6 @@ class BaseViz(object):
"extras": extras,
"timeseries_limit_metric": timeseries_limit_metric,
"order_desc": order_desc,
"prequeries": [],
"is_prequery": False,
}
return d

View File

@ -195,11 +195,11 @@ class SqlaTableModelTestCase(SupersetTestCase):
ds_col.expression = None
ds_col.python_date_format = None
spec = self.get_database_by_id(tbl.database_id).db_engine_spec
if not spec.inner_joins and inner_join:
if not spec.allows_joins and inner_join:
# if the db does not support inner joins, we cannot force it so
return None
old_inner_join = spec.inner_joins
spec.inner_joins = inner_join
old_inner_join = spec.allows_joins
spec.allows_joins = inner_join
arbitrary_gby = "state || gender || '_test'"
arbitrary_metric = dict(
label="arbitrary", expressionType="SQL", sqlExpression="COUNT(1)"
@ -209,12 +209,10 @@ class SqlaTableModelTestCase(SupersetTestCase):
metrics=[arbitrary_metric],
filter=[],
is_timeseries=is_timeseries,
prequeries=[],
columns=[],
granularity="ds",
from_dttm=None,
to_dttm=None,
is_prequery=False,
extras=dict(time_grain_sqla="P1Y"),
)
qr = tbl.query(query_obj)
@ -226,7 +224,7 @@ class SqlaTableModelTestCase(SupersetTestCase):
self.assertIn("JOIN", sql.upper())
else:
self.assertNotIn("JOIN", sql.upper())
spec.inner_joins = old_inner_join
spec.allows_joins = old_inner_join
self.assertIsNotNone(qr.df)
return qr.df
@ -258,7 +256,6 @@ class SqlaTableModelTestCase(SupersetTestCase):
granularity=None,
from_dttm=None,
to_dttm=None,
is_prequery=False,
extras={},
)
sql = tbl.get_query_str(query_obj)
@ -285,7 +282,6 @@ class SqlaTableModelTestCase(SupersetTestCase):
granularity=None,
from_dttm=None,
to_dttm=None,
is_prequery=False,
extras={},
)
@ -306,7 +302,6 @@ class SqlaTableModelTestCase(SupersetTestCase):
granularity=None,
from_dttm=None,
to_dttm=None,
is_prequery=False,
extras={},
)

View File

@ -52,7 +52,6 @@ class DatabaseModelTestCase(SupersetTestCase):
"metrics": [],
"is_timeseries": False,
"filter": [],
"is_prequery": False,
"extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
}
extra_cache_keys = table.get_extra_cache_keys(query_obj)