feat: push predicates into virtual datasets (#31486)

This commit is contained in:
Beto Dealmeida 2025-01-08 22:11:28 -05:00 committed by GitHub
parent f29eafd044
commit e4b3ecd372
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 191 additions and 5 deletions

View File

@ -517,6 +517,8 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the # Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the
# query, and might break queries and/or allow users to bypass RLS. Use with care! # query, and might break queries and/or allow users to bypass RLS. Use with care!
"RLS_IN_SQLLAB": False, "RLS_IN_SQLLAB": False,
# Try to optimize SQL queries — for now only predicate pushdown is supported.
"OPTIMIZE_SQL": False,
# When impersonating a user, use the email prefix instead of the username # When impersonating a user, use the email prefix instead of the username
"IMPERSONATE_WITH_EMAIL_PREFIX": False, "IMPERSONATE_WITH_EMAIL_PREFIX": False,
# Enable caching per impersonation key (e.g username) in a datasource where user # Enable caching per impersonation key (e.g username) in a datasource where user

View File

@ -1568,7 +1568,11 @@ class SqlaTable(
# probe adhoc column type # probe adhoc column type
tbl, _ = self.get_from_clause(template_processor) tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl) qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry) sql = self.database.compile_sqla_query(
qry,
catalog=self.catalog,
schema=self.schema,
)
col_desc = get_columns_description( col_desc = get_columns_description(
self.database, self.database,
self.catalog, self.catalog,

View File

@ -1701,7 +1701,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
) )
if partition_query is not None: if partition_query is not None:
qry = partition_query qry = partition_query
sql = database.compile_sqla_query(qry) sql = database.compile_sqla_query(qry, table.catalog, table.schema)
if indent: if indent:
sql = SQLScript(sql, engine=cls.engine).format() sql = SQLScript(sql, engine=cls.engine).format()
return sql return sql

View File

@ -74,6 +74,7 @@ from superset.extensions import (
) )
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
from superset.result_set import SupersetResultSet from superset.result_set import SupersetResultSet
from superset.sql.parse import SQLScript
from superset.sql_parse import Table from superset.sql_parse import Table
from superset.superset_typing import ( from superset.superset_typing import (
DbapiDescription, DbapiDescription,
@ -740,6 +741,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
qry: Select, qry: Select,
catalog: str | None = None, catalog: str | None = None,
schema: str | None = None, schema: str | None = None,
is_virtual: bool = False,
) -> str: ) -> str:
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
@ -748,6 +750,12 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
if engine.dialect.identifier_preparer._double_percents: # noqa if engine.dialect.identifier_preparer._double_percents: # noqa
sql = sql.replace("%%", "%") sql = sql.replace("%%", "%")
# for now we only optimize queries on virtual datasources, since the only
# optimization available is predicate pushdown
if is_feature_enabled("OPTIMIZE_SQL") and is_virtual:
script = SQLScript(sql, self.db_engine_spec.engine).optimize()
sql = script.format()
return sql return sql
def select_star( # pylint: disable=too-many-arguments def select_star( # pylint: disable=too-many-arguments

View File

@ -883,7 +883,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
mutate: bool = True, mutate: bool = True,
) -> QueryStringExtended: ) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj) sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query) sql = self.database.compile_sqla_query(
sqlaq.sqla_query,
catalog=self.catalog,
schema=self.schema,
is_virtual=bool(self.sql),
)
sql = self._apply_cte(sql, sqlaq.cte) sql = self._apply_cte(sql, sqlaq.cte)
if mutate: if mutate:

View File

@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import copy
import enum import enum
import logging import logging
import re import re
@ -31,6 +32,7 @@ from deprecation import deprecated
from sqlglot import exp from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError from sqlglot.errors import ParseError
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from superset.exceptions import SupersetParseError from superset.exceptions import SupersetParseError
@ -227,6 +229,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
""" """
raise NotImplementedError() raise NotImplementedError()
def optimize(self) -> BaseSQLStatement[InternalRepresentation]:
    """
    Return an optimized version of this statement.

    The base class provides no implementation; concrete statement classes
    are expected to override this method.

    :raises NotImplementedError: always, unless overridden.
    """
    raise NotImplementedError
def __str__(self) -> str: def __str__(self) -> str:
return self.format() return self.format()
@ -431,6 +439,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
for eq in set_item.find_all(exp.EQ) for eq in set_item.find_all(exp.EQ)
} }
def optimize(self) -> SQLStatement:
    """
    Return a new, optimized statement.

    The only optimization applied is sqlglot's predicate pushdown. The
    current statement is not modified: both branches build the returned
    statement from a copy of the parsed AST.

    :returns: a new ``SQLStatement``; when no custom dialect is set, the
        statement is returned unoptimized (same SQL, copied AST).
    """
    # only optimize statements that have a custom dialect
    if not self._dialect:
        return SQLStatement(self._sql, self.engine, self._parsed.copy())

    # sqlglot optimizer rules rewrite the expression they receive in place,
    # so run the rule on a copy to keep this statement's AST intact (and to
    # avoid sharing one AST between the original and the optimized statement).
    optimized = pushdown_predicates(self._parsed.copy(), dialect=self._dialect)
    sql = optimized.sql(dialect=self._dialect)

    return SQLStatement(sql, self.engine, optimized)
class KQLSplitState(enum.Enum): class KQLSplitState(enum.Enum):
""" """
@ -589,6 +610,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
""" """
return self._parsed.startswith(".") and not self._parsed.startswith(".show") return self._parsed.startswith(".") and not self._parsed.startswith(".show")
def optimize(self) -> KustoKQLStatement:
    """
    Return an "optimized" statement.

    No optimization is performed for Kusto KQL: the result is a fresh
    ``KustoKQLStatement`` built from the same SQL, engine and parsed value.
    """
    unchanged = KustoKQLStatement(self._sql, self.engine, self._parsed)
    return unchanged
