feat: `is_mutating` method (#30177)
This commit is contained in:
parent
05197db71b
commit
1f890718a2
|
|
@ -56,7 +56,7 @@ const fakeDatabaseApiResult = {
|
|||
allow_file_upload: 'Allow Csv Upload',
|
||||
allow_ctas: 'Allow Ctas',
|
||||
allow_cvas: 'Allow Cvas',
|
||||
allow_dml: 'Allow Dml',
|
||||
allow_dml: 'Allow DDL and DML',
|
||||
allow_run_async: 'Allow Run Async',
|
||||
allows_cost_estimate: 'Allows Cost Estimate',
|
||||
allows_subquery: 'Allows Subquery',
|
||||
|
|
|
|||
|
|
@ -172,11 +172,11 @@ const ExtraOptions = ({
|
|||
indeterminate={false}
|
||||
checked={!!db?.allow_dml}
|
||||
onChange={onInputChange}
|
||||
labelText={t('Allow DML')}
|
||||
labelText={t('Allow DDL and DML')}
|
||||
/>
|
||||
<InfoTooltip
|
||||
tooltip={t(
|
||||
'Allow manipulation of the database using non-SELECT statements such as UPDATE, DELETE, CREATE, etc.',
|
||||
'Allow the execution of DDL (Data Definition Language: CREATE, DROP, TRUNCATE, etc.) and DML (Data Modification Language: INSERT, UPDATE, DELETE, etc)',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -700,9 +700,9 @@ describe('DatabaseModal', () => {
|
|||
/force all tables and views to be created in this schema when clicking ctas or cvas in sql lab\./i,
|
||||
);
|
||||
const allowDMLCheckbox = screen.getByRole('checkbox', {
|
||||
name: /allow dml/i,
|
||||
name: /allow ddl and dml/i,
|
||||
});
|
||||
const allowDMLText = screen.getByText(/allow dml/i);
|
||||
const allowDMLText = screen.getByText(/allow ddl and dml/i);
|
||||
const enableQueryCostEstimationCheckbox = screen.getByRole('checkbox', {
|
||||
name: /enable query cost estimation/i,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ beforeEach(() => {
|
|||
allow_file_upload: 'Allow Csv Upload',
|
||||
allow_ctas: 'Allow Ctas',
|
||||
allow_cvas: 'Allow Cvas',
|
||||
allow_dml: 'Allow Dml',
|
||||
allow_dml: 'Allow DDL and DML',
|
||||
allow_multi_schema_metadata_fetch: 'Allow Multi Schema Metadata Fetch',
|
||||
allow_run_async: 'Allow Run Async',
|
||||
allows_cost_estimate: 'Allows Cost Estimate',
|
||||
|
|
|
|||
|
|
@ -274,7 +274,7 @@ class SupersetShillelaghAdapter(Adapter):
|
|||
# to perform updates and deletes. Otherwise we can only do inserts and selects.
|
||||
self._rowid: str | None = None
|
||||
|
||||
# Does the database allow DML?
|
||||
# Does the database allow DDL/DML?
|
||||
self._allow_dml: bool = False
|
||||
|
||||
# Read column information from the database, and store it for later.
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ from superset.sql_parse import (
|
|||
insert_rls_as_subquery,
|
||||
insert_rls_in_predicate,
|
||||
ParsedQuery,
|
||||
SQLStatement,
|
||||
Table,
|
||||
)
|
||||
from superset.sqllab.limiting_factor import LimitingFactor
|
||||
|
|
@ -194,7 +195,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
|
|||
return handle_query_error(ex, query)
|
||||
|
||||
|
||||
def execute_sql_statement( # pylint: disable=too-many-statements
|
||||
def execute_sql_statement( # pylint: disable=too-many-statements, too-many-locals
|
||||
sql_statement: str,
|
||||
query: Query,
|
||||
cursor: Any,
|
||||
|
|
@ -236,7 +237,8 @@ def execute_sql_statement( # pylint: disable=too-many-statements
|
|||
# We are testing to see if more rows exist than the limit.
|
||||
increased_limit = None if query.limit is None else query.limit + 1
|
||||
|
||||
if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
|
||||
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
|
||||
if parsed_statement.is_mutating() and not database.allow_dml:
|
||||
raise SupersetErrorException(
|
||||
SupersetError(
|
||||
message=__("Only SELECT statements are allowed against this database."),
|
||||
|
|
|
|||
|
|
@ -452,6 +452,14 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_mutating(self) -> bool:
|
||||
"""
|
||||
Check if the statement mutates data (DDL/DML).
|
||||
|
||||
:return: True if the statement mutates data.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.format()
|
||||
|
||||
|
|
@ -522,6 +530,43 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
return extract_tables_from_statement(parsed, dialect)
|
||||
|
||||
def is_mutating(self) -> bool:
|
||||
"""
|
||||
Check if the statement mutates data (DDL/DML).
|
||||
|
||||
:return: True if the statement mutates data.
|
||||
"""
|
||||
for node in self._parsed.walk():
|
||||
if isinstance(
|
||||
node,
|
||||
(
|
||||
exp.Insert,
|
||||
exp.Update,
|
||||
exp.Delete,
|
||||
exp.Merge,
|
||||
exp.Create,
|
||||
exp.Drop,
|
||||
exp.TruncateTable,
|
||||
),
|
||||
):
|
||||
return True
|
||||
|
||||
if isinstance(node, exp.Command) and node.name == "ALTER":
|
||||
return True
|
||||
|
||||
# Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see
|
||||
# https://www.postgresql.org/docs/current/sql-explain.html
|
||||
if (
|
||||
self._dialect == Dialects.POSTGRES
|
||||
and isinstance(self._parsed, exp.Command)
|
||||
and self._parsed.name == "EXPLAIN"
|
||||
and self._parsed.expression.name.upper().startswith("ANALYZE ")
|
||||
):
|
||||
analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
|
||||
return SQLStatement(analyzed_sql, self.engine).is_mutating()
|
||||
|
||||
return False
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
|
|
@ -688,6 +733,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
|||
|
||||
return {}
|
||||
|
||||
def is_mutating(self) -> bool:
|
||||
"""
|
||||
Check if the statement mutates data (DDL/DML).
|
||||
|
||||
:return: True if the statement mutates data.
|
||||
"""
|
||||
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
||||
|
||||
|
||||
class SQLScript:
|
||||
"""
|
||||
|
|
@ -730,6 +783,14 @@ class SQLScript:
|
|||
|
||||
return settings
|
||||
|
||||
def has_mutation(self) -> bool:
|
||||
"""
|
||||
Check if the script contains mutating statements.
|
||||
|
||||
:return: True if the script contains mutating statements
|
||||
"""
|
||||
return any(statement.is_mutating() for statement in self.statements)
|
||||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ class DatabaseMixin:
|
|||
"expose_in_sqllab": _("Expose in SQL Lab"),
|
||||
"allow_ctas": _("Allow CREATE TABLE AS"),
|
||||
"allow_cvas": _("Allow CREATE VIEW AS"),
|
||||
"allow_dml": _("Allow DML"),
|
||||
"allow_dml": _("Allow DDL/DML"),
|
||||
"force_ctas_schema": _("CTAS Schema"),
|
||||
"database_name": _("Database"),
|
||||
"creator": _("Creator"),
|
||||
|
|
|
|||
|
|
@ -2058,3 +2058,60 @@ on $left.Day1 == $right.Day
|
|||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("engine", "sql", "expected"),
|
||||
[
|
||||
# SQLite tests
|
||||
("sqlite", "SELECT 1", False),
|
||||
("sqlite", "INSERT INTO foo VALUES (1)", True),
|
||||
("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
|
||||
("sqlite", "DELETE FROM foo WHERE id = 1", True),
|
||||
("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
|
||||
("sqlite", "DROP TABLE foo", True),
|
||||
("sqlite", "EXPLAIN SELECT * FROM foo", False),
|
||||
("sqlite", "PRAGMA table_info(foo)", False),
|
||||
("postgresql", "SELECT 1", False),
|
||||
("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
|
||||
("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
|
||||
("postgresql", "DELETE FROM foo WHERE id = 1", True),
|
||||
("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
|
||||
("postgresql", "DROP TABLE foo", True),
|
||||
("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
|
||||
("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
|
||||
("postgresql", "SHOW search_path", False),
|
||||
("postgresql", "SET search_path TO public", False),
|
||||
(
|
||||
"postgres",
|
||||
"""
|
||||
with source as (
|
||||
select 1 as one
|
||||
)
|
||||
select * from source
|
||||
""",
|
||||
False,
|
||||
),
|
||||
("trino", "SELECT 1", False),
|
||||
("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
|
||||
("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
|
||||
("trino", "DELETE FROM foo WHERE id = 1", True),
|
||||
("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
|
||||
("trino", "DROP TABLE foo", True),
|
||||
("trino", "EXPLAIN SELECT * FROM foo", False),
|
||||
("trino", "SHOW SCHEMAS", False),
|
||||
("trino", "SET SESSION optimization_level = '3'", False),
|
||||
("kustokql", "tbl | limit 100", False),
|
||||
("kustokql", "let foo = 1; tbl | where bar == foo", False),
|
||||
("kustokql", ".show tables", False),
|
||||
("kustokql", "print 1", False),
|
||||
("kustokql", "set querytrace; Events | take 100", False),
|
||||
("kustokql", ".drop table foo", True),
|
||||
("kustokql", ".set-or-append table foo <| bar", True),
|
||||
],
|
||||
)
|
||||
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
|
||||
"""
|
||||
Test the `has_mutation` method.
|
||||
"""
|
||||
assert SQLScript(sql, engine).has_mutation() == expected
|
||||
|
|
|
|||
Loading…
Reference in New Issue