chore: organize SQL parsing files (#30258)

This commit is contained in:
Beto Dealmeida 2024-09-13 16:24:19 -04:00 committed by GitHub
parent 8cd18cac8c
commit bdf29cb7c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1650 additions and 886 deletions

View File

@ -63,7 +63,8 @@ from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.sql.parse import SQLScript, Table
from superset.sql_parse import ParsedQuery
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,

View File

@ -35,7 +35,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.sql_parse import SQLScript
from superset.sql.parse import SQLScript
from superset.utils import core as utils, json
from superset.utils.core import GenericDataType

View File

@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections import defaultdict
from typing import Any, Optional
@ -304,12 +307,30 @@ class SupersetParseError(SupersetErrorException):
status = 422
def __init__(self, sql: str, engine: Optional[str] = None):
def __init__( # pylint: disable=too-many-arguments
self,
sql: str,
engine: Optional[str] = None,
message: Optional[str] = None,
highlight: Optional[str] = None,
line: Optional[int] = None,
column: Optional[int] = None,
):
if message is None:
parts = [_("Error parsing")]
if highlight:
parts.append(_(" near '%(highlight)s'", highlight=highlight))
if line:
parts.append(_(" at line %(line)d", line=line))
if column:
parts.append(_(":%(column)d", column=column))
message = "".join(parts)
error = SupersetError(
message=_("The SQL is invalid and cannot be parsed."),
message=message,
error_type=SupersetErrorType.INVALID_SQL_ERROR,
level=ErrorLevel.ERROR,
extra={"sql": sql, "engine": engine},
extra={"sql": sql, "engine": engine, "line": line, "column": column},
)
super().__init__(error)

View File

@ -68,13 +68,12 @@ from superset.exceptions import (
)
from superset.extensions import feature_flag_manager
from superset.jinja_context import BaseTemplateProcessor
from superset.sql.parse import SQLScript, SQLStatement
from superset.sql_parse import (
has_table_query,
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
SQLScript,
SQLStatement,
)
from superset.superset_typing import (
AdhocMetric,

16
superset/sql/__init__.py Normal file
View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

648
superset/sql/parse.py Normal file
View File

@ -0,0 +1,648 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import enum
import logging
import re
import urllib.parse
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from superset.exceptions import SupersetParseError
logger = logging.getLogger(__name__)
# Mapping between the `engine` attribute of Superset DB engine specs and the sqlglot
# dialect used to parse SQL for that engine. Engines that are commented out (`???`)
# have no dialect assigned; the code in this module looks engines up with
# `SQLGLOT_DIALECTS.get(engine)`, so those engines (and any unknown engine) resolve to
# `None`.
SQLGLOT_DIALECTS = {
    "base": Dialects.DIALECT,
    "ascend": Dialects.HIVE,
    "awsathena": Dialects.PRESTO,
    "bigquery": Dialects.BIGQUERY,
    "clickhouse": Dialects.CLICKHOUSE,
    "clickhousedb": Dialects.CLICKHOUSE,
    "cockroachdb": Dialects.POSTGRES,
    "couchbase": Dialects.MYSQL,
    # "crate": ???
    # "databend": ???
    "databricks": Dialects.DATABRICKS,
    # "db2": ???
    # "dremio": ???
    "drill": Dialects.DRILL,
    # "druid": ???
    "duckdb": Dialects.DUCKDB,
    # "dynamodb": ???
    # "elasticsearch": ???
    # "exa": ???
    # "firebird": ???
    # "firebolt": ???
    "gsheets": Dialects.SQLITE,
    "hana": Dialects.POSTGRES,
    "hive": Dialects.HIVE,
    # "ibmi": ???
    # "impala": ???
    # "kustokql": ???
    # "kylin": ???
    "mssql": Dialects.TSQL,
    "mysql": Dialects.MYSQL,
    "netezza": Dialects.POSTGRES,
    # "ocient": ???
    # "odelasticsearch": ???
    "oracle": Dialects.ORACLE,
    # "pinot": ???
    "postgresql": Dialects.POSTGRES,
    "presto": Dialects.PRESTO,
    "pydoris": Dialects.DORIS,
    "redshift": Dialects.REDSHIFT,
    # "risingwave": ???
    # "rockset": ???
    "shillelagh": Dialects.SQLITE,
    "snowflake": Dialects.SNOWFLAKE,
    # "solr": ???
    "spark": Dialects.SPARK,
    "sqlite": Dialects.SQLITE,
    "starrocks": Dialects.STARROCKS,
    "superset": Dialects.SQLITE,
    "teradatasql": Dialects.TERADATA,
    "trino": Dialects.TRINO,
    "vertica": Dialects.POSTGRES,
}
@dataclass(eq=True, frozen=True)
class Table:
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    Equality is defined on the string form returned by ``__str__``, so a table
    compares equal to any object with the same string representation (another
    ``Table`` or a plain string such as ``"schema.table"``). The hash is derived
    from the same string so that equal objects always hash equally.
    """

    table: str
    schema: str | None = None
    catalog: str | None = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Should not be used for SQL generation, only for logging and debugging, since the
        quoting is not engine-specific.
        """
        # Each part is URL-quoted, and dots inside a part are encoded as well, so the
        # only literal dots in the result are the separators between parts. Empty or
        # missing parts are skipped.
        return ".".join(
            urllib.parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )

    def __eq__(self, other: Any) -> bool:
        """
        Compare by string representation.
        """
        return str(self) == str(other)

    def __hash__(self) -> int:
        """
        Hash the string representation, matching ``__eq__``.

        The dataclass-generated hash is based on the raw field values, which breaks
        the hash/eq contract: `Table("t", schema="")` and `Table("t")` compare equal
        but would land in different buckets of a set or dict.
        """
        return hash(str(self))
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
# store the AST as a string (the original query), and manipulate it with regular
# expressions.
InternalRepresentation = TypeVar("InternalRepresentation")
# The base type. This helps type checking the `split_script` method correctly, since
# each derived class has a more specific return type (the class itself). This will no
# longer be needed once Python 3.11 is the lowest version supported. See PEP 673 for
# more information: https://peps.python.org/pep-0673/
TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name
class BaseSQLStatement(Generic[InternalRepresentation]):
    """
    Common interface shared by all statement classes.

    An instance can be built either from the text of a statement or from an AST that
    was already produced by a parser. The latter avoids parsing twice when a whole
    script has been split into parsed statements (e.g. by `sqlglot.parse`).

    `engine` is the value of the `engine` attribute of a Superset DB engine spec.
    """

    def __init__(
        self,
        statement: str | InternalRepresentation,
        engine: str,
    ):
        # Only raw text needs parsing; anything else is taken to be a ready-made AST.
        if isinstance(statement, str):
            parsed = self._parse_statement(statement, engine)
        else:
            parsed = statement

        self._parsed: InternalRepresentation = parsed
        self.engine = engine
        self.tables = self._extract_tables_from_statement(self._parsed, self.engine)

    @classmethod
    def split_script(
        cls: type[TBaseSQLStatement],
        script: str,
        engine: str,
    ) -> list[TBaseSQLStatement]:
        """
        Turn a full script into a list of statement instances.

        `SQLScript` relies on this helper to build the statements that make up a
        script. Must be implemented by subclasses.
        """
        raise NotImplementedError()

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> InternalRepresentation:
        """
        Parse the text of one statement and return its AST.

        Implementations MUST verify that `statement` really holds a single statement
        instead of assuming it; how to check that depends on the parser, which is why
        the validation is delegated to subclasses.
        """
        raise NotImplementedError()

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: InternalRepresentation,
        engine: str,
    ) -> set[Table]:
        """
        Return every table referenced by the parsed statement.
        """
        raise NotImplementedError()

    def format(self, comments: bool = True) -> str:
        """
        Return the formatted statement, optionally omitting comments.
        """
        raise NotImplementedError()

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings assigned by the statement.

        Given the statement:

            sql> SET foo = 'bar';

        the expected result is `{"foo": "'bar'"}` (the single quotes are kept).
        """
        raise NotImplementedError()

    def is_mutating(self) -> bool:
        """
        Tell whether the statement mutates data (DDL/DML).

        :return: True if the statement mutates data.
        """
        raise NotImplementedError()

    def __str__(self) -> str:
        return self.format()
class SQLStatement(BaseSQLStatement[exp.Expression]):
    """
    A SQL statement.

    This class is used for all engines with dialects that can be parsed using sqlglot.
    """

    def __init__(
        self,
        statement: str | exp.Expression,
        engine: str,
    ):
        # `None` when the engine has no entry in `SQLGLOT_DIALECTS`.
        self._dialect = SQLGLOT_DIALECTS.get(engine)
        super().__init__(statement, engine)

    @classmethod
    def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
        """
        Parse a script with sqlglot and return the resulting ASTs.

        :param script: SQL text, possibly containing several statements
        :param engine: `engine` attribute of the DB engine spec, used to pick a dialect
        :return: the list returned by `sqlglot.parse`; callers filter out falsy entries
        :raises SupersetParseError: if sqlglot cannot parse the script
        """
        dialect = SQLGLOT_DIALECTS.get(engine)
        try:
            return sqlglot.parse(script, dialect=dialect)
        except sqlglot.errors.ParseError as ex:
            # Don't assume structured details are present: an empty `errors` list
            # would otherwise turn a parse error into an `IndexError`.
            error = ex.errors[0] if ex.errors else {}
            raise SupersetParseError(
                script,
                engine,
                highlight=error.get("highlight"),
                line=error.get("line"),
                column=error.get("col"),
            ) from ex
        except sqlglot.errors.SqlglotError as ex:
            raise SupersetParseError(
                script,
                engine,
                message="Unable to parse script",
            ) from ex

    @classmethod
    def split_script(
        cls,
        script: str,
        engine: str,
    ) -> list[SQLStatement]:
        """
        Split a script into one `SQLStatement` per parsed statement.

        Falsy entries returned by the parser are dropped.
        """
        return [
            cls(statement, engine)
            for statement in cls._parse(script, engine)
            if statement
        ]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> exp.Expression:
        """
        Parse a single SQL statement.

        :raises SupersetParseError: if the text cannot be parsed, or if it does not
            contain exactly one statement
        """
        statements = cls.split_script(statement, engine)
        if len(statements) != 1:
            # The first positional argument of `SupersetParseError` is the SQL, so the
            # explanation has to be passed as `message`; otherwise it would be stored
            # as the offending SQL and the user would only see a generic message.
            raise SupersetParseError(
                statement,
                engine,
                message="SQLStatement should have exactly one statement",
            )

        return statements[0]._parsed  # pylint: disable=protected-access

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: exp.Expression,
        engine: str,
    ) -> set[Table]:
        """
        Find all referenced tables.
        """
        dialect = SQLGLOT_DIALECTS.get(engine)
        return extract_tables_from_statement(parsed, dialect)

    def is_mutating(self) -> bool:
        """
        Check if the statement mutates data (DDL/DML).

        :return: True if the statement mutates data.
        """
        for node in self._parsed.walk():
            if isinstance(
                node,
                (
                    exp.Insert,
                    exp.Update,
                    exp.Delete,
                    exp.Merge,
                    exp.Create,
                    exp.Drop,
                    exp.TruncateTable,
                ),
            ):
                return True

            # `ALTER` is matched here as a generic command node rather than through a
            # dedicated expression class.
            if isinstance(node, exp.Command) and node.name == "ALTER":
                return True

        # Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see
        # https://www.postgresql.org/docs/current/sql-explain.html
        if (
            self._dialect == Dialects.POSTGRES
            and isinstance(self._parsed, exp.Command)
            and self._parsed.name == "EXPLAIN"
            and self._parsed.expression.name.upper().startswith("ANALYZE ")
        ):
            # Re-check the statement that follows `ANALYZE `.
            analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
            return SQLStatement(analyzed_sql, self.engine).is_mutating()

        return False

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL statement.

        :param comments: whether comments are kept in the output
        """
        write = Dialect.get_or_raise(self._dialect)
        return write.generate(self._parsed, copy=False, comments=comments, pretty=True)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = SQLStatement("SET foo = 'bar'", "base")
            >>> statement.get_settings()
            {"foo": "'bar'"}

        """
        return {
            eq.this.sql(): eq.expression.sql()
            for set_item in self._parsed.find_all(exp.SetItem)
            for eq in set_item.find_all(exp.EQ)
        }
class KQLSplitState(enum.Enum):
    """
    State machine for splitting a KQL script.

    The state machine keeps track of whether we're inside a string or not, so we
    don't split the script in a semi-colon that's part of a string.
    """

    OUTSIDE_STRING = enum.auto()
    INSIDE_SINGLE_QUOTED_STRING = enum.auto()
    INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
    INSIDE_MULTILINE_STRING = enum.auto()


def _is_escaped(script: str, index: int, string_start: int) -> bool:
    """
    Tell whether the character at `index` is escaped by a backslash.

    Only an odd run of backslashes right before the character escapes it; in an
    even run every backslash is itself escaped. Backslashes at or before
    `string_start` (the opening quote) are not part of the string and are ignored.
    """
    backslashes = 0
    position = index - 1
    while position > string_start and script[position] == "\\":
        backslashes += 1
        position -= 1
    return backslashes % 2 == 1


def split_kql(kql: str) -> list[str]:
    """
    Custom function for splitting KQL statements.

    The script is split at semi-colons that are not inside a string literal. Single
    quoted, double quoted, verbatim (`@"..."` / `@'...'`) and multi-line (triple
    backtick) strings are recognized. Statements are returned as found, without
    stripping whitespace; a statement whose string literal is never closed is not
    returned.

    :param kql: the KQL script
    :return: the text of each statement, without the terminating semi-colon
    """
    statements: list[str] = []
    state = KQLSplitState.OUTSIDE_STRING
    statement_start = 0
    # Index of the character that opened the current string, and whether that
    # string is verbatim (backslash is a regular character).
    string_start = 0
    verbatim = False
    quote_states = {
        "'": KQLSplitState.INSIDE_SINGLE_QUOTED_STRING,
        '"': KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING,
    }

    # Guarantee a final terminator so the last statement is flushed by the loop.
    script = kql if kql.endswith(";") else kql + ";"
    for i, character in enumerate(script):
        if state == KQLSplitState.OUTSIDE_STRING:
            if character == ";":
                statements.append(script[statement_start:i])
                statement_start = i + 1
            elif character in quote_states:
                state = quote_states[character]
                string_start = i
                verbatim = i > 0 and script[i - 1] == "@"
            elif character == "`" and i >= 2 and script[i - 2 : i] == "``":
                state = KQLSplitState.INSIDE_MULTILINE_STRING
                string_start = i
        elif state == KQLSplitState.INSIDE_MULTILINE_STRING:
            # The closing fence must be made of backticks that come after the
            # opening fence, otherwise the opener would close its own string.
            if (
                character == "`"
                and i - 2 > string_start
                and script[i - 2 : i] == "``"
            ):
                state = KQLSplitState.OUTSIDE_STRING
        elif quote_states.get(character) == state and (
            # In a verbatim string a quote is escaped by doubling it; treating the
            # first one as a close and the second as a new open gives the same
            # result for splitting purposes.
            verbatim
            or not _is_escaped(script, i, string_start)
        ):
            state = KQLSplitState.OUTSIDE_STRING

    return statements
class KustoKQLStatement(BaseSQLStatement[str]):
    """
    Special class for Kusto KQL.

    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
    like this:

        StormEvents
        | summarize PropertyDamage = sum(DamageProperty) by State
        | join kind=innerunique PopulationData on State
        | project State, PropertyDamagePerCapita = PropertyDamage / Population
        | sort by PropertyDamagePerCapita

    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
    details about it.

    The "parsed" representation of a statement is simply its text, stripped of
    surrounding whitespace.
    """

    @classmethod
    def split_script(
        cls,
        script: str,
        engine: str,
    ) -> list[KustoKQLStatement]:
        """
        Split a script at semi-colons.

        Since we don't have a parser, we use a simple state machine based function. See
        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
        for more information.
        """
        return [cls(statement, engine) for statement in split_kql(script)]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> str:
        """
        Validate a single KQL statement and return it stripped.

        :raises SupersetParseError: if the engine is not `kustokql`, or if the text
            does not contain exactly one statement
        """
        # The first positional argument of `SupersetParseError` is the SQL, so the
        # explanation must go in `message`; passing it positionally would report it
        # as the offending SQL and show only a generic message.
        if engine != "kustokql":
            raise SupersetParseError(
                statement,
                engine,
                message=f"Invalid engine: {engine}",
            )

        statements = split_kql(statement)
        if len(statements) != 1:
            raise SupersetParseError(
                statement,
                engine,
                message="KustoKQLStatement should have exactly one statement",
            )

        return statements[0].strip()

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: str,
        engine: str,
    ) -> set[Table]:
        """
        Extract all tables referenced in the statement.

            StormEvents
            | where InjuriesDirect + InjuriesIndirect > 50
            | join (PopulationData) on State
            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect

        Table extraction is not implemented for KQL: a warning is logged and an empty
        set is always returned.
        """
        logger.warning(
            "Kusto KQL doesn't support table extraction. This means that data access "
            "roles will not be enforced by Superset in the database."
        )
        return set()

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL statement.

        No formatting is performed for KQL; the stored statement is returned as is and
        `comments` is ignored.
        """
        return self._parsed

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
            >>> statement.get_settings()
            {"querytrace": True}

        """
        # `set name` maps the name to `True`; `set name = value` maps it to the value
        # as a string.
        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
            return {match.group("name"): match.group("value") or True}

        return {}

    def is_mutating(self) -> bool:
        """
        Check if the statement mutates data (DDL/DML).

        Statements starting with a dot are considered mutating, except those starting
        with `.show`.

        :return: True if the statement mutates data.
        """
        return self._parsed.startswith(".") and not self._parsed.startswith(".show")
class SQLScript:
    """
    A SQL script made of zero or more statements.
    """

    # Engines whose language sqlglot cannot parse, mapped to the statement class that
    # handles them. Every non-SQL engine adds a lot of complexity to Superset, so new
    # entries should be avoided.
    special_engines = {
        "kustokql": KustoKQLStatement,
    }

    def __init__(
        self,
        script: str,
        engine: str,
    ):
        # Anything not listed in `special_engines` is handled by `SQLStatement`.
        statement_class = self.special_engines.get(engine, SQLStatement)
        self.engine = engine
        self.statements = statement_class.split_script(script, engine)

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the script, one formatted statement after another.
        """
        formatted = [statement.format(comments) for statement in self.statements]
        return ";\n".join(formatted)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL script.

        When a setting is assigned more than once the last assignment wins:

            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
            >>> statement.get_settings()
            {"foo": "'baz'"}

        """
        settings: dict[str, str | bool] = {}
        for statement in self.statements:
            for name, value in statement.get_settings().items():
                settings[name] = value

        return settings

    def has_mutation(self) -> bool:
        """
        Check if the script contains mutating statements.

        :return: True if the script contains mutating statements
        """
        for statement in self.statements:
            if statement.is_mutating():
                return True

        return False
def extract_tables_from_statement(
    statement: exp.Expression,
    dialect: Dialects | None,
) -> set[Table]:
    """
    Extract all table references in a single statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.
    """
    table_nodes: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # sqlglot reports no sources for a `DESCRIBE` query, so every table node in
        # the statement is collected explicitly.
        table_nodes = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # A command such as `SHOW COLUMNS FROM foo` is rewritten as a `SELECT`
        # statement so that its tables can be extracted.
        command_text = statement.find(exp.Literal)
        if not command_text:
            return set()

        pseudo_sql = f"SELECT {command_text.this}"
        try:
            pseudo_query = sqlglot.parse_one(pseudo_sql, dialect=dialect)
        except ParseError:
            return set()

        table_nodes = pseudo_query.find_all(exp.Table)
    else:
        # Regular statements: walk every scope and keep the sources that are real
        # tables, skipping names that refer to CTEs.
        scoped_tables: list[exp.Table] = []
        for scope in traverse_scope(statement):
            for candidate in scope.sources.values():
                if isinstance(candidate, exp.Table) and not is_cte(candidate, scope):
                    scoped_tables.append(candidate)
        table_nodes = scoped_tables

    tables: set[Table] = set()
    for node in table_nodes:
        # Empty schema/catalog names are normalized to `None`.
        schema = None if node.db == "" else node.db
        catalog = None if node.catalog == "" else node.catalog
        tables.add(Table(node.name, schema, catalog))

    return tables
def is_cte(source: exp.Table, scope: Scope) -> bool:
    """
    Tell whether the source refers to a CTE.

    A CTE defined in the parent scope is referenced through an `exp.Table` node, just
    like a real table, but it must not be reported as one. Otherwise a user with
    access to table `foo` could read any table with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo

    """
    parent_sources = scope.parent.sources if scope.parent else {}
    wanted = source.name

    for name, parent_source in parent_sources.items():
        # Only sources that are themselves CTE scopes count.
        if not isinstance(parent_source, Scope):
            continue
        if parent_source.scope_type == ScopeType.CTE and name == wanted:
            return True

    return False

View File

@ -51,13 +51,12 @@ from superset.extensions import celery_app, event_logger
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql.parse import SQLStatement, Table
from superset.sql_parse import (
CtasMethod,
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery,
SQLStatement,
Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import write_ipc_buffer

View File

@ -19,23 +19,16 @@
from __future__ import annotations
import enum
import logging
import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar
from collections.abc import Iterator
from typing import Any, cast, TYPE_CHECKING
import sqlglot
import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlglot.dialects.dialect import Dialects
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
@ -68,6 +61,7 @@ from superset.exceptions import (
SupersetParseError,
SupersetSecurityException,
)
from superset.sql.parse import extract_tables_from_statement, SQLScript, Table
from superset.utils.backports import StrEnum
try:
@ -226,7 +220,9 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
def check_sql_functions_exist(
sql: str, function_list: set[str], engine: str | None = None
sql: str,
function_list: set[str],
engine: str = "base",
) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
@ -238,7 +234,7 @@ def check_sql_functions_exist(
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@ -255,554 +251,18 @@ def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
)
@dataclass(eq=True, frozen=True)
class Table:
"""
A fully qualified SQL table conforming to [[catalog.]schema.]table.
"""
table: str
schema: str | None = None
catalog: str | None = None
def __str__(self) -> str:
"""
Return the fully qualified SQL table name.
"""
return ".".join(
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
def __eq__(self, __o: object) -> bool:
return str(self) == str(__o)
def extract_tables_from_statement(
statement: exp.Expression,
dialect: Dialects | None,
) -> set[Table]:
"""
Extract all table references in a single statement.
Please not that this is not trivial; consider the following queries:
DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
See the unit tests for other tricky cases.
"""
sources: Iterable[exp.Table]
if isinstance(statement, exp.Describe):
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
# query for all tables.
sources = statement.find_all(exp.Table)
elif isinstance(statement, exp.Command):
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
# `SELECT` statetement in order to extract tables.
literal = statement.find(exp.Literal)
if not literal:
return set()
try:
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
except ParseError:
return set()
sources = pseudo_query.find_all(exp.Table)
else:
sources = [
source
for scope in traverse_scope(statement)
for source in scope.sources.values()
if isinstance(source, exp.Table) and not is_cte(source, scope)
]
return {
Table(
source.name,
source.db if source.db != "" else None,
source.catalog if source.catalog != "" else None,
)
for source in sources
}
def is_cte(source: exp.Table, scope: Scope) -> bool:
"""
Is the source a CTE?
CTEs in the parent scope look like tables (and are represented by
exp.Table objects), but should not be considered as such;
otherwise a user with access to table `foo` could access any table
with a query like this:
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
"""
parent_sources = scope.parent.sources if scope.parent else {}
ctes_in_scope = {
name
for name, parent_scope in parent_sources.items()
if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE
}
return source.name in ctes_in_scope
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
# store the AST as a string (the original query), and manipulate it with regular
# expressions.
InternalRepresentation = TypeVar("InternalRepresentation")
# The base type. This helps type checking the `split_query` method correctly, since each
# derived class has a more specific return type (the class itself). This will no longer
# be needed once Python 3.11 is the lowest version supported. See PEP 673 for more
# information: https://peps.python.org/pep-0673/
TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name
class BaseSQLStatement(Generic[InternalRepresentation]):
"""
Base class for SQL statements.
The class can be instantiated with a string representation of the query or, for
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
which will split a query in multiple already parsed statements.
The `engine` parameters comes from the `engine` attribute in a Superset DB engine
spec.
"""
def __init__(
self,
statement: str | InternalRepresentation,
engine: str,
):
self._parsed: InternalRepresentation = (
self._parse_statement(statement, engine)
if isinstance(statement, str)
else statement
)
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
@classmethod
def split_query(
cls: type[TBaseSQLStatement],
query: str,
engine: str,
) -> list[TBaseSQLStatement]:
"""
Split a query into multiple instantiated statements.
This is a helper function to split a full SQL query into multiple
`BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
statements within a query.
"""
raise NotImplementedError()
@classmethod
def _parse_statement(
cls,
statement: str,
engine: str,
) -> InternalRepresentation:
"""
Parse a string containing a single SQL statement, and returns the parsed AST.
Derived classes should not assume that `statement` contains a single statement,
and MUST explicitly validate that. Since this validation is parser dependent the
responsibility is left to the children classes.
"""
raise NotImplementedError()
@classmethod
def _extract_tables_from_statement(
cls,
parsed: InternalRepresentation,
engine: str,
) -> set[Table]:
"""
Extract all table references in a given statement.
"""
raise NotImplementedError()
def format(self, comments: bool = True) -> str:
"""
Format the statement, optionally ommitting comments.
"""
raise NotImplementedError()
def get_settings(self) -> dict[str, str | bool]:
"""
Return any settings set by the statement.
For example, for this statement:
sql> SET foo = 'bar';
The method should return `{"foo": "'bar'"}`. Note the single quotes.
"""
raise NotImplementedError()
def is_mutating(self) -> bool:
"""
Check if the statement mutates data (DDL/DML).
:return: True if the statement mutates data.
"""
raise NotImplementedError()
def __str__(self) -> str:
return self.format()
class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
A SQL statement.
This class is used for all engines with dialects that can be parsed using sqlglot.
"""
def __init__(
self,
statement: str | exp.Expression,
engine: str,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine)
@classmethod
def split_query(
cls,
query: str,
engine: str,
) -> list[SQLStatement]:
dialect = SQLGLOT_DIALECTS.get(engine)
try:
statements = sqlglot.parse(query, dialect=dialect)
except sqlglot.errors.ParseError as ex:
raise SupersetParseError("Unable to split query") from ex
return [cls(statement, engine) for statement in statements if statement]
@classmethod
def _parse_statement(
cls,
statement: str,
engine: str,
) -> exp.Expression:
"""
Parse a single SQL statement.
"""
dialect = SQLGLOT_DIALECTS.get(engine)
# We could parse with `sqlglot.parse_one` to get a single statement, but we need
# to verify that the string contains exactly one statement.
try:
statements = sqlglot.parse(statement, dialect=dialect)
except sqlglot.errors.ParseError as ex:
raise SupersetParseError("Unable to split query") from ex
statements = [statement for statement in statements if statement]
if len(statements) != 1:
raise SupersetParseError("SQLStatement should have exactly one statement")
return statements[0]
@classmethod
def _extract_tables_from_statement(
cls,
parsed: exp.Expression,
engine: str,
) -> set[Table]:
"""
Find all referenced tables.
"""
dialect = SQLGLOT_DIALECTS.get(engine)
return extract_tables_from_statement(parsed, dialect)
def is_mutating(self) -> bool:
"""
Check if the statement mutates data (DDL/DML).
:return: True if the statement mutates data.
"""
for node in self._parsed.walk():
if isinstance(
node,
(
exp.Insert,
exp.Update,
exp.Delete,
exp.Merge,
exp.Create,
exp.Drop,
exp.TruncateTable,
),
):
return True
if isinstance(node, exp.Command) and node.name == "ALTER":
return True
# Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see
# https://www.postgresql.org/docs/current/sql-explain.html
if (
self._dialect == Dialects.POSTGRES
and isinstance(self._parsed, exp.Command)
and self._parsed.name == "EXPLAIN"
and self._parsed.expression.name.upper().startswith("ANALYZE ")
):
analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
return SQLStatement(analyzed_sql, self.engine).is_mutating()
return False
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
def get_settings(self) -> dict[str, str | bool]:
"""
Return the settings for the SQL statement.
>>> statement = SQLStatement("SET foo = 'bar'")
>>> statement.get_settings()
{"foo": "'bar'"}
"""
return {
eq.this.sql(): eq.expression.sql()
for set_item in self._parsed.find_all(exp.SetItem)
for eq in set_item.find_all(exp.EQ)
}
class KQLSplitState(enum.Enum):
    """
    State machine for splitting a KQL query.

    The state machine keeps track of whether we're inside a string or not, so we
    don't split the query in a semi-colon that's part of a string.
    """

    OUTSIDE_STRING = enum.auto()
    INSIDE_SINGLE_QUOTED_STRING = enum.auto()
    INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
    INSIDE_MULTILINE_STRING = enum.auto()


def split_kql(kql: str) -> list[str]:
    """
    Custom function for splitting KQL statements.

    The script is split at every semi-colon that is not inside a
    single-quoted, double-quoted or triple-backtick string. Inside quoted
    strings a backslash escapes the character that follows it, so an escaped
    quote does not end the string while an escaped backslash right before
    the closing quote (``'a\\\\'``) still lets the string end.

    NOTE: verbatim strings (``@"..."``) and ``//`` comments get no special
    treatment here.

    :param kql: the KQL script, with or without a trailing semi-colon.
    :return: the statements, without their terminating semi-colon and with
        surrounding whitespace preserved.
    """
    closing_quotes = {
        KQLSplitState.INSIDE_SINGLE_QUOTED_STRING: "'",
        KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING: '"',
    }

    statements: list[str] = []
    state = KQLSplitState.OUTSIDE_STRING
    statement_start = 0
    # True when the previous character was an unescaped backslash in a string
    escaped = False

    # make sure the last statement is terminated, so the loop flushes it
    query = kql if kql.endswith(";") else kql + ";"

    for i, character in enumerate(query):
        if state == KQLSplitState.OUTSIDE_STRING:
            if character == ";":
                statements.append(query[statement_start:i])
                statement_start = i + 1
            elif character == "'":
                state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
            elif character == '"':
                state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
            elif character == "`" and query[i - 2 : i] == "``":
                state = KQLSplitState.INSIDE_MULTILINE_STRING

        elif state in closing_quotes:
            if escaped:
                # this character is consumed by the preceding backslash
                escaped = False
            elif character == "\\":
                escaped = True
            elif character == closing_quotes[state]:
                state = KQLSplitState.OUTSIDE_STRING

        elif character == "`" and query[i - 2 : i] == "``":
            # only the multiline state remains: a third backtick closes it
            state = KQLSplitState.OUTSIDE_STRING

    return statements
class KustoKQLStatement(BaseSQLStatement[str]):
    """
    Statement class for Kusto KQL.

    KQL resembles SQL but cannot be parsed by sqlglot, so the "parsed" form of
    a statement kept by this class is simply its stripped text. A query looks
    like this:

        StormEvents
        | summarize PropertyDamage = sum(DamageProperty) by State
        | join kind=innerunique PopulationData on State
        | project State, PropertyDamagePerCapita = PropertyDamage / Population
        | sort by PropertyDamagePerCapita

    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for
    the language reference.
    """

    @classmethod
    def split_query(
        cls,
        query: str,
        engine: str,
    ) -> list[KustoKQLStatement]:
        """
        Break a query into one statement object per semi-colon separated part.

        There is no KQL parser available, so the splitting is done by the
        state machine in ``split_kql``. String literal rules are described in
        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
        """
        pieces = split_kql(query)
        return [cls(piece, engine) for piece in pieces]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> str:
        """
        Validate a single KQL statement and return its stripped text.

        :raises SupersetParseError: if the engine is not ``kustokql`` or the
            text does not split into exactly one statement.
        """
        if engine != "kustokql":
            raise SupersetParseError(f"Invalid engine: {engine}")

        pieces = split_kql(statement)
        if len(pieces) == 1:
            return pieces[0].strip()

        raise SupersetParseError("SQLStatement should have exactly one statement")

    @classmethod
    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
        """
        Table extraction is not implemented for KQL.

        A warning is logged and an empty set is returned, so no table is
        reported for statements such as:

            StormEvents
            | where InjuriesDirect + InjuriesIndirect > 50
            | join (PopulationData) on State
            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect
        """
        logger.warning(
            "Kusto KQL doesn't support table extraction. This means that data access "
            "roles will not be enforced by Superset in the database."
        )
        return set()

    def format(self, comments: bool = True) -> str:
        """
        Return the statement text unchanged; ``comments`` is not used.
        """
        return self._parsed

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the setting defined by a ``set`` statement, if any.

        A bare ``set name`` maps the name to ``True``; ``set name = value``
        maps it to the value. Any other statement yields an empty dict.

            >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
            >>> statement.get_settings()
            {"querytrace": True}
        """
        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
        match = re.match(set_regex, self._parsed, re.IGNORECASE)
        if match is None:
            return {}
        return {match.group("name"): match.group("value") or True}

    def is_mutating(self) -> bool:
        """
        Check if the statement mutates data (DDL/DML).

        Statements starting with a dot are treated as mutating, except those
        starting with ``.show``.

        :return: True if the statement mutates data.
        """
        if not self._parsed.startswith("."):
            return False
        return not self._parsed.startswith(".show")
class SQLScript:
    """
    A SQL script, made of zero or more statements.
    """

    # Engines whose language sqlglot cannot parse, mapped to the statement class
    # that handles them. Supporting non-SQL engines adds a lot of complexity to
    # Superset, so new entries should be avoided.
    special_engines = {
        "kustokql": KustoKQLStatement,
    }

    def __init__(
        self,
        query: str,
        engine: str,
    ):
        """
        Split ``query`` into statements using the class registered for
        ``engine``, defaulting to ``SQLStatement``.
        """
        split = self.special_engines.get(engine, SQLStatement).split_query
        self.statements = split(query, engine)

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the script: each statement is formatted on its own and
        the results are joined with ``";\\n"``.
        """
        formatted = [statement.format(comments) for statement in self.statements]
        return ";\n".join(formatted)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings defined in the script.

        Statements are processed in order, so a later assignment overrides an
        earlier one:

            >>> script = SQLScript("SET foo = 'bar'; SET foo = 'baz'", "sqlite")
            >>> script.get_settings()
            {"foo": "'baz'"}
        """
        return {
            name: value
            for statement in self.statements
            for name, value in statement.get_settings().items()
        }

    def has_mutation(self) -> bool:
        """
        Check if the script contains mutating statements.

        :return: True if the script contains mutating statements
        """
        for statement in self.statements:
            if statement.is_mutating():
                return True
        return False
class ParsedQuery:
def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
engine: str | None = None,
engine: str = "base",
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
self._engine = engine
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
@ -854,24 +314,18 @@ class ParsedQuery:
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
statements = parse(self.stripped(), dialect=self._dialect)
except SqlglotError as ex:
statements = [
statement._parsed # pylint: disable=protected-access
for statement in SQLScript(self.stripped(), self._engine).statements
]
except SupersetParseError as ex:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
message = (
"Error parsing near '{highlight}' at line {line}:{col}".format( # pylint: disable=consider-using-f-string
**ex.errors[0]
)
if isinstance(ex, ParseError)
else str(ex)
)
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
message=__(
"You may have an error in your SQL statement. {message}"
).format(message=message),
).format(message=ex.error.message),
level=ErrorLevel.ERROR,
)
) from ex
@ -883,77 +337,6 @@ class ParsedQuery:
if statement
}
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
    """
    Extract all table references in a single statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.
    """
    sources: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # sqlglot reports no sources for a `DESCRIBE`, so every table node in
        # the statement is collected explicitly.
        sources = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # Commands such as `SHOW COLUMNS FROM foo` are rewritten as a `SELECT`
        # statement so the tables can be pulled out of it.
        literal = statement.find(exp.Literal)
        if not literal:
            return set()

        try:
            pseudo_query = parse_one(
                f"SELECT {literal.this}",
                dialect=self._dialect,
            )
        except SqlglotError:
            return set()
        else:
            sources = pseudo_query.find_all(exp.Table)
    else:
        # regular statements: take the table sources of every scope, leaving
        # out names that actually refer to a CTE
        sources = [
            candidate
            for scope in traverse_scope(statement)
            for candidate in scope.sources.values()
            if isinstance(candidate, exp.Table)
            and not self._is_cte(candidate, scope)
        ]

    tables: set[Table] = set()
    for source in sources:
        # empty schema/catalog strings are normalized to `None`
        schema = source.db if source.db != "" else None
        catalog = source.catalog if source.catalog != "" else None
        tables.add(Table(source.name, schema, catalog))
    return tables
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
    """
    Tell whether ``source`` is really a reference to a CTE.

    A CTE defined in the parent scope shows up as an ``exp.Table`` object,
    but it must not be counted as a table; otherwise a user with access to
    table `foo` could read any table with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
    """
    if not scope.parent:
        return False

    target = source.name
    for name, parent_source in scope.parent.sources.items():
        if (
            name == target
            and isinstance(parent_source, Scope)
            and parent_source.scope_type == ScopeType.CTE
        ):
            return True
    return False
@property
def limit(self) -> int | None:
    """
    Return the value stored in ``_limit`` (an integer, or ``None``).
    """
    current_limit = self._limit
    return current_limit

View File

@ -35,8 +35,8 @@ from superset.daos.query import QueryDAO
from superset.extensions import event_logger
from superset.jinja_context import get_template_processor
from superset.models.sql_lab import Query
from superset.sql.parse import SQLScript
from superset.sql_lab import get_sql_results
from superset.sql_parse import SQLScript
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import (
QueryIsForbiddenToAccessException,

View File

@ -108,9 +108,8 @@ def cache_dashboard_thumbnail(
)
# pylint: disable=too-many-arguments
@celery_app.task(name="cache_dashboard_screenshot", soft_time_limit=300)
def cache_dashboard_screenshot(
def cache_dashboard_screenshot( # pylint: disable=too-many-arguments
username: str,
dashboard_id: int,
dashboard_url: str,

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,920 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
import pytest
from superset.exceptions import SupersetParseError
from superset.sql.parse import (
extract_tables_from_statement,
KustoKQLStatement,
split_kql,
SQLGLOT_DIALECTS,
SQLScript,
SQLStatement,
Table,
)
def test_table() -> None:
    """
    Test the `Table` class and its string conversion.

    Special characters in the table, schema, or catalog name should come out
    percent-encoded in the string representation.
    """
    cases = [
        (Table("tbname"), "tbname"),
        (Table("tbname", "schemaname"), "schemaname.tbname"),
        (
            Table("tbname", "schemaname", "catalogname"),
            "catalogname.schemaname.tbname",
        ),
        (
            Table("table.name", "schema/name", "catalog\nname"),
            "catalog%0Aname.schema%2Fname.table%2Ename",
        ),
    ]
    for table, expected in cases:
        assert str(table) == expected
def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]:
    """
    Helper returning the union of the tables referenced by every statement
    in ``sql``, parsed with the dialect registered for ``engine``.
    """
    dialect = SQLGLOT_DIALECTS.get(engine)
    tables: set[Table] = set()
    for statement in SQLScript(sql, engine).statements:
        tables.update(extract_tables_from_statement(statement._parsed, dialect))
    return tables
def test_extract_tables_from_sql() -> None:
    """
    Test that referenced tables are parsed correctly from the SQL.
    """
    cases: list[tuple[str, set[Table]]] = [
        ("SELECT * FROM tbname", {Table("tbname")}),
        ("SELECT * FROM tbname foo", {Table("tbname")}),
        ("SELECT * FROM tbname AS foo", {Table("tbname")}),
        # underscore
        ("SELECT * FROM tb_name", {Table("tb_name")}),
        # quotes
        ('SELECT * FROM "tbname"', {Table("tbname")}),
        # unicode
        (
            'SELECT * FROM "tb_name" WHERE city = "Lübeck"',
            {Table("tb_name")},
        ),
        # columns
        ("SELECT field1, field2 FROM tb_name", {Table("tb_name")}),
        ("SELECT t1.f1, t2.f2 FROM t1, t2", {Table("t1"), Table("t2")}),
        # named table
        (
            "SELECT a.date, a.field FROM left_table a LIMIT 10",
            {Table("left_table")},
        ),
        (
            "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
            {Table("forbidden_table")},
        ),
        (
            "select * from (select * from forbidden_table) forbidden_table",
            {Table("forbidden_table")},
        ),
    ]
    for sql, expected in cases:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_subselect() -> None:
    """
    Test that tables inside subselects are parsed correctly.
    """
    # subselect next to a regular table
    sql = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(sql) == {Table("t1", "s1"), Table("t2", "s2")}

    # subselect on its own
    sql = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(sql) == {Table("t1", "s1")}

    # deeply nested subselects
    sql = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t2.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    expected = {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
    assert extract_tables_from_sql(sql) == expected
def test_extract_tables_select_in_expression() -> None:
    """
    Test that the parser works with `SELECT`s used as expressions, with and
    without an alias on the sub-select.
    """
    expected = {Table("t1"), Table("t2")}
    queries = [
        "SELECT f1, (SELECT count(1) FROM t2) FROM t1",
        "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1",
    ]
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_parenthesis() -> None:
    """
    Test that a parenthesized expression in the projection is not mistaken
    for a table.
    """
    tables = extract_tables_from_sql("SELECT f1, (x + y) AS f2 FROM t1")
    assert tables == {Table("t1")}
def test_extract_tables_with_schema() -> None:
    """
    Test that schemas are parsed correctly, quoted or not, aliased or not.
    """
    expected = {Table("tbname", "schemaname")}
    queries = [
        "SELECT * FROM schemaname.tbname",
        'SELECT * FROM "schemaname"."tbname"',
        'SELECT * FROM "schemaname"."tbname" foo',
        'SELECT * FROM "schemaname"."tbname" AS foo',
    ]
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_union() -> None:
    """
    Test that set operations (`UNION`, `UNION ALL`, `INTERSECT ALL`) report
    the tables of both sides.
    """
    expected = {Table("t1"), Table("t2")}
    queries = [
        "SELECT * FROM t1 UNION SELECT * FROM t2",
        "SELECT * FROM t1 UNION ALL SELECT * FROM t2",
        "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2",
    ]
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_select_from_values() -> None:
    """
    Test that selecting from values returns no tables.
    """
    tables = extract_tables_from_sql("SELECT * FROM VALUES (13, 42)")
    assert tables == set()
def test_extract_tables_select_array() -> None:
    """
    Test that queries selecting arrays work as expected.
    """
    sql = """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
def test_extract_tables_select_if() -> None:
    """
    Test that queries with an `IF` work as expected.
    """
    sql = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
def test_extract_tables_with_catalog() -> None:
    """
    Test that catalogs are parsed correctly.
    """
    tables = extract_tables_from_sql("SELECT * FROM catalogname.schemaname.tbname")
    assert tables == {Table("tbname", "schemaname", "catalogname")}
def test_extract_tables_illdefined() -> None:
    """
    Test that ill-defined table references raise a `SupersetParseError` with
    the expected message.
    """
    failing = [
        ("SELECT * FROM schemaname.", "Error parsing near '.' at line 1:25"),
        (
            "SELECT * FROM catalogname.schemaname.",
            "Error parsing near '.' at line 1:37",
        ),
        ("SELECT * FROM catalogname..", "Error parsing near '.' at line 1:27"),
        ('SELECT * FROM "tbname', "Unable to parse script"),
    ]
    for sql, message in failing:
        with pytest.raises(SupersetParseError) as excinfo:
            extract_tables_from_sql(sql)
        assert str(excinfo.value) == message

    # odd edge case that works
    tables = extract_tables_from_sql("SELECT * FROM catalogname..tbname")
    assert tables == {Table(table="tbname", schema=None, catalog="catalogname")}
def test_extract_tables_show_tables_from() -> None:
    """
    Test that `SHOW TABLES FROM` (MySQL) reports no tables.
    """
    tables = extract_tables_from_sql("SHOW TABLES FROM s1 like '%order%'", "mysql")
    assert tables == set()
def test_extract_tables_show_columns_from() -> None:
    """
    Test that `SHOW COLUMNS FROM` reports the inspected table.
    """
    tables = extract_tables_from_sql("SHOW COLUMNS FROM t1")
    assert tables == {Table("t1")}
def test_extract_tables_where_subquery() -> None:
    """
    Test that tables in a `WHERE` subquery are parsed correctly, for scalar,
    `IN` and `EXISTS` subqueries.
    """
    expected = {Table("t1"), Table("t2")}
    queries = [
        """
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
""",
        """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
""",
        """
SELECT name
FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
""",
    ]
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_describe() -> None:
    """
    Test that `DESCRIBE` reports the described table.
    """
    tables = extract_tables_from_sql("DESCRIBE t1")
    assert tables == {Table("t1")}
def test_extract_tables_show_partitions() -> None:
    """
    Test that `SHOW PARTITIONS` reports the inspected table.
    """
    sql = """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
    assert extract_tables_from_sql(sql) == {Table("orders")}
def test_extract_tables_join() -> None:
    """
    Test that both sides of a join are reported, for several join types.
    """
    # plain join between two tables
    tables = extract_tables_from_sql(
        "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
    )
    assert tables == {Table("t1"), Table("t2")}

    # joins against a sub-select, one query per join type
    expected = {Table("left_table"), Table("right_table")}
    queries = [
        """
SELECT a.date, b.name
FROM left_table a
JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
        """
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
        """
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
        """
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
    ]
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_semi_join() -> None:
    """
    Test that `LEFT SEMI JOIN` reports both sides of the join.
    """
    sql = """
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.data = b.date
"""
    expected = {Table("left_table"), Table("right_table")}
    assert extract_tables_from_sql(sql) == expected
def test_extract_tables_combinations() -> None:
    """
    Test complex cases with nested queries.
    """
    # unions, joins and correlated sub-queries mixed together
    sql = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a
) tmp_join
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t3.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    expected = {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
    assert extract_tables_from_sql(sql) == expected

    # several levels of aliased sub-selects around one table
    sql = """
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
SELECT * FROM EmployeeS
) AS S1
) AS S2
) AS S3
"""
    assert extract_tables_from_sql(sql) == {Table("EmployeeS")}
def test_extract_tables_with() -> None:
    """
    Test that CTE names in a `WITH` clause are not reported as tables, while
    the tables they read from are.
    """
    # independent CTEs
    sql = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
    assert extract_tables_from_sql(sql) == {Table("t1"), Table("t2"), Table("t3")}

    # chained CTEs
    sql = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
def test_extract_tables_reusing_aliases() -> None:
    """
    Test that the parser follows aliases.
    """
    sql = """
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
    assert extract_tables_from_sql(sql) == {Table("src")}

    # weird query with circular dependency
    sql = """
with src as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from src) a
"""
    assert extract_tables_from_sql(sql) == set()
