diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index ed58e8cb8..93df7c721 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -357,6 +357,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods top_keywords: Set[str] = {"TOP"} # A set of disallowed connection query parameters disallow_uri_query_params: Set[str] = set() + # A Dict of query parameters that will always be used on every connection + enforce_uri_query_params: Dict[str, Any] = {} force_column_alias_quotes = False arraysize = 0 @@ -1089,11 +1091,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a given query is running in order to enforce permissions (see #23385 and #23401). - Currently, changing the catalog is not supported. The method acceps a catalog so - that when catalog support is added to Superse the interface remains the same. This - is important because DB engine specs can be installed from 3rd party packages. + Currently, changing the catalog is not supported. The method accepts a catalog so + that when catalog support is added to Superset the interface remains the same. + This is important because DB engine specs can be installed from 3rd party + packages. """ - return uri, connect_args + return uri, {**connect_args, **cls.enforce_uri_query_params} @classmethod def patch(cls) -> None: diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index e5ff964f8..07d2aea36 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -176,6 +176,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): ), } disallow_uri_query_params = {"local_infile"} + enforce_uri_query_params = {"local_infile": 0} @classmethod def convert_dttm( @@ -198,10 +199,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): catalog: Optional[str] = None, schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: + uri, new_connect_args = super( + MySQLEngineSpec, MySQLEngineSpec + ).adjust_engine_params(uri, connect_args, catalog, schema) if schema: uri = uri.set(database=parse.quote(schema, safe="")) - return uri, connect_args + return uri, new_connect_args @classmethod def get_schema_from_engine_params( diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index da6c5e6a3..35dbcc0a6 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -188,6 +188,21 @@ class TestDatabaseModel(SupersetTestCase): "password": "original_user_password", } + @unittest.skipUnless( + SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" + ) + @mock.patch("superset.models.core.create_engine") + def test_adjust_engine_params_mysql(self, mocked_create_engine): + model = Database( + database_name="test_database", + sqlalchemy_uri="mysql://user:password@localhost", + ) + model._get_sqla_engine() + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "mysql://user:password@localhost" + assert call_args[1]["connect_args"]["local_infile"] == 0 + @mock.patch("superset.models.core.create_engine") def test_impersonate_user_trino(self, mocked_create_engine): principal_user = security_manager.find_user(username="gamma") diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 091cdb3b4..31e01ace5 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Tuple, Type from unittest.mock import Mock, patch import pytest @@ -33,7 +33,7 @@ from sqlalchemy.dialects.mysql import ( TINYINT, TINYTEXT, ) -from sqlalchemy.engine.url import make_url +from sqlalchemy.engine.url import make_url, URL from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -119,6 +119,36 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: MySQLEngineSpec.validate_database_uri(url) +@pytest.mark.parametrize( + "sqlalchemy_uri,connect_args,returns", + [ + ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}), + ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}), + ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}), + ( + "mysql://user:password@host/db1", + {"param1": "some_value"}, + {"local_infile": 0, "param1": "some_value"}, + ), + ( + "mysql://user:password@host/db1", + {"local_infile": 1, "param1": "some_value"}, + {"local_infile": 0, "param1": "some_value"}, + ), + ], +) +def test_adjust_engine_params( + sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any] +) -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + url = make_url(sqlalchemy_uri) + returned_url, returned_connect_args = MySQLEngineSpec.adjust_engine_params( + url, connect_args + ) + assert returned_connect_args == returns + + @patch("sqlalchemy.engine.Engine.connect") def test_get_cancel_query_id(engine_mock: Mock) -> None: from superset.db_engine_specs.mysql import MySQLEngineSpec