From cc9fd88c0d2a2406f202da7dcea2f0d14f1d017e Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 27 Sep 2024 15:26:36 -0400 Subject: [PATCH] chore: improve DML check (#30417) --- pyproject.toml | 2 +- requirements/base.txt | 2 +- superset/sql/parse.py | 2 +- superset/sql_lab.py | 29 ++++++++++++++++++------- tests/integration_tests/sqllab_tests.py | 6 ++++- tests/unit_tests/sql/parse_tests.py | 14 ++++++++++++ 6 files changed, 43 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2318f6fca..e1f32afa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ dependencies = [ "slack_sdk>=3.19.0, <4", "sqlalchemy>=1.4, <2", "sqlalchemy-utils>=0.38.3, <0.39", - "sqlglot>=23.0.2,<24", + "sqlglot>=25.24.0,<26", "sqlparse>=0.5.0", "tabulate>=0.8.9, <0.9", "typing-extensions>=4, <5", diff --git a/requirements/base.txt b/requirements/base.txt index 22540e3d7..904b1a18f 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -350,7 +350,7 @@ sqlalchemy-utils==0.38.3 # via # apache-superset # flask-appbuilder -sqlglot==23.6.3 +sqlglot==25.24.0 # via apache-superset sqlparse==0.5.0 # via apache-superset diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 3ec928fab..377411b94 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -362,7 +362,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ return { - eq.this.sql(): eq.expression.sql() + eq.this.sql(comments=False): eq.expression.sql(comments=False) for set_item in self._parsed.find_all(exp.SetItem) for eq in set_item.find_all(exp.EQ) } diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 3d3b2898f..65a093610 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -46,6 +46,7 @@ from superset.exceptions import ( OAuth2RedirectError, SupersetErrorException, SupersetErrorsException, + SupersetParseError, ) from superset.extensions import celery_app, event_logger from superset.models.core import Database @@ -236,15 +237,27 @@ 
def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca # We are testing to see if more rows exist than the limit. increased_limit = None if query.limit is None else query.limit + 1 - parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine) - if parsed_statement.is_mutating() and not database.allow_dml: - raise SupersetErrorException( - SupersetError( - message=__("Only SELECT statements are allowed against this database."), - error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR, - level=ErrorLevel.ERROR, + if not database.allow_dml: + try: + parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine) + disallowed = parsed_statement.is_mutating() + except SupersetParseError: + # if we fail to parse the query, disallow by default + disallowed = True + + if disallowed: + raise SupersetErrorException( + SupersetError( + message=__( + "This database does not allow for DDL/DML, and the query " + "could not be parsed to confirm it is a read-only query. Please " + "contact your administrator for more assistance." + ), + error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR, + level=ErrorLevel.ERROR, + ) ) - ) + if apply_ctas: if not query.tmp_table_name: start_dttm = datetime.fromtimestamp(query.start_time) diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 829854d96..cc1a813c9 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -137,7 +137,11 @@ class TestSqlLab(SupersetTestCase): assert data == { "errors": [ { - "message": "Only SELECT statements are allowed against this database.", + "message": ( + "This database does not allow for DDL/DML, and the query " + "could not be parsed to confirm it is a read-only query. Please " + "contact your administrator for more assistance."
+ ), "error_type": SupersetErrorType.DML_NOT_ALLOWED_ERROR, "level": ErrorLevel.ERROR, "extra": { diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index f5d55bc13..6c1e57912 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -918,3 +918,17 @@ def test_has_mutation(engine: str, sql: str, expected: bool) -> None: Test the `has_mutation` method. """ assert SQLScript(sql, engine).has_mutation() == expected + + +def test_get_settings() -> None: + """ + Test `get_settings` in some edge cases. + """ + sql = """ +set +-- this is a tricky comment +search_path -- another one += bar; +SELECT * FROM some_table; + """ + assert SQLScript(sql, "postgresql").get_settings() == {"search_path": "bar"}