def test_extract_tables_multistatement() -> None:
    """
    Test that tables from every statement of a multi-statement script are
    reported.
    """
    both = {Table("t1"), Table("t2")}
    assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2") == both
    assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2;") == both

    hive_tables = extract_tables_from_sql(
        "ADD JAR file:///hive.jar; SELECT * FROM t1;",
        engine="hive",
    )
    assert hive_tables == {Table("t1")}
def test_extract_tables_complex() -> None:
    """
    Test a few complex queries.
    """
    # aggregation over a sub-select with joins and an `IN` sub-query
    sql = """
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
COUNT(DISTINCT id_userid) AS m_examples,
some_more_info
FROM my_b_table b
JOIN my_t_table t ON b.ds=t.ds
JOIN my_l_table l ON b.uid=l.uid
WHERE
b.rid IN (
SELECT other_col
FROM inner_table
)
AND l.bla IN ('x', 'y')
GROUP BY 2
ORDER BY 2 ASC
) AS "meh"
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
    expected = {
        Table("my_l_table"),
        Table("my_b_table"),
        Table("my_t_table"),
        Table("inner_table"),
    }
    assert extract_tables_from_sql(sql) == expected

    # comma-separated tables with aliases
    sql = """
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    expected = {Table("table_a"), Table("table_b"), Table("table_c")}
    assert extract_tables_from_sql(sql) == expected

    # CTEs nested in a sub-select, combined with unions
    sql = """
SELECT somecol AS somecol
FROM (
WITH bla AS (
SELECT col_a
FROM a
WHERE
1=1
AND column_of_choice NOT IN (
SELECT interesting_col
FROM b
)
),
rb AS (
SELECT yet_another_column
FROM (
SELECT a
FROM c
GROUP BY the_other_col
) not_table
LEFT JOIN bla foo
ON foo.prop = not_table.bad_col0
WHERE 1=1
GROUP BY
not_table.bad_col1 ,
not_table.bad_col2 ,
ORDER BY not_table.bad_col_3 DESC ,
not_table.bad_col4 ,
not_table.bad_col5
)
SELECT random_col
FROM d
WHERE 1=1
UNION ALL SELECT even_more_cols
FROM e
WHERE 1=1
UNION ALL SELECT lets_go_deeper
FROM f
WHERE 1=1
WHERE 2=2
GROUP BY last_col
LIMIT 50000
)
"""
    expected = {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
    assert extract_tables_from_sql(sql) == expected
def test_extract_tables_mixed_from_clause() -> None:
    """
    Test that the parser handles a `FROM` clause mixing tables and a subselect.
    """
    sql = """
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    expected = {Table("table_a"), Table("table_b"), Table("table_c")}
    assert extract_tables_from_sql(sql) == expected
def test_extract_tables_nested_select() -> None:
    """
    Test that the parser handles selects inside functions.
    """
    expected = {Table("COLUMNS", "INFORMATION_SCHEMA")}
    queries = [
        """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""",
        """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
""",
    ]
    for sql in queries:
        assert extract_tables_from_sql(sql, "mysql") == expected
def test_extract_tables_complex_cte_with_prefix() -> None:
    """
    Test that the parser handles CTEs with prefixes.
    """
    sql = """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
FROM SalesOrderHeader
WHERE SalesPersonID IS NOT NULL
)
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
    assert extract_tables_from_sql(sql) == {Table("SalesOrderHeader")}
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
    """
    Test that aliases that are keywords are parsed correctly.
    """
    sql = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
    assert extract_tables_from_sql(sql) == {Table("foo")}
