feat: add enforce URI query params with a specific for MySQL (#23723)
This commit is contained in:
parent
e9b4022787
commit
0ad6c879b3
|
|
@ -357,6 +357,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
top_keywords: Set[str] = {"TOP"}
|
||||
# A set of disallowed connection query parameters
|
||||
disallow_uri_query_params: Set[str] = set()
|
||||
# A Dict of query parameters that will always be used on every connection
|
||||
enforce_uri_query_params: Dict[str, Any] = {}
|
||||
|
||||
force_column_alias_quotes = False
|
||||
arraysize = 0
|
||||
|
|
@ -1089,11 +1091,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
|
||||
given query is running in order to enforce permissions (see #23385 and #23401).
|
||||
|
||||
Currently, changing the catalog is not supported. The method acceps a catalog so
|
||||
that when catalog support is added to Superse the interface remains the same. This
|
||||
is important because DB engine specs can be installed from 3rd party packages.
|
||||
Currently, changing the catalog is not supported. The method accepts a catalog so
|
||||
that when catalog support is added to Superset the interface remains the same.
|
||||
This is important because DB engine specs can be installed from 3rd party
|
||||
packages.
|
||||
"""
|
||||
return uri, connect_args
|
||||
return uri, {**connect_args, **cls.enforce_uri_query_params}
|
||||
|
||||
@classmethod
|
||||
def patch(cls) -> None:
|
||||
|
|
|
|||
|
|
@ -176,6 +176,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
|
|||
),
|
||||
}
|
||||
disallow_uri_query_params = {"local_infile"}
|
||||
enforce_uri_query_params = {"local_infile": 0}
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(
|
||||
|
|
@ -198,10 +199,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
|
|||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
uri, new_connect_args = super(
|
||||
MySQLEngineSpec, MySQLEngineSpec
|
||||
).adjust_engine_params(uri, connect_args, catalog, schema)
|
||||
if schema:
|
||||
uri = uri.set(database=parse.quote(schema, safe=""))
|
||||
|
||||
return uri, connect_args
|
||||
return uri, new_connect_args
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
|
|
|
|||
|
|
@ -188,6 +188,21 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
"password": "original_user_password",
|
||||
}
|
||||
|
||||
@unittest.skipUnless(
|
||||
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
|
||||
)
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_adjust_engine_params_mysql(self, mocked_create_engine):
|
||||
model = Database(
|
||||
database_name="test_database",
|
||||
sqlalchemy_uri="mysql://user:password@localhost",
|
||||
)
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "mysql://user:password@localhost"
|
||||
assert call_args[1]["connect_args"]["local_infile"] == 0
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_trino(self, mocked_create_engine):
|
||||
principal_user = security_manager.find_user(username="gamma")
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -33,7 +33,7 @@ from sqlalchemy.dialects.mysql import (
|
|||
TINYINT,
|
||||
TINYTEXT,
|
||||
)
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import (
|
||||
|
|
@ -119,6 +119,36 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
|
|||
MySQLEngineSpec.validate_database_uri(url)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sqlalchemy_uri,connect_args,returns",
|
||||
[
|
||||
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
|
||||
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
|
||||
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
|
||||
(
|
||||
"mysql://user:password@host/db1",
|
||||
{"param1": "some_value"},
|
||||
{"local_infile": 0, "param1": "some_value"},
|
||||
),
|
||||
(
|
||||
"mysql://user:password@host/db1",
|
||||
{"local_infile": 1, "param1": "some_value"},
|
||||
{"local_infile": 0, "param1": "some_value"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_adjust_engine_params(
|
||||
sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
|
||||
) -> None:
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
|
||||
url = make_url(sqlalchemy_uri)
|
||||
returned_url, returned_connect_args = MySQLEngineSpec.adjust_engine_params(
|
||||
url, connect_args
|
||||
)
|
||||
assert returned_connect_args == returns
|
||||
|
||||
|
||||
@patch("sqlalchemy.engine.Engine.connect")
|
||||
def test_get_cancel_query_id(engine_mock: Mock) -> None:
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
|
|
|
|||
Loading…
Reference in New Issue