diff --git a/superset/config.py b/superset/config.py
index fa55318f6..f274cf31e 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -74,6 +74,8 @@ if TYPE_CHECKING:
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
 
+    DialectExtensions = dict[str, Dialects | type[Dialect]]
+
 # Realtime stats logger, a StatsD implementation exists
 STATS_LOGGER = DummyStatsLogger()
 
@@ -251,7 +253,9 @@ SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER = ( # pylint: disable=invalid-name
 )
 
 # Extends the default SQLGlot dialects with additional dialects
-SQLGLOT_DIALECTS_EXTENSIONS: dict[str, Dialects | type[Dialect]] = {}
+# Either a mapping of engine name to dialect, or a callable taking no arguments that
+# returns such a mapping (it is called by `configure_sqlglot_dialects` at app init)
+SQLGLOT_DIALECTS_EXTENSIONS: DialectExtensions | Callable[[], DialectExtensions] = {}
 
 # The limit of queries fetched for query search
 QUERY_SEARCH_LIMIT = 1000
diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py
index ee7fcf9ef..10a5e5089 100644
--- a/superset/initialization/__init__.py
+++ b/superset/initialization/__init__.py
@@ -547,7 +547,18 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
         feature_flag_manager.init_app(self.superset_app)
 
     def configure_sqlglot_dialects(self) -> None:
-        SQLGLOT_DIALECTS.update(self.config["SQLGLOT_DIALECTS_EXTENSIONS"])
+        """
+        Register the dialects configured in `SQLGLOT_DIALECTS_EXTENSIONS`.
+
+        The setting is either a mapping or a zero-argument callable returning one; a
+        callable is invoked here, and the result is merged into `SQLGLOT_DIALECTS`.
+        """
+        extensions = self.config["SQLGLOT_DIALECTS_EXTENSIONS"]
+
+        if callable(extensions):
+            extensions = extensions()
+
+        SQLGLOT_DIALECTS.update(extensions)
 
     @transaction()
     def configure_fab(self) -> None:
diff --git a/superset/sql/dialects/__init__.py b/superset/sql/dialects/__init__.py
index 13a83393a..ab09de3c2 100644
--- a/superset/sql/dialects/__init__.py
+++ b/superset/sql/dialects/__init__.py
@@ -14,3 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from .firebolt import Firebolt, FireboltOld
+
+__all__ = ["Firebolt", "FireboltOld"]
diff --git a/superset/sql/dialects/firebolt.py b/superset/sql/dialects/firebolt.py
index 119ee3ba1..c939dee45 100644
--- a/superset/sql/dialects/firebolt.py
+++ b/superset/sql/dialects/firebolt.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 from sqlglot import exp, generator, parser
 from sqlglot.dialects.dialect import Dialect
+from sqlglot.helper import csv
 from sqlglot.tokens import TokenType
 
 
@@ -73,3 +74,139 @@ class Firebolt(Dialect):
                 return f"NOT ({self.sql(expression, 'this')})"
 
             return super().not_sql(expression)
