From fcf90dffa804bb0c78d2ef05d1423d60f996cb88 Mon Sep 17 00:00:00 2001 From: Guen Prawiroatmodjo Date: Thu, 28 Mar 2024 18:05:28 -0700 Subject: [PATCH] feat(db_engine): Add custom_user_agent when connecting to MotherDuck (#27665) --- superset/db_engine_specs/duckdb.py | 26 ++++++++++++-- .../unit_tests/db_engine_specs/test_duckdb.py | 34 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index 291b5521e..fc8efdaa3 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -25,7 +25,8 @@ from flask_babel import gettext as __ from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector -from superset.constants import TimeGrain +from superset.config import VERSION_STRING +from superset.constants import TimeGrain, USER_AGENT from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType @@ -41,6 +42,8 @@ class DuckDBEngineSpec(BaseEngineSpec): engine = "duckdb" engine_name = "DuckDB" + sqlalchemy_uri_placeholder = "duckdb:////path/to/duck.db" + _time_grain_expressions = { None: "{col}", TimeGrain.SECOND: "DATE_TRUNC('second', {col})", @@ -81,9 +84,28 @@ class DuckDBEngineSpec(BaseEngineSpec): ) -> set[str]: return set(inspector.get_table_names(schema)) + @staticmethod + def get_extra_params(database: Database) -> dict[str, Any]: + """ + Add a user agent to be used in the requests. + """ + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) + config: dict[str, Any] = connect_args.setdefault("config", {}) + custom_user_agent = config.pop("custom_user_agent", "") + delim = " " if custom_user_agent else "" + user_agent = USER_AGENT.replace(" ", "-").lower() + user_agent = f"{user_agent}/{VERSION_STRING}{delim}{custom_user_agent}" + config.setdefault("custom_user_agent", user_agent) + + return extra + class MotherDuckEngineSpec(DuckDBEngineSpec): engine = "duckdb" engine_name = "MotherDuck" - sqlalchemy_uri_placeholder = "duckdb:///md:{SERVICE_TOKEN}@{database_name}" + sqlalchemy_uri_placeholder = ( + "duckdb:///md:{database_name}?motherduck_token={SERVICE_TOKEN}" + ) diff --git a/tests/unit_tests/db_engine_specs/test_duckdb.py b/tests/unit_tests/db_engine_specs/test_duckdb.py index 72d018f4f..39c70470f 100644 --- a/tests/unit_tests/db_engine_specs/test_duckdb.py +++ b/tests/unit_tests/db_engine_specs/test_duckdb.py @@ -15,11 +15,14 @@ # specific language governing permissions and limitations # under the License. +import json from datetime import datetime from typing import Optional import pytest +from pytest_mock import MockerFixture +from superset.config import VERSION_STRING from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm @@ -38,3 +41,34 @@ def test_convert_dttm( from superset.db_engine_specs.duckdb import DuckDBEngineSpec as spec assert_convert_dttm(spec, target_type, expected_result, dttm) + + +def test_get_extra_params(mocker: MockerFixture) -> None: + """ + Test the ``get_extra_params`` method. + """ + from superset.db_engine_specs.duckdb import DuckDBEngineSpec + + database = mocker.MagicMock() + + database.extra = {} + assert DuckDBEngineSpec.get_extra_params(database) == { + "engine_params": { + "connect_args": { + "config": {"custom_user_agent": f"apache-superset/{VERSION_STRING}"} + } + } + } + + database.extra = json.dumps( + {"engine_params": {"connect_args": {"config": {"custom_user_agent": "my-app"}}}} + ) + assert DuckDBEngineSpec.get_extra_params(database) == { + "engine_params": { + "connect_args": { + "config": { + "custom_user_agent": f"apache-superset/{VERSION_STRING} my-app" + } + } + } + }