feat(postgresql): dynamic schema (#23401)

This commit is contained in:
Beto Dealmeida 2023-03-17 17:53:42 -07:00 committed by GitHub
parent f4035e096f
commit 2c6f581fa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 187 additions and 78 deletions

View File

@ -371,7 +371,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
supports_file_upload = True
# Is the DB engine spec able to change the default schema? This requires implementing
# a custom `adjust_database_uri` method.
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False
@classmethod
@ -472,7 +472,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Determining the correct schema is crucial for managing access to data, so please
make sure you understand this logic when working on a new DB engine spec.
"""
# default schema varies on a per-query basis
# dynamic schema varies on a per-query basis
if cls.supports_dynamic_schema:
return query.schema
@ -1057,30 +1057,33 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
]
@classmethod
def adjust_database_uri( # pylint: disable=unused-argument
def adjust_engine_params( # pylint: disable=unused-argument
cls,
uri: URL,
selected_schema: Optional[str],
) -> URL:
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
"""
Return a modified URL with a new database component.
Return a new URL and ``connect_args`` for a specific catalog/schema.
The URI here represents the URI as entered when saving the database,
``selected_schema`` is the schema currently active presumably in
the SQL Lab dropdown. Based on that, for some database engine,
we can return a new altered URI that connects straight to the
active schema, meaning the users won't have to prefix the object
names by the schema name.
This is used in SQL Lab, allowing users to select a schema from the list of
schemas available in a given database, and have the query run with that schema as
the default one.
Some databases engines have 2 level of namespacing: database and
schema (postgres, oracle, mssql, ...)
For those it's probably better to not alter the database
component of the URI with the schema name, it won't work.
For some databases (like MySQL, Presto, Snowflake) this requires modifying the
SQLAlchemy URI before creating the connection. For others (like Postgres), it
requires additional parameters in ``connect_args``.
Some database drivers like Presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
When a DB engine spec implements this method it should also have the attribute
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
given query is running in order to enforce permissions (see #23385 and #23401).
Currently, changing the catalog is not supported. The method acceps a catalog so
that when catalog support is added to Superse the interface remains the same. This
is important because DB engine specs can be installed from 3rd party packages.
"""
return uri
return uri, connect_args
@classmethod
def patch(cls) -> None:

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib import parse
from sqlalchemy import types
@ -71,13 +71,17 @@ class DrillEngineSpec(BaseEngineSpec):
return None
@classmethod
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
if selected_schema:
uri = uri.set(
database=parse.quote(selected_schema.replace(".", "/"), safe="")
)
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
return uri
return uri, connect_args
@classmethod
def get_schema_from_engine_params(

View File

@ -260,13 +260,17 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))
return uri
return uri, connect_args
@classmethod
def get_schema_from_engine_params(

View File

@ -191,15 +191,17 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
return None
@classmethod
def adjust_database_uri(
def adjust_engine_params(
cls,
uri: URL,
selected_schema: Optional[str] = None,
) -> URL:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))
return uri
return uri, connect_args
@classmethod
def get_schema_from_engine_params(

View File

@ -72,12 +72,30 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P<syntax_error>.*?)"')
def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]:
"""
Parse ``options`` from ``connect_args`` into a dictionary.
"""
if not isinstance(connect_args.get("options"), str):
return {}
tokens = (
tuple(token.strip() for token in option.strip().split("=", 1))
for option in re.split(r"-c\s?", connect_args["options"])
if "=" in option
)
return {token[0]: token[1] for token in tokens}
class PostgresBaseEngineSpec(BaseEngineSpec):
"""Abstract class for Postgres 'like' databases"""
engine = ""
engine_name = "PostgreSQL"
supports_dynamic_schema = True
_time_grain_expressions = {
None: "{col}",
"PT1S": "DATE_TRUNC('second', {col})",
@ -147,6 +165,25 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
),
}
@classmethod
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if not schema:
return uri, connect_args
options = parse_options(connect_args)
options["search_path"] = schema
connect_args["options"] = " ".join(
f"-c{key}={value}" for key, value in options.items()
)
return uri, connect_args
@classmethod
def get_schema_from_engine_params(
cls,
@ -166,19 +203,16 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
to determine the schema for a non-qualified table in a query. In cases like
that we raise an exception.
"""
options = re.split(r"-c\s?", connect_args.get("options", ""))
for option in options:
if "=" not in option:
continue
key, value = option.strip().split("=", 1)
if key.strip() == "search_path":
if "," in value:
raise Exception(
"Multiple schemas are configured in the search path, which means "
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)
return value.strip()
options = parse_options(connect_args)
if search_path := options.get("search_path"):
schemas = search_path.split(",")
if len(schemas) > 1:
raise Exception(
"Multiple schemas are configured in the search path, which means "
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)
return schemas[0]
return None

View File

@ -301,19 +301,23 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return "from_unixtime({col})"
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
if schema and database:
schema = parse.quote(schema, safe="")
if "/" in database:
database = database.split("/")[0] + "/" + selected_schema
database = database.split("/")[0] + "/" + schema
else:
database += "/" + selected_schema
database += "/" + schema
uri = uri.set(database=database)
return uri
return uri, connect_args
@classmethod
def get_schema_from_engine_params(

View File

@ -135,17 +135,21 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
return extra
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
if selected_schema:
selected_schema = parse.quote(selected_schema, safe="")
uri = uri.set(database=f"{database}/{selected_schema}")
if "/" in database:
database = database.split("/")[0]
if schema:
schema = parse.quote(schema, safe="")
uri = uri.set(database=f"{database}/{schema}")
return uri
return uri, connect_args
@classmethod
def get_schema_from_engine_params(

View File

@ -421,32 +421,58 @@ class Database(
source: Optional[utils.QuerySource] = None,
sqlalchemy_uri: Optional[str] = None,
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
extra = self.get_extra()
params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool
connect_args = params.get("connect_args", {})
# The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and
# had its signature changed in order to support more DB engine specs. Since DB
# engine specs can be released as 3rd party modules we want to make sure the old
# method is still supported so we don't introduce a breaking change.
if hasattr(self.db_engine_spec, "adjust_database_uri"):
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(
sqlalchemy_url,
schema,
)
logger.warning(
"DB engine spec %s implements the method `adjust_database_uri`, which is "
"deprecated and will be removed in version 3.0. Please update it to "
"implement `adjust_engine_params` instead.",
self.db_engine_spec,
)
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
uri=sqlalchemy_url,
connect_args=connect_args,
catalog=None,
schema=schema,
)
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
sqlalchemy_url, self.impersonate_user, effective_username
sqlalchemy_url,
self.impersonate_user,
effective_username,
)
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool
connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
connect_args,
str(sqlalchemy_url),
effective_username,
)
if connect_args:
@ -464,7 +490,11 @@ class Database(
source = utils.QuerySource.SQL_LAB
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url, params, effective_username, security_manager, source
sqlalchemy_url,
params,
effective_username,
security_manager,
source,
)
try:
return create_engine(sqlalchemy_url, **params)

View File

@ -131,3 +131,27 @@ def test_get_schema_from_engine_params() -> None:
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)
def test_adjust_engine_params() -> None:
"""
Test the ``adjust_engine_params`` method.
"""
from superset.db_engine_specs.postgres import PostgresEngineSpec
uri = make_url("postgres://user:password@host/catalog")
assert PostgresEngineSpec.adjust_engine_params(uri, {}, None, "secret") == (
uri,
{"options": "-csearch_path=secret"},
)
assert PostgresEngineSpec.adjust_engine_params(
uri,
{"foo": "bar", "options": "-csearch_path=default -c debug=1"},
None,
"secret",
) == (
uri,
{"foo": "bar", "options": "-csearch_path=secret -cdebug=1"},
)