def test_sqlscript() -> None:
    """
    Test the `SQLScript` class: splitting, formatting and settings.
    """
    two_selects = SQLScript("SELECT 1; SELECT 2;", "sqlite")
    assert len(two_selects.statements) == 2
    assert two_selects.format() == "SELECT\n 1;\nSELECT\n 2"
    first_statement = two_selects.statements[0]
    assert first_statement.format() == "SELECT\n 1"

    # the last assignment of a setting wins
    with_settings = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
    assert with_settings.get_settings() == {"a": "2"}

    kql_script = SQLScript(
        """set querytrace;
Events | take 100""",
        "kustokql",
    )
    assert kql_script.get_settings() == {"querytrace": True}
def test_sqlstatement() -> None:
    """
    Test the `SQLStatement` class: tables, formatting and settings.
    """
    union = SQLStatement(
        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
        "sqlite",
    )
    expected_tables = {
        Table(table="table1", schema=None, catalog=None),
        Table(table="table2", schema=None, catalog=None),
    }
    assert union.tables == expected_tables

    expected_sql = "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
    assert union.format() == expected_sql

    assignment = SQLStatement("SET a=1", "sqlite")
    assert assignment.get_settings() == {"a": "1"}
def test_kustokqlstatement_split_script() -> None:
    """
    Test that `KustoKQLStatement.split_script` yields one statement per
    semi-colon separated part of the script.
    """
    script = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    statements = KustoKQLStatement.split_script(script, "kustokql")
    assert len(statements) == 4
