fix: allow db driver distinction on enforced URI params (#23769)
This commit is contained in:
parent
adde66785c
commit
6ae5388dcf
|
|
@ -355,10 +355,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
# This set will give the keywords for data limit statements
|
||||
# to consider for the engines with TOP SQL parsing
|
||||
top_keywords: Set[str] = {"TOP"}
|
||||
# A set of disallowed connection query parameters
|
||||
disallow_uri_query_params: Set[str] = set()
|
||||
# A set of disallowed connection query parameters by driver name
|
||||
disallow_uri_query_params: Dict[str, Set[str]] = {}
|
||||
# A Dict of query parameters that will always be used on every connection
|
||||
enforce_uri_query_params: Dict[str, Any] = {}
|
||||
# by driver name
|
||||
enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
force_column_alias_quotes = False
|
||||
arraysize = 0
|
||||
|
|
@ -1099,7 +1100,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
This is important because DB engine specs can be installed from 3rd party
|
||||
packages.
|
||||
"""
|
||||
return uri, {**connect_args, **cls.enforce_uri_query_params}
|
||||
return uri, {
|
||||
**connect_args,
|
||||
**cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def patch(cls) -> None:
|
||||
|
|
@ -1853,9 +1857,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
:param sqlalchemy_uri:
|
||||
"""
|
||||
if existing_disallowed := cls.disallow_uri_query_params.intersection(
|
||||
sqlalchemy_uri.query
|
||||
):
|
||||
if existing_disallowed := cls.disallow_uri_query_params.get(
|
||||
sqlalchemy_uri.get_driver_name(), set()
|
||||
).intersection(sqlalchemy_uri.query):
|
||||
raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -175,8 +175,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
|
|||
{},
|
||||
),
|
||||
}
|
||||
disallow_uri_query_params = {"local_infile"}
|
||||
enforce_uri_query_params = {"local_infile": 0}
|
||||
disallow_uri_query_params = {
|
||||
"mysqldb": {"local_infile"},
|
||||
"mysqlconnector": {"allow_local_infile"},
|
||||
}
|
||||
enforce_uri_query_params = {
|
||||
"mysqldb": {"local_infile": 0},
|
||||
"mysqlconnector": {"allow_local_infile": 0},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_adjust_engine_params_mysql(self, mocked_create_engine):
|
||||
model = Database(
|
||||
database_name="test_database",
|
||||
database_name="test_database1",
|
||||
sqlalchemy_uri="mysql://user:password@localhost",
|
||||
)
|
||||
model._get_sqla_engine()
|
||||
|
|
@ -203,6 +203,16 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
assert str(call_args[0][0]) == "mysql://user:password@localhost"
|
||||
assert call_args[1]["connect_args"]["local_infile"] == 0
|
||||
|
||||
model = Database(
|
||||
database_name="test_database2",
|
||||
sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
|
||||
)
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
|
||||
assert call_args[1]["connect_args"]["allow_local_infile"] == 0
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_trino(self, mocked_create_engine):
|
||||
principal_user = security_manager.find_user(username="gamma")
|
||||
|
|
|
|||
|
|
@ -104,8 +104,11 @@ def test_convert_dttm(
|
|||
"sqlalchemy_uri,error",
|
||||
[
|
||||
("mysql://user:password@host/db1?local_infile=1", True),
|
||||
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True),
|
||||
("mysql://user:password@host/db1?local_infile=0", True),
|
||||
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True),
|
||||
("mysql://user:password@host/db1", False),
|
||||
("mysql+mysqlconnector://user:password@host/db1", False),
|
||||
],
|
||||
)
|
||||
def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
|
||||
|
|
@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
|
|||
"sqlalchemy_uri,connect_args,returns",
|
||||
[
|
||||
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
|
||||
(
|
||||
"mysql+mysqlconnector://user:password@host/db1",
|
||||
{"allow_local_infile": 1},
|
||||
{"allow_local_infile": 0},
|
||||
),
|
||||
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
|
||||
(
|
||||
"mysql+mysqlconnector://user:password@host/db1",
|
||||
{"allow_local_infile": -1},
|
||||
{"allow_local_infile": 0},
|
||||
),
|
||||
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
|
||||
(
|
||||
"mysql+mysqlconnector://user:password@host/db1",
|
||||
{"allow_local_infile": 0},
|
||||
{"allow_local_infile": 0},
|
||||
),
|
||||
(
|
||||
"mysql://user:password@host/db1",
|
||||
{"param1": "some_value"},
|
||||
{"local_infile": 0, "param1": "some_value"},
|
||||
),
|
||||
(
|
||||
"mysql+mysqlconnector://user:password@host/db1",
|
||||
{"param1": "some_value"},
|
||||
{"allow_local_infile": 0, "param1": "some_value"},
|
||||
),
|
||||
(
|
||||
"mysql://user:password@host/db1",
|
||||
{"local_infile": 1, "param1": "some_value"},
|
||||
{"local_infile": 0, "param1": "some_value"},
|
||||
),
|
||||
(
|
||||
"mysql+mysqlconnector://user:password@host/db1",
|
||||
{"allow_local_infile": 1, "param1": "some_value"},
|
||||
{"allow_local_infile": 0, "param1": "some_value"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_adjust_engine_params(
|
||||
|
|
|
|||
Loading…
Reference in New Issue