fix(trino): handle missing db in migration (#29997)

This commit is contained in:
Ville Brofeldt 2024-08-22 15:52:56 -07:00 committed by GitHub
parent 5906890b78
commit 17eecb1981
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 10 deletions

View File

@ -296,10 +296,13 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return "from_unixtime({col})"
@classmethod
def get_default_catalog(cls, database: "Database") -> str | None:
def get_default_catalog(cls, database: Database) -> str | None:
"""
Return the default catalog.
"""
if database.url_object.database is None:
return None
return database.url_object.database.split("/")[0]
@classmethod

View File

@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from __future__ import annotations
import copy
from collections import namedtuple
from datetime import datetime
@ -719,7 +721,15 @@ def test_adjust_engine_params_catalog_only() -> None:
assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema"
def test_get_default_catalog() -> None:
@pytest.mark.parametrize(
"sqlalchemy_uri,result",
[
("trino://user:pass@localhost:8080/system", "system"),
("trino://user:pass@localhost:8080/system/default", "system"),
("trino://trino@localhost:8081", None),
],
)
def test_get_default_catalog(sqlalchemy_uri: str, result: str | None) -> None:
"""
Test the ``get_default_catalog`` method.
"""
@ -728,15 +738,9 @@ def test_get_default_catalog() -> None:
database = Database(
database_name="my_db",
sqlalchemy_uri="trino://user:pass@localhost:8080/system",
sqlalchemy_uri=sqlalchemy_uri,
)
assert TrinoEngineSpec.get_default_catalog(database) == "system"
database = Database(
database_name="my_db",
sqlalchemy_uri="trino://user:pass@localhost:8080/system/default",
)
assert TrinoEngineSpec.get_default_catalog(database) == "system"
assert TrinoEngineSpec.get_default_catalog(database) == result
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")