def test_kustokqlstatement_with_program() -> None:
    """
    Test that semi-colons inside a triple-backtick string (here an embedded
    program) do not split the KQL.
    """
    script = """
print program = ```
public class Program {
public static void Main() {
System.Console.WriteLine("Hello!");
}
}```
"""
    statements = KustoKQLStatement.split_script(script, "kustokql")
    assert len(statements) == 1
def test_kustokqlstatement_with_set() -> None:
    """
    Test the `KustoKQLStatement` split method when the KQL has a set command.
    """
    script = """
set querytrace;
Events | take 100
"""
    statements = KustoKQLStatement.split_script(script, "kustokql")
    assert len(statements) == 2

    set_statement, query_statement = statements
    assert set_statement.format() == "set querytrace"
    assert query_statement.format() == "Events | take 100"
@pytest.mark.parametrize(
    "kql,statements",
    [
        # semi-colon free call with several double-quoted strings
        ('print banner=strcat("Hello", ", ", "World!")', 1),
        # escaped quotes inside a single-quoted string
        (r"print 'O\'Malley\'s'", 1),
        # semi-colon inside a single-quoted string
        (r"print 'O\'Mal;ley\'s'", 1),
        # semi-colons inside a triple-backtick string
        ("print ```foo;\nbar;\nbaz;```\n", 1),
    ],
)
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
    """
    Test that quotes, escapes and semi-colons inside strings do not produce
    extra statements.
    """
    parts = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(parts) == statements
