feat(sqlparse): improve table parsing (#26476)
This commit is contained in:
parent
d34874cf2b
commit
c0b57bd1c3
|
|
@ -141,7 +141,9 @@ geographiclib==1.52
|
||||||
geopy==2.2.0
|
geopy==2.2.0
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
greenlet==2.0.2
|
greenlet==2.0.2
|
||||||
# via shillelagh
|
# via
|
||||||
|
# shillelagh
|
||||||
|
# sqlalchemy
|
||||||
gunicorn==21.2.0
|
gunicorn==21.2.0
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
hashids==1.3.1
|
hashids==1.3.1
|
||||||
|
|
@ -155,7 +157,10 @@ idna==3.2
|
||||||
# email-validator
|
# email-validator
|
||||||
# requests
|
# requests
|
||||||
importlib-metadata==6.6.0
|
importlib-metadata==6.6.0
|
||||||
# via apache-superset
|
# via
|
||||||
|
# apache-superset
|
||||||
|
# flask
|
||||||
|
# shillelagh
|
||||||
importlib-resources==5.12.0
|
importlib-resources==5.12.0
|
||||||
# via limits
|
# via limits
|
||||||
isodate==0.6.0
|
isodate==0.6.0
|
||||||
|
|
@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3
|
||||||
# via
|
# via
|
||||||
# apache-superset
|
# apache-superset
|
||||||
# flask-appbuilder
|
# flask-appbuilder
|
||||||
|
sqlglot==20.8.0
|
||||||
|
# via apache-superset
|
||||||
sqlparse==0.4.4
|
sqlparse==0.4.4
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
sshtunnel==0.4.0
|
sshtunnel==0.4.0
|
||||||
|
|
@ -376,7 +383,9 @@ wtforms-json==0.3.5
|
||||||
xlsxwriter==3.0.7
|
xlsxwriter==3.0.7
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
zipp==3.15.0
|
zipp==3.15.0
|
||||||
# via importlib-metadata
|
# via
|
||||||
|
# importlib-metadata
|
||||||
|
# importlib-resources
|
||||||
|
|
||||||
# The following packages are considered to be unsafe in a requirements file:
|
# The following packages are considered to be unsafe in a requirements file:
|
||||||
# setuptools
|
# setuptools
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,6 @@ db-dtypes==1.1.1
|
||||||
# via pandas-gbq
|
# via pandas-gbq
|
||||||
docker==6.1.1
|
docker==6.1.1
|
||||||
# via -r requirements/testing.in
|
# via -r requirements/testing.in
|
||||||
exceptiongroup==1.1.1
|
|
||||||
# via pytest
|
|
||||||
ephem==4.1.4
|
|
||||||
# via lunarcalendar
|
|
||||||
flask-testing==0.8.1
|
flask-testing==0.8.1
|
||||||
# via -r requirements/testing.in
|
# via -r requirements/testing.in
|
||||||
fonttools==4.39.4
|
fonttools==4.39.4
|
||||||
|
|
@ -121,6 +117,8 @@ pyee==9.0.4
|
||||||
# via playwright
|
# via playwright
|
||||||
pyfakefs==5.2.2
|
pyfakefs==5.2.2
|
||||||
# via -r requirements/testing.in
|
# via -r requirements/testing.in
|
||||||
|
pyhive[presto]==0.7.0
|
||||||
|
# via apache-superset
|
||||||
pytest==7.3.1
|
pytest==7.3.1
|
||||||
# via
|
# via
|
||||||
# -r requirements/testing.in
|
# -r requirements/testing.in
|
||||||
|
|
|
||||||
1
setup.py
1
setup.py
|
|
@ -125,6 +125,7 @@ setup(
|
||||||
"slack_sdk>=3.19.0, <4",
|
"slack_sdk>=3.19.0, <4",
|
||||||
"sqlalchemy>=1.4, <2",
|
"sqlalchemy>=1.4, <2",
|
||||||
"sqlalchemy-utils>=0.38.3, <0.39",
|
"sqlalchemy-utils>=0.38.3, <0.39",
|
||||||
|
"sqlglot>=20,<21",
|
||||||
"sqlparse>=0.4.4, <0.5",
|
"sqlparse>=0.4.4, <0.5",
|
||||||
"tabulate>=0.8.9, <0.9",
|
"tabulate>=0.8.9, <0.9",
|
||||||
"typing-extensions>=4, <5",
|
"typing-extensions>=4, <5",
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,10 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
|
||||||
table.normalize_columns = self._base_model.normalize_columns
|
table.normalize_columns = self._base_model.normalize_columns
|
||||||
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
|
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
|
||||||
table.is_sqllab_view = True
|
table.is_sqllab_view = True
|
||||||
table.sql = ParsedQuery(self._base_model.sql).stripped()
|
table.sql = ParsedQuery(
|
||||||
|
self._base_model.sql,
|
||||||
|
engine=database.db_engine_spec.engine,
|
||||||
|
).stripped()
|
||||||
db.session.add(table)
|
db.session.add(table)
|
||||||
cols = []
|
cols = []
|
||||||
for config_ in self._base_model.columns:
|
for config_ in self._base_model.columns:
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,10 @@ class SqlResultExportCommand(BaseCommand):
|
||||||
limit = None
|
limit = None
|
||||||
else:
|
else:
|
||||||
sql = self._query.executed_sql
|
sql = self._query.executed_sql
|
||||||
limit = ParsedQuery(sql).limit
|
limit = ParsedQuery(
|
||||||
|
sql,
|
||||||
|
engine=self._query.database.db_engine_spec.engine,
|
||||||
|
).limit
|
||||||
if limit is not None and self._query.limiting_factor in {
|
if limit is not None and self._query.limiting_factor in {
|
||||||
LimitingFactor.QUERY,
|
LimitingFactor.QUERY,
|
||||||
LimitingFactor.DROPDOWN,
|
LimitingFactor.DROPDOWN,
|
||||||
|
|
|
||||||
|
|
@ -1457,7 +1457,7 @@ class SqlaTable(
|
||||||
return self.get_sqla_table(), None
|
return self.get_sqla_table(), None
|
||||||
|
|
||||||
from_sql = self.get_rendered_sql(template_processor)
|
from_sql = self.get_rendered_sql(template_processor)
|
||||||
parsed_query = ParsedQuery(from_sql)
|
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
||||||
if not (
|
if not (
|
||||||
parsed_query.is_unknown()
|
parsed_query.is_unknown()
|
||||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
or self.db_engine_spec.is_readonly_query(parsed_query)
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
|
||||||
sql = dataset.get_template_processor().process_template(
|
sql = dataset.get_template_processor().process_template(
|
||||||
dataset.sql, **dataset.template_params_dict
|
dataset.sql, **dataset.template_params_dict
|
||||||
)
|
)
|
||||||
parsed_query = ParsedQuery(sql)
|
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
|
||||||
if not db_engine_spec.is_readonly_query(parsed_query):
|
if not db_engine_spec.is_readonly_query(parsed_query):
|
||||||
raise SupersetSecurityException(
|
raise SupersetSecurityException(
|
||||||
SupersetError(
|
SupersetError(
|
||||||
|
|
|
||||||
|
|
@ -899,7 +899,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
return database.compile_sqla_query(qry)
|
return database.compile_sqla_query(qry)
|
||||||
|
|
||||||
if cls.limit_method == LimitMethod.FORCE_LIMIT:
|
if cls.limit_method == LimitMethod.FORCE_LIMIT:
|
||||||
parsed_query = sql_parse.ParsedQuery(sql)
|
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||||
sql = parsed_query.set_or_update_query_limit(limit, force=force)
|
sql = parsed_query.set_or_update_query_limit(limit, force=force)
|
||||||
|
|
||||||
return sql
|
return sql
|
||||||
|
|
@ -980,7 +980,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
:param sql: SQL query
|
:param sql: SQL query
|
||||||
:return: Value of limit clause in query
|
:return: Value of limit clause in query
|
||||||
"""
|
"""
|
||||||
parsed_query = sql_parse.ParsedQuery(sql)
|
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||||
return parsed_query.limit
|
return parsed_query.limit
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -992,7 +992,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
:param limit: New limit to insert/replace into query
|
:param limit: New limit to insert/replace into query
|
||||||
:return: Query with new limit
|
:return: Query with new limit
|
||||||
"""
|
"""
|
||||||
parsed_query = sql_parse.ParsedQuery(sql)
|
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||||
return parsed_query.set_or_update_query_limit(limit)
|
return parsed_query.set_or_update_query_limit(limit)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -1487,7 +1487,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
:param database: Database instance
|
:param database: Database instance
|
||||||
:return: Dictionary with different costs
|
:return: Dictionary with different costs
|
||||||
"""
|
"""
|
||||||
parsed_query = ParsedQuery(statement)
|
parsed_query = ParsedQuery(statement, engine=cls.engine)
|
||||||
sql = parsed_query.stripped()
|
sql = parsed_query.stripped()
|
||||||
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
|
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
|
||||||
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
|
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
|
||||||
|
|
@ -1522,7 +1522,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"Database does not support cost estimation"
|
"Database does not support cost estimation"
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed_query = sql_parse.ParsedQuery(sql)
|
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||||
statements = parsed_query.get_statements()
|
statements = parsed_query.get_statements()
|
||||||
|
|
||||||
costs = []
|
costs = []
|
||||||
|
|
@ -1583,7 +1583,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not cls.allows_sql_comments:
|
if not cls.allows_sql_comments:
|
||||||
query = sql_parse.strip_comments_from_sql(query)
|
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
|
||||||
|
|
||||||
if cls.arraysize:
|
if cls.arraysize:
|
||||||
cursor.arraysize = cls.arraysize
|
cursor.arraysize = cls.arraysize
|
||||||
|
|
|
||||||
|
|
@ -435,7 +435,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
||||||
if not cls.get_allow_cost_estimate(extra):
|
if not cls.get_allow_cost_estimate(extra):
|
||||||
raise SupersetException("Database does not support cost estimation")
|
raise SupersetException("Database does not support cost estimation")
|
||||||
|
|
||||||
parsed_query = sql_parse.ParsedQuery(sql)
|
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||||
statements = parsed_query.get_statements()
|
statements = parsed_query.get_statements()
|
||||||
costs = []
|
costs = []
|
||||||
for statement in statements:
|
for statement in statements:
|
||||||
|
|
|
||||||
|
|
@ -1093,7 +1093,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from_sql = self.get_rendered_sql(template_processor)
|
from_sql = self.get_rendered_sql(template_processor)
|
||||||
parsed_query = ParsedQuery(from_sql)
|
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
||||||
if not (
|
if not (
|
||||||
parsed_query.is_unknown()
|
parsed_query.is_unknown()
|
||||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
or self.db_engine_spec.is_readonly_query(parsed_query)
|
||||||
|
|
|
||||||
|
|
@ -183,7 +183,7 @@ class Query(
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sql_tables(self) -> list[Table]:
|
def sql_tables(self) -> list[Table]:
|
||||||
return list(ParsedQuery(self.sql).tables)
|
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def columns(self) -> list["TableColumn"]:
|
def columns(self) -> list["TableColumn"]:
|
||||||
|
|
@ -427,7 +427,9 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sql_tables(self) -> list[Table]:
|
def sql_tables(self) -> list[Table]:
|
||||||
return list(ParsedQuery(self.sql).tables)
|
return list(
|
||||||
|
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_run_humanized(self) -> str:
|
def last_run_humanized(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1876,7 +1876,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
default_schema = database.get_default_schema_for_query(query)
|
default_schema = database.get_default_schema_for_query(query)
|
||||||
tables = {
|
tables = {
|
||||||
Table(table_.table, table_.schema or default_schema)
|
Table(table_.table, table_.schema or default_schema)
|
||||||
for table_ in sql_parse.ParsedQuery(query.sql).tables
|
for table_ in sql_parse.ParsedQuery(
|
||||||
|
query.sql,
|
||||||
|
engine=database.db_engine_spec.engine,
|
||||||
|
).tables
|
||||||
}
|
}
|
||||||
elif table:
|
elif table:
|
||||||
tables = {table}
|
tables = {table}
|
||||||
|
|
|
||||||
|
|
@ -199,7 +199,7 @@ def execute_sql_statement(
|
||||||
database: Database = query.database
|
database: Database = query.database
|
||||||
db_engine_spec = database.db_engine_spec
|
db_engine_spec = database.db_engine_spec
|
||||||
|
|
||||||
parsed_query = ParsedQuery(sql_statement)
|
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
|
||||||
if is_feature_enabled("RLS_IN_SQLLAB"):
|
if is_feature_enabled("RLS_IN_SQLLAB"):
|
||||||
# There are two ways to insert RLS: either replacing the table with a subquery
|
# There are two ways to insert RLS: either replacing the table with a subquery
|
||||||
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
|
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
|
||||||
|
|
@ -219,7 +219,8 @@ def execute_sql_statement(
|
||||||
database.id,
|
database.id,
|
||||||
query.schema,
|
query.schema,
|
||||||
)
|
)
|
||||||
)
|
),
|
||||||
|
engine=db_engine_spec.engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
sql = parsed_query.stripped()
|
sql = parsed_query.stripped()
|
||||||
|
|
@ -409,7 +410,11 @@ def execute_sql_statements(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Breaking down into multiple statements
|
# Breaking down into multiple statements
|
||||||
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
|
parsed_query = ParsedQuery(
|
||||||
|
rendered_query,
|
||||||
|
strip_comments=True,
|
||||||
|
engine=db_engine_spec.engine,
|
||||||
|
)
|
||||||
if not db_engine_spec.run_multiple_statements_as_one:
|
if not db_engine_spec.run_multiple_statements_as_one:
|
||||||
statements = parsed_query.get_statements()
|
statements = parsed_query.get_statements()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -14,15 +14,22 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections.abc import Iterator
|
import urllib.parse
|
||||||
|
from collections.abc import Iterable, Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, cast, Optional
|
from typing import Any, cast, Optional
|
||||||
from urllib import parse
|
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
|
from sqlglot import exp, parse, parse_one
|
||||||
|
from sqlglot.dialects import Dialects
|
||||||
|
from sqlglot.errors import ParseError
|
||||||
|
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||||
from sqlparse import keywords
|
from sqlparse import keywords
|
||||||
from sqlparse.lexer import Lexer
|
from sqlparse.lexer import Lexer
|
||||||
from sqlparse.sql import (
|
from sqlparse.sql import (
|
||||||
|
|
@ -53,7 +60,7 @@ from superset.utils.backports import StrEnum
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sqloxide import parse_sql as sqloxide_parse
|
from sqloxide import parse_sql as sqloxide_parse
|
||||||
except: # pylint: disable=bare-except
|
except (ImportError, ModuleNotFoundError):
|
||||||
sqloxide_parse = None
|
sqloxide_parse = None
|
||||||
|
|
||||||
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
|
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
|
||||||
|
|
@ -72,6 +79,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
|
||||||
lex.set_SQL_REGEX(sqlparser_sql_regex)
|
lex.set_SQL_REGEX(sqlparser_sql_regex)
|
||||||
|
|
||||||
|
|
||||||
|
# mapping between DB engine specs and sqlglot dialects
|
||||||
|
SQLGLOT_DIALECTS = {
|
||||||
|
"ascend": Dialects.HIVE,
|
||||||
|
"awsathena": Dialects.PRESTO,
|
||||||
|
"bigquery": Dialects.BIGQUERY,
|
||||||
|
"clickhouse": Dialects.CLICKHOUSE,
|
||||||
|
"clickhousedb": Dialects.CLICKHOUSE,
|
||||||
|
"cockroachdb": Dialects.POSTGRES,
|
||||||
|
# "crate": ???
|
||||||
|
# "databend": ???
|
||||||
|
"databricks": Dialects.DATABRICKS,
|
||||||
|
# "db2": ???
|
||||||
|
# "dremio": ???
|
||||||
|
"drill": Dialects.DRILL,
|
||||||
|
# "druid": ???
|
||||||
|
"duckdb": Dialects.DUCKDB,
|
||||||
|
# "dynamodb": ???
|
||||||
|
# "elasticsearch": ???
|
||||||
|
# "exa": ???
|
||||||
|
# "firebird": ???
|
||||||
|
# "firebolt": ???
|
||||||
|
"gsheets": Dialects.SQLITE,
|
||||||
|
"hana": Dialects.POSTGRES,
|
||||||
|
"hive": Dialects.HIVE,
|
||||||
|
# "ibmi": ???
|
||||||
|
# "impala": ???
|
||||||
|
# "kustokql": ???
|
||||||
|
# "kylin": ???
|
||||||
|
# "mssql": ???
|
||||||
|
"mysql": Dialects.MYSQL,
|
||||||
|
"netezza": Dialects.POSTGRES,
|
||||||
|
# "ocient": ???
|
||||||
|
# "odelasticsearch": ???
|
||||||
|
"oracle": Dialects.ORACLE,
|
||||||
|
# "pinot": ???
|
||||||
|
"postgresql": Dialects.POSTGRES,
|
||||||
|
"presto": Dialects.PRESTO,
|
||||||
|
"pydoris": Dialects.DORIS,
|
||||||
|
"redshift": Dialects.REDSHIFT,
|
||||||
|
# "risingwave": ???
|
||||||
|
# "rockset": ???
|
||||||
|
"shillelagh": Dialects.SQLITE,
|
||||||
|
"snowflake": Dialects.SNOWFLAKE,
|
||||||
|
# "solr": ???
|
||||||
|
"sqlite": Dialects.SQLITE,
|
||||||
|
"starrocks": Dialects.STARROCKS,
|
||||||
|
"superset": Dialects.SQLITE,
|
||||||
|
"teradatasql": Dialects.TERADATA,
|
||||||
|
"trino": Dialects.TRINO,
|
||||||
|
"vertica": Dialects.POSTGRES,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class CtasMethod(StrEnum):
|
class CtasMethod(StrEnum):
|
||||||
TABLE = "TABLE"
|
TABLE = "TABLE"
|
||||||
VIEW = "VIEW"
|
VIEW = "VIEW"
|
||||||
|
|
@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
||||||
return cte, remainder
|
return cte, remainder
|
||||||
|
|
||||||
|
|
||||||
def strip_comments_from_sql(statement: str) -> str:
|
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Strips comments from a SQL statement, does a simple test first
|
Strips comments from a SQL statement, does a simple test first
|
||||||
to avoid always instantiating the expensive ParsedQuery constructor
|
to avoid always instantiating the expensive ParsedQuery constructor
|
||||||
|
|
@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
|
||||||
:param statement: A string with the SQL statement
|
:param statement: A string with the SQL statement
|
||||||
:return: SQL statement without comments
|
:return: SQL statement without comments
|
||||||
"""
|
"""
|
||||||
return ParsedQuery(statement).strip_comments() if "--" in statement else statement
|
return (
|
||||||
|
ParsedQuery(statement, engine=engine).strip_comments()
|
||||||
|
if "--" in statement
|
||||||
|
else statement
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=True, frozen=True)
|
@dataclass(eq=True, frozen=True)
|
||||||
|
|
@ -179,7 +243,7 @@ class Table:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return ".".join(
|
return ".".join(
|
||||||
parse.quote(part, safe="").replace(".", "%2E")
|
urllib.parse.quote(part, safe="").replace(".", "%2E")
|
||||||
for part in [self.catalog, self.schema, self.table]
|
for part in [self.catalog, self.schema, self.table]
|
||||||
if part
|
if part
|
||||||
)
|
)
|
||||||
|
|
@ -189,11 +253,17 @@ class Table:
|
||||||
|
|
||||||
|
|
||||||
class ParsedQuery:
|
class ParsedQuery:
|
||||||
def __init__(self, sql_statement: str, strip_comments: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
sql_statement: str,
|
||||||
|
strip_comments: bool = False,
|
||||||
|
engine: Optional[str] = None,
|
||||||
|
):
|
||||||
if strip_comments:
|
if strip_comments:
|
||||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||||
|
|
||||||
self.sql: str = sql_statement
|
self.sql: str = sql_statement
|
||||||
|
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||||
self._tables: set[Table] = set()
|
self._tables: set[Table] = set()
|
||||||
self._alias_names: set[str] = set()
|
self._alias_names: set[str] = set()
|
||||||
self._limit: Optional[int] = None
|
self._limit: Optional[int] = None
|
||||||
|
|
@ -206,14 +276,94 @@ class ParsedQuery:
|
||||||
@property
|
@property
|
||||||
def tables(self) -> set[Table]:
|
def tables(self) -> set[Table]:
|
||||||
if not self._tables:
|
if not self._tables:
|
||||||
for statement in self._parsed:
|
self._tables = self._extract_tables_from_sql()
|
||||||
self._extract_from_token(statement)
|
|
||||||
|
|
||||||
self._tables = {
|
|
||||||
table for table in self._tables if str(table) not in self._alias_names
|
|
||||||
}
|
|
||||||
return self._tables
|
return self._tables
|
||||||
|
|
||||||
|
def _extract_tables_from_sql(self) -> set[Table]:
|
||||||
|
"""
|
||||||
|
Extract all table references in a query.
|
||||||
|
|
||||||
|
Note: this uses sqlglot, since it's better at catching more edge cases.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statements = parse(self.sql, dialect=self._dialect)
|
||||||
|
except ParseError:
|
||||||
|
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
|
||||||
|
return set()
|
||||||
|
|
||||||
|
return {
|
||||||
|
table
|
||||||
|
for statement in statements
|
||||||
|
for table in self._extract_tables_from_statement(statement)
|
||||||
|
if statement
|
||||||
|
}
|
||||||
|
|
||||||
|
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
|
||||||
|
"""
|
||||||
|
Extract all table references in a single statement.
|
||||||
|
|
||||||
|
Please not that this is not trivial; consider the following queries:
|
||||||
|
|
||||||
|
DESCRIBE some_table;
|
||||||
|
SHOW PARTITIONS FROM some_table;
|
||||||
|
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
|
||||||
|
|
||||||
|
See the unit tests for other tricky cases.
|
||||||
|
"""
|
||||||
|
sources: Iterable[exp.Table]
|
||||||
|
|
||||||
|
if isinstance(statement, exp.Describe):
|
||||||
|
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
|
||||||
|
# query for all tables.
|
||||||
|
sources = statement.find_all(exp.Table)
|
||||||
|
elif isinstance(statement, exp.Command):
|
||||||
|
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
|
||||||
|
# `SELECT` statetement in order to extract tables.
|
||||||
|
literal = statement.find(exp.Literal)
|
||||||
|
if not literal:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
|
||||||
|
sources = pseudo_query.find_all(exp.Table)
|
||||||
|
else:
|
||||||
|
sources = [
|
||||||
|
source
|
||||||
|
for scope in traverse_scope(statement)
|
||||||
|
for source in scope.sources.values()
|
||||||
|
if isinstance(source, exp.Table) and not self._is_cte(source, scope)
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
Table(
|
||||||
|
source.name,
|
||||||
|
source.db if source.db != "" else None,
|
||||||
|
source.catalog if source.catalog != "" else None,
|
||||||
|
)
|
||||||
|
for source in sources
|
||||||
|
}
|
||||||
|
|
||||||
|
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
|
||||||
|
"""
|
||||||
|
Is the source a CTE?
|
||||||
|
|
||||||
|
CTEs in the parent scope look like tables (and are represented by
|
||||||
|
exp.Table objects), but should not be considered as such;
|
||||||
|
otherwise a user with access to table `foo` could access any table
|
||||||
|
with a query like this:
|
||||||
|
|
||||||
|
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
|
||||||
|
|
||||||
|
"""
|
||||||
|
parent_sources = scope.parent.sources if scope.parent else {}
|
||||||
|
ctes_in_scope = {
|
||||||
|
name
|
||||||
|
for name, parent_scope in parent_sources.items()
|
||||||
|
if isinstance(parent_scope, Scope)
|
||||||
|
and parent_scope.scope_type == ScopeType.CTE
|
||||||
|
}
|
||||||
|
|
||||||
|
return source.name in ctes_in_scope
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def limit(self) -> Optional[int]:
|
def limit(self) -> Optional[int]:
|
||||||
return self._limit
|
return self._limit
|
||||||
|
|
@ -393,28 +543,6 @@ class ParsedQuery:
|
||||||
def _is_identifier(token: Token) -> bool:
|
def _is_identifier(token: Token) -> bool:
|
||||||
return isinstance(token, (IdentifierList, Identifier))
|
return isinstance(token, (IdentifierList, Identifier))
|
||||||
|
|
||||||
def _process_tokenlist(self, token_list: TokenList) -> None:
|
|
||||||
"""
|
|
||||||
Add table names to table set
|
|
||||||
|
|
||||||
:param token_list: TokenList to be processed
|
|
||||||
"""
|
|
||||||
# exclude subselects
|
|
||||||
if "(" not in str(token_list):
|
|
||||||
table = self.get_table(token_list)
|
|
||||||
if table and not table.table.startswith(CTE_PREFIX):
|
|
||||||
self._tables.add(table)
|
|
||||||
return
|
|
||||||
|
|
||||||
# store aliases
|
|
||||||
if token_list.has_alias():
|
|
||||||
self._alias_names.add(token_list.get_alias())
|
|
||||||
|
|
||||||
# some aliases are not parsed properly
|
|
||||||
if token_list.tokens[0].ttype == Name:
|
|
||||||
self._alias_names.add(token_list.tokens[0].value)
|
|
||||||
self._extract_from_token(token_list)
|
|
||||||
|
|
||||||
def as_create_table(
|
def as_create_table(
|
||||||
self,
|
self,
|
||||||
table_name: str,
|
table_name: str,
|
||||||
|
|
@ -441,50 +569,6 @@ class ParsedQuery:
|
||||||
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
|
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
|
||||||
return exec_sql
|
return exec_sql
|
||||||
|
|
||||||
def _extract_from_token(self, token: Token) -> None:
|
|
||||||
"""
|
|
||||||
<Identifier> store a list of subtokens and <IdentifierList> store lists of
|
|
||||||
subtoken list.
|
|
||||||
|
|
||||||
It extracts <IdentifierList> and <Identifier> from :param token: and loops
|
|
||||||
through all subtokens recursively. It finds table_name_preceding_token and
|
|
||||||
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
|
|
||||||
self._tables.
|
|
||||||
|
|
||||||
:param token: instance of Token or child class, e.g. TokenList, to be processed
|
|
||||||
"""
|
|
||||||
if not hasattr(token, "tokens"):
|
|
||||||
return
|
|
||||||
|
|
||||||
table_name_preceding_token = False
|
|
||||||
|
|
||||||
for item in token.tokens:
|
|
||||||
if item.is_group and (
|
|
||||||
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
|
|
||||||
):
|
|
||||||
self._extract_from_token(item)
|
|
||||||
|
|
||||||
if item.ttype in Keyword and (
|
|
||||||
item.normalized in PRECEDES_TABLE_NAME
|
|
||||||
or item.normalized.endswith(" JOIN")
|
|
||||||
):
|
|
||||||
table_name_preceding_token = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
if item.ttype in Keyword:
|
|
||||||
table_name_preceding_token = False
|
|
||||||
continue
|
|
||||||
if table_name_preceding_token:
|
|
||||||
if isinstance(item, Identifier):
|
|
||||||
self._process_tokenlist(item)
|
|
||||||
elif isinstance(item, IdentifierList):
|
|
||||||
for token2 in item.get_identifiers():
|
|
||||||
if isinstance(token2, TokenList):
|
|
||||||
self._process_tokenlist(token2)
|
|
||||||
elif isinstance(item, IdentifierList):
|
|
||||||
if any(not self._is_identifier(token2) for token2 in item.tokens):
|
|
||||||
self._extract_from_token(item)
|
|
||||||
|
|
||||||
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
|
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
|
||||||
"""Returns the query with the specified limit.
|
"""Returns the query with the specified limit.
|
||||||
|
|
||||||
|
|
@ -881,7 +965,7 @@ def insert_rls_in_predicate(
|
||||||
|
|
||||||
|
|
||||||
# mapping between sqloxide and SQLAlchemy dialects
|
# mapping between sqloxide and SQLAlchemy dialects
|
||||||
SQLOXITE_DIALECTS = {
|
SQLOXIDE_DIALECTS = {
|
||||||
"ansi": {"trino", "trinonative", "presto"},
|
"ansi": {"trino", "trinonative", "presto"},
|
||||||
"hive": {"hive", "databricks"},
|
"hive": {"hive", "databricks"},
|
||||||
"ms": {"mssql"},
|
"ms": {"mssql"},
|
||||||
|
|
@ -914,7 +998,7 @@ def extract_table_references(
|
||||||
tree = None
|
tree = None
|
||||||
|
|
||||||
if sqloxide_parse:
|
if sqloxide_parse:
|
||||||
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
|
for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
|
||||||
if sqla_dialect in sqla_dialects:
|
if sqla_dialect in sqla_dialects:
|
||||||
break
|
break
|
||||||
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
|
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||||
) -> Optional[SQLValidationAnnotation]:
|
) -> Optional[SQLValidationAnnotation]:
|
||||||
# pylint: disable=too-many-locals
|
# pylint: disable=too-many-locals
|
||||||
db_engine_spec = database.db_engine_spec
|
db_engine_spec = database.db_engine_spec
|
||||||
parsed_query = ParsedQuery(statement)
|
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
|
||||||
sql = parsed_query.stripped()
|
sql = parsed_query.stripped()
|
||||||
|
|
||||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||||
|
|
@ -154,7 +154,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||||
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
||||||
VALIDATE) SELECT 1 FROM default.mytable.
|
VALIDATE) SELECT 1 FROM default.mytable.
|
||||||
"""
|
"""
|
||||||
parsed_query = ParsedQuery(sql)
|
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
|
||||||
statements = parsed_query.get_statements()
|
statements = parsed_query.get_statements()
|
||||||
|
|
||||||
logger.info("Validating %i statement(s)", len(statements))
|
logger.info("Validating %i statement(s)", len(statements))
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,11 @@ class SqlQueryRenderImpl(SqlQueryRender):
|
||||||
database=query_model.database, query=query_model
|
database=query_model.database, query=query_model
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
|
parsed_query = ParsedQuery(
|
||||||
|
query_model.sql,
|
||||||
|
strip_comments=True,
|
||||||
|
engine=query_model.database.db_engine_spec.engine,
|
||||||
|
)
|
||||||
rendered_query = sql_template_processor.process_template(
|
rendered_query = sql_template_processor.process_template(
|
||||||
parsed_query.stripped(), **execution_context.template_params
|
parsed_query.stripped(), **execution_context.template_params
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -40,11 +40,11 @@ from superset.sql_parse import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_tables(query: str) -> set[Table]:
|
def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
|
||||||
"""
|
"""
|
||||||
Helper function to extract tables referenced in a query.
|
Helper function to extract tables referenced in a query.
|
||||||
"""
|
"""
|
||||||
return ParsedQuery(query).tables
|
return ParsedQuery(query, engine=engine).tables
|
||||||
|
|
||||||
|
|
||||||
def test_table() -> None:
|
def test_table() -> None:
|
||||||
|
|
@ -96,8 +96,13 @@ def test_extract_tables() -> None:
|
||||||
Table("left_table")
|
Table("left_table")
|
||||||
}
|
}
|
||||||
|
|
||||||
# reverse select
|
assert extract_tables(
|
||||||
assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
|
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
|
||||||
|
) == {Table("forbidden_table")}
|
||||||
|
|
||||||
|
assert extract_tables(
|
||||||
|
"select * from (select * from forbidden_table) forbidden_table"
|
||||||
|
) == {Table("forbidden_table")}
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_subselect() -> None:
|
def test_extract_tables_subselect() -> None:
|
||||||
|
|
@ -263,14 +268,16 @@ def test_extract_tables_illdefined() -> None:
|
||||||
assert extract_tables("SELECT * FROM schemaname.") == set()
|
assert extract_tables("SELECT * FROM schemaname.") == set()
|
||||||
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
|
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
|
||||||
assert extract_tables("SELECT * FROM catalogname..") == set()
|
assert extract_tables("SELECT * FROM catalogname..") == set()
|
||||||
assert extract_tables("SELECT * FROM catalogname..tbname") == set()
|
assert extract_tables("SELECT * FROM catalogname..tbname") == {
|
||||||
|
Table(table="tbname", schema=None, catalog="catalogname")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_show_tables_from() -> None:
|
def test_extract_tables_show_tables_from() -> None:
|
||||||
"""
|
"""
|
||||||
Test ``SHOW TABLES FROM``.
|
Test ``SHOW TABLES FROM``.
|
||||||
"""
|
"""
|
||||||
assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
|
assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_show_columns_from() -> None:
|
def test_extract_tables_show_columns_from() -> None:
|
||||||
|
|
@ -311,7 +318,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
|
||||||
"""
|
"""
|
||||||
SELECT name
|
SELECT name
|
||||||
FROM t1
|
FROM t1
|
||||||
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
|
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
== {Table("t1"), Table("t2")}
|
== {Table("t1"), Table("t2")}
|
||||||
|
|
@ -526,6 +533,18 @@ select * from (select key from q1) a
|
||||||
== {Table("src")}
|
== {Table("src")}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# weird query with circular dependency
|
||||||
|
assert (
|
||||||
|
extract_tables(
|
||||||
|
"""
|
||||||
|
with src as ( select key from q2 where key = '5'),
|
||||||
|
q2 as ( select key from src where key = '5')
|
||||||
|
select * from (select key from src) a
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
== set()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_multistatement() -> None:
|
def test_extract_tables_multistatement() -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -665,7 +684,8 @@ def test_extract_tables_nested_select() -> None:
|
||||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
||||||
from INFORMATION_SCHEMA.COLUMNS
|
from INFORMATION_SCHEMA.COLUMNS
|
||||||
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
||||||
"""
|
""",
|
||||||
|
"mysql",
|
||||||
)
|
)
|
||||||
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||||
)
|
)
|
||||||
|
|
@ -676,7 +696,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
||||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
||||||
from INFORMATION_SCHEMA.COLUMNS
|
from INFORMATION_SCHEMA.COLUMNS
|
||||||
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
|
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
|
||||||
"""
|
""",
|
||||||
|
"mysql",
|
||||||
)
|
)
|
||||||
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||||
)
|
)
|
||||||
|
|
@ -1306,6 +1327,14 @@ def test_sqlparse_issue_652():
|
||||||
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
|
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
|
||||||
True,
|
True,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
|
||||||
|
True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_has_table_query(sql: str, expected: bool) -> None:
|
def test_has_table_query(sql: str, expected: bool) -> None:
|
||||||
|
|
@ -1790,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
|
||||||
assert extract_table_references(
|
assert extract_table_references(
|
||||||
sql,
|
sql,
|
||||||
"trino",
|
"trino",
|
||||||
) == {Table(table="other_table", schema=None, catalog=None)}
|
) == {
|
||||||
|
Table(table="table", schema=None, catalog=None),
|
||||||
|
Table(table="other_table", schema=None, catalog=None),
|
||||||
|
}
|
||||||
logger.warning.assert_called_once()
|
logger.warning.assert_called_once()
|
||||||
|
|
||||||
logger = mocker.patch("superset.migrations.shared.utils.logger")
|
logger = mocker.patch("superset.migrations.shared.utils.logger")
|
||||||
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
|
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
|
||||||
assert extract_table_references(sql, "trino", show_warning=False) == {
|
assert extract_table_references(sql, "trino", show_warning=False) == {
|
||||||
Table(table="other_table", schema=None, catalog=None)
|
Table(table="table", schema=None, catalog=None),
|
||||||
|
Table(table="other_table", schema=None, catalog=None),
|
||||||
}
|
}
|
||||||
logger.warning.assert_not_called()
|
logger.warning.assert_not_called()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue