refactor: remove more sqlparse (#31032)

This commit is contained in:
Beto Dealmeida 2024-11-26 17:01:07 -05:00 committed by GitHub
parent 9224051b80
commit 09802acf0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 95 additions and 172 deletions

View File

@ -36,7 +36,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException
from superset.extensions import db
from superset.models.core import Database
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -70,10 +70,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
table.sql = self._base_model.sql.strip().strip(";")
db.session.add(table)
cols = []
for config_ in self._base_model.columns:

View File

@ -1784,7 +1784,7 @@ GUEST_TOKEN_VALIDATOR_HOOK = None
# def DATASET_HEALTH_CHECK(datasource: SqlaTable) -> Optional[str]:
# if (
# datasource.sql and
# len(sql_parse.ParsedQuery(datasource.sql, strip_comments=True).tables) == 1
# len(SQLScript(datasource.sql).tables) == 1
# ):
# return (
# "This virtual dataset queries only one table and therefore could be "

View File

@ -67,7 +67,7 @@ from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.expression import Label
from sqlalchemy.sql.selectable import Alias, TableClause
from superset import app, db, is_feature_enabled, security_manager
@ -104,7 +104,7 @@ from superset.models.helpers import (
QueryResult,
)
from superset.models.slice import Slice
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
@ -1469,34 +1469,13 @@ class SqlaTable(
return tbl
def get_from_clause(
self, template_processor: BaseTemplateProcessor | None = None
self,
template_processor: BaseTemplateProcessor | None = None,
) -> tuple[TableClause | Alias, str | None]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
if not self.is_virtual:
return self.get_sqla_table(), None
from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
):
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
cte = self.db_engine_spec.get_cte_query(from_sql)
from_clause = (
table(self.db_engine_spec.cte_alias)
if cte
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
)
return from_clause, cte
return super().get_from_clause(template_processor)
def adhoc_metric_to_sqla(
self,

View File

@ -63,7 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql.parse import SQLScript, Table
from superset.sql.parse import BaseSQLStatement, SQLScript, Table
from superset.sql_parse import ParsedQuery
from superset.superset_typing import (
OAuth2ClientConfig,
@ -1737,18 +1737,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
)
@classmethod
def process_statement(cls, statement: str, database: Database) -> str:
def process_statement(
cls,
statement: BaseSQLStatement[Any],
database: Database,
) -> str:
"""
Process a SQL statement by stripping and mutating it.
Process a SQL statement by mutating it.
:param statement: A single SQL statement
:param database: Database instance
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()
return database.mutate_sql_based_on_config(sql, is_split=True)
return database.mutate_sql_based_on_config(str(statement), is_split=True)
@classmethod
def estimate_query_cost( # pylint: disable=too-many-arguments
@ -1773,8 +1774,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"Database does not support cost estimation"
)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=cls.engine)
with database.get_raw_connection(
catalog=catalog,
@ -1788,7 +1788,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls.process_statement(statement, database),
cursor,
)
for statement in statements
for statement in parsed_script.statements
]
@classmethod
@ -2056,15 +2056,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
logger.error(ex, exc_info=True)
raise
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return (
parsed_query.is_select()
or parsed_query.is_explain()
or parsed_query.is_show()
)
@classmethod
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
"""

View File

@ -36,7 +36,6 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes
from superset import sql_parse
from superset.constants import TimeGrain
from superset.databases.schemas import encrypted_field_properties, EncryptedString
from superset.databases.utils import make_url_safe
@ -44,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.errors import SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils, json
@ -449,8 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=cls.engine)
with cls.get_engine(
database,
@ -463,7 +462,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
cls.process_statement(statement, database),
client,
)
for statement in statements
for statement in parsed_script.statements
]
@classmethod

View File

@ -45,7 +45,7 @@ from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.exceptions import SupersetException
from superset.extensions import cache_manager
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
@ -598,15 +598,6 @@ class HiveEngineSpec(PrestoEngineSpec):
# otherwise, return no function names to prevent errors
return []
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return (
super().is_readonly_query(parsed_query)
or parsed_query.is_set()
or parsed_query.is_show()
)
@classmethod
def has_implicit_cancel(cls) -> bool:
"""

View File

@ -104,11 +104,6 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
return None
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return parsed_query.sql.lower().startswith("select")
class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
limit_method = LimitMethod.WRAP_SQL
@ -158,15 +153,6 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
return None
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""
Pessimistic readonly, 100% sure statement won't mutate anything.
"""
return KustoKqlEngineSpec.is_select_query(
parsed_query
) or parsed_query.sql.startswith(".show")
@classmethod
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
return not parsed_query.sql.startswith(".")

