diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 5a375896c..f0664564f 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -296,10 +296,13 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): return "from_unixtime({col})" @classmethod - def get_default_catalog(cls, database: "Database") -> str | None: + def get_default_catalog(cls, database: Database) -> str | None: """ Return the default catalog. """ + if database.url_object.database is None: + return None + return database.url_object.database.split("/")[0] @classmethod diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index f7183ba7d..5a32cd050 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +from __future__ import annotations + import copy from collections import namedtuple from datetime import datetime @@ -719,7 +721,15 @@ def test_adjust_engine_params_catalog_only() -> None: assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema" -def test_get_default_catalog() -> None: +@pytest.mark.parametrize( + "sqlalchemy_uri,result", + [ + ("trino://user:pass@localhost:8080/system", "system"), + ("trino://user:pass@localhost:8080/system/default", "system"), + ("trino://trino@localhost:8081", None), + ], +) +def test_get_default_catalog(sqlalchemy_uri: str, result: str | None) -> None: """ Test the ``get_default_catalog`` method. """ @@ -728,15 +738,9 @@ def test_get_default_catalog() -> None: database = Database( database_name="my_db", - sqlalchemy_uri="trino://user:pass@localhost:8080/system", + sqlalchemy_uri=sqlalchemy_uri, ) - assert TrinoEngineSpec.get_default_catalog(database) == "system" - - database = Database( - database_name="my_db", - sqlalchemy_uri="trino://user:pass@localhost:8080/system/default", - ) - assert TrinoEngineSpec.get_default_catalog(database) == "system" + assert TrinoEngineSpec.get_default_catalog(database) == result @patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")