feat: push predicates into virtual datasets (#31486)
This commit is contained in:
parent
f29eafd044
commit
e4b3ecd372
|
|
@ -517,6 +517,8 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
|
||||||
# Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the
|
# Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the
|
||||||
# query, and might break queries and/or allow users to bypass RLS. Use with care!
|
# query, and might break queries and/or allow users to bypass RLS. Use with care!
|
||||||
"RLS_IN_SQLLAB": False,
|
"RLS_IN_SQLLAB": False,
|
||||||
|
# Try to optimize SQL queries — for now only predicate pushdown is supported.
|
||||||
|
"OPTIMIZE_SQL": False,
|
||||||
# When impersonating a user, use the email prefix instead of the username
|
# When impersonating a user, use the email prefix instead of the username
|
||||||
"IMPERSONATE_WITH_EMAIL_PREFIX": False,
|
"IMPERSONATE_WITH_EMAIL_PREFIX": False,
|
||||||
# Enable caching per impersonation key (e.g username) in a datasource where user
|
# Enable caching per impersonation key (e.g username) in a datasource where user
|
||||||
|
|
|
||||||
|
|
@ -1568,7 +1568,11 @@ class SqlaTable(
|
||||||
# probe adhoc column type
|
# probe adhoc column type
|
||||||
tbl, _ = self.get_from_clause(template_processor)
|
tbl, _ = self.get_from_clause(template_processor)
|
||||||
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
|
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
|
||||||
sql = self.database.compile_sqla_query(qry)
|
sql = self.database.compile_sqla_query(
|
||||||
|
qry,
|
||||||
|
catalog=self.catalog,
|
||||||
|
schema=self.schema,
|
||||||
|
)
|
||||||
col_desc = get_columns_description(
|
col_desc = get_columns_description(
|
||||||
self.database,
|
self.database,
|
||||||
self.catalog,
|
self.catalog,
|
||||||
|
|
|
||||||
|
|
@ -1701,7 +1701,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
)
|
)
|
||||||
if partition_query is not None:
|
if partition_query is not None:
|
||||||
qry = partition_query
|
qry = partition_query
|
||||||
sql = database.compile_sqla_query(qry)
|
sql = database.compile_sqla_query(qry, table.catalog, table.schema)
|
||||||
if indent:
|
if indent:
|
||||||
sql = SQLScript(sql, engine=cls.engine).format()
|
sql = SQLScript(sql, engine=cls.engine).format()
|
||||||
return sql
|
return sql
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,7 @@ from superset.extensions import (
|
||||||
)
|
)
|
||||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
|
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
|
||||||
from superset.result_set import SupersetResultSet
|
from superset.result_set import SupersetResultSet
|
||||||
|
from superset.sql.parse import SQLScript
|
||||||
from superset.sql_parse import Table
|
from superset.sql_parse import Table
|
||||||
from superset.superset_typing import (
|
from superset.superset_typing import (
|
||||||
DbapiDescription,
|
DbapiDescription,
|
||||||
|
|
@ -740,6 +741,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||||
qry: Select,
|
qry: Select,
|
||||||
catalog: str | None = None,
|
catalog: str | None = None,
|
||||||
schema: str | None = None,
|
schema: str | None = None,
|
||||||
|
is_virtual: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
|
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
|
||||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||||
|
|
@ -748,6 +750,12 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||||
if engine.dialect.identifier_preparer._double_percents: # noqa
|
if engine.dialect.identifier_preparer._double_percents: # noqa
|
||||||
sql = sql.replace("%%", "%")
|
sql = sql.replace("%%", "%")
|
||||||
|
|
||||||
|
# for now we only optimize queries on virtual datasources, since the only
|
||||||
|
# optimization available is predicate pushdown
|
||||||
|
if is_feature_enabled("OPTIMIZE_SQL") and is_virtual:
|
||||||
|
script = SQLScript(sql, self.db_engine_spec.engine).optimize()
|
||||||
|
sql = script.format()
|
||||||
|
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
def select_star( # pylint: disable=too-many-arguments
|
def select_star( # pylint: disable=too-many-arguments
|
||||||
|
|
|
||||||
|
|
@ -883,7 +883,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||||
mutate: bool = True,
|
mutate: bool = True,
|
||||||
) -> QueryStringExtended:
|
) -> QueryStringExtended:
|
||||||
sqlaq = self.get_sqla_query(**query_obj)
|
sqlaq = self.get_sqla_query(**query_obj)
|
||||||
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
|
sql = self.database.compile_sqla_query(
|
||||||
|
sqlaq.sqla_query,
|
||||||
|
catalog=self.catalog,
|
||||||
|
schema=self.schema,
|
||||||
|
is_virtual=bool(self.sql),
|
||||||
|
)
|
||||||
sql = self._apply_cte(sql, sqlaq.cte)
|
sql = self._apply_cte(sql, sqlaq.cte)
|
||||||
|
|
||||||
if mutate:
|
if mutate:
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
@ -31,6 +32,7 @@ from deprecation import deprecated
|
||||||
from sqlglot import exp
|
from sqlglot import exp
|
||||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||||
from sqlglot.errors import ParseError
|
from sqlglot.errors import ParseError
|
||||||
|
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||||
|
|
||||||
from superset.exceptions import SupersetParseError
|
from superset.exceptions import SupersetParseError
|
||||||
|
|
@ -227,6 +229,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def optimize(self) -> BaseSQLStatement[InternalRepresentation]:
    """
    Build an optimized version of this statement.

    The base class has no implementation and always raises
    ``NotImplementedError``.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.format()
|
return self.format()
|
||||||
|
|
||||||
|
|
@ -431,6 +439,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||||
for eq in set_item.find_all(exp.EQ)
|
for eq in set_item.find_all(exp.EQ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def optimize(self) -> SQLStatement:
    """
    Return optimized statement.

    The only optimization applied is sqlglot's predicate pushdown. This
    statement is left untouched: the rewrite runs on a copy of the parsed tree
    and the result is wrapped in a new ``SQLStatement``.

    :return: a new statement. When no dialect is set it carries the original
        SQL and an unmodified copy of the parsed tree.
    """
    # only optimize statements that have a custom dialect
    if not self._dialect:
        return SQLStatement(self._sql, self.engine, self._parsed.copy())

    # ``pushdown_predicates`` rewrites the tree it receives in place, so hand it
    # a copy; otherwise ``self._parsed`` would be mutated while ``self._sql``
    # still held the un-optimized text.
    optimized = pushdown_predicates(self._parsed.copy(), dialect=self._dialect)
    sql = optimized.sql(dialect=self._dialect)

    return SQLStatement(sql, self.engine, optimized)
|
||||||
|
|
||||||
|
|
||||||
class KQLSplitState(enum.Enum):
|
class KQLSplitState(enum.Enum):
|
||||||
"""
|
"""
|
||||||
|
|
@ -589,6 +610,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||||
"""
|
"""
|
||||||
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
||||||
|
|
||||||
|
def optimize(self) -> KustoKQLStatement:
    """
    Return a new statement equivalent to this one.

    Kusto KQL doesn't support optimization, so nothing is rewritten: the new
    statement is built from the same SQL text, engine and parsed form.
    """
    sql, engine, parsed = self._sql, self.engine, self._parsed
    return KustoKQLStatement(sql, engine, parsed)
|
||||||
|
|
||||||
|
|
||||||
class SQLScript:
|
class SQLScript:
|
||||||
"""
|
"""
|
||||||
|
|
@ -643,6 +672,17 @@ class SQLScript:
|
||||||
"""
|
"""
|
||||||
return any(statement.is_mutating() for statement in self.statements)
|
return any(statement.is_mutating() for statement in self.statements)
|
||||||
|
|
||||||
|
def optimize(self) -> SQLScript:
    """
    Build an optimized copy of this script.

    The script is deep-copied, and the copy's statements are then replaced by
    the result of calling ``optimize()`` on each statement of the original.
    """
    clone = copy.deepcopy(self)
    optimized = [stmt.optimize() for stmt in self.statements]
    clone.statements = optimized  # type: ignore

    return clone
|
||||||
|
|
||||||
|
|
||||||
def extract_tables_from_statement(
|
def extract_tables_from_statement(
|
||||||
statement: exp.Expression,
|
statement: exp.Expression,
|
||||||
|
|
|
||||||
|
|
@ -226,7 +226,7 @@ def test_select_star(mocker: MockerFixture) -> None:
|
||||||
|
|
||||||
# mock the database so we can compile the query
|
# mock the database so we can compile the query
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
database.compile_sqla_query = lambda query: str(
|
database.compile_sqla_query = lambda query, catalog, schema: str(
|
||||||
query.compile(dialect=sqlite.dialect())
|
query.compile(dialect=sqlite.dialect())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,7 @@ def test_select_star(mocker: MockerFixture) -> None:
|
||||||
|
|
||||||
# mock the database so we can compile the query
|
# mock the database so we can compile the query
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
database.compile_sqla_query = lambda query: str(
|
database.compile_sqla_query = lambda query, catalog, schema: str(
|
||||||
query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True})
|
query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,17 @@ from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
MetaData,
|
||||||
|
select,
|
||||||
|
Table as SqlalchemyTable,
|
||||||
|
)
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.engine.url import make_url
|
from sqlalchemy.engine.url import make_url
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
from sqlalchemy.sql import Select
|
||||||
|
|
||||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||||
from superset.errors import SupersetErrorType
|
from superset.errors import SupersetErrorType
|
||||||
|
|
@ -45,6 +53,29 @@ oauth2_client_info = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def query() -> Select:
    """
    A nested query fixture used to test query optimization.

    The outer query selects ``a`` and ``b`` from a subquery over ``some_table``
    and filters on both columns, so the outer ``WHERE`` clause holds predicates
    that refer only to the subquery.
    """
    metadata = MetaData()
    some_table = SqlalchemyTable(
        "some_table",
        metadata,
        Column("a", Integer),
        Column("b", Integer),
        Column("c", Integer),
    )

    # make the subquery explicit: reading ``.c`` straight off a ``Select`` relies
    # on SQLAlchemy's deprecated implicit-subquery coercion
    inner_select = select(some_table.c.a, some_table.c.b, some_table.c.c).subquery()
    outer_select = select(inner_select.c.a, inner_select.c.b).where(
        inner_select.c.a > 1,
        inner_select.c.b == 2,
    )

    return outer_select
|
||||||
|
|
||||||
|
|
||||||
def test_get_metrics(mocker: MockerFixture) -> None:
|
def test_get_metrics(mocker: MockerFixture) -> None:
|
||||||
"""
|
"""
|
||||||
Tests for ``get_metrics``.
|
Tests for ``get_metrics``.
|
||||||
|
|
@ -683,3 +714,56 @@ def test_purge_oauth2_tokens(session: Session) -> None:
|
||||||
# make sure database was not deleted... just in case
|
# make sure database was not deleted... just in case
|
||||||
database = session.query(Database).filter_by(id=database1.id).one()
|
database = session.query(Database).filter_by(id=database1.id).one()
|
||||||
assert database.name == "my_oauth2_db"
|
assert database.name == "my_oauth2_db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile_sqla_query_no_optimization(query: Select) -> None:
    """
    Test `compile_sqla_query` on a virtual query when nothing enables optimization.

    This test sets no feature flags, and the expected SQL keeps both predicates
    in the outer query.
    """
    from superset.models.core import Database

    database = Database(
        database_name="db",
        sqlalchemy_uri="sqlite://",
    )
    compiled = database.compile_sqla_query(query, is_virtual=True)

    # some expected lines end with a trailing space; it is interpolated rather
    # than typed (presumably so tooling does not strip it)
    space = " "
    expected = f"""SELECT anon_1.a, anon_1.b{space}
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c{space}
FROM some_table) AS anon_1{space}
WHERE anon_1.a > 1 AND anon_1.b = 2"""

    assert compiled == expected
|
||||||
|
|
||||||
|
|
||||||
|
@with_feature_flags(OPTIMIZE_SQL=True)
|
||||||
|
def test_compile_sqla_query(query: Select) -> None:
|
||||||
|
"""
|
||||||
|
Test the `compile_sqla_query` method.
|
||||||
|
"""
|
||||||
|
from superset.models.core import Database
|
||||||
|
|
||||||
|
database = Database(
|
||||||
|
database_name="db",
|
||||||
|
sqlalchemy_uri="sqlite://",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
database.compile_sqla_query(query, is_virtual=True)
|
||||||
|
== """SELECT
|
||||||
|
anon_1.a,
|
||||||
|
anon_1.b
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
some_table.a AS a,
|
||||||
|
some_table.b AS b,
|
||||||
|
some_table.c AS c
|
||||||
|
FROM some_table
|
||||||
|
WHERE
|
||||||
|
some_table.a > 1 AND some_table.b = 2
|
||||||
|
) AS anon_1
|
||||||
|
WHERE
|
||||||
|
TRUE AND TRUE"""
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1070,3 +1070,46 @@ def test_is_mutating(engine: str) -> None:
|
||||||
"with source as ( select 1 as one ) select * from source",
|
"with source as ( select 1 as one ) select * from source",
|
||||||
engine=engine,
|
engine=engine,
|
||||||
).is_mutating()
|
).is_mutating()
|
||||||
|
|
||||||
|
|
||||||
|
def test_optimize() -> None:
|
||||||
|
"""
|
||||||
|
Test that the `optimize` method works as expected.
|
||||||
|
|
||||||
|
The SQL optimization only works with engines that have a corresponding dialect.
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT anon_1.a, anon_1.b
|
||||||
|
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c
|
||||||
|
FROM some_table) AS anon_1
|
||||||
|
WHERE anon_1.a > 1 AND anon_1.b = 2
|
||||||
|
"""
|
||||||
|
|
||||||
|
optimized = """SELECT
|
||||||
|
anon_1.a,
|
||||||
|
anon_1.b
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
some_table.a AS a,
|
||||||
|
some_table.b AS b,
|
||||||
|
some_table.c AS c
|
||||||
|
FROM some_table
|
||||||
|
WHERE
|
||||||
|
some_table.a > 1 AND some_table.b = 2
|
||||||
|
) AS anon_1
|
||||||
|
WHERE
|
||||||
|
TRUE AND TRUE"""
|
||||||
|
|
||||||
|
not_optimized = """
|
||||||
|
SELECT anon_1.a,
|
||||||
|
anon_1.b
|
||||||
|
FROM
|
||||||
|
(SELECT some_table.a AS a,
|
||||||
|
some_table.b AS b,
|
||||||
|
some_table.c AS c
|
||||||
|
FROM some_table) AS anon_1
|
||||||
|
WHERE anon_1.a > 1
|
||||||
|
AND anon_1.b = 2"""
|
||||||
|
|
||||||
|
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
|
||||||
|
assert SQLStatement(sql, "firebolt").optimize().format() == not_optimized
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue