fix: `sqlparse` fallback for formatting queries (#30578)
This commit is contained in:
parent
9a2b1a5cf7
commit
47c1e09c75
|
|
@ -26,6 +26,8 @@ from dataclasses import dataclass
|
|||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from deprecation import deprecated
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.errors import ParseError
|
||||
|
|
@ -138,9 +140,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
|||
"""
|
||||
Base class for SQL statements.
|
||||
|
||||
The class can be instantiated with a string representation of the script or, for
|
||||
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
|
||||
which will split a script in multiple already parsed statements.
|
||||
The class should be instantiated with a string representation of the script and, for
|
||||
efficiency reasons, optionally with a pre-parsed AST. This is useful with
|
||||
`sqlglot.parse`, which will split a script into multiple already-parsed statements.
|
||||
|
||||
The `engine` parameter comes from the `engine` attribute in a Superset DB engine
|
||||
spec.
|
||||
|
|
@ -148,14 +150,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
statement: str | InternalRepresentation,
|
||||
statement: str,
|
||||
engine: str,
|
||||
ast: InternalRepresentation | None = None,
|
||||
):
|
||||
self._parsed: InternalRepresentation = (
|
||||
self._parse_statement(statement, engine)
|
||||
if isinstance(statement, str)
|
||||
else statement
|
||||
)
|
||||
self._sql = statement
|
||||
self._parsed = ast or self._parse_statement(statement, engine)
|
||||
self.engine = engine
|
||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
||||
|
||||
|
|
@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
statement: str | exp.Expression,
|
||||
statement: str,
|
||||
engine: str,
|
||||
ast: exp.Expression | None = None,
|
||||
):
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
super().__init__(statement, engine)
|
||||
super().__init__(statement, engine, ast)
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
|
||||
|
|
@ -275,11 +276,47 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||
script: str,
|
||||
engine: str,
|
||||
) -> list[SQLStatement]:
|
||||
return [
|
||||
cls(statement, engine)
|
||||
for statement in cls._parse(script, engine)
|
||||
if statement
|
||||
]
|
||||
if engine in SQLGLOT_DIALECTS:
|
||||
try:
|
||||
return [
|
||||
cls(ast.sql(), engine, ast)
|
||||
for ast in cls._parse(script, engine)
|
||||
if ast
|
||||
]
|
||||
except ValueError:
|
||||
# `ast.sql()` might raise an error in some cases (e.g., `SHOW TABLES
|
||||
# FROM`). In this case, we rely on the tokenizer to generate the
|
||||
# statements.
|
||||
pass
|
||||
|
||||
# Without a sqlglot dialect (or when `ast.sql()` failed above) we can't rely on `ast.sql()` to correctly
|
||||
# generate the SQL of each statement, so we tokenize the script and split it
|
||||
# based on the location of semi-colons.
|
||||
statements = []
|
||||
start = 0
|
||||
remainder = script
|
||||
|
||||
try:
|
||||
tokens = sqlglot.tokenize(script)
|
||||
except sqlglot.errors.TokenError as ex:
|
||||
raise SupersetParseError(
|
||||
script,
|
||||
engine,
|
||||
message="Unable to tokenize script",
|
||||
) from ex
|
||||
|
||||
for token in tokens:
|
||||
if token.token_type == sqlglot.TokenType.SEMICOLON:
|
||||
statement, start = script[start : token.start], token.end + 1
|
||||
ast = cls._parse(statement, engine)[0]
|
||||
statements.append(cls(statement.strip(), engine, ast))
|
||||
remainder = script[start:]
|
||||
|
||||
if remainder.strip():
|
||||
ast = cls._parse(remainder, engine)[0]
|
||||
statements.append(cls(remainder.strip(), engine, ast))
|
||||
|
||||
return statements
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
|
|
@ -349,8 +386,34 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
|
||||
if self._dialect:
|
||||
try:
|
||||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(
|
||||
self._parsed,
|
||||
copy=False,
|
||||
comments=comments,
|
||||
pretty=True,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return self._fallback_formatting()
|
||||
|
||||
@deprecated(deprecated_in="4.0", removed_in="5.0")
def _fallback_formatting(self) -> str:
    """
    Pretty-format the original SQL text with `sqlparse` instead of sqlglot.

    This is the path taken when no dialect-specific formatting is available.
    Regenerating SQL through the generic sqlglot dialect is known to break
    queries: for example, `foo NOT IN (1, 2)` becomes `NOT foo IN (1,2)`, which
    is not valid for Firebolt. Formatting the raw text with `sqlparse` avoids
    rewriting the query structure.

    The method is scheduled to change in 5.0, when `sqlparse` should be removed
    and the query returned unmodified.
    """
    # Only re-indent and upper-case keywords; the statement text itself is the
    # string the object was created with, not SQL regenerated from the AST.
    formatting_options = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(self._sql, **formatting_options)
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
|
|
@ -456,7 +519,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
|||
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
|
||||
for more information.
|
||||
"""
|
||||
return [cls(statement, engine) for statement in split_kql(script)]
|
||||
return [
|
||||
cls(statement, engine, statement.strip()) for statement in split_kql(script)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
|
|
@ -498,7 +563,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
|||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
return self._parsed
|
||||
return self._sql.strip()
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
|
|
@ -548,6 +613,9 @@ class SQLScript:
|
|||
def format(self, comments: bool = True) -> str:
    """
    Pretty-format every statement of the script and join the results.

    Each statement is formatted on its own (forwarding `comments`), and the
    formatted statements are joined with a semi-colon followed by a newline.

    Note that even though KQL is very different from SQL, multiple statements are
    still separated by semi-colons, so the same separator is used for both.
    """
    formatted = [statement.format(comments) for statement in self.statements]
    return ";\n".join(formatted)
|
||||
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
"/api/v1/sqllab/format_sql/",
|
||||
json=data,
|
||||
)
|
||||
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
|
||||
success_resp = {"result": "SELECT 1\nFROM my_table"}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
|
||||
assert rv.status_code == 200
|
||||
|
|
|
|||
|
|
@ -241,14 +241,7 @@ def test_select_star(mocker: MockerFixture) -> None:
|
|||
latest_partition=False,
|
||||
cols=cols,
|
||||
)
|
||||
assert (
|
||||
sql
|
||||
== """SELECT
|
||||
a
|
||||
FROM my_table
|
||||
LIMIT ?
|
||||
OFFSET ?"""
|
||||
)
|
||||
assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?"
|
||||
|
||||
sql = NoLimitDBEngineSpec.select_star(
|
||||
database=database,
|
||||
|
|
@ -260,12 +253,7 @@ OFFSET ?"""
|
|||
latest_partition=False,
|
||||
cols=cols,
|
||||
)
|
||||
assert (
|
||||
sql
|
||||
== """SELECT
|
||||
a
|
||||
FROM my_table"""
|
||||
)
|
||||
assert sql == "SELECT a\nFROM my_table"
|
||||
|
||||
|
||||
def test_extra_table_metadata(mocker: MockerFixture) -> None:
|
||||
|
|
|
|||
|
|
@ -284,6 +284,40 @@ def test_extract_tables_show_tables_from() -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_format_show_tables() -> None:
    """
    Test format when `ast.sql()` raises an exception.

    `SHOW TABLES FROM ... LIKE ...` is such a case for MySQL; formatting should
    then be delegated to sqlparse, which upper-cases the `like` keyword.
    """
    script = SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql")
    expected = "SHOW TABLES FROM s1 LIKE '%order%'"
    assert script.format() == expected
|
||||
|
||||
|
||||
def test_format_no_dialect() -> None:
    """
    Test format with an engine that has no corresponding dialect.

    The `NOT IN` predicate must be kept as written rather than rewritten.
    """
    script = SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt")
    # NOTE(review): the indentation after the last `\n` in the expected string may
    # have been collapsed when this text was extracted -- confirm against the
    # actual sqlparse output before relying on it.
    expected = "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
    assert script.format() == expected
|
||||
|
||||
|
||||
def test_split_no_dialect() -> None:
    """
    Test the statement split when the engine has no corresponding dialect.

    The script should be split on semi-colons, and each statement should keep
    its original text with surrounding whitespace removed.
    """
    sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
    first, second, third = (
        "SELECT col FROM t WHERE col NOT IN (1, 2)",
        "SELECT * FROM t",
        "SELECT foo",
    )

    statements = SQLScript(sql, "firebolt").statements

    assert len(statements) == 3
    assert statements[0]._sql == first
    assert statements[1]._sql == second
    assert statements[2]._sql == third
|
||||
|
||||
|
||||
def test_extract_tables_show_columns_from() -> None:
|
||||
"""
|
||||
Test `SHOW COLUMNS FROM`.
|
||||
|
|
|
|||
Loading…
Reference in New Issue