View File

@ -72,7 +72,6 @@ from superset.sql.parse import SQLScript
from superset.sql_parse import (
has_table_query,
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
)
from superset.superset_typing import (
@ -1039,6 +1038,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"""
Render sql with template engine (Jinja).
"""
if not self.sql:
return ""
sql = self.sql.strip("\t\r\n; ")
if template_processor:
try:
@ -1072,13 +1074,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
):
parsed_script = SQLScript(from_sql, engine=self.db_engine_spec.engine)
if parsed_script.has_mutation():
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)

View File

@ -20,11 +20,11 @@ from __future__ import annotations
import logging
import time
from contextlib import closing
from typing import Any
from typing import Any, cast
from superset import app
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql.parse import SQLScript, SQLStatement
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import QuerySource
@ -46,17 +46,15 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate_statement(
cls,
statement: str,
statement: SQLStatement,
database: Database,
cursor: Any,
) -> SQLValidationAnnotation | None:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
sql = parsed_query.stripped()
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = database.mutate_sql_based_on_config(sql)
sql = database.mutate_sql_based_on_config(str(statement))
# Transform the final statement to an explain call before sending it on
# to presto to validate
@ -155,10 +153,9 @@ class PrestoDBSQLValidator(BaseSQLValidator):
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=database.db_engine_spec.engine)
logger.info("Validating %i statement(s)", len(statements))
logger.info("Validating %i statement(s)", len(parsed_script.statements))
# todo(hughhh): update this to use new database.get_raw_connection()
# this function keeps stalling CI
with database.get_sqla_engine(
@ -171,8 +168,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
annotations: list[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(statement, database, cursor)
for statement in parsed_script.statements:
annotation = cls.validate_statement(
cast(SQLStatement, statement),
database,
cursor,
)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))

View File

@ -26,7 +26,6 @@ from jinja2.meta import find_undeclared_variables
from superset import is_feature_enabled
from superset.commands.sql_lab.execute import SqlQueryRender
from superset.errors import SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.sqllab.exceptions import SqlLabException
from superset.utils import core as utils
@ -58,12 +57,9 @@ class SqlQueryRenderImpl(SqlQueryRender):
database=query_model.database, query=query_model
)
parsed_query = ParsedQuery(
query_model.sql,
engine=query_model.database.db_engine_spec.engine,
)
rendered_query = sql_template_processor.process_template(
parsed_query.stripped(), **execution_context.template_params
query_model.sql.strip().strip(";"),
**execution_context.template_params,
)
self._validate(execution_context, rendered_query, sql_template_processor)
return rendered_query

View File

@ -30,7 +30,7 @@ from superset.db_engine_specs.base import (
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
@ -310,20 +310,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
)
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return BaseEngineSpec.is_readonly_query(ParsedQuery(sql))
assert is_readonly("SHOW LOCKS test EXTENDED")
assert not is_readonly("SET hivevar:desc='Legislators'")
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
assert is_readonly("SHOW CATALOGS")
assert is_readonly("SHOW TABLES")
def test_time_grain_denylist():
config = app.config.copy()
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M", "SQLITE_NONEXISTENT_GRAIN"]

View File

