fix: `sqlparse` fallback for formatting queries (#30578)

This commit is contained in:
Beto Dealmeida 2024-10-11 15:45:40 -04:00 committed by GitHub
parent 9a2b1a5cf7
commit 47c1e09c75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 125 additions and 35 deletions

View File

@ -26,6 +26,8 @@ from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import sqlglot
import sqlparse
from deprecation import deprecated
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
@ -138,9 +140,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
Base class for SQL statements.
The class can be instantiated with a string representation of the script or, for
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
which will split a script in multiple already parsed statements.
The class should be instantiated with a string representation of the script and, for
efficiency reasons, optionally with a pre-parsed AST. This is useful with
`sqlglot.parse`, which will split a script in multiple already parsed statements.
The `engine` parameters comes from the `engine` attribute in a Superset DB engine
spec.
@ -148,14 +150,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
def __init__(
self,
statement: str | InternalRepresentation,
statement: str,
engine: str,
ast: InternalRepresentation | None = None,
):
self._parsed: InternalRepresentation = (
self._parse_statement(statement, engine)
if isinstance(statement, str)
else statement
)
self._sql = statement
self._parsed = ast or self._parse_statement(statement, engine)
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def __init__(
self,
statement: str | exp.Expression,
statement: str,
engine: str,
ast: exp.Expression | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine)
super().__init__(statement, engine, ast)
@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@ -275,11 +276,47 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str,
engine: str,
) -> list[SQLStatement]:
return [
cls(statement, engine)
for statement in cls._parse(script, engine)
if statement
]
if engine in SQLGLOT_DIALECTS:
try:
return [
cls(ast.sql(), engine, ast)
for ast in cls._parse(script, engine)
if ast
]
except ValueError:
# `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
# FROM`). In this case, we rely on the tokenizer to generate the
# statements.
pass
# When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
# generate the SQL of each statement, so we tokenize the script and split it
# based on the location of semi-colons.
statements = []
start = 0
remainder = script
try:
tokens = sqlglot.tokenize(script)
except sqlglot.errors.TokenError as ex:
raise SupersetParseError(
script,
engine,
message="Unable to tokenize script",
) from ex
for token in tokens:
if token.token_type == sqlglot.TokenType.SEMICOLON:
statement, start = script[start : token.start], token.end + 1
ast = cls._parse(statement, engine)[0]
statements.append(cls(statement.strip(), engine, ast))
remainder = script[start:]
if remainder.strip():
ast = cls._parse(remainder, engine)[0]
statements.append(cls(remainder.strip(), engine, ast))
return statements
@classmethod
def _parse_statement(
@ -349,8 +386,34 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
if self._dialect:
try:
write = Dialect.get_or_raise(self._dialect)
return write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
except ValueError:
pass
return self._fallback_formatting()
@deprecated(deprecated_in="4.0", removed_in="5.0")
def _fallback_formatting(self) -> str:
"""
Format SQL without a specific dialect.
Reformatting SQL using the generic sqlglot dialect is known to break queries.
For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
when the dialect is not known.
In 5.0 we should remove `sqlparse`, and the method should return the query
unmodified.
"""
return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
def get_settings(self) -> dict[str, str | bool]:
"""
@ -456,7 +519,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
for more information.
"""
return [cls(statement, engine) for statement in split_kql(script)]
return [
cls(statement, engine, statement.strip()) for statement in split_kql(script)
]
@classmethod
def _parse_statement(
@ -498,7 +563,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
"""
Pretty-format the SQL statement.
"""
return self._parsed
return self._sql.strip()
def get_settings(self) -> dict[str, str | bool]:
"""
@ -548,6 +613,9 @@ class SQLScript:
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL script.
Note that even though KQL is very different from SQL, multiple statements are
still separated by semi-colons.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)

View File

@ -281,7 +281,7 @@ class TestSqlLabApi(SupersetTestCase):
"/api/v1/sqllab/format_sql/",
json=data,
)
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
success_resp = {"result": "SELECT 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200

View File

@ -241,14 +241,7 @@ def test_select_star(mocker: MockerFixture) -> None:
latest_partition=False,
cols=cols,
)
assert (
sql
== """SELECT
a
FROM my_table
LIMIT ?
OFFSET ?"""
)
assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?"
sql = NoLimitDBEngineSpec.select_star(
database=database,
@ -260,12 +253,7 @@ OFFSET ?"""
latest_partition=False,
cols=cols,
)
assert (
sql
== """SELECT
a
FROM my_table"""
)
assert sql == "SELECT a\nFROM my_table"
def test_extra_table_metadata(mocker: MockerFixture) -> None:

View File

@ -284,6 +284,40 @@ def test_extract_tables_show_tables_from() -> None:
)
def test_format_show_tables() -> None:
"""
Test format when `ast.sql()` raises an exception.
In that case sqlparse should be used instead.
"""
assert (
SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
== "SHOW TABLES FROM s1 LIKE '%order%'"
)
def test_format_no_dialect() -> None:
"""
Test format with an engine that has no corresponding dialect.
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format()
== "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
)
def test_split_no_dialect() -> None:
"""
Test the statement split when the engine has no corresponding dialect.
"""
sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
statements = SQLScript(sql, "firebolt").statements
assert len(statements) == 3
assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
assert statements[1]._sql == "SELECT * FROM t"
assert statements[2]._sql == "SELECT foo"
def test_extract_tables_show_columns_from() -> None:
"""
Test `SHOW COLUMNS FROM`.