From 09802acf0d8ba0939c22a2ca851576a1d1e47649 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 26 Nov 2024 17:01:07 -0500 Subject: [PATCH] refactor: remove more sqlparse (#31032) --- superset/commands/dataset/duplicate.py | 7 +-- superset/config.py | 2 +- superset/connectors/sqla/models.py | 31 ++--------- superset/db_engine_specs/base.py | 29 ++++------ superset/db_engine_specs/bigquery.py | 7 ++- superset/db_engine_specs/hive.py | 11 +--- superset/db_engine_specs/kusto.py | 14 ----- superset/models/helpers.py | 12 ++--- superset/sql_validators/presto_db.py | 23 ++++---- superset/sqllab/query_render.py | 8 +-- .../db_engine_specs/base_engine_spec_tests.py | 16 +----- .../db_engine_specs/hive_tests.py | 15 +----- .../db_engine_specs/presto_tests.py | 15 +----- .../unit_tests/db_engine_specs/test_kusto.py | 53 ++++++++++--------- tests/unit_tests/sql/parse_tests.py | 24 ++++++++- 15 files changed, 95 insertions(+), 172 deletions(-) diff --git a/superset/commands/dataset/duplicate.py b/superset/commands/dataset/duplicate.py index 8e82a7662..242ffa8e6 100644 --- a/superset/commands/dataset/duplicate.py +++ b/superset/commands/dataset/duplicate.py @@ -36,7 +36,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetErrorException from superset.extensions import db from superset.models.core import Database -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import Table from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -70,10 +70,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand): table.normalize_columns = self._base_model.normalize_columns table.always_filter_main_dttm = self._base_model.always_filter_main_dttm table.is_sqllab_view = True - table.sql = ParsedQuery( - self._base_model.sql, - engine=database.db_engine_spec.engine, - ).stripped() + table.sql = self._base_model.sql.strip().strip(";") db.session.add(table) cols = [] for config_ in self._base_model.columns: diff --git a/superset/config.py b/superset/config.py index 9e92bf79b..acead4c2d 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1784,7 +1784,7 @@ GUEST_TOKEN_VALIDATOR_HOOK = None # def DATASET_HEALTH_CHECK(datasource: SqlaTable) -> Optional[str]: # if ( # datasource.sql and -# len(sql_parse.ParsedQuery(datasource.sql, strip_comments=True).tables) == 1 +# len(SQLScript(datasource.sql).tables) == 1 # ): # return ( # "This virtual dataset queries only one table and therefore could be " diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 292a230ae..01889cb2e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -67,7 +67,7 @@ from sqlalchemy.orm.mapper import Mapper from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table from sqlalchemy.sql.elements import ColumnClause, TextClause -from sqlalchemy.sql.expression import Label, TextAsFrom +from sqlalchemy.sql.expression import Label from sqlalchemy.sql.selectable import Alias, TableClause from superset import app, db, is_feature_enabled, security_manager @@ -104,7 +104,7 @@ from superset.models.helpers import ( QueryResult, ) from superset.models.slice import Slice -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import Table from superset.superset_typing import ( AdhocColumn, AdhocMetric, @@ -1469,34 +1469,13 @@ class SqlaTable( return tbl def get_from_clause( - self, template_processor: BaseTemplateProcessor | None = None + self, + template_processor: BaseTemplateProcessor | None = None, ) -> tuple[TableClause | Alias, str | None]: - """ - Return where to select the columns and metrics from. Either a physical table - or a virtual table with it's own subquery. If the FROM is referencing a - CTE, the CTE is returned as the second value in the return tuple. - """ if not self.is_virtual: return self.get_sqla_table(), None - from_sql = self.get_rendered_sql(template_processor) + "\n" - parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) - if not ( - parsed_query.is_unknown() - or self.db_engine_spec.is_readonly_query(parsed_query) - ): - raise QueryObjectValidationError( - _("Virtual dataset query must be read-only") - ) - - cte = self.db_engine_spec.get_cte_query(from_sql) - from_clause = ( - table(self.db_engine_spec.cte_alias) - if cte - else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) - ) - - return from_clause, cte + return super().get_from_clause(template_processor) def adhoc_metric_to_sqla( self, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 8cabb1e58..e6c52d684 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -63,7 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError -from superset.sql.parse import SQLScript, Table +from superset.sql.parse import BaseSQLStatement, SQLScript, Table from superset.sql_parse import ParsedQuery from superset.superset_typing import ( OAuth2ClientConfig, @@ -1737,18 +1737,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ) @classmethod - def process_statement(cls, statement: str, database: Database) -> str: + def process_statement( + cls, + statement: BaseSQLStatement[Any], + database: Database, + ) -> str: """ - Process a SQL statement by stripping and mutating it. + Process a SQL statement by mutating it. :param statement: A single SQL statement :param database: Database instance :return: Dictionary with different costs """ - parsed_query = ParsedQuery(statement, engine=cls.engine) - sql = parsed_query.stripped() - - return database.mutate_sql_based_on_config(sql, is_split=True) + return database.mutate_sql_based_on_config(str(statement), is_split=True) @classmethod def estimate_query_cost( # pylint: disable=too-many-arguments @@ -1773,8 +1774,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "Database does not support cost estimation" ) - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - statements = parsed_query.get_statements() + parsed_script = SQLScript(sql, engine=cls.engine) with database.get_raw_connection( catalog=catalog, @@ -1788,7 +1788,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls.process_statement(statement, database), cursor, ) - for statement in statements + for statement in parsed_script.statements ] @classmethod @@ -2056,15 +2056,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods logger.error(ex, exc_info=True) raise - @classmethod - def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: - """Pessimistic readonly, 100% sure statement won't mutate anything""" - return ( - parsed_query.is_select() - or parsed_query.is_explain() - or parsed_query.is_show() - ) - @classmethod def is_select_query(cls, parsed_query: ParsedQuery) -> bool: """ diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 70bc4bc84..74a5d151f 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -36,7 +36,6 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.sql import sqltypes -from superset import sql_parse from superset.constants import TimeGrain from superset.databases.schemas import encrypted_field_properties, EncryptedString from superset.databases.utils import make_url_safe @@ -44,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError from superset.errors import SupersetError, SupersetErrorType from superset.exceptions import SupersetException +from superset.sql.parse import SQLScript from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils, json @@ -449,8 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met if not cls.get_allow_cost_estimate(extra): raise SupersetException("Database does not support cost estimation") - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - statements = parsed_query.get_statements() + parsed_script = SQLScript(sql, engine=cls.engine) with cls.get_engine( database, @@ -463,7 +462,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met cls.process_statement(statement, database), client, ) - for statement in statements + for statement in parsed_script.statements ] @classmethod diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index b9dc09b1d..c2cde83c7 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -45,7 +45,7 @@ from superset.db_engine_specs.presto import PrestoEngineSpec from superset.exceptions import SupersetException from superset.extensions import cache_manager from superset.models.sql_lab import Query -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType if TYPE_CHECKING: @@ -598,15 +598,6 @@ class HiveEngineSpec(PrestoEngineSpec): # otherwise, return no function names to prevent errors return [] - @classmethod - def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: - """Pessimistic readonly, 100% sure statement won't mutate anything""" - return ( - super().is_readonly_query(parsed_query) - or parsed_query.is_set() - or parsed_query.is_show() - ) - @classmethod def has_implicit_cancel(cls) -> bool: """ diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index 332d1c1c6..696faf74b 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -104,11 +104,6 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)""" return None - @classmethod - def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: - """Pessimistic readonly, 100% sure statement won't mutate anything""" - return parsed_query.sql.lower().startswith("select") - class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method limit_method = LimitMethod.WRAP_SQL @@ -158,15 +153,6 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method return None - @classmethod - def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: - """ - Pessimistic readonly, 100% sure statement won't mutate anything. - """ - return KustoKqlEngineSpec.is_select_query( - parsed_query - ) or parsed_query.sql.startswith(".show") - @classmethod def is_select_query(cls, parsed_query: ParsedQuery) -> bool: return not parsed_query.sql.startswith(".") diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 51808f9a4..feb05a401 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -72,7 +72,6 @@ from superset.sql.parse import SQLScript from superset.sql_parse import ( has_table_query, insert_rls_in_predicate, - ParsedQuery, sanitize_clause, ) from superset.superset_typing import ( @@ -1039,6 +1038,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods """ Render sql with template engine (Jinja). """ + if not self.sql: + return "" + sql = self.sql.strip("\t\r\n; ") if template_processor: try: @@ -1072,13 +1074,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods or a virtual table with it's own subquery. If the FROM is referencing a CTE, the CTE is returned as the second value in the return tuple. """ - from_sql = self.get_rendered_sql(template_processor) + "\n" - parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) - if not ( - parsed_query.is_unknown() - or self.db_engine_spec.is_readonly_query(parsed_query) - ): + parsed_script = SQLScript(from_sql, engine=self.db_engine_spec.engine) + if parsed_script.has_mutation(): raise QueryObjectValidationError( _("Virtual dataset query must be read-only") ) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 010272ea3..e247a0322 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -20,11 +20,11 @@ from __future__ import annotations import logging import time from contextlib import closing -from typing import Any +from typing import Any, cast from superset import app from superset.models.core import Database -from superset.sql_parse import ParsedQuery +from superset.sql.parse import SQLScript, SQLStatement from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation from superset.utils.core import QuerySource @@ -46,17 +46,15 @@ class PrestoDBSQLValidator(BaseSQLValidator): @classmethod def validate_statement( cls, - statement: str, + statement: SQLStatement, database: Database, cursor: Any, ) -> SQLValidationAnnotation | None: # pylint: disable=too-many-locals db_engine_spec = database.db_engine_spec - parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine) - sql = parsed_query.stripped() # Hook to allow environment-specific mutation (usually comments) to the SQL - sql = database.mutate_sql_based_on_config(sql) + sql = database.mutate_sql_based_on_config(str(statement)) # Transform the final statement to an explain call before sending it on # to presto to validate @@ -155,10 +153,9 @@ class PrestoDBSQLValidator(BaseSQLValidator): For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable. """ - parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine) - statements = parsed_query.get_statements() + parsed_script = SQLScript(sql, engine=database.db_engine_spec.engine) - logger.info("Validating %i statement(s)", len(statements)) + logger.info("Validating %i statement(s)", len(parsed_script.statements)) # todo(hughhh): update this to use new database.get_raw_connection() # this function keeps stalling CI with database.get_sqla_engine( @@ -171,8 +168,12 @@ class PrestoDBSQLValidator(BaseSQLValidator): annotations: list[SQLValidationAnnotation] = [] with closing(engine.raw_connection()) as conn: cursor = conn.cursor() - for statement in parsed_query.get_statements(): - annotation = cls.validate_statement(statement, database, cursor) + for statement in parsed_script.statements: + annotation = cls.validate_statement( + cast(SQLStatement, statement), + database, + cursor, + ) if annotation: annotations.append(annotation) logger.debug("Validation found %i error(s)", len(annotations)) diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py index 67bea2b1f..7d41d7fb0 100644 --- a/superset/sqllab/query_render.py +++ b/superset/sqllab/query_render.py @@ -26,7 +26,6 @@ from jinja2.meta import find_undeclared_variables from superset import is_feature_enabled from superset.commands.sql_lab.execute import SqlQueryRender from superset.errors import SupersetErrorType -from superset.sql_parse import ParsedQuery from superset.sqllab.exceptions import SqlLabException from superset.utils import core as utils @@ -58,12 +57,9 @@ class SqlQueryRenderImpl(SqlQueryRender): database=query_model.database, query=query_model ) - parsed_query = ParsedQuery( - query_model.sql, - engine=query_model.database.db_engine_spec.engine, - ) rendered_query = sql_template_processor.process_template( - parsed_query.stripped(), **execution_context.template_params + query_model.sql.strip().strip(";"), + **execution_context.template_params, ) self._validate(execution_context, rendered_query, sql_template_processor) return rendered_query diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 715657e4f..916de39cd 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -30,7 +30,7 @@ from superset.db_engine_specs.base import ( from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.sqlite import SqliteEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import Table from superset.utils.database import get_example_database from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.integration_tests.test_app import app @@ -310,20 +310,6 @@ class TestDbEngineSpecs(TestDbEngineSpec): ) -def test_is_readonly(): - def is_readonly(sql: str) -> bool: - return BaseEngineSpec.is_readonly_query(ParsedQuery(sql)) - - assert is_readonly("SHOW LOCKS test EXTENDED") - assert not is_readonly("SET hivevar:desc='Legislators'") - assert not is_readonly("UPDATE t1 SET col1 = NULL") - assert is_readonly("EXPLAIN SELECT 1") - assert is_readonly("SELECT 1") - assert is_readonly("WITH (SELECT 1) bla SELECT * from bla") - assert is_readonly("SHOW CATALOGS") - assert is_readonly("SHOW TABLES") - - def test_time_grain_denylist(): config = app.config.copy() app.config["TIME_GRAIN_DENYLIST"] = ["PT1M", "SQLITE_NONEXISTENT_GRAIN"] diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index 77d1b4d6a..7ef4854cd 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -25,7 +25,7 @@ from sqlalchemy.sql import select from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3 from superset.exceptions import SupersetException -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import Table from tests.integration_tests.test_app import app @@ -227,19 +227,6 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g): app.config = config -def test_is_readonly(): - def is_readonly(sql: str) -> bool: - return HiveEngineSpec.is_readonly_query(ParsedQuery(sql)) - - assert not is_readonly("UPDATE t1 SET col1 = NULL") - assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA") - assert is_readonly("SHOW LOCKS test EXTENDED") - assert is_readonly("SET hivevar:desc='Legislators'") - assert is_readonly("EXPLAIN SELECT 1") - assert is_readonly("SELECT 1") - assert is_readonly("WITH (SELECT 1) bla SELECT * from bla") - - @pytest.mark.parametrize( "schema,upload_prefix", [("foo", "EXTERNAL_HIVE_TABLES/1/foo/"), (None, "EXTERNAL_HIVE_TABLES/1/")], diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 94e3ea627..9d83bb5bb 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -25,7 +25,7 @@ from sqlalchemy.sql import select from superset.db_engine_specs.presto import PrestoEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import Table from superset.utils.database import get_example_database from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -1172,19 +1172,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): ] -def test_is_readonly(): - def is_readonly(sql: str) -> bool: - return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql)) - - assert not is_readonly("SET hivevar:desc='Legislators'") - assert not is_readonly("UPDATE t1 SET col1 = NULL") - assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA") - assert is_readonly("SHOW LOCKS test EXTENDED") - assert is_readonly("EXPLAIN SELECT 1") - assert is_readonly("SELECT 1") - assert is_readonly("WITH (SELECT 1) bla SELECT * from bla") - - def test_get_catalog_names(app_context: AppContext) -> None: """ Test the ``get_catalog_names`` method. diff --git a/tests/unit_tests/db_engine_specs/test_kusto.py b/tests/unit_tests/db_engine_specs/test_kusto.py index 9fc1cd39f..68330ed2e 100644 --- a/tests/unit_tests/db_engine_specs/test_kusto.py +++ b/tests/unit_tests/db_engine_specs/test_kusto.py @@ -20,6 +20,8 @@ from typing import Optional import pytest +from superset.sql.parse import SQLScript +from superset.sql_parse import ParsedQuery from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm # noqa: F401 @@ -27,24 +29,26 @@ from tests.unit_tests.fixtures.common import dttm # noqa: F401 @pytest.mark.parametrize( "sql,expected", [ - ("SELECT foo FROM tbl", True), + ("SELECT foo FROM tbl", False), ("SHOW TABLES", False), ("EXPLAIN SELECT foo FROM tbl", False), - ("INSERT INTO tbl (foo) VALUES (1)", False), + ("INSERT INTO tbl (foo) VALUES (1)", True), ], ) -def test_sql_is_readonly_query(sql: str, expected: bool) -> None: +def test_sql_has_mutation(sql: str, expected: bool) -> None: """ Make sure that SQL dialect consider only SELECT statements as read-only """ from superset.db_engine_specs.kusto import KustoSqlEngineSpec - from superset.sql_parse import ParsedQuery - parsed_query = ParsedQuery(sql) - is_readonly = KustoSqlEngineSpec.is_readonly_query(parsed_query) - - assert expected == is_readonly + assert ( + SQLScript( + sql, + engine=KustoSqlEngineSpec.engine, + ).has_mutation() + == expected + ) @pytest.mark.parametrize( @@ -62,38 +66,37 @@ def test_kql_is_select_query(kql: str, expected: bool) -> None: """ from superset.db_engine_specs.kusto import KustoKqlEngineSpec - from superset.sql_parse import ParsedQuery parsed_query = ParsedQuery(kql) - is_select = KustoKqlEngineSpec.is_select_query(parsed_query) - - assert expected == is_select + assert KustoKqlEngineSpec.is_select_query(parsed_query) == expected @pytest.mark.parametrize( "kql,expected", [ - ("tbl | limit 100", True), - ("let foo = 1; tbl | where bar == foo", True), - (".show tables", True), - ("print 1", True), - ("set querytrace; Events | take 100", True), - (".drop table foo", False), - (".set-or-append table foo <| bar", False), + ("tbl | limit 100", False), + ("let foo = 1; tbl | where bar == foo", False), + (".show tables", False), + ("print 1", False), + ("set querytrace; Events | take 100", False), + (".drop table foo", True), + (".set-or-append table foo <| bar", True), ], ) -def test_kql_is_readonly_query(kql: str, expected: bool) -> None: +def test_kql_has_mutation(kql: str, expected: bool) -> None: """ Make sure that KQL dialect consider only SELECT statements as read-only """ from superset.db_engine_specs.kusto import KustoKqlEngineSpec - from superset.sql_parse import ParsedQuery - parsed_query = ParsedQuery(kql) - is_readonly = KustoKqlEngineSpec.is_readonly_query(parsed_query) - - assert expected == is_readonly + assert ( + SQLScript( + kql, + engine=KustoKqlEngineSpec.engine, + ).has_mutation() + == expected + ) def test_kql_parse_sql() -> None: diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 726e2294e..4911d4c0e 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -945,6 +945,28 @@ on $left.Day1 == $right.Day ("kustokql", "set querytrace; Events | take 100", False), ("kustokql", ".drop table foo", True), ("kustokql", ".set-or-append table foo <| bar", True), + ("base", "SHOW LOCKS test EXTENDED", False), + ("base", "SET hivevar:desc='Legislators'", False), + ("base", "UPDATE t1 SET col1 = NULL", True), + ("base", "EXPLAIN SELECT 1", False), + ("base", "SELECT 1", False), + ("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False), + ("base", "SHOW CATALOGS", False), + ("base", "SHOW TABLES", False), + ("hive", "UPDATE t1 SET col1 = NULL", True), + ("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True), + ("hive", "SHOW LOCKS test EXTENDED", False), + ("hive", "SET hivevar:desc='Legislators'", False), + ("hive", "EXPLAIN SELECT 1", False), + ("hive", "SELECT 1", False), + ("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False), + ("presto", "SET hivevar:desc='Legislators'", False), + ("presto", "UPDATE t1 SET col1 = NULL", True), + ("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True), + ("presto", "SHOW LOCKS test EXTENDED", False), + ("presto", "EXPLAIN SELECT 1", False), + ("presto", "SELECT 1", False), + ("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False), ], ) def test_has_mutation(engine: str, sql: str, expected: bool) -> None: @@ -1042,7 +1064,7 @@ def test_custom_dialect(app: None) -> None: ) def test_is_mutating(engine: str) -> None: """ - Tests for `is_mutating`. + Global tests for `is_mutating`, covering all supported engines. """ assert not SQLStatement( "with source as ( select 1 as one ) select * from source",