@ -25,7 +25,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from tests.integration_tests.test_app import app
@ -227,19 +227,6 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
app.config = config
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return HiveEngineSpec.is_readonly_query(ParsedQuery(sql))
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
assert is_readonly("SHOW LOCKS test EXTENDED")
assert is_readonly("SET hivevar:desc='Legislators'")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
@pytest.mark.parametrize(
"schema,upload_prefix",
[("foo", "EXTERNAL_HIVE_TABLES/1/foo/"), (None, "EXTERNAL_HIVE_TABLES/1/")],

View File

@ -25,7 +25,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
@ -1172,19 +1172,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
]
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))
assert not is_readonly("SET hivevar:desc='Legislators'")
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
assert is_readonly("SHOW LOCKS test EXTENDED")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
def test_get_catalog_names(app_context: AppContext) -> None:
"""
Test the ``get_catalog_names`` method.

View File

@ -20,6 +20,8 @@ from typing import Optional
import pytest
from superset.sql.parse import SQLScript
from superset.sql_parse import ParsedQuery
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@ -27,24 +29,26 @@ from tests.unit_tests.fixtures.common import dttm # noqa: F401
@pytest.mark.parametrize(
"sql,expected",
[
("SELECT foo FROM tbl", True),
("SELECT foo FROM tbl", False),
("SHOW TABLES", False),
("EXPLAIN SELECT foo FROM tbl", False),
("INSERT INTO tbl (foo) VALUES (1)", False),
("INSERT INTO tbl (foo) VALUES (1)", True),
],
)
def test_sql_is_readonly_query(sql: str, expected: bool) -> None:
def test_sql_has_mutation(sql: str, expected: bool) -> None:
"""
Make sure that SQL dialect consider only SELECT statements as read-only
"""
from superset.db_engine_specs.kusto import KustoSqlEngineSpec
from superset.sql_parse import ParsedQuery
parsed_query = ParsedQuery(sql)
is_readonly = KustoSqlEngineSpec.is_readonly_query(parsed_query)
assert expected == is_readonly
assert (
SQLScript(
sql,
engine=KustoSqlEngineSpec.engine,
).has_mutation()
== expected
)
@pytest.mark.parametrize(
@ -62,38 +66,37 @@ def test_kql_is_select_query(kql: str, expected: bool) -> None:
"""
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
from superset.sql_parse import ParsedQuery
parsed_query = ParsedQuery(kql)
is_select = KustoKqlEngineSpec.is_select_query(parsed_query)
assert expected == is_select
assert KustoKqlEngineSpec.is_select_query(parsed_query) == expected
@pytest.mark.parametrize(
"kql,expected",
[
("tbl | limit 100", True),
("let foo = 1; tbl | where bar == foo", True),
(".show tables", True),
("print 1", True),
("set querytrace; Events | take 100", True),
(".drop table foo", False),
(".set-or-append table foo <| bar", False),
("tbl | limit 100", False),
("let foo = 1; tbl | where bar == foo", False),
(".show tables", False),
("print 1", False),
("set querytrace; Events | take 100", False),
(".drop table foo", True),
(".set-or-append table foo <| bar", True),
],
)
def test_kql_is_readonly_query(kql: str, expected: bool) -> None:
def test_kql_has_mutation(kql: str, expected: bool) -> None:
"""
Make sure that KQL dialect consider only SELECT statements as read-only
"""
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
from superset.sql_parse import ParsedQuery
parsed_query = ParsedQuery(kql)
is_readonly = KustoKqlEngineSpec.is_readonly_query(parsed_query)
assert expected == is_readonly
assert (
SQLScript(
kql,
engine=KustoKqlEngineSpec.engine,
).has_mutation()
== expected
)
def test_kql_parse_sql() -> None:

View File

@ -945,6 +945,28 @@ on $left.Day1 == $right.Day
("kustokql", "set querytrace; Events | take 100", False),
("kustokql", ".drop table foo", True),
("kustokql", ".set-or-append table foo <| bar", True),
("base", "SHOW LOCKS test EXTENDED", False),
("base", "SET hivevar:desc='Legislators'", False),
("base", "UPDATE t1 SET col1 = NULL", True),
("base", "EXPLAIN SELECT 1", False),
("base", "SELECT 1", False),
("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
("base", "SHOW CATALOGS", False),
("base", "SHOW TABLES", False),
("hive", "UPDATE t1 SET col1 = NULL", True),
("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
("hive", "SHOW LOCKS test EXTENDED", False),
("hive", "SET hivevar:desc='Legislators'", False),
("hive", "EXPLAIN SELECT 1", False),
("hive", "SELECT 1", False),
("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
("presto", "SET hivevar:desc='Legislators'", False),
("presto", "UPDATE t1 SET col1 = NULL", True),
("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
("presto", "SHOW LOCKS test EXTENDED", False),
("presto", "EXPLAIN SELECT 1", False),
("presto", "SELECT 1", False),
("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
@ -1042,7 +1064,7 @@ def test_custom_dialect(app: None) -> None:
)
def test_is_mutating(engine: str) -> None:
"""
Tests for `is_mutating`.
Global tests for `is_mutating`, covering all supported engines.
"""
assert not SQLStatement(
"with source as ( select 1 as one ) select * from source",