From bdeedaaf80deb5785d82b786e713c8a3cb579ee3 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 15 Dec 2022 17:08:34 -0800 Subject: [PATCH] chore: set Snowflake user agent (#22432) --- superset/db_engine_specs/databricks.py | 12 ++--- superset/db_engine_specs/snowflake.py | 16 ++++++- .../db_engine_specs/databricks_tests.py | 21 ++++----- .../db_engine_specs/test_databricks.py | 44 ++++++++++++++++++- .../db_engine_specs/test_snowflake.py | 31 +++++++++++++ 5 files changed, 107 insertions(+), 17 deletions(-) diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 7ebe6ab1a..131679359 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -163,11 +163,13 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin) """ Add a user agent to be used in the requests. """ - extra = { - "http_headers": [("User-Agent", USER_AGENT)], - "_user_agent_entry": USER_AGENT, - } - extra.update(BaseEngineSpec.get_extra_params(database)) + extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + + connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)]) + connect_args.setdefault("_user_agent_entry", USER_AGENT) + return extra @classmethod diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 0704712d6..578ded965 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -31,8 +31,9 @@ from marshmallow import fields, Schema from sqlalchemy.engine.url import URL from typing_extensions import TypedDict +from superset.constants import USER_AGENT from superset.databases.utils import make_url_safe -from superset.db_engine_specs.base import BasicPropertiesType +from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query @@ -118,6 +119,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): ), } + @staticmethod + def get_extra_params(database: "Database") -> Dict[str, Any]: + """ + Add a user agent to be used in the requests. + """ + extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + + connect_args.setdefault("application", USER_AGENT) + + return extra + @classmethod def adjust_database_uri( cls, uri: URL, selected_schema: Optional[str] = None diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py index b399e41fd..c2d57831a 100644 --- a/tests/integration_tests/db_engine_specs/databricks_tests.py +++ b/tests/integration_tests/db_engine_specs/databricks_tests.py @@ -44,16 +44,17 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec): db.extra = default_db_extra db.server_cert = None extras = DatabricksNativeEngineSpec.get_extra_params(db) - assert "connect_args" not in extras["engine_params"] - - def test_extras_with_user_agent(self): - db = mock.Mock() - db.extra = default_db_extra - extras = DatabricksNativeEngineSpec.get_extra_params(db) - _, user_agent = extras["http_headers"][0] - user_agent_entry = extras["_user_agent_entry"] - assert user_agent == USER_AGENT - assert user_agent_entry == USER_AGENT + assert extras == { + "engine_params": { + "connect_args": { + "_user_agent_entry": "Apache Superset", + "http_headers": [("User-Agent", "Apache Superset")], + }, + }, + "metadata_cache_timeout": {}, + "metadata_params": {}, + "schemas_allowed_for_file_upload": [], + } def test_extras_with_ssl_custom(self): db = mock.Mock() diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py index 0cc0907f4..50c7fd47a 100644 --- a/tests/unit_tests/db_engine_specs/test_databricks.py +++ b/tests/unit_tests/db_engine_specs/test_databricks.py @@ -18,8 +18,9 @@ import json +from pytest_mock import MockerFixture + from superset.utils.core import GenericDataType -from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types def test_get_parameters_from_uri() -> None: @@ -110,6 +111,7 @@ def test_generic_type() -> None: assert that generic types match """ from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec + from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types type_expectations = ( # Numeric @@ -133,3 +135,43 @@ def test_generic_type() -> None: ("BOOLEAN", GenericDataType.BOOLEAN), ) assert_generic_types(DatabricksNativeEngineSpec, type_expectations) + + +def test_get_extra_params(mocker: MockerFixture) -> None: + """ + Test the ``get_extra_params`` method. + """ + from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec + + database = mocker.MagicMock() + + database.extra = {} + assert DatabricksNativeEngineSpec.get_extra_params(database) == { + "engine_params": { + "connect_args": { + "http_headers": [("User-Agent", "Apache Superset")], + "_user_agent_entry": "Apache Superset", + } + } + } + + database.extra = json.dumps( + { + "engine_params": { + "connect_args": { + "http_headers": [("User-Agent", "Custom user agent")], + "_user_agent_entry": "Custom user agent", + "foo": "bar", + } + } + } + ) + assert DatabricksNativeEngineSpec.get_extra_params(database) == { + "engine_params": { + "connect_args": { + "http_headers": [["User-Agent", "Custom user agent"]], + "_user_agent_entry": "Custom user agent", + "foo": "bar", + } + } + } diff --git a/tests/unit_tests/db_engine_specs/test_snowflake.py b/tests/unit_tests/db_engine_specs/test_snowflake.py index 2479e071f..3611c7214 100644 --- a/tests/unit_tests/db_engine_specs/test_snowflake.py +++ b/tests/unit_tests/db_engine_specs/test_snowflake.py @@ -14,11 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +# pylint: disable=import-outside-toplevel + import json from datetime import datetime from unittest import mock import pytest +from pytest_mock import MockerFixture from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from tests.unit_tests.fixtures.common import dttm @@ -122,3 +126,30 @@ def test_cancel_query_failed(engine_mock: mock.Mock) -> None: query = Query() cursor_mock = engine_mock.raiseError.side_effect = Exception() assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, "123") is False + + +def test_get_extra_params(mocker: MockerFixture) -> None: + """ + Test the ``get_extra_params`` method. + """ + from superset.db_engine_specs.snowflake import SnowflakeEngineSpec + + database = mocker.MagicMock() + + database.extra = {} + assert SnowflakeEngineSpec.get_extra_params(database) == { + "engine_params": {"connect_args": {"application": "Apache Superset"}} + } + + database.extra = json.dumps( + { + "engine_params": { + "connect_args": {"application": "Custom user agent", "foo": "bar"} + } + } + ) + assert SnowflakeEngineSpec.get_extra_params(database) == { + "engine_params": { + "connect_args": {"application": "Custom user agent", "foo": "bar"} + } + }