feat(sqlparse): improve table parsing (#26476)

Beto Dealmeida 2024-01-22 11:16:50 -05:00 committed by GitHub
parent d34874cf2b
commit c0b57bd1c3
17 changed files with 265 additions and 120 deletions

View File

@@ -141,7 +141,9 @@ geographiclib==1.52
 geopy==2.2.0
     # via apache-superset
 greenlet==2.0.2
-    # via shillelagh
+    # via
+    #   shillelagh
+    #   sqlalchemy
 gunicorn==21.2.0
     # via apache-superset
 hashids==1.3.1
@@ -155,7 +157,10 @@ idna==3.2
     #   email-validator
     #   requests
 importlib-metadata==6.6.0
-    # via apache-superset
+    # via
+    #   apache-superset
+    #   flask
+    #   shillelagh
 importlib-resources==5.12.0
     # via limits
 isodate==0.6.0
@@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3
     # via
     #   apache-superset
     #   flask-appbuilder
+sqlglot==20.8.0
+    # via apache-superset
 sqlparse==0.4.4
     # via apache-superset
 sshtunnel==0.4.0
@@ -376,7 +383,9 @@ wtforms-json==0.3.5
 xlsxwriter==3.0.7
     # via apache-superset
 zipp==3.15.0
-    # via importlib-metadata
+    # via
+    #   importlib-metadata
+    #   importlib-resources
 # The following packages are considered to be unsafe in a requirements file:
 # setuptools

View File

@@ -24,10 +24,6 @@ db-dtypes==1.1.1
     # via pandas-gbq
 docker==6.1.1
     # via -r requirements/testing.in
-exceptiongroup==1.1.1
-    # via pytest
-ephem==4.1.4
-    # via lunarcalendar
 flask-testing==0.8.1
     # via -r requirements/testing.in
 fonttools==4.39.4
@@ -121,6 +117,8 @@ pyee==9.0.4
     # via playwright
 pyfakefs==5.2.2
     # via -r requirements/testing.in
+pyhive[presto]==0.7.0
+    # via apache-superset
 pytest==7.3.1
     # via
     #   -r requirements/testing.in

View File

@@ -125,6 +125,7 @@ setup(
         "slack_sdk>=3.19.0, <4",
         "sqlalchemy>=1.4, <2",
         "sqlalchemy-utils>=0.38.3, <0.39",
+        "sqlglot>=20,<21",
         "sqlparse>=0.4.4, <0.5",
         "tabulate>=0.8.9, <0.9",
         "typing-extensions>=4, <5",

View File

@@ -70,7 +70,10 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
         table.normalize_columns = self._base_model.normalize_columns
         table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
         table.is_sqllab_view = True
-        table.sql = ParsedQuery(self._base_model.sql).stripped()
+        table.sql = ParsedQuery(
+            self._base_model.sql,
+            engine=database.db_engine_spec.engine,
+        ).stripped()
         db.session.add(table)
         cols = []
         for config_ in self._base_model.columns:

View File

@@ -115,7 +115,10 @@ class SqlResultExportCommand(BaseCommand):
             limit = None
         else:
             sql = self._query.executed_sql
-            limit = ParsedQuery(sql).limit
+            limit = ParsedQuery(
+                sql,
+                engine=self._query.database.db_engine_spec.engine,
+            ).limit
         if limit is not None and self._query.limiting_factor in {
             LimitingFactor.QUERY,
             LimitingFactor.DROPDOWN,

View File

@@ -1457,7 +1457,7 @@ class SqlaTable(
             return self.get_sqla_table(), None

         from_sql = self.get_rendered_sql(template_processor)
-        parsed_query = ParsedQuery(from_sql)
+        parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
         if not (
             parsed_query.is_unknown()
             or self.db_engine_spec.is_readonly_query(parsed_query)

View File

@@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
-    parsed_query = ParsedQuery(sql)
+    parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
     if not db_engine_spec.is_readonly_query(parsed_query):
         raise SupersetSecurityException(
             SupersetError(

View File

@@ -899,7 +899,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             return database.compile_sqla_query(qry)

         if cls.limit_method == LimitMethod.FORCE_LIMIT:
-            parsed_query = sql_parse.ParsedQuery(sql)
+            parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
             sql = parsed_query.set_or_update_query_limit(limit, force=force)

         return sql
@@ -980,7 +980,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param sql: SQL query
         :return: Value of limit clause in query
         """
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         return parsed_query.limit

     @classmethod
@@ -992,7 +992,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param limit: New limit to insert/replace into query
         :return: Query with new limit
         """
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         return parsed_query.set_or_update_query_limit(limit)

     @classmethod
@@ -1487,7 +1487,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param database: Database instance
         :return: Dictionary with different costs
         """
-        parsed_query = ParsedQuery(statement)
+        parsed_query = ParsedQuery(statement, engine=cls.engine)
         sql = parsed_query.stripped()
         sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
         mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
@@ -1522,7 +1522,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
                 "Database does not support cost estimation"
             )

-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         statements = parsed_query.get_statements()

         costs = []
@@ -1583,7 +1583,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :return:
         """
         if not cls.allows_sql_comments:
-            query = sql_parse.strip_comments_from_sql(query)
+            query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)

         if cls.arraysize:
             cursor.arraysize = cls.arraysize
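With the engine name threaded through, every `ParsedQuery` the engine spec builds can pick the matching sqlglot dialect, while limit handling itself still goes through sqlparse. A small usage sketch under that assumption (the query is illustrative):

    from superset.sql_parse import ParsedQuery

    # Illustrative query; `engine` selects the sqlglot dialect used for table
    # extraction, limit parsing is unchanged.
    parsed = ParsedQuery("SELECT * FROM birth_names LIMIT 100", engine="postgresql")
    print(parsed.limit)  # 100
    print(parsed.set_or_update_query_limit(1000, force=True))  # ... LIMIT 1000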

View File

@@ -435,7 +435,7 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         if not cls.get_allow_cost_estimate(extra):
             raise SupersetException("Database does not support cost estimation")

-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         statements = parsed_query.get_statements()
         costs = []
         for statement in statements:

View File

@@ -1093,7 +1093,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         """

         from_sql = self.get_rendered_sql(template_processor)
-        parsed_query = ParsedQuery(from_sql)
+        parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
         if not (
             parsed_query.is_unknown()
             or self.db_engine_spec.is_readonly_query(parsed_query)

View File

@@ -183,7 +183,7 @@ class Query(
     @property
     def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql).tables)
+        return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)

     @property
     def columns(self) -> list["TableColumn"]:
@@ -427,7 +427,9 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
     @property
     def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql).tables)
+        return list(
+            ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
+        )

     @property
     def last_run_humanized(self) -> str:

View File

@@ -1876,7 +1876,10 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             default_schema = database.get_default_schema_for_query(query)
             tables = {
                 Table(table_.table, table_.schema or default_schema)
-                for table_ in sql_parse.ParsedQuery(query.sql).tables
+                for table_ in sql_parse.ParsedQuery(
+                    query.sql,
+                    engine=database.db_engine_spec.engine,
+                ).tables
             }
         elif table:
             tables = {table}
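The authorization path therefore parses the submitted SQL with the database's own dialect before resolving each referenced table against the default schema. A hedged sketch of that behaviour with a made-up query and schema:

    from superset.sql_parse import ParsedQuery, Table

    # Hypothetical query and default schema, for illustration only.
    sql = "SELECT * FROM orders JOIN archive.orders_old USING (id)"
    default_schema = "public"

    tables = {
        Table(table_.table, table_.schema or default_schema)
        for table_ in ParsedQuery(sql, engine="postgresql").tables
    }
    # Expected: {Table("orders", "public"), Table("orders_old", "archive")}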

View File

@@ -199,7 +199,7 @@ def execute_sql_statement(
     database: Database = query.database
     db_engine_spec = database.db_engine_spec

-    parsed_query = ParsedQuery(sql_statement)
+    parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
     if is_feature_enabled("RLS_IN_SQLLAB"):
         # There are two ways to insert RLS: either replacing the table with a subquery
         # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
@@ -219,7 +219,8 @@
                     database.id,
                     query.schema,
                 )
-            )
+            ),
+            engine=db_engine_spec.engine,
         )

     sql = parsed_query.stripped()
@@ -409,7 +410,11 @@
     )

     # Breaking down into multiple statements
-    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
+    parsed_query = ParsedQuery(
+        rendered_query,
+        strip_comments=True,
+        engine=db_engine_spec.engine,
+    )
     if not db_engine_spec.run_multiple_statements_as_one:
         statements = parsed_query.get_statements()
         logger.info(
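A sketch of that multi-statement path under the new signature: comments are stripped up front and the engine is passed along so each statement is parsed with the matching dialect (the script below is illustrative):

    from superset.sql_parse import ParsedQuery

    # Illustrative two-statement script with a leading comment.
    rendered_query = """
    -- refresh the staging table
    DELETE FROM staging.events;
    INSERT INTO staging.events SELECT * FROM raw.events;
    """
    parsed_query = ParsedQuery(rendered_query, strip_comments=True, engine="postgresql")
    for statement in parsed_query.get_statements():
        print(statement)  # one comment-free statement per iteration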

View File

@@ -14,15 +14,22 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=too-many-lines
 import logging
 import re
-from collections.abc import Iterator
+import urllib.parse
+from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
 from typing import Any, cast, Optional
-from urllib import parse

 import sqlparse
 from sqlalchemy import and_
+from sqlglot import exp, parse, parse_one
+from sqlglot.dialects import Dialects
+from sqlglot.errors import ParseError
+from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
 from sqlparse import keywords
 from sqlparse.lexer import Lexer
 from sqlparse.sql import (
@@ -53,7 +60,7 @@ from superset.utils.backports import StrEnum

 try:
     from sqloxide import parse_sql as sqloxide_parse
-except:  # pylint: disable=bare-except
+except (ImportError, ModuleNotFoundError):
     sqloxide_parse = None

 RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@@ -72,6 +79,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
 lex.set_SQL_REGEX(sqlparser_sql_regex)

+# mapping between DB engine specs and sqlglot dialects
+SQLGLOT_DIALECTS = {
+    "ascend": Dialects.HIVE,
+    "awsathena": Dialects.PRESTO,
+    "bigquery": Dialects.BIGQUERY,
+    "clickhouse": Dialects.CLICKHOUSE,
+    "clickhousedb": Dialects.CLICKHOUSE,
+    "cockroachdb": Dialects.POSTGRES,
+    # "crate": ???
+    # "databend": ???
+    "databricks": Dialects.DATABRICKS,
+    # "db2": ???
+    # "dremio": ???
+    "drill": Dialects.DRILL,
+    # "druid": ???
+    "duckdb": Dialects.DUCKDB,
+    # "dynamodb": ???
+    # "elasticsearch": ???
+    # "exa": ???
+    # "firebird": ???
+    # "firebolt": ???
+    "gsheets": Dialects.SQLITE,
+    "hana": Dialects.POSTGRES,
+    "hive": Dialects.HIVE,
+    # "ibmi": ???
+    # "impala": ???
+    # "kustokql": ???
+    # "kylin": ???
+    # "mssql": ???
+    "mysql": Dialects.MYSQL,
+    "netezza": Dialects.POSTGRES,
+    # "ocient": ???
+    # "odelasticsearch": ???
+    "oracle": Dialects.ORACLE,
+    # "pinot": ???
+    "postgresql": Dialects.POSTGRES,
+    "presto": Dialects.PRESTO,
+    "pydoris": Dialects.DORIS,
+    "redshift": Dialects.REDSHIFT,
+    # "risingwave": ???
+    # "rockset": ???
+    "shillelagh": Dialects.SQLITE,
+    "snowflake": Dialects.SNOWFLAKE,
+    # "solr": ???
+    "sqlite": Dialects.SQLITE,
+    "starrocks": Dialects.STARROCKS,
+    "superset": Dialects.SQLITE,
+    "teradatasql": Dialects.TERADATA,
+    "trino": Dialects.TRINO,
+    "vertica": Dialects.POSTGRES,
+}
+

 class CtasMethod(StrEnum):
     TABLE = "TABLE"
     VIEW = "VIEW"
@@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder


-def strip_comments_from_sql(statement: str) -> str:
+def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
     :param statement: A string with the SQL statement
     :return: SQL statement without comments
     """
-    return ParsedQuery(statement).strip_comments() if "--" in statement else statement
+    return (
+        ParsedQuery(statement, engine=engine).strip_comments()
+        if "--" in statement
+        else statement
+    )


 @dataclass(eq=True, frozen=True)
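The cheap `--` check keeps the common case free of a `ParsedQuery` construction; only statements that actually contain a comment marker get parsed. A short usage sketch (queries are illustrative):

    from superset.sql_parse import strip_comments_from_sql

    # No "--" marker: returned unchanged, no ParsedQuery is built.
    print(strip_comments_from_sql("SELECT 1", engine="postgresql"))
    # Marker present: the statement is parsed and the comment stripped.
    print(strip_comments_from_sql("SELECT 1 -- a trailing comment", engine="postgresql"))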
@@ -179,7 +243,7 @@ class Table:
         """
         return ".".join(
-            parse.quote(part, safe="").replace(".", "%2E")
+            urllib.parse.quote(part, safe="").replace(".", "%2E")
             for part in [self.catalog, self.schema, self.table]
             if part
         )
@@ -189,11 +253,17 @@

 class ParsedQuery:
-    def __init__(self, sql_statement: str, strip_comments: bool = False):
+    def __init__(
+        self,
+        sql_statement: str,
+        strip_comments: bool = False,
+        engine: Optional[str] = None,
+    ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)

         self.sql: str = sql_statement
+        self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
         self._limit: Optional[int] = None
@@ -206,14 +276,94 @@
     @property
     def tables(self) -> set[Table]:
         if not self._tables:
-            for statement in self._parsed:
-                self._extract_from_token(statement)
-
-            self._tables = {
-                table for table in self._tables if str(table) not in self._alias_names
-            }
+            self._tables = self._extract_tables_from_sql()
         return self._tables

+    def _extract_tables_from_sql(self) -> set[Table]:
+        """
+        Extract all table references in a query.
+
+        Note: this uses sqlglot, since it's better at catching more edge cases.
+        """
+        try:
+            statements = parse(self.sql, dialect=self._dialect)
+        except ParseError:
+            logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
+            return set()
+
+        return {
+            table
+            for statement in statements
+            for table in self._extract_tables_from_statement(statement)
+            if statement
+        }
+
+    def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
+        """
+        Extract all table references in a single statement.
+
+        Please note that this is not trivial; consider the following queries:
+
+            DESCRIBE some_table;
+            SHOW PARTITIONS FROM some_table;
+            WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
+
+        See the unit tests for other tricky cases.
+        """
+        sources: Iterable[exp.Table]
+
+        if isinstance(statement, exp.Describe):
+            # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
+            # query for all tables.
+            sources = statement.find_all(exp.Table)
+        elif isinstance(statement, exp.Command):
+            # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+            # `SELECT` statement in order to extract tables.
+            literal = statement.find(exp.Literal)
+            if not literal:
+                return set()
+
+            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
+            sources = pseudo_query.find_all(exp.Table)
+        else:
+            sources = [
+                source
+                for scope in traverse_scope(statement)
+                for source in scope.sources.values()
+                if isinstance(source, exp.Table) and not self._is_cte(source, scope)
+            ]
+
+        return {
+            Table(
+                source.name,
+                source.db if source.db != "" else None,
+                source.catalog if source.catalog != "" else None,
+            )
+            for source in sources
+        }
+
+    def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
+        """
+        Is the source a CTE?
+
+        CTEs in the parent scope look like tables (and are represented by
+        exp.Table objects), but should not be considered as such;
+        otherwise a user with access to table `foo` could access any table
+        with a query like this:
+
+            WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
+
+        """
+        parent_sources = scope.parent.sources if scope.parent else {}
+        ctes_in_scope = {
+            name
+            for name, parent_scope in parent_sources.items()
+            if isinstance(parent_scope, Scope)
+            and parent_scope.scope_type == ScopeType.CTE
+        }
+
+        return source.name in ctes_in_scope
+
     @property
     def limit(self) -> Optional[int]:
         return self._limit
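Taken together, these helpers make the `tables` property cover the cases called out in the docstrings, including the CTE shadowing that `_is_cte` guards against. A hedged sketch of the intended behaviour (queries come from the docstrings above; the expectations reflect intent rather than a verified run):

    from superset.sql_parse import ParsedQuery

    print(ParsedQuery("DESCRIBE some_table", engine="mysql").tables)
    # expected: {Table("some_table")}
    print(ParsedQuery("SHOW PARTITIONS FROM some_table", engine="hive").tables)
    # expected: {Table("some_table")}
    print(
        ParsedQuery(
            "WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo",
            engine="postgresql",
        ).tables
    )
    # expected: {Table("target_table")}; the CTE alias itself is not reported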
@@ -393,28 +543,6 @@
     def _is_identifier(token: Token) -> bool:
         return isinstance(token, (IdentifierList, Identifier))

-    def _process_tokenlist(self, token_list: TokenList) -> None:
-        """
-        Add table names to table set
-
-        :param token_list: TokenList to be processed
-        """
-        # exclude subselects
-        if "(" not in str(token_list):
-            table = self.get_table(token_list)
-            if table and not table.table.startswith(CTE_PREFIX):
-                self._tables.add(table)
-            return
-
-        # store aliases
-        if token_list.has_alias():
-            self._alias_names.add(token_list.get_alias())
-
-        # some aliases are not parsed properly
-        if token_list.tokens[0].ttype == Name:
-            self._alias_names.add(token_list.tokens[0].value)
-        self._extract_from_token(token_list)
-
     def as_create_table(
         self,
         table_name: str,
@@ -441,50 +569,6 @@
         exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
         return exec_sql

-    def _extract_from_token(self, token: Token) -> None:
-        """
-        <Identifier> store a list of subtokens and <IdentifierList> store lists of
-        subtoken list.
-
-        It extracts <IdentifierList> and <Identifier> from :param token: and loops
-        through all subtokens recursively. It finds table_name_preceding_token and
-        passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
-        self._tables.
-
-        :param token: instance of Token or child class, e.g. TokenList, to be processed
-        """
-        if not hasattr(token, "tokens"):
-            return
-
-        table_name_preceding_token = False
-
-        for item in token.tokens:
-            if item.is_group and (
-                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
-            ):
-                self._extract_from_token(item)
-
-            if item.ttype in Keyword and (
-                item.normalized in PRECEDES_TABLE_NAME
-                or item.normalized.endswith(" JOIN")
-            ):
-                table_name_preceding_token = True
-                continue
-
-            if item.ttype in Keyword:
-                table_name_preceding_token = False
-                continue
-
-            if table_name_preceding_token:
-                if isinstance(item, Identifier):
-                    self._process_tokenlist(item)
-                elif isinstance(item, IdentifierList):
-                    for token2 in item.get_identifiers():
-                        if isinstance(token2, TokenList):
-                            self._process_tokenlist(token2)
-            elif isinstance(item, IdentifierList):
-                if any(not self._is_identifier(token2) for token2 in item.tokens):
-                    self._extract_from_token(item)
-
     def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
         """Returns the query with the specified limit.
@@ -881,7 +965,7 @@ def insert_rls_in_predicate(

 # mapping between sqloxide and SQLAlchemy dialects
-SQLOXITE_DIALECTS = {
+SQLOXIDE_DIALECTS = {
     "ansi": {"trino", "trinonative", "presto"},
     "hive": {"hive", "databricks"},
     "ms": {"mssql"},
@@ -914,7 +998,7 @@
     tree = None

     if sqloxide_parse:
-        for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
+        for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
             if sqla_dialect in sqla_dialects:
                 break
         sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)

View File

@@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
     ) -> Optional[SQLValidationAnnotation]:
         # pylint: disable=too-many-locals
         db_engine_spec = database.db_engine_spec
-        parsed_query = ParsedQuery(statement)
+        parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
         sql = parsed_query.stripped()

         # Hook to allow environment-specific mutation (usually comments) to the SQL
@@ -154,7 +154,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
         For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
         VALIDATE) SELECT 1 FROM default.mytable.
         """
-        parsed_query = ParsedQuery(sql)
+        parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
         statements = parsed_query.get_statements()

         logger.info("Validating %i statement(s)", len(statements))

View File

@@ -58,7 +58,11 @@ class SqlQueryRenderImpl(SqlQueryRender):
                 database=query_model.database, query=query_model
             )

-            parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
+            parsed_query = ParsedQuery(
+                query_model.sql,
+                strip_comments=True,
+                engine=query_model.database.db_engine_spec.engine,
+            )
             rendered_query = sql_template_processor.process_template(
                 parsed_query.stripped(), **execution_context.template_params
             )

View File

@@ -40,11 +40,11 @@ from superset.sql_parse import (
 )


-def extract_tables(query: str) -> set[Table]:
+def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
     """
     Helper function to extract tables referenced in a query.
     """
-    return ParsedQuery(query).tables
+    return ParsedQuery(query, engine=engine).tables


 def test_table() -> None:
@@ -96,8 +96,13 @@ def test_extract_tables() -> None:
         Table("left_table")
     }

-    # reverse select
-    assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
+    assert extract_tables(
+        "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
+    ) == {Table("forbidden_table")}
+
+    assert extract_tables(
+        "select * from (select * from forbidden_table) forbidden_table"
+    ) == {Table("forbidden_table")}


 def test_extract_tables_subselect() -> None:
@@ -263,14 +268,16 @@ def test_extract_tables_illdefined() -> None:
     assert extract_tables("SELECT * FROM schemaname.") == set()
     assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
     assert extract_tables("SELECT * FROM catalogname..") == set()
-    assert extract_tables("SELECT * FROM catalogname..tbname") == set()
+    assert extract_tables("SELECT * FROM catalogname..tbname") == {
+        Table(table="tbname", schema=None, catalog="catalogname")
+    }


 def test_extract_tables_show_tables_from() -> None:
     """
     Test ``SHOW TABLES FROM``.
     """
-    assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
+    assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()


 def test_extract_tables_show_columns_from() -> None:
@@ -311,7 +318,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
             """
 SELECT name
 FROM t1
-WHERE regionkey EXISTS (SELECT regionkey FROM t2)
+WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
 """
         )
         == {Table("t1"), Table("t2")}
@@ -526,6 +533,18 @@ select * from (select key from q1) a
         == {Table("src")}
     )

+    # weird query with circular dependency
+    assert (
+        extract_tables(
+            """
+with src as ( select key from q2 where key = '5'),
+q2 as ( select key from src where key = '5')
+select * from (select key from src) a
+"""
+        )
+        == set()
+    )


 def test_extract_tables_multistatement() -> None:
     """
@@ -665,7 +684,8 @@ def test_extract_tables_nested_select() -> None:
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
-"""
+""",
+            "mysql",
         )
         == {Table("COLUMNS", "INFORMATION_SCHEMA")}
     )
@@ -676,7 +696,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
-"""
+""",
+            "mysql",
        )
         == {Table("COLUMNS", "INFORMATION_SCHEMA")}
     )
@@ -1306,6 +1327,14 @@ def test_sqlparse_issue_652():
             "(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
             True,
         ),
+        (
+            "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
+            True,
+        ),
+        (
+            "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
+            True,
+        ),
     ],
 )
 def test_has_table_query(sql: str, expected: bool) -> None:
@@ -1790,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
     assert extract_table_references(
         sql,
         "trino",
-    ) == {Table(table="other_table", schema=None, catalog=None)}
+    ) == {
+        Table(table="table", schema=None, catalog=None),
+        Table(table="other_table", schema=None, catalog=None),
+    }
     logger.warning.assert_called_once()

     logger = mocker.patch("superset.migrations.shared.utils.logger")
     sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
     assert extract_table_references(sql, "trino", show_warning=False) == {
-        Table(table="other_table", schema=None, catalog=None)
+        Table(table="table", schema=None, catalog=None),
+        Table(table="other_table", schema=None, catalog=None),
     }
     logger.warning.assert_not_called()