From 5da6d2bd8889e75c4b44507f3ce7dae5065ceefa Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 16 May 2024 12:49:31 -0400 Subject: [PATCH] feat: add support for catalogs (#28416) --- superset/db_engine_specs/README.md | 22 +---- superset/db_engine_specs/bigquery.py | 34 ++++++- superset/db_engine_specs/presto.py | 57 ++++++----- superset/db_engine_specs/snowflake.py | 28 ++++-- superset/migrations/shared/catalogs.py | 8 +- ...58d051681a3b_add_catalog_perm_to_tables.py | 4 +- ...81be5b6b74_enable_catalog_in_databricks.py | 4 +- ...nable_catalog_in_bigquery_presto_trino_.py | 40 ++++++++ .../integration_tests/databases/api_tests.py | 2 +- tests/integration_tests/model_tests.py | 8 +- .../db_engine_specs/test_bigquery.py | 94 +++++++++++++++++++ .../unit_tests/db_engine_specs/test_presto.py | 88 +++++++++++++++++ .../db_engine_specs/test_snowflake.py | 88 +++++++++++++++++ .../unit_tests/db_engine_specs/test_trino.py | 89 ++++++++++++++++++ 14 files changed, 504 insertions(+), 62 deletions(-) create mode 100644 superset/migrations/versions/2024-05-09_18-44_87ffc36f9842_enable_catalog_in_bigquery_presto_trino_.py diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md index 4a108be65..88362f6b0 100644 --- a/superset/db_engine_specs/README.md +++ b/superset/db_engine_specs/README.md @@ -706,29 +706,11 @@ Hive and Trino: 4. Table 5. Column -If the database supports catalogs, then the DB engine spec should have the `supports_catalog` class attribute set to true. +If the database supports catalogs, then the DB engine spec should have the `supports_catalog` class attribute set to true. It should also implement the `get_default_catalog` method, so that the proper permissions can be created when datasets are added. ### Dynamic catalog -Superset has no support for multiple catalogs. A given SQLAlchemy URI connects to a single catalog, and it's impossible to browse other catalogs, or change the catalog. This means that datasets can only be added for the main catalog of the database. For example, with this Postgres SQLAlchemy URI: - -``` -postgresql://admin:password123@db.example.org:5432/db -``` - -Here, datasets can only be added to the `db` catalog (which Postgres calls a "database"). - -One confusing problem is that many databases allow querying across catalogs in SQL Lab. For example, with BigQuery one can write: - -```sql -SELECT * FROM project.schema.table -``` - -This means that **even though the database is configured for a given catalog (project), users can query other projects**. This is a common workaround for creating datasets in catalogs other than the catalog configured in the database: just create a virtual dataset. - -Ideally we would want users to be able to choose the catalog when using SQL Lab and when creating datasets. In order to do that, DB engine specs need to implement a method that rewrites the SQLAlchemy URI depending on the desired catalog. This method already exists, and is the same method used for dynamic schemas, `adjust_engine_params`, but currently there are no UI affordances for choosing a catalog. - -Before the UI is implemented Superset still needs to implement support for catalogs in its security manager. But in the meantime, it's possible for DB engine spec developers to support dynamic catalogs, by setting `supports_dynamic_catalog` to true and implementing `adjust_engine_params` to handle a catalog. +Superset support for multiple catalogs. Since, in general, a given SQLAlchemy URI connects only to a single catalog, it requires DB engine specs to implement the `adjust_engine_params` method to rewrite the URL to connect to a different catalog, similar to how dynamic schemas work. Additionally, DB engine specs should also implement the `get_catalog_names` method, so that users can browse the available catalogs. ### SSH tunneling diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index ca52bd51c..4fff5ab3d 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -35,6 +35,7 @@ from marshmallow.exceptions import ValidationError from sqlalchemy import column, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.engine.url import URL from sqlalchemy.sql import sqltypes from superset import sql_parse @@ -127,7 +128,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met allows_hidden_cc_in_orderby = True - supports_catalog = False + supports_catalog = supports_dynamic_catalog = True """ https://www.python.org/dev/peps/pep-0249/#arraysize @@ -459,6 +460,24 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met for statement in statements ] + @classmethod + def get_default_catalog(cls, database: Database) -> str | None: + """ + Get the default catalog. + """ + url = database.url_object + + # The SQLAlchemy driver accepts both `bigquery://project` (where the project is + # technically a host) and `bigquery:///project` (where it's a database). But + # both can be missing, and the project is inferred from the authentication + # credentials. + if project := url.host or url.database: + return project + + with database.get_sqla_engine() as engine: + client = cls._get_client(engine) + return client.project + @classmethod def get_catalog_names( cls, @@ -477,6 +496,19 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met return {project.project_id for project in projects} + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: + if catalog: + uri = uri.set(host=catalog, database="") + + return uri, connect_args + @classmethod def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index b8bc7f62d..ba542b8f6 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + # pylint: disable=too-many-lines + from __future__ import annotations import contextlib @@ -165,6 +167,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): """ supports_dynamic_schema = True + supports_catalog = supports_dynamic_catalog = True column_type_mappings = ( ( @@ -295,6 +298,24 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): def epoch_to_dttm(cls) -> str: return "from_unixtime({col})" + @classmethod + def get_default_catalog(cls, database: "Database") -> str | None: + """ + Return the default catalog. + """ + return database.url_object.database.split("/")[0] + + @classmethod + def get_catalog_names( + cls, + database: Database, + inspector: Inspector, + ) -> set[str]: + """ + Get all catalogs. + """ + return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} + @classmethod def adjust_engine_params( cls, @@ -303,14 +324,22 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): catalog: str | None = None, schema: str | None = None, ) -> tuple[URL, dict[str, Any]]: - database = uri.database - if schema and database: + if uri.database and "/" in uri.database: + current_catalog, current_schema = uri.database.split("/", 1) + else: + current_catalog, current_schema = uri.database, None + + if schema: schema = parse.quote(schema, safe="") - if "/" in database: - database = database.split("/")[0] + "/" + schema - else: - database += "/" + schema - uri = uri.set(database=database) + + adjusted_database = "/".join( + [ + catalog or current_catalog or "", + schema or current_schema or "", + ] + ).rstrip("/") + + uri = uri.set(database=adjusted_database) return uri, connect_args @@ -651,8 +680,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): engine_name = "Presto" allows_alias_to_source_column = False - supports_catalog = False - custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __( @@ -815,17 +842,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): results = cursor.fetchall() return {row[0] for row in results} - @classmethod - def get_catalog_names( - cls, - database: Database, - inspector: Inspector, - ) -> set[str]: - """ - Get all catalogs. - """ - return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} - @classmethod def _create_column_info( cls, name: str, data_type: types.TypeEngine @@ -1251,7 +1267,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): ), } - # flake8 is not matching `Optional[str]` to `Any` for some reason... metadata["view"] = cast( Any, cls.get_create_view(database, table.schema, table.table), diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 9a82cfcca..137cc4e00 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -85,7 +85,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): sqlalchemy_uri_placeholder = "snowflake://" supports_dynamic_schema = True - supports_catalog = False + supports_catalog = supports_dynamic_catalog = True _time_grain_expressions = { None: "{col}", @@ -144,12 +144,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): catalog: Optional[str] = None, schema: Optional[str] = None, ) -> tuple[URL, dict[str, Any]]: - database = uri.database - if "/" in database: - database = database.split("/")[0] - if schema: - schema = parse.quote(schema, safe="") - uri = uri.set(database=f"{database}/{schema}") + if "/" in uri.database: + current_catalog, current_schema = uri.database.split("/", 1) + else: + current_catalog, current_schema = uri.database, None + + adjusted_database = "/".join( + [ + catalog or current_catalog, + schema or current_schema or "", + ] + ).rstrip("/") + + uri = uri.set(database=adjusted_database) return uri, connect_args @@ -169,6 +176,13 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): return parse.unquote(database.split("/")[1]) + @classmethod + def get_default_catalog(cls, database: "Database") -> Optional[str]: + """ + Return the default catalog. + """ + return database.url_object.database.split("/")[0] + @classmethod def get_catalog_names( cls, diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py index 4b13d6043..6c03faec4 100644 --- a/superset/migrations/shared/catalogs.py +++ b/superset/migrations/shared/catalogs.py @@ -86,7 +86,7 @@ class Slice(Base): schema_perm = sa.Column(sa.String(1000)) -def upgrade_catalog_perms(engine: str | None = None) -> None: +def upgrade_catalog_perms(engines: set[str] | None = None) -> None: """ Update models when catalogs are introduced in a DB engine spec. @@ -102,7 +102,7 @@ def upgrade_catalog_perms(engine: str | None = None) -> None: for database in session.query(Database).all(): db_engine_spec = database.db_engine_spec if ( - engine and db_engine_spec.engine != engine + engines and db_engine_spec.engine not in engines ) or not db_engine_spec.supports_catalog: continue @@ -166,7 +166,7 @@ def upgrade_catalog_perms(engine: str | None = None) -> None: session.commit() -def downgrade_catalog_perms(engine: str | None = None) -> None: +def downgrade_catalog_perms(engines: set[str] | None = None) -> None: """ Reverse the process of `upgrade_catalog_perms`. """ @@ -175,7 +175,7 @@ def downgrade_catalog_perms(engine: str | None = None) -> None: for database in session.query(Database).all(): db_engine_spec = database.db_engine_spec if ( - engine and db_engine_spec.engine != engine + engines and db_engine_spec.engine not in engines ) or not db_engine_spec.supports_catalog: continue diff --git a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py index f8f782474..856ad2ad0 100644 --- a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py +++ b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py @@ -44,10 +44,10 @@ def upgrade(): "slices", sa.Column("catalog_perm", sa.String(length=1000), nullable=True), ) - upgrade_catalog_perms(engine="postgresql") + upgrade_catalog_perms(engines={"postgresql"}) def downgrade(): op.drop_column("slices", "catalog_perm") op.drop_column("tables", "catalog_perm") - downgrade_catalog_perms(engine="postgresql") + downgrade_catalog_perms(engines={"postgresql"}) diff --git a/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py index f39d6fa0d..3f7f82ee9 100644 --- a/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py +++ b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py @@ -33,8 +33,8 @@ down_revision = "645bb206f96c" def upgrade(): - upgrade_catalog_perms(engine="databricks") + upgrade_catalog_perms(engines={"databricks"}) def downgrade(): - downgrade_catalog_perms(engine="databricks") + downgrade_catalog_perms(engines={"databricks"}) diff --git a/superset/migrations/versions/2024-05-09_18-44_87ffc36f9842_enable_catalog_in_bigquery_presto_trino_.py b/superset/migrations/versions/2024-05-09_18-44_87ffc36f9842_enable_catalog_in_bigquery_presto_trino_.py new file mode 100644 index 000000000..bca3b1e4b --- /dev/null +++ b/superset/migrations/versions/2024-05-09_18-44_87ffc36f9842_enable_catalog_in_bigquery_presto_trino_.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Enable catalog in BigQuery/Presto/Trino/Snowflake + +Revision ID: 87ffc36f9842 +Revises: 4081be5b6b74 +Create Date: 2024-05-09 18:44:43.289445 + +""" + +from superset.migrations.shared.catalogs import ( + downgrade_catalog_perms, + upgrade_catalog_perms, +) + +# revision identifiers, used by Alembic. +revision = "87ffc36f9842" +down_revision = "4081be5b6b74" + + +def upgrade(): + upgrade_catalog_perms(engines={"trino", "presto", "bigquery", "snowflake"}) + + +def downgrade(): + downgrade_catalog_perms(engines={"trino", "presto", "bigquery", "snowflake"}) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 520a068d4..a6deb942a 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -3281,7 +3281,7 @@ class TestDatabaseApi(SupersetTestCase): "sqlalchemy_uri_placeholder": "bigquery://{project_id}", "engine_information": { "supports_file_upload": True, - "supports_dynamic_catalog": False, + "supports_dynamic_catalog": True, "disable_ssh_tunneling": True, }, }, diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 4dc15a2af..df806b04b 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -167,7 +167,7 @@ class TestDatabaseModel(SupersetTestCase): model._get_sqla_engine() call_args = mocked_create_engine.call_args - assert str(call_args[0][0]) == "presto://gamma@localhost" + assert str(call_args[0][0]) == "presto://gamma@localhost/" assert call_args[1]["connect_args"] == { "protocol": "https", @@ -180,7 +180,7 @@ class TestDatabaseModel(SupersetTestCase): model._get_sqla_engine() call_args = mocked_create_engine.call_args - assert str(call_args[0][0]) == "presto://localhost" + assert str(call_args[0][0]) == "presto://localhost/" assert call_args[1]["connect_args"] == { "protocol": "https", @@ -225,7 +225,7 @@ class TestDatabaseModel(SupersetTestCase): model._get_sqla_engine() call_args = mocked_create_engine.call_args - assert str(call_args[0][0]) == "trino://localhost" + assert str(call_args[0][0]) == "trino://localhost/" assert call_args[1]["connect_args"]["user"] == "gamma" model = Database( @@ -239,7 +239,7 @@ class TestDatabaseModel(SupersetTestCase): assert ( str(call_args[0][0]) - == "trino://original_user:original_user_password@localhost" + == "trino://original_user:original_user_password@localhost/" ) assert call_args[1]["connect_args"]["user"] == "gamma" diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index 616ae6684..13eecda09 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -24,6 +24,7 @@ from typing import Optional import pytest from pytest_mock import MockFixture from sqlalchemy import select +from sqlalchemy.engine.url import make_url from sqlalchemy.sql import sqltypes from sqlalchemy_bigquery import BigQueryDialect @@ -333,3 +334,96 @@ def test_convert_dttm( from superset.db_engine_specs.bigquery import BigQueryEngineSpec as spec assert_convert_dttm(spec, target_type, expected_result, dttm) + + +def test_get_default_catalog(mocker: MockFixture) -> None: + """ + Test that we get the default catalog from the connection URI. + """ + from superset.db_engine_specs.bigquery import BigQueryEngineSpec + from superset.models.core import Database + + mocker.patch.object(Database, "get_sqla_engine") + get_client = mocker.patch.object(BigQueryEngineSpec, "_get_client") + get_client().project = "project" + + database = Database( + database_name="my_db", + sqlalchemy_uri="bigquery://project", + ) + assert BigQueryEngineSpec.get_default_catalog(database) == "project" + + database = Database( + database_name="my_db", + sqlalchemy_uri="bigquery:///project", + ) + assert BigQueryEngineSpec.get_default_catalog(database) == "project" + + database = Database( + database_name="my_db", + sqlalchemy_uri="bigquery://", + ) + assert BigQueryEngineSpec.get_default_catalog(database) == "project" + + +def test_adjust_engine_params_catalog_as_host() -> None: + """ + Test passing a custom catalog. + + In this test, the original URI has the catalog as the host. + """ + from superset.db_engine_specs.bigquery import BigQueryEngineSpec + + url = make_url("bigquery://project") + + uri = BigQueryEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "bigquery://project" + + uri = BigQueryEngineSpec.adjust_engine_params( + url, + {}, + catalog="other-project", + )[0] + assert str(uri) == "bigquery://other-project/" + + +def test_adjust_engine_params_catalog_as_database() -> None: + """ + Test passing a custom catalog. + + In this test, the original URI has the catalog as the database. + """ + from superset.db_engine_specs.bigquery import BigQueryEngineSpec + + url = make_url("bigquery:///project") + + uri = BigQueryEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "bigquery:///project" + + uri = BigQueryEngineSpec.adjust_engine_params( + url, + {}, + catalog="other-project", + )[0] + assert str(uri) == "bigquery://other-project/" + + +def test_adjust_engine_params_no_catalog() -> None: + """ + Test passing a custom catalog. + + In this test, the original URI has no catalog. + """ + from superset.db_engine_specs.bigquery import BigQueryEngineSpec + + url = make_url("bigquery://") + + uri = BigQueryEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "bigquery://" + + uri = BigQueryEngineSpec.adjust_engine_params( + url, + {}, + catalog="other-project", + )[0] + assert str(uri) == "bigquery://other-project/" diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 3d7703f0f..f9680006d 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -155,3 +155,91 @@ def test_where_latest_partition( ) assert str(actual) == expected + + +def test_adjust_engine_params_fully_qualified() -> None: + """ + Test the ``adjust_engine_params`` method when the URL has catalog and schema. + """ + from superset.db_engine_specs.presto import PrestoEngineSpec + + url = make_url("presto://localhost:8080/hive/default") + + uri = PrestoEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "presto://localhost:8080/hive/default" + + uri = PrestoEngineSpec.adjust_engine_params( + url, + {}, + schema="new_schema", + )[0] + assert str(uri) == "presto://localhost:8080/hive/new_schema" + + uri = PrestoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + )[0] + assert str(uri) == "presto://localhost:8080/new_catalog/default" + + uri = PrestoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + schema="new_schema", + )[0] + assert str(uri) == "presto://localhost:8080/new_catalog/new_schema" + + +def test_adjust_engine_params_catalog_only() -> None: + """ + Test the ``adjust_engine_params`` method when the URL has only the catalog. + """ + from superset.db_engine_specs.presto import PrestoEngineSpec + + url = make_url("presto://localhost:8080/hive") + + uri = PrestoEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "presto://localhost:8080/hive" + + uri = PrestoEngineSpec.adjust_engine_params( + url, + {}, + schema="new_schema", + )[0] + assert str(uri) == "presto://localhost:8080/hive/new_schema" + + uri = PrestoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + )[0] + assert str(uri) == "presto://localhost:8080/new_catalog" + + uri = PrestoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + schema="new_schema", + )[0] + assert str(uri) == "presto://localhost:8080/new_catalog/new_schema" + + +def test_get_default_catalog() -> None: + """ + Test the ``get_default_catalog`` method. + """ + from superset.db_engine_specs.presto import PrestoEngineSpec + from superset.models.core import Database + + database = Database( + database_name="my_db", + sqlalchemy_uri="presto://localhost:8080/hive", + ) + assert PrestoEngineSpec.get_default_catalog(database) == "hive" + + database = Database( + database_name="my_db", + sqlalchemy_uri="presto://localhost:8080/hive/default", + ) + assert PrestoEngineSpec.get_default_catalog(database) == "hive" diff --git a/tests/unit_tests/db_engine_specs/test_snowflake.py b/tests/unit_tests/db_engine_specs/test_snowflake.py index cf2393e84..dbbd58f00 100644 --- a/tests/unit_tests/db_engine_specs/test_snowflake.py +++ b/tests/unit_tests/db_engine_specs/test_snowflake.py @@ -203,3 +203,91 @@ def test_get_schema_from_engine_params() -> None: ) is None ) + + +def test_adjust_engine_params_fully_qualified() -> None: + """ + Test the ``adjust_engine_params`` method when the URL has catalog and schema. + """ + from superset.db_engine_specs.snowflake import SnowflakeEngineSpec + + url = make_url("snowflake://user:pass@account/database_name/default") + + uri = SnowflakeEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "snowflake://user:pass@account/database_name/default" + + uri = SnowflakeEngineSpec.adjust_engine_params( + url, + {}, + schema="new_schema", + )[0] + assert str(uri) == "snowflake://user:pass@account/database_name/new_schema" + + uri = SnowflakeEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + )[0] + assert str(uri) == "snowflake://user:pass@account/new_catalog/default" + + uri = SnowflakeEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + schema="new_schema", + )[0] + assert str(uri) == "snowflake://user:pass@account/new_catalog/new_schema" + + +def test_adjust_engine_params_catalog_only() -> None: + """ + Test the ``adjust_engine_params`` method when the URL has only the catalog. + """ + from superset.db_engine_specs.snowflake import SnowflakeEngineSpec + + url = make_url("snowflake://user:pass@account/database_name") + + uri = SnowflakeEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "snowflake://user:pass@account/database_name" + + uri = SnowflakeEngineSpec.adjust_engine_params( + url, + {}, + schema="new_schema", + )[0] + assert str(uri) == "snowflake://user:pass@account/database_name/new_schema" + + uri = SnowflakeEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + )[0] + assert str(uri) == "snowflake://user:pass@account/new_catalog" + + uri = SnowflakeEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + schema="new_schema", + )[0] + assert str(uri) == "snowflake://user:pass@account/new_catalog/new_schema" + + +def test_get_default_catalog() -> None: + """ + Test the ``get_default_catalog`` method. + """ + from superset.db_engine_specs.snowflake import SnowflakeEngineSpec + from superset.models.core import Database + + database = Database( + database_name="my_db", + sqlalchemy_uri="snowflake://user:pass@account/database_name", + ) + assert SnowflakeEngineSpec.get_default_catalog(database) == "database_name" + + database = Database( + database_name="my_db", + sqlalchemy_uri="snowflake://user:pass@account/database_name/default", + ) + assert SnowflakeEngineSpec.get_default_catalog(database) == "database_name" diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 5bd83828e..e35615f57 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -26,6 +26,7 @@ import pytest from pytest_mock import MockerFixture from requests.exceptions import ConnectionError as RequestsConnectionError from sqlalchemy import types +from sqlalchemy.engine.url import make_url from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError from trino.sqlalchemy import datatype @@ -556,3 +557,91 @@ def test_get_dbapi_exception_mapping(): assert mapping.get(TrinoExternalError) == SupersetDBAPIOperationalError assert mapping.get(RequestsConnectionError) == SupersetDBAPIConnectionError assert mapping.get(Exception) is None + + +def test_adjust_engine_params_fully_qualified() -> None: + """ + Test the ``adjust_engine_params`` method when the URL has catalog and schema. + """ + from superset.db_engine_specs.trino import TrinoEngineSpec + + url = make_url("trino://user:pass@localhost:8080/system/default") + + uri = TrinoEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "trino://user:pass@localhost:8080/system/default" + + uri = TrinoEngineSpec.adjust_engine_params( + url, + {}, + schema="new_schema", + )[0] + assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema" + + uri = TrinoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + )[0] + assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/default" + + uri = TrinoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + schema="new_schema", + )[0] + assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema" + + +def test_adjust_engine_params_catalog_only() -> None: + """ + Test the ``adjust_engine_params`` method when the URL has only the catalog. + """ + from superset.db_engine_specs.trino import TrinoEngineSpec + + url = make_url("trino://user:pass@localhost:8080/system") + + uri = TrinoEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "trino://user:pass@localhost:8080/system" + + uri = TrinoEngineSpec.adjust_engine_params( + url, + {}, + schema="new_schema", + )[0] + assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema" + + uri = TrinoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + )[0] + assert str(uri) == "trino://user:pass@localhost:8080/new_catalog" + + uri = TrinoEngineSpec.adjust_engine_params( + url, + {}, + catalog="new_catalog", + schema="new_schema", + )[0] + assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema" + + +def test_get_default_catalog() -> None: + """ + Test the ``get_default_catalog`` method. + """ + from superset.db_engine_specs.trino import TrinoEngineSpec + from superset.models.core import Database + + database = Database( + database_name="my_db", + sqlalchemy_uri="trino://user:pass@localhost:8080/system", + ) + assert TrinoEngineSpec.get_default_catalog(database) == "system" + + database = Database( + database_name="my_db", + sqlalchemy_uri="trino://user:pass@localhost:8080/system/default", + ) + assert TrinoEngineSpec.get_default_catalog(database) == "system"