+
+
+class FireboltOld(Firebolt):
+    """
+    Dialect for the old version of Firebolt (https://old.docs.firebolt.io/).
+
+    The main difference is that `UNNEST` is an operator like `JOIN`, instead of a
+    function.
+    """
+
+    class Parser(Firebolt.Parser):
+        # `UNNEST` is removed from the tokens accepted as a table alias; presumably so
+        # that in `t1 UNNEST(...)` it is not consumed as the alias of `t1`.
+        TABLE_ALIAS_TOKENS = Firebolt.Parser.TABLE_ALIAS_TOKENS - {TokenType.UNNEST}
+
+        def _parse_join(
+            self,
+            skip_join_token: bool = False,
+            parse_bracket: bool = False,
+        ) -> exp.Join | None:
+            """
+            Parse a join, treating a bare `UNNEST(...)` as a join operator.
+
+            If the next token is `UNNEST` the parsed expression is wrapped in an
+            `exp.Join`; otherwise the parent implementation is used.
+            """
+            if unnest := self._parse_unnest():
+                return self.expression(exp.Join, this=unnest)
+
+            return super()._parse_join(skip_join_token, parse_bracket)
+
+        def _parse_unnest(self, with_alias: bool = True) -> exp.Unnest | None:
+            """
+            Parse `UNNEST(expr [AS alias], ...)` into an `exp.Unnest`.
+
+            Returns `None` when the next token is not `UNNEST`.
+            """
+            if not self._match(TokenType.UNNEST):
+                return None
+
+            # parse expressions (col1 AS foo), instead of equalities as in the original
+            # dialect
+            expressions = self._parse_wrapped_csv(self._parse_expression)
+            offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
+
+            alias = self._parse_table_alias() if with_alias else None
+
+            if alias:
+                if self.dialect.UNNEST_COLUMN_ONLY:
+                    if alias.args.get("columns"):
+                        self.raise_error("Unexpected extra column alias in unnest.")
+
+                    alias.set("columns", [alias.this])
+                    alias.set("this", None)
+
+                columns = alias.args.get("columns") or []
+                if offset and len(expressions) < len(columns):
+                    offset = columns.pop()
+
+            if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET):
+                self._match(TokenType.ALIAS)
+                offset = self._parse_id_var(
+                    any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS
+                ) or exp.to_identifier("offset")
+
+            return self.expression(
+                exp.Unnest,
+                expressions=expressions,
+                alias=alias,
+                offset=offset,
+            )
+
+    class Generator(Firebolt.Generator):
+        def join_sql(self, expression: exp.Join) -> str:
+            """
+            Generate the SQL for a join.
+
+            A join with no operator and no `ON`/`USING` clause is rendered as
+            `, <target>`, unless the target is an `UNNEST` (or a `LATERAL` with
+            `cross_apply` set), in which case the comma is omitted.
+            """
+            if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in (
+                "SEMI",
+                "ANTI",
+            ):
+                side = None
+            else:
+                side = expression.side
+
+            op_sql = " ".join(
+                op
+                for op in (
+                    expression.method,
+                    "GLOBAL" if expression.args.get("global") else None,
+                    side,
+                    expression.kind,
+                    expression.hint if self.JOIN_HINTS else None,
+                )
+                if op
+            )
+            match_cond = self.sql(expression, "match_condition")
+            match_cond = f" MATCH_CONDITION ({match_cond})" if match_cond else ""
+            on_sql = self.sql(expression, "on")
+            using = expression.args.get("using")
+
+            if not on_sql and using:
+                on_sql = csv(*(self.sql(column) for column in using))
+
+            this = expression.this
+            this_sql = self.sql(this)
+
+            if exprs := self.expressions(expression):
+                this_sql = f"{this_sql},{self.seg(exprs)}"
+
+            if on_sql:
+                on_sql = self.indent(on_sql, skip_first=True)
+                space = self.seg(" " * self.pad) if self.pretty else " "
+                if using:
+                    on_sql = f"{space}USING ({on_sql})"
+                else:
+                    on_sql = f"{space}ON {on_sql}"
+            elif not op_sql:
+                # the main difference with the base dialect is the lack of comma before
+                # an `UNNEST`
+                if (
+                    isinstance(this, exp.Lateral)
+                    and this.args.get("cross_apply") is not None
+                ) or isinstance(this, exp.Unnest):
+                    return f" {this_sql}"
+
+                return f", {this_sql}"
+
+            if op_sql != "STRAIGHT_JOIN":
+                op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
+
+            return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 3fd13a800..f5923fecc 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -40,8 +40,6 @@ from superset.sql.dialects.firebolt import Firebolt
 
 logger = logging.getLogger(__name__)
 
-# register 3rd party dialects
-Dialect.classes["firebolt"] = Firebolt
 
 # mapping between DB engine specs and sqlglot dialects
 SQLGLOT_DIALECTS = {
@@ -65,7 +63,7 @@ SQLGLOT_DIALECTS = {
     # "elasticsearch": ???
     # "exa": ???
     # "firebird": ???
-    "firebolt": "firebolt",
+    "firebolt": Firebolt,
     "gsheets": Dialects.SQLITE,
     "hana": Dialects.POSTGRES,
     "hive": Dialects.HIVE,
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index a2aff686a..2df24a1b3 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -1146,3 +1146,25 @@ SELECT
 FROM tbl
     """.strip()
     )
+
+
+def test_firebolt_old() -> None:
+    """
+    Test the dialect for the old Firebolt syntax.
+    """
+    from unittest.mock import patch
+
+    from superset.sql.dialects import FireboltOld
+    from superset.sql.parse import SQLGLOT_DIALECTS
+
+    sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"
+
+    # `SQLGLOT_DIALECTS` is a module-level dict shared by every test; scope the
+    # override so that "firebolt" is restored once the assertion has run
+    with patch.dict(SQLGLOT_DIALECTS, {"firebolt": FireboltOld}):
+        assert (
+            SQLStatement(sql, "firebolt").format()
+            == """SELECT
+  *
+FROM t1 UNNEST(col1 AS foo)"""
+        )