def test_split_kql() -> None:
    """
    Test that `split_kql` returns the exact text between semi-colons,
    whitespace included.
    """
    script = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    expected = [
        """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
        """
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
        """
let cachedResult = materialize(materializedScope)""",
        """
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
    ]
    assert split_kql(script) == expected
@pytest.mark.parametrize(
    ("engine", "sql", "expected"),
    [
        # SQLite tests
        ("sqlite", "SELECT 1", False),
        ("sqlite", "INSERT INTO foo VALUES (1)", True),
        ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
        ("sqlite", "DELETE FROM foo WHERE id = 1", True),
        ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
        ("sqlite", "DROP TABLE foo", True),
        ("sqlite", "EXPLAIN SELECT * FROM foo", False),
        ("sqlite", "PRAGMA table_info(foo)", False),
        # PostgreSQL tests
        ("postgresql", "SELECT 1", False),
        ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
        ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
        ("postgresql", "DELETE FROM foo WHERE id = 1", True),
        ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
        ("postgresql", "DROP TABLE foo", True),
        ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
        ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
        ("postgresql", "SHOW search_path", False),
        ("postgresql", "SET search_path TO public", False),
        (
            "postgres",
            """
with source as (
select 1 as one
)
select * from source
""",
            False,
        ),
        # Trino tests
        ("trino", "SELECT 1", False),
        ("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
        ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
        ("trino", "DELETE FROM foo WHERE id = 1", True),
        ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
        ("trino", "DROP TABLE foo", True),
        ("trino", "EXPLAIN SELECT * FROM foo", False),
        ("trino", "SHOW SCHEMAS", False),
        ("trino", "SET SESSION optimization_level = '3'", False),
        # Kusto KQL tests
        ("kustokql", "tbl | limit 100", False),
        ("kustokql", "let foo = 1; tbl | where bar == foo", False),
        ("kustokql", ".show tables", False),
        ("kustokql", "print 1", False),
        ("kustokql", "set querytrace; Events | take 100", False),
        ("kustokql", ".drop table foo", True),
        ("kustokql", ".set-or-append table foo <| bar", True),
    ],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
    """
    Test that `SQLScript.has_mutation` flags scripts with DDL/DML statements.
    """
    script = SQLScript(sql, engine)
    assert script.has_mutation() == expected

View File

@ -30,6 +30,7 @@ from superset.exceptions import (
QueryClauseValidationException,
SupersetSecurityException,
)
from superset.sql.parse import Table
from superset.sql_parse import (
add_table_name,
check_sql_functions_exist,
@ -39,18 +40,13 @@ from superset.sql_parse import (
has_table_query,
insert_rls_as_subquery,
insert_rls_in_predicate,
KustoKQLStatement,
ParsedQuery,
sanitize_clause,
split_kql,
SQLScript,
SQLStatement,
strip_comments_from_sql,
Table,
)
def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
def extract_tables(query: str, engine: str = "base") -> set[Table]:
"""
Helper function to extract tables referenced in a query.
"""
@ -285,7 +281,7 @@ def test_extract_tables_illdefined() -> None:
extract_tables('SELECT * FROM "tbname')
assert (
str(excinfo.value)
== "You may have an error in your SQL statement. Error tokenizing 'SELECT * FROM \"tbnam'"
== "You may have an error in your SQL statement. Unable to parse script"
)
# odd edge case that works
@ -1834,49 +1830,6 @@ SELECT * FROM t"""
assert ParsedQuery("USE foo; SELECT * FROM bar").is_select()
def test_sqlquery() -> None:
    """
    Test the `SQLScript` class: splitting, formatting and settings.
    """
    two_selects = SQLScript("SELECT 1; SELECT 2;", "sqlite")
    assert len(two_selects.statements) == 2
    assert two_selects.format() == "SELECT\n 1;\nSELECT\n 2"
    first_statement = two_selects.statements[0]
    assert first_statement.format() == "SELECT\n 1"

    # the last assignment of a setting wins
    with_settings = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
    assert with_settings.get_settings() == {"a": "2"}

    kql_script = SQLScript(
        """set querytrace;
Events | take 100""",
        "kustokql",
    )
    assert kql_script.get_settings() == {"querytrace": True}
def test_sqlstatement() -> None:
    """
    Exercise `SQLStatement`: referenced tables, formatting, and settings.
    """
    union_all = SQLStatement(
        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
        "sqlite",
    )

    # Both sides of the UNION ALL contribute a table.
    expected_tables = {
        Table(table=name, schema=None, catalog=None)
        for name in ("table1", "table2")
    }
    assert union_all.tables == expected_tables

    # NOTE(review): expected string is kept verbatim; confirm the indent
    # width against the formatter's actual output.
    expected_sql = "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
    assert union_all.format() == expected_sql

    set_statement = SQLStatement("SET a=1", "sqlite")
    assert set_statement.get_settings() == {"a": "1"}
@pytest.mark.parametrize(
"engine",
[
@ -1924,194 +1877,3 @@ def test_extract_tables_from_jinja_sql(
)
== expected
)
def test_kustokqlstatement_split_query() -> None:
    """
    Check that `KustoKQLStatement.split_query` splits a multi-statement script.

    The script holds three `let` statements followed by one tabular
    expression, so four statements are expected.
    """
    kql = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    parts = KustoKQLStatement.split_query(kql, "kustokql")
    assert len(parts) == 4
def test_kustokqlstatement_with_program() -> None:
    """
    Check `KustoKQLStatement.split_query` on KQL that embeds a program.

    The fenced program text contains `;` characters, yet the whole input is
    expected to come back as a single statement.
    """
    kql = """
print program = ```
public class Program {
public static void Main() {
System.Console.WriteLine("Hello!");
}
}```
"""
    parts = KustoKQLStatement.split_query(kql, "kustokql")
    assert len(parts) == 1
def test_kustokqlstatement_with_set() -> None:
    """
    Check `KustoKQLStatement.split_query` on KQL that starts with `set`.

    The `set` command and the query after it are expected as two separate
    statements, each formatted without surrounding whitespace.
    """
    kql = """
set querytrace;
Events | take 100
"""
    parts = KustoKQLStatement.split_query(kql, "kustokql")
    assert len(parts) == 2

    set_command, query = parts
    assert set_command.format() == "set querytrace"
    assert query.format() == "Events | take 100"
@pytest.mark.parametrize(
    "kql,statements",
    [
        # Commas inside double-quoted string arguments.
        ('print banner=strcat("Hello", ", ", "World!")', 1),
        # Escaped single quotes inside a single-quoted string.
        (r"print 'O\'Malley\'s'", 1),
        # A semicolon inside a single-quoted string with escaped quotes.
        (r"print 'O\'Mal;ley\'s'", 1),
        # Semicolons and newlines inside a triple-backtick block.
        ("print ```foo;\nbar;\nbaz;```\n", 1),
    ],
)
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
    """
    Check the statement count for KQL with quotes, escapes, and fenced blocks.
    """
    split = KustoKQLStatement.split_query(kql, "kustokql")
    assert len(split) == statements
def test_split_kql() -> None:
    """
    Check that `split_kql` returns the raw text of each statement.

    Each expected piece keeps its leading newline and omits the `;` that
    separated it from the next statement.
    """
    script = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    expected_pieces = [
        """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
        """
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
        """
let cachedResult = materialize(materializedScope)""",
        """
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
    ]
    assert split_kql(script) == expected_pieces
@pytest.mark.parametrize(
    ("engine", "sql", "expected"),
    [
        # SQLite tests
        ("sqlite", "SELECT 1", False),
        ("sqlite", "INSERT INTO foo VALUES (1)", True),
        ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
        ("sqlite", "DELETE FROM foo WHERE id = 1", True),
        ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
        ("sqlite", "DROP TABLE foo", True),
        ("sqlite", "EXPLAIN SELECT * FROM foo", False),
        ("sqlite", "PRAGMA table_info(foo)", False),
        # PostgreSQL tests
        ("postgresql", "SELECT 1", False),
        ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
        ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
        ("postgresql", "DELETE FROM foo WHERE id = 1", True),
        ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
        ("postgresql", "DROP TABLE foo", True),
        ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
        ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
        ("postgresql", "SHOW search_path", False),
        ("postgresql", "SET search_path TO public", False),
        # CTE feeding a plain SELECT (engine name is "postgres" here)
        (
            "postgres",
            """
with source as (
select 1 as one
)
select * from source
""",
            False,
        ),
        # Trino tests
        ("trino", "SELECT 1", False),
        ("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
        ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
        ("trino", "DELETE FROM foo WHERE id = 1", True),
        ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
        ("trino", "DROP TABLE foo", True),
        ("trino", "EXPLAIN SELECT * FROM foo", False),
        ("trino", "SHOW SCHEMAS", False),
        ("trino", "SET SESSION optimization_level = '3'", False),
        # Kusto KQL tests
        ("kustokql", "tbl | limit 100", False),
        ("kustokql", "let foo = 1; tbl | where bar == foo", False),
        ("kustokql", ".show tables", False),
        ("kustokql", "print 1", False),
        ("kustokql", "set querytrace; Events | take 100", False),
        ("kustokql", ".drop table foo", True),
        ("kustokql", ".set-or-append table foo <| bar", True),
    ],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
    """
    Check `SQLScript.has_mutation` for each engine/script pair.
    """
    script = SQLScript(sql, engine)
    assert script.has_mutation() == expected