class SQLScript: class SQLScript:
""" """
@ -643,6 +672,17 @@ class SQLScript:
""" """
return any(statement.is_mutating() for statement in self.statements) return any(statement.is_mutating() for statement in self.statements)
def optimize(self) -> SQLScript:
    """
    Return an optimized copy of the script.

    The script is deep-copied, and the copy's statements are replaced by the
    result of calling ``optimize()`` on each statement of this script.
    """
    clone = copy.deepcopy(self)

    optimized_statements = []
    for statement in self.statements:
        optimized_statements.append(statement.optimize())

    clone.statements = optimized_statements  # type: ignore
    return clone
def extract_tables_from_statement( def extract_tables_from_statement(
statement: exp.Expression, statement: exp.Expression,

View File

@ -226,7 +226,7 @@ def test_select_star(mocker: MockerFixture) -> None:
# mock the database so we can compile the query # mock the database so we can compile the query
database = mocker.MagicMock() database = mocker.MagicMock()
database.compile_sqla_query = lambda query: str( database.compile_sqla_query = lambda query, catalog, schema: str(
query.compile(dialect=sqlite.dialect()) query.compile(dialect=sqlite.dialect())
) )

View File

@ -149,7 +149,7 @@ def test_select_star(mocker: MockerFixture) -> None:
# mock the database so we can compile the query # mock the database so we can compile the query
database = mocker.MagicMock() database = mocker.MagicMock()
database.compile_sqla_query = lambda query: str( database.compile_sqla_query = lambda query, catalog, schema: str(
query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True}) query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True})
) )

View File

@ -21,9 +21,17 @@ from datetime import datetime
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from sqlalchemy import (
Column,
Integer,
MetaData,
select,
Table as SqlalchemyTable,
)
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql import Select
from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
@ -45,6 +53,29 @@ oauth2_client_info = {
} }
@pytest.fixture
def query() -> Select:
    """
    Build a nested ``SELECT`` used by the query optimization tests.

    The inner query selects columns ``a``, ``b`` and ``c`` from ``some_table``;
    the outer query selects ``a`` and ``b`` from it, filtered on
    ``a > 1`` and ``b = 2``.
    """
    columns = [Column(name, Integer) for name in ("a", "b", "c")]
    some_table = SqlalchemyTable("some_table", MetaData(), *columns)

    inner = select(some_table.c.a, some_table.c.b, some_table.c.c)
    predicates = (inner.c.a > 1, inner.c.b == 2)

    return select(inner.c.a, inner.c.b).where(*predicates)
def test_get_metrics(mocker: MockerFixture) -> None: def test_get_metrics(mocker: MockerFixture) -> None:
""" """
Tests for ``get_metrics``. Tests for ``get_metrics``.
@ -683,3 +714,56 @@ def test_purge_oauth2_tokens(session: Session) -> None:
# make sure database was not deleted... just in case # make sure database was not deleted... just in case
database = session.query(Database).filter_by(id=database1.id).one() database = session.query(Database).filter_by(id=database1.id).one()
assert database.name == "my_oauth2_db" assert database.name == "my_oauth2_db"
def test_compile_sqla_query_no_optimization(query: Select) -> None:
    """
    Test `compile_sqla_query` on a virtual dataset without the `OPTIMIZE_SQL`
    flag being set by this test: the predicates stay in the outer query.
    """
    from superset.models.core import Database

    database = Database(database_name="db", sqlalchemy_uri="sqlite://")

    # trailing spaces are interpolated so editors don't strip them
    space = " "
    expected = f"""SELECT anon_1.a, anon_1.b{space}
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c{space}
FROM some_table) AS anon_1{space}
WHERE anon_1.a > 1 AND anon_1.b = 2"""

    compiled = database.compile_sqla_query(query, is_virtual=True)
    assert compiled == expected
@with_feature_flags(OPTIMIZE_SQL=True)
def test_compile_sqla_query(query: Select) -> None:
    """
    Test `compile_sqla_query` on a virtual dataset with `OPTIMIZE_SQL` enabled:
    the outer predicates are expected to be pushed into the inner query.
    """
    from superset.models.core import Database

    database = Database(database_name="db", sqlalchemy_uri="sqlite://")

    expected = """SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
WHERE
some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
TRUE AND TRUE"""

    compiled = database.compile_sqla_query(query, is_virtual=True)
    assert compiled == expected

View File

@ -1070,3 +1070,46 @@ def test_is_mutating(engine: str) -> None:
"with source as ( select 1 as one ) select * from source", "with source as ( select 1 as one ) select * from source",
engine=engine, engine=engine,
).is_mutating() ).is_mutating()
def test_optimize() -> None:
    """
    Test the `optimize` method of `SQLStatement`.

    The SQL optimization only works with engines that have a corresponding
    dialect: the same input is expected to be rewritten for ``sqlite`` and
    left unoptimized for ``firebolt``.
    """
    nested_sql = """
SELECT anon_1.a, anon_1.b
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c
FROM some_table) AS anon_1
WHERE anon_1.a > 1 AND anon_1.b = 2
"""

    expected_by_engine = {
        "sqlite": """SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
WHERE
some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
TRUE AND TRUE""",
        "firebolt": """
SELECT anon_1.a,
anon_1.b
FROM
(SELECT some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table) AS anon_1
WHERE anon_1.a > 1
AND anon_1.b = 2""",
    }

    for engine, expected in expected_by_engine.items():
        assert SQLStatement(nested_sql, engine).optimize().format() == expected