diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 74e8e6c3d..ca1d4bc57 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -77,7 +77,7 @@ from superset.connectors.sqla.utils import ( get_physical_table_metadata, get_virtual_table_metadata, ) -from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression +from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.exceptions import QueryObjectValidationError from superset.jinja_context import ( BaseTemplateProcessor, @@ -107,6 +107,7 @@ VIRTUAL_TABLE_ALIAS = "virtual_table" class SqlaQuery(NamedTuple): applied_template_filters: List[str] + cte: Optional[str] extra_cache_keys: List[Any] labels_expected: List[str] prequeries: List[str] @@ -562,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def __repr__(self) -> str: return self.name + @staticmethod + def _apply_cte(sql: str, cte: Optional[str]) -> str: + """ + Append a CTE before the SELECT statement if defined + + :param sql: SELECT statement + :param cte: CTE statement + :return: + """ + if cte: + sql = f"{cte}\n{sql}" + return sql + @property def db_engine_spec(self) -> Type[BaseEngineSpec]: return self.database.db_engine_spec @@ -743,12 +757,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() + tbl, cte = self.get_from_clause(tp) - qry = ( - select([target_col.get_sqla_col()]) - .select_from(self.get_from_clause(tp)) - .distinct() - ) + qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct() if limit: qry = qry.limit(limit) @@ -756,7 +767,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho qry = qry.where(self.get_fetch_values_predicate()) engine = self.database.get_sqla_engine() - sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True})) + sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) + sql = self._apply_cte(sql, cte) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) @@ -778,6 +790,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) + sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended( @@ -800,13 +813,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def get_from_clause( self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> Union[TableClause, Alias]: + ) -> Tuple[Union[TableClause, Alias], Optional[str]]: """ Return where to select the columns and metrics from. Either a physical table - or a virtual table with it's own subquery. + or a virtual table with it's own subquery. If the FROM is referencing a + CTE, the CTE is returned as the second value in the return tuple. """ if not self.is_virtual: - return self.get_sqla_table() + return self.get_sqla_table(), None from_sql = self.get_rendered_sql(template_processor) parsed_query = ParsedQuery(from_sql) @@ -817,7 +831,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho raise QueryObjectValidationError( _("Virtual dataset query must be read-only") ) - return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) + + cte = self.db_engine_spec.get_cte_query(from_sql) + from_clause = ( + table(CTE_ALIAS) + if cte + else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) + ) + + return from_clause, cte def get_rendered_sql( self, template_processor: Optional[BaseTemplateProcessor] = None @@ -1224,7 +1246,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho qry = sa.select(select_exprs) - tbl = self.get_from_clause(template_processor) + tbl, cte = self.get_from_clause(template_processor) if groupby_all_columns: qry = qry.group_by(*groupby_all_columns.values()) @@ -1491,6 +1513,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho return SqlaQuery( applied_template_filters=applied_template_filters, + cte=cte, extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index bdd1922d2..764f3fde7 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -54,6 +54,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.types import TypeEngine +from sqlparse.tokens import CTE from typing_extensions import TypedDict from superset import security_manager, sql_parse @@ -80,6 +81,9 @@ ColumnTypeMapping = Tuple[ logger = logging.getLogger() +CTE_ALIAS = "__cte" + + class TimeGrain(NamedTuple): name: str # TODO: redundant field, remove label: str @@ -292,6 +296,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # But for backward compatibility, False by default allows_hidden_cc_in_orderby = False + # Whether allow CTE as subquery or regular CTE + # If True, then it will allow in subquery , + # if False it will allow as regular CTE + allows_cte_in_subquery = True + force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 @@ -663,6 +672,31 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods parsed_query = sql_parse.ParsedQuery(sql) return parsed_query.set_or_update_query_limit(limit) + @classmethod + def get_cte_query(cls, sql: str) -> Optional[str]: + """ + Convert the input CTE based SQL to the SQL for virtual table conversion + + :param sql: SQL query + :return: CTE with the main select query aliased as `__cte` + + """ + if not cls.allows_cte_in_subquery: + stmt = sqlparse.parse(sql)[0] + + # The first meaningful token for CTE will be with WITH + idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True) + if not (token and token.ttype == CTE): + return None + idx, token = stmt.token_next(idx) + idx = stmt.token_index(token) + 1 + + # extract rest of the SQLs after CTE + remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip() + return f"WITH {token.value},\n{CTE_ALIAS} AS (\n{remainder}\n)" + + return None + @classmethod def df_to_sql( cls, diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 874992847..e5c66e046 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -47,6 +47,7 @@ class MssqlEngineSpec(BaseEngineSpec): engine_name = "Microsoft SQL Server" limit_method = LimitMethod.WRAP_SQL max_column_name_length = 128 + allows_cte_in_subquery = False _time_grain_expressions = { None: "{col}", diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 43288e08a..a4c95a20d 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -984,7 +984,7 @@ class TestCore(SupersetTestCase): sql=commented_query, database=get_example_database(), ) - rendered_query = str(table.get_from_clause()) + rendered_query = str(table.get_from_clause()[0]) self.assertEqual(clean_query, rendered_query) def test_slice_payload_no_datasource(self): diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index d822f50de..4dc27c092 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -16,7 +16,11 @@ # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +from textwrap import dedent + +import pytest from flask.ctx import AppContext +from sqlalchemy.types import TypeEngine def test_get_text_clause_with_colon(app_context: AppContext) -> None: @@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None: "SELECT foo FROM tbl1", "SELECT bar FROM tbl2", ] + + +@pytest.mark.parametrize( + "original,expected", + [ + ( + dedent( + """ +with currency as +( +select 'INR' as cur +) +select * from currency +""" + ), + None, + ), + ("SELECT 1 as cnt", None,), + ( + dedent( + """ +select 'INR' as cur +union +select 'AUD' as cur +union +select 'USD' as cur +""" + ), + None, + ), + ], +) +def test_cte_query_parsing( + app_context: AppContext, original: TypeEngine, expected: str +) -> None: + from superset.db_engine_specs.base import BaseEngineSpec + + actual = BaseEngineSpec.get_cte_query(original) + assert actual == expected diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 75d2dcb10..250b8158f 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -180,6 +180,57 @@ def test_column_datatype_to_string( assert actual == expected +@pytest.mark.parametrize( + "original,expected", + [ + ( + dedent( + """ +with currency as ( +select 'INR' as cur +), +currency_2 as ( +select 'EUR' as cur +) +select * from currency union all select * from currency_2 +""" + ), + dedent( + """WITH currency as ( +select 'INR' as cur +), +currency_2 as ( +select 'EUR' as cur +), +__cte AS ( +select * from currency union all select * from currency_2 +)""" + ), + ), + ("SELECT 1 as cnt", None,), + ( + dedent( + """ +select 'INR' as cur +union +select 'AUD' as cur +union +select 'USD' as cur +""" + ), + None, + ), + ], +) +def test_cte_query_parsing( + app_context: AppContext, original: TypeEngine, expected: str +) -> None: + from superset.db_engine_specs.mssql import MssqlEngineSpec + + actual = MssqlEngineSpec.get_cte_query(original) + assert actual == expected + + def test_extract_errors(app_context: AppContext) -> None: """ Test that custom error messages are extracted correctly.