feat: add enforce URI query params with a specific for MySQL (#23723)

This commit is contained in:
Daniel Vaz Gaspar 2023-04-18 17:07:37 +01:00 committed by GitHub
parent e9b4022787
commit 0ad6c879b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 7 deletions

View File

@ -357,6 +357,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
# A Dict of query parameters that will always be used on every connection
enforce_uri_query_params: Dict[str, Any] = {}
force_column_alias_quotes = False
arraysize = 0
@ -1089,11 +1091,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
given query is running in order to enforce permissions (see #23385 and #23401).
Currently, changing the catalog is not supported. The method acceps a catalog so
that when catalog support is added to Superse the interface remains the same. This
is important because DB engine specs can be installed from 3rd party packages.
Currently, changing the catalog is not supported. The method accepts a catalog so
that when catalog support is added to Superset the interface remains the same.
This is important because DB engine specs can be installed from 3rd party
packages.
"""
return uri, connect_args
return uri, {**connect_args, **cls.enforce_uri_query_params}
@classmethod
def patch(cls) -> None:

View File

@ -176,6 +176,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
),
}
disallow_uri_query_params = {"local_infile"}
enforce_uri_query_params = {"local_infile": 0}
@classmethod
def convert_dttm(
@ -198,10 +199,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
uri, new_connect_args = super(
MySQLEngineSpec, MySQLEngineSpec
).adjust_engine_params(uri, connect_args, catalog, schema)
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))
return uri, connect_args
return uri, new_connect_args
@classmethod
def get_schema_from_engine_params(

View File

@ -188,6 +188,21 @@ class TestDatabaseModel(SupersetTestCase):
"password": "original_user_password",
}
@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")

View File

@ -16,7 +16,7 @@
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Tuple, Type
from unittest.mock import Mock, patch
import pytest
@ -33,7 +33,7 @@ from sqlalchemy.dialects.mysql import (
TINYINT,
TINYTEXT,
)
from sqlalchemy.engine.url import make_url
from sqlalchemy.engine.url import make_url, URL
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
@ -119,6 +119,36 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
MySQLEngineSpec.validate_database_uri(url)
@pytest.mark.parametrize(
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
],
)
def test_adjust_engine_params(
sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
url = make_url(sqlalchemy_uri)
returned_url, returned_connect_args = MySQLEngineSpec.adjust_engine_params(
url, connect_args
)
assert returned_connect_args == returns
@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec