From 47890743091eabe3ae22c1b46329d548fec592fa Mon Sep 17 00:00:00 2001 From: Jesse Yang Date: Thu, 1 Apr 2021 18:10:17 -0700 Subject: [PATCH] fix(sqla-query): order by aggregations in Presto and Hive (#13739) --- superset/connectors/sqla/models.py | 130 +++++++++++++++++++--------- superset/db_engine_specs/base.py | 38 +++++---- superset/db_engine_specs/hive.py | 3 + superset/db_engine_specs/pinot.py | 10 +-- superset/db_engine_specs/presto.py | 1 + superset/utils/core.py | 16 ++++ tests/conftest.py | 21 +++-- tests/databases/commands_tests.py | 5 +- tests/fixtures/query_context.py | 67 ++++++++++++++- tests/query_context_tests.py | 132 ++++++++++++++++++++++------- 10 files changed, 315 insertions(+), 108 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 2b61520cd..a433f03fe 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -16,6 +16,7 @@ # under the License. import json import logging +import re from collections import defaultdict, OrderedDict from contextlib import closing from dataclasses import dataclass, field # pylint: disable=wrong-import-order @@ -50,6 +51,7 @@ from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table, text from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause +from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy.types import TypeEngine from superset import app, db, is_feature_enabled, security_manager @@ -70,7 +72,7 @@ from superset.result_set import SupersetResultSet from superset.sql_parse import ParsedQuery from superset.typing import AdhocMetric, Metric, OrderBy, QueryObjectDict from superset.utils import core as utils -from superset.utils.core import GenericDataType +from superset.utils.core import GenericDataType, remove_duplicates config = app.config metadata = Model.metadata # pylint: 
disable=no-member @@ -465,7 +467,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False) fetch_values_predicate = Column(String(1000)) owners = relationship(owner_class, secondary=sqlatable_user, backref="tables") - database = relationship( + database: Database = relationship( "Database", backref=backref("tables", cascade="all, delete-orphan"), foreign_keys=[database_id], @@ -507,22 +509,6 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at "MAX": sa.func.MAX, } - def make_sqla_column_compatible( - self, sqla_col: Column, label: Optional[str] = None - ) -> Column: - """Takes a sqlalchemy column object and adds label info if supported by engine. - :param sqla_col: sqlalchemy column instance - :param label: alias/label that column is expected to have - :return: either a sql alchemy column or label instance if supported by engine - """ - label_expected = label or sqla_col.name - db_engine_spec = self.database.db_engine_spec - # add quotes to tables - if db_engine_spec.allows_alias_in_select: - label = db_engine_spec.make_label_compatible(label_expected) - sqla_col = sqla_col.label(label) - return sqla_col - def __repr__(self) -> str: return self.name @@ -708,11 +694,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at def data(self) -> Dict[str, Any]: data_ = super().data if self.type == "table": - grains = self.database.grains() or [] - if grains: - grains = [(g.duration, g.name) for g in grains] data_["granularity_sqla"] = utils.choicify(self.dttm_cols) - data_["time_grain_sqla"] = grains + data_["time_grain_sqla"] = [ + (g.duration, g.name) for g in self.database.grains() or [] + ] data_["main_dttm_col"] = self.main_dttm_col data_["fetch_values_predicate"] = self.fetch_values_predicate data_["template_params"] = self.template_params @@ -800,7 +785,7 @@ class SqlaTable( # pylint: 
disable=too-many-public-methods,too-many-instance-at all_queries = query_str_ext.prequeries + [query_str_ext.sql] return ";\n\n".join(all_queries) + ";" - def get_sqla_table(self) -> table: + def get_sqla_table(self) -> TableClause: tbl = table(self.table_name) if self.schema: tbl.schema = self.schema @@ -808,7 +793,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at def get_from_clause( self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> Union[table, TextAsFrom]: + ) -> Union[TableClause, Alias]: """ Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. @@ -882,6 +867,51 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at return self.make_sqla_column_compatible(sqla_metric, label) + def make_sqla_column_compatible( + self, sqla_col: Column, label: Optional[str] = None + ) -> Column: + """Takes a sqlalchemy column object and adds label info if supported by engine. + :param sqla_col: sqlalchemy column instance + :param label: alias/label that column is expected to have + :return: either a sql alchemy column or label instance if supported by engine + """ + label_expected = label or sqla_col.name + db_engine_spec = self.database.db_engine_spec + # add quotes to tables + if db_engine_spec.allows_alias_in_select: + label = db_engine_spec.make_label_compatible(label_expected) + sqla_col = sqla_col.label(label) + return sqla_col + + def make_orderby_compatible( + self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] + ) -> None: + """ + If needed, make sure aliases for selected columns are not used in + `ORDER BY`. + + In some databases (e.g. Presto), `ORDER BY` clause is not able to + automatically pick the source column if a `SELECT` clause alias is named + the same as a source column. In this case, we update the SELECT alias to + another name to avoid the conflict. 
+ """ + if self.database.db_engine_spec.allows_alias_to_source_column: + return + + def is_alias_used_in_orderby(col: ColumnElement) -> bool: + if not isinstance(col, Label): + return False + regexp = re.compile(f"\\(.*\\b{re.escape(col.name)}\\b.*\\)", re.IGNORECASE) + return any(regexp.search(str(x)) for x in orderby_exprs) + + # Iterate through selected columns, if column alias appears in orderby + # use another `alias`. The final output columns will still use the + # original names, because they are updated by `labels_expected` after + # querying. + for col in select_exprs: + if is_alias_used_in_orderby(col): + col.name = f"{col.name}__" + def _get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor ) -> List[str]: @@ -995,9 +1025,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at # To ensure correct handling of the ORDER BY labeling we need to reference the # metric instance if defined in the SELECT clause. - metrics_exprs_by_label = { - m.name: m for m in metrics_exprs # pylint: disable=protected-access - } + metrics_exprs_by_label = {m.name: m for m in metrics_exprs} + metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} # Since orderby may use adhoc metrics, too; we need to process them first orderby_exprs: List[ColumnElement] = [] @@ -1007,21 +1036,25 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists col = self.adhoc_metric_to_sqla(col, columns_by_name) + # if the adhoc metric has been defined before + # use the existing instance. 
+ col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: col = columns_by_name[col].get_sqla_col() + elif col in metrics_exprs_by_label: + col = metrics_exprs_by_label[col] + need_groupby = True elif col in metrics_by_name: col = metrics_by_name[col].get_sqla_col() need_groupby = True - elif col in metrics_exprs_by_label: - col = metrics_exprs_by_label[col] if isinstance(col, ColumnElement): orderby_exprs.append(col) else: # Could not convert a column reference to valid ColumnElement raise QueryObjectValidationError( - _("Unknown column used in orderby: %(col)", col=orig_col) + _("Unknown column used in orderby: %(col)s", col=orig_col) ) select_exprs: List[Union[Column, Label]] = [] @@ -1093,11 +1126,21 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at dttm_col.get_time_filter(from_dttm, to_dttm, time_range_endpoints) ) - select_exprs += metrics_exprs - labels_expected = [c.name for c in select_exprs] - select_exprs = db_engine_spec.make_select_compatible( - groupby_exprs_with_timestamp.values(), select_exprs + # Always remove duplicates by column name, as sometimes `metrics_exprs` + # can have the same name as a groupby column (e.g. when users use + # raw columns as custom SQL adhoc metric). 
+ select_exprs = remove_duplicates( + select_exprs + metrics_exprs, key=lambda x: x.name ) + + # Expected output columns + labels_expected = [c.name for c in select_exprs] + + # Order by columns are "hidden" columns, some databases require them + # always be present in SELECT if an aggregation function is used + if not db_engine_spec.allows_hidden_ordeby_agg: + select_exprs = remove_duplicates(select_exprs + orderby_exprs) + qry = sa.select(select_exprs) tbl = self.get_from_clause(template_processor) @@ -1213,12 +1256,13 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) + self.make_orderby_compatible(select_exprs, orderby_exprs) + for col, (orig_col, ascending) in zip(orderby_exprs, orderby): - if ( - db_engine_spec.allows_alias_in_orderby - and col.name in metrics_exprs_by_label - ): - col = Label(col.name, metrics_exprs_by_label[col.name]) + if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): + # if engine does not allow using SELECT alias in ORDER BY + # revert to the underlying column + col = col.element direction = asc if ascending else desc qry = qry.order_by(direction(col)) @@ -1315,6 +1359,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at result.df, dimensions, groupby_exprs_sans_timestamp ) qry = qry.where(top_groups) + + qry = qry.select_from(tbl) + if is_rowcount: if not db_engine_spec.allows_subqueries: raise QueryObjectValidationError( @@ -1322,10 +1369,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at ) label = "rowcount" col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - qry = select([col]).select_from(qry.select_from(tbl).alias("rowcount_qry")) + qry = select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] - else: - qry = qry.select_from(tbl) + return SqlaQuery( extra_cache_keys=extra_cache_keys, 
labels_expected=labels_expected, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index b0545e625..e59bc8465 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -49,7 +49,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Session from sqlalchemy.sql import quoted_name, text -from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom +from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom from sqlalchemy.types import String, TypeEngine, UnicodeText from superset import app, security_manager, sql_parse @@ -137,7 +137,18 @@ class LimitMethod: # pylint: disable=too-few-public-methods class BaseEngineSpec: # pylint: disable=too-many-public-methods - """Abstract class for database engine specific configurations""" + """Abstract class for database engine specific configurations + + Attributes: + allows_alias_to_source_column: Whether the engine is able to pick the + source column for aggregation clauses + used in ORDER BY when a column in SELECT + has an alias that is the same as a source + column. + allows_hidden_ordeby_agg: Whether the engine allows ORDER BY to + directly use aggregation clauses, without + having to add the same aggregation in SELECT. + """ engine = "base" # str as defined in sqlalchemy.engine.engine engine_aliases: Optional[Tuple[str]] = None @@ -241,6 +252,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods allows_alias_in_select = True allows_alias_in_orderby = True allows_sql_comments = True + + # Whether ORDER BY clause can use aliases created in SELECT + # that are the same as a source column + allows_alias_to_source_column = True + + # Whether aggregations used only in ORDER BY must also be added to SELECT; + # if True, they may stay "hidden" (omitted from the SELECT clause).
+ allows_hidden_ordeby_agg = True + force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 @@ -441,20 +461,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ) ) - @classmethod - def make_select_compatible( - cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement] - ) -> List[ColumnElement]: - """ - Some databases will just return the group-by field into the select, but don't - allow the group-by field to be put into the select list. - - :param groupby_exprs: mapping between column name and column object - :param select_exprs: all columns in the select clause - :return: columns to be included in the final select clause - """ - return select_exprs - @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 51bedbee3..4234ddc63 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -81,6 +81,9 @@ class HiveEngineSpec(PrestoEngineSpec): engine = "hive" engine_name = "Apache Hive" max_column_name_length = 767 + allows_alias_to_source_column = True + allows_hidden_ordeby_agg = False + # pylint: disable=line-too-long _time_grain_expressions = { None: "{col}", diff --git a/superset/db_engine_specs/pinot.py b/superset/db_engine_specs/pinot.py index e2207305e..b07a6256d 100644 --- a/superset/db_engine_specs/pinot.py +++ b/superset/db_engine_specs/pinot.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Dict, List, Optional +from typing import Dict, Optional -from sqlalchemy.sql.expression import ColumnClause, ColumnElement +from sqlalchemy.sql.expression import ColumnClause from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression @@ -112,9 +112,3 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method time_expr = f"DATETIMECONVERT({{col}}, '{tf}', '{tf}', '{granularity}')" return TimestampExpression(time_expr, col) - - @classmethod - def make_select_compatible( - cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement] - ) -> List[ColumnElement]: - return select_exprs diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 27fad223e..f9524aa61 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -128,6 +128,7 @@ def get_children(column: Dict[str, str]) -> List[Dict[str, str]]: class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods engine = "presto" engine_name = "Presto" + allows_alias_to_source_column = False _time_grain_expressions = { None: "{col}", diff --git a/superset/utils/core.py b/superset/utils/core.py index 11842786c..882423ed5 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1631,6 +1631,22 @@ def find_duplicates(items: Iterable[InputType]) -> List[InputType]: return [item for item, count in collections.Counter(items).items() if count > 1] +def remove_duplicates( + items: Iterable[InputType], key: Optional[Callable[[InputType], Any]] = None +) -> List[InputType]: + """Remove duplicate items in an iterable.""" + if not key: + return list(dict.fromkeys(items).keys()) + seen = set() + result = [] + for item in items: + item_key = key(item) + if item_key not in seen: + seen.add(item_key) + result.append(item) + return result + + def normalize_dttm_col( df: pd.DataFrame, timestamp_format: Optional[str], diff --git a/tests/conftest.py b/tests/conftest.py 
index b8543851d..3bd5170d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from sqlalchemy.engine import Engine from tests.test_app import app from superset import db -from superset.utils.core import get_example_database +from superset.utils.core import get_example_database, json_dumps_w_dates CTAS_SCHEMA_NAME = "sqllab_test_db" @@ -73,13 +73,22 @@ def drop_from_schema(engine: Engine, schema_name: str): def setup_presto_if_needed(): backend = app.config["SQLALCHEMY_EXAMPLES_URI"].split("://")[0] + database = get_example_database() + extra = database.get_extra() + if backend == "presto": # decrease poll interval for tests - presto_poll_interval = app.config["PRESTO_POLL_INTERVAL"] - extra = f'{{"engine_params": {{"connect_args": {{"poll_interval": {presto_poll_interval}}}}}}}' - database = get_example_database() - database.extra = extra - db.session.commit() + extra = { + **extra, + "engine_params": { + "connect_args": {"poll_interval": app.config["PRESTO_POLL_INTERVAL"]} + }, + } + else: + # remove `poll_interval` from databases that do not support it + extra = {**extra, "engine_params": {}} + database.extra = json_dumps_w_dates(extra) + db.session.commit() if backend in {"presto", "hive"}: database = get_example_database() diff --git a/tests/databases/commands_tests.py b/tests/databases/commands_tests.py index 5594b56f5..0ff25194b 100644 --- a/tests/databases/commands_tests.py +++ b/tests/databases/commands_tests.py @@ -82,7 +82,10 @@ class TestExportDatabasesCommand(SupersetTestCase): "schemas_allowed_for_csv_upload": [], } if backend() == "presto": - expected_extra = {"engine_params": {"connect_args": {"poll_interval": 0.1}}} + expected_extra = { + **expected_extra, + "engine_params": {"connect_args": {"poll_interval": 0.1}}, + } assert core_files.issubset(set(contents.keys())) diff --git a/tests/fixtures/query_context.py b/tests/fixtures/query_context.py index 38e156aae..12fddc630 100644 --- a/tests/fixtures/query_context.py +++ 
b/tests/fixtures/query_context.py @@ -31,13 +31,13 @@ query_birth_names = { }, "groupby": ["name"], "metrics": [{"label": "sum__num"}], - "order_desc": True, "orderby": [["sum__num", False]], "row_limit": 100, "granularity": "ds", "time_range": "100 years ago : now", "timeseries_limit": 0, "timeseries_limit_metric": None, + "order_desc": True, "filters": [ {"col": "gender", "op": "==", "val": "boy"}, {"col": "num", "op": "IS NOT NULL"}, @@ -49,8 +49,57 @@ query_birth_names = { } QUERY_OBJECTS: Dict[str, Dict[str, object]] = { - "birth_names": {**query_birth_names, "is_timeseries": False,}, - "birth_names:include_time": {**query_birth_names, "groupby": [DTTM_ALIAS, "name"],}, + "birth_names": query_birth_names, + # `:suffix` are overrides only + "birth_names:include_time": {"groupby": [DTTM_ALIAS, "name"],}, + "birth_names:orderby_dup_alias": { + "metrics": [ + { + "expressionType": "SIMPLE", + "column": {"column_name": "num_girls", "type": "BIGINT(20)"}, + "aggregate": "SUM", + "label": "num_girls", + }, + { + "expressionType": "SIMPLE", + "column": {"column_name": "num_boys", "type": "BIGINT(20)"}, + "aggregate": "SUM", + "label": "num_boys", + }, + ], + "orderby": [ + [ + { + "expressionType": "SIMPLE", + "column": {"column_name": "num_girls", "type": "BIGINT(20)"}, + "aggregate": "SUM", + # the same underlying expression, but different label + "label": "SUM(num_girls)", + }, + False, + ], + # reference the ambiguous alias in SIMPLE metric + [ + { + "expressionType": "SIMPLE", + "column": {"column_name": "num_boys", "type": "BIGINT(20)"}, + "aggregate": "AVG", + "label": "AVG(num_boys)", + }, + False, + ], + # reference the ambiguous alias in CUSTOM SQL metric + [ + { + "expressionType": "SQL", + "sqlExpression": "MAX(CASE WHEN num_boys > 0 THEN 1 ELSE 0 END)", + "label": "MAX(CASE WHEN...", + }, + True, + ], + ], + }, + "birth_names:only_orderby_has_metric": {"metrics": [],}, } ANNOTATION_LAYERS = { @@ -150,7 +199,17 @@ def get_query_object( ) -> Dict[str, Any]: 
if query_name not in QUERY_OBJECTS: raise Exception(f"QueryObject fixture not defined for datasource: {query_name}") - query_object = copy.deepcopy(QUERY_OBJECTS[query_name]) + obj = QUERY_OBJECTS[query_name] + + # apply overrides + if ":" in query_name: + parent_query_name = query_name.split(":")[0] + obj = { + **QUERY_OBJECTS[parent_query_name], + **obj, + } + + query_object = copy.deepcopy(obj) if add_postprocessing_operations: query_object["post_processing"] = _get_postprocessing_operation(query_name) return query_object diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index 377f717be..660bcddcc 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import re +from typing import Any, Dict import pytest @@ -24,9 +25,9 @@ from superset.common.query_context import QueryContext from superset.common.query_object import QueryObject from superset.connectors.connector_registry import ConnectorRegistry from superset.extensions import cache_manager -from superset.models.cache import CacheKey from superset.utils.core import ( AdhocMetricExpressionType, + backend, ChartDataResultFormat, ChartDataResultType, TimeRangeEndpoint, @@ -36,6 +37,17 @@ from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with from tests.fixtures.query_context import get_query_context +def get_sql_text(payload: Dict[str, Any]) -> str: + payload["result_type"] = ChartDataResultType.QUERY.value + query_context = ChartDataQueryContextSchema().load(payload) + responses = query_context.get_payload() + assert len(responses) == 1 + response = responses["queries"][0] + assert len(response) == 2 + assert response["language"] == "sql" + return response["query"] + + class TestQueryContext(SupersetTestCase): def test_schema_deserialization(self): """ @@ -301,14 +313,7 @@ class TestQueryContext(SupersetTestCase): """ self.login(username="admin") 
payload = get_query_context("birth_names") - payload["result_type"] = ChartDataResultType.QUERY.value - query_context = ChartDataQueryContextSchema().load(payload) - responses = query_context.get_payload() - assert len(responses) == 1 - response = responses["queries"][0] - assert len(response) == 2 - sql_text = response["query"] - assert response["language"] == "sql" + sql_text = get_sql_text(payload) assert "SELECT" in sql_text assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text) assert re.search( @@ -318,37 +323,102 @@ class TestQueryContext(SupersetTestCase): ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_fetch_values_predicate_in_query(self): + def test_handle_sort_by_metrics(self): """ - Ensure that fetch values predicate is added to query + Should properly handle sort by metrics in various scenarios. """ self.login(username="admin") - payload = get_query_context("birth_names") - payload["result_type"] = ChartDataResultType.QUERY.value - payload["queries"][0]["apply_fetch_values_predicate"] = True - query_context = ChartDataQueryContextSchema().load(payload) - responses = query_context.get_payload() - assert len(responses) == 1 - response = responses["queries"][0] - assert len(response) == 2 - assert response["language"] == "sql" - assert "123 = 123" in response["query"] + + sql_text = get_sql_text(get_query_context("birth_names")) + if backend() == "hive": + # should have no duplicate `SUM(num)` + assert "SUM(num) AS `sum__num`," not in sql_text + assert "SUM(num) AS `sum__num`" in sql_text + # the alias should be in ORDER BY + assert "ORDER BY `sum__num` DESC" in sql_text + else: + assert re.search(r'ORDER BY [`"\[]?sum__num[`"\]]? 
DESC', sql_text) + + sql_text = get_sql_text( + get_query_context("birth_names:only_orderby_has_metric") + ) + if backend() == "hive": + assert "SUM(num) AS `sum__num`," not in sql_text + assert "SUM(num) AS `sum__num`" in sql_text + assert "ORDER BY `sum__num` DESC" in sql_text + else: + assert re.search( + r'ORDER BY SUM\([`"\[]?num[`"\]]?\) DESC', sql_text, re.IGNORECASE + ) + + sql_text = get_sql_text(get_query_context("birth_names:orderby_dup_alias")) + + # Check SELECT clauses + if backend() == "presto": + # presto cannot have ambiguous alias in order by, so selected column + # alias is renamed. + assert 'sum("num_boys") AS "num_boys__"' in sql_text + else: + assert re.search( + r'SUM\([`"\[]?num_boys[`"\]]?\) AS [`\"\[]?num_boys[`"\]]?', + sql_text, + re.IGNORECASE, + ) + + # Check ORDER BY clauses + if backend() == "hive": + # Hive must add additional SORT BY metrics to SELECT + assert re.search( + r"MAX\(CASE.*END\) AS `MAX\(CASE WHEN...`", + sql_text, + re.IGNORECASE | re.DOTALL, + ) + + # The additional column with the same expression but a different label + # as an existing metric should not be added + assert "sum(`num_girls`) AS `SUM(num_girls)`" not in sql_text + + # Should reference all ORDER BY columns by aliases + assert "ORDER BY `num_girls` DESC," in sql_text + assert "`AVG(num_boys)` DESC," in sql_text + assert "`MAX(CASE WHEN...` ASC" in sql_text + else: + if backend() == "presto": + # since the selected `num_boys` is renamed to `num_boys__` + # it must be references as expression + assert re.search( + r'ORDER BY SUM\([`"\[]?num_girls[`"\]]?\) DESC', + sql_text, + re.IGNORECASE, + ) + else: + # Should reference the adhoc metric by alias when possible + assert re.search( + r'ORDER BY [`"\[]?num_girls[`"\]]? 
DESC', sql_text, re.IGNORECASE, + ) + + # ORDER BY only columns should always be expressions + assert re.search( + r'AVG\([`"\[]?num_boys[`"\]]?\) DESC', sql_text, re.IGNORECASE, + ) + assert re.search( + r"MAX\(CASE.*END\) ASC", sql_text, re.IGNORECASE | re.DOTALL + ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_fetch_values_predicate_not_in_query(self): + def test_fetch_values_predicate(self): """ - Ensure that fetch values predicate is not added to query + Ensure that fetch values predicate is added to query if needed """ self.login(username="admin") + payload = get_query_context("birth_names") - payload["result_type"] = ChartDataResultType.QUERY.value - query_context = ChartDataQueryContextSchema().load(payload) - responses = query_context.get_payload() - assert len(responses) == 1 - response = responses["queries"][0] - assert len(response) == 2 - assert response["language"] == "sql" - assert "123 = 123" not in response["query"] + sql_text = get_sql_text(payload) + assert "123 = 123" not in sql_text + + payload["queries"][0]["apply_fetch_values_predicate"] = True + sql_text = get_sql_text(payload) + assert "123 = 123" in sql_text def test_query_object_unknown_fields(self): """