feat(postgresql): dynamic schema (#23401)
This commit is contained in:
parent
f4035e096f
commit
2c6f581fa6
|
|
@ -371,7 +371,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
supports_file_upload = True
|
||||
|
||||
# Is the DB engine spec able to change the default schema? This requires implementing
|
||||
# a custom `adjust_database_uri` method.
|
||||
# a custom `adjust_engine_params` method.
|
||||
supports_dynamic_schema = False
|
||||
|
||||
@classmethod
|
||||
|
|
@ -472,7 +472,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
Determining the correct schema is crucial for managing access to data, so please
|
||||
make sure you understand this logic when working on a new DB engine spec.
|
||||
"""
|
||||
# default schema varies on a per-query basis
|
||||
# dynamic schema varies on a per-query basis
|
||||
if cls.supports_dynamic_schema:
|
||||
return query.schema
|
||||
|
||||
|
|
@ -1057,30 +1057,33 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
]
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri( # pylint: disable=unused-argument
|
||||
def adjust_engine_params( # pylint: disable=unused-argument
|
||||
cls,
|
||||
uri: URL,
|
||||
selected_schema: Optional[str],
|
||||
) -> URL:
|
||||
connect_args: Dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
"""
|
||||
Return a modified URL with a new database component.
|
||||
Return a new URL and ``connect_args`` for a specific catalog/schema.
|
||||
|
||||
The URI here represents the URI as entered when saving the database,
|
||||
``selected_schema`` is the schema currently active presumably in
|
||||
the SQL Lab dropdown. Based on that, for some database engine,
|
||||
we can return a new altered URI that connects straight to the
|
||||
active schema, meaning the users won't have to prefix the object
|
||||
names by the schema name.
|
||||
This is used in SQL Lab, allowing users to select a schema from the list of
|
||||
schemas available in a given database, and have the query run with that schema as
|
||||
the default one.
|
||||
|
||||
Some database engines have 2 levels of namespacing: database and
|
||||
schema (postgres, oracle, mssql, ...)
|
||||
For those it's probably better to not alter the database
|
||||
component of the URI with the schema name, it won't work.
|
||||
For some databases (like MySQL, Presto, Snowflake) this requires modifying the
|
||||
SQLAlchemy URI before creating the connection. For others (like Postgres), it
|
||||
requires additional parameters in ``connect_args``.
|
||||
|
||||
Some database drivers like Presto accept '{catalog}/{schema}' in
|
||||
the database component of the URL, that can be handled here.
|
||||
When a DB engine spec implements this method it should also have the attribute
|
||||
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
|
||||
given query is running in order to enforce permissions (see #23385 and #23401).
|
||||
|
||||
Currently, changing the catalog is not supported. The method accepts a catalog so
|
||||
that when catalog support is added to Superset the interface remains the same. This
|
||||
is important because DB engine specs can be installed from 3rd party packages.
|
||||
"""
|
||||
return uri
|
||||
return uri, connect_args
|
||||
|
||||
@classmethod
|
||||
def patch(cls) -> None:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib import parse
|
||||
|
||||
from sqlalchemy import types
|
||||
|
|
@ -71,13 +71,17 @@ class DrillEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
|
||||
if selected_schema:
|
||||
uri = uri.set(
|
||||
database=parse.quote(selected_schema.replace(".", "/"), safe="")
|
||||
)
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
uri: URL,
|
||||
connect_args: Dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
if schema:
|
||||
uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
|
||||
|
||||
return uri
|
||||
return uri, connect_args
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
|
|
|
|||
|
|
@ -260,13 +260,17 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> URL:
|
||||
if selected_schema:
|
||||
uri = uri.set(database=parse.quote(selected_schema, safe=""))
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
uri: URL,
|
||||
connect_args: Dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
if schema:
|
||||
uri = uri.set(database=parse.quote(schema, safe=""))
|
||||
|
||||
return uri
|
||||
return uri, connect_args
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
|
|
|
|||
|
|
@ -191,15 +191,17 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
uri: URL,
|
||||
selected_schema: Optional[str] = None,
|
||||
) -> URL:
|
||||
if selected_schema:
|
||||
uri = uri.set(database=parse.quote(selected_schema, safe=""))
|
||||
connect_args: Dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
if schema:
|
||||
uri = uri.set(database=parse.quote(schema, safe=""))
|
||||
|
||||
return uri
|
||||
return uri, connect_args
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
|
|
|
|||
|
|
@ -72,12 +72,30 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
|
|||
SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P<syntax_error>.*?)"')
|
||||
|
||||
|
||||
def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]:
    """
    Parse ``options`` from ``connect_args`` into a dictionary.

    The ``options`` value is expected to be a string of ``-c key=value`` settings,
    eg, ``"-csearch_path=public -c statement_timeout=1000"``. The whitespace after
    ``-c`` is optional, and whitespace around keys and values is stripped. If a key
    is repeated, the last value wins.

    :param connect_args: connection arguments, optionally holding an ``options``
        string
    :return: mapping of setting name to value; empty when ``options`` is missing or
        is not a string
    """
    raw_options = connect_args.get("options")
    if not isinstance(raw_options, str):
        return {}

    # ``-c`` only introduces a setting at the start of the string or after
    # whitespace. Splitting on every "-c" substring would truncate values that
    # happen to contain it, eg, ``search_path=my-company``.
    settings = (
        option.split("=", 1)
        for option in re.split(r"(?:^|\s)-c\s*", raw_options)
        if "=" in option
    )

    return {key.strip(): value.strip() for key, value in settings}
|
||||
|
||||
|
||||
class PostgresBaseEngineSpec(BaseEngineSpec):
|
||||
"""Abstract class for Postgres 'like' databases"""
|
||||
|
||||
engine = ""
|
||||
engine_name = "PostgreSQL"
|
||||
|
||||
supports_dynamic_schema = True
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
"PT1S": "DATE_TRUNC('second', {col})",
|
||||
|
|
@ -147,6 +165,25 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
),
|
||||
}
|
||||
|
||||
@classmethod
def adjust_engine_params(
    cls,
    uri: URL,
    connect_args: Dict[str, Any],
    catalog: Optional[str] = None,
    schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
    """
    Return the URL and ``connect_args`` needed to run queries in ``schema``.

    The URL is returned unchanged. The schema is selected by setting
    ``search_path`` in the ``options`` entry of ``connect_args``; any other
    settings already present in ``options`` are preserved.

    :param uri: SQLAlchemy URL, returned as is
    :param connect_args: connection arguments; not modified by this method
    :param catalog: accepted for interface compatibility, not used here
    :param schema: schema to use as ``search_path``; when empty or ``None`` the
        inputs are returned untouched
    :return: tuple with the URL and the (possibly new) ``connect_args``
    """
    if not schema:
        return uri, connect_args

    options = parse_options(connect_args)
    options["search_path"] = schema

    # Build a new dictionary rather than mutating the one owned by the caller.
    adjusted_args = {
        **connect_args,
        "options": " ".join(f"-c{key}={value}" for key, value in options.items()),
    }

    return uri, adjusted_args
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
cls,
|
||||
|
|
@ -166,19 +203,16 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
to determine the schema for a non-qualified table in a query. In cases like
|
||||
that we raise an exception.
|
||||
"""
|
||||
options = re.split(r"-c\s?", connect_args.get("options", ""))
|
||||
for option in options:
|
||||
if "=" not in option:
|
||||
continue
|
||||
key, value = option.strip().split("=", 1)
|
||||
if key.strip() == "search_path":
|
||||
if "," in value:
|
||||
raise Exception(
|
||||
"Multiple schemas are configured in the search path, which means "
|
||||
"Superset is unable to determine the schema of unqualified table "
|
||||
"names and enforce permissions."
|
||||
)
|
||||
return value.strip()
|
||||
options = parse_options(connect_args)
|
||||
if search_path := options.get("search_path"):
|
||||
schemas = search_path.split(",")
|
||||
if len(schemas) > 1:
|
||||
raise Exception(
|
||||
"Multiple schemas are configured in the search path, which means "
|
||||
"Superset is unable to determine the schema of unqualified table "
|
||||
"names and enforce permissions."
|
||||
)
|
||||
return schemas[0]
|
||||
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -301,19 +301,23 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||
return "from_unixtime({col})"
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> URL:
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
uri: URL,
|
||||
connect_args: Dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
database = uri.database
|
||||
if selected_schema and database:
|
||||
selected_schema = parse.quote(selected_schema, safe="")
|
||||
if schema and database:
|
||||
schema = parse.quote(schema, safe="")
|
||||
if "/" in database:
|
||||
database = database.split("/")[0] + "/" + selected_schema
|
||||
database = database.split("/")[0] + "/" + schema
|
||||
else:
|
||||
database += "/" + selected_schema
|
||||
database += "/" + schema
|
||||
uri = uri.set(database=database)
|
||||
|
||||
return uri
|
||||
return uri, connect_args
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
|
|
|
|||
|
|
@ -135,17 +135,21 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||
return extra
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> URL:
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
uri: URL,
|
||||
connect_args: Dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
database = uri.database
|
||||
if "/" in uri.database:
|
||||
database = uri.database.split("/")[0]
|
||||
if selected_schema:
|
||||
selected_schema = parse.quote(selected_schema, safe="")
|
||||
uri = uri.set(database=f"{database}/{selected_schema}")
|
||||
if "/" in database:
|
||||
database = database.split("/")[0]
|
||||
if schema:
|
||||
schema = parse.quote(schema, safe="")
|
||||
uri = uri.set(database=f"{database}/{schema}")
|
||||
|
||||
return uri
|
||||
return uri, connect_args
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
|
|
|
|||
|
|
@ -421,32 +421,58 @@ class Database(
|
|||
source: Optional[utils.QuerySource] = None,
|
||||
sqlalchemy_uri: Optional[str] = None,
|
||||
) -> Engine:
|
||||
extra = self.get_extra()
|
||||
sqlalchemy_url = make_url_safe(
|
||||
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
|
||||
)
|
||||
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
|
||||
|
||||
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
|
||||
extra = self.get_extra()
|
||||
params = extra.get("engine_params", {})
|
||||
if nullpool:
|
||||
params["poolclass"] = NullPool
|
||||
connect_args = params.get("connect_args", {})
|
||||
|
||||
# The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and
|
||||
# had its signature changed in order to support more DB engine specs. Since DB
|
||||
# engine specs can be released as 3rd party modules we want to make sure the old
|
||||
# method is still supported so we don't introduce a breaking change.
|
||||
if hasattr(self.db_engine_spec, "adjust_database_uri"):
|
||||
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(
|
||||
sqlalchemy_url,
|
||||
schema,
|
||||
)
|
||||
logger.warning(
|
||||
"DB engine spec %s implements the method `adjust_database_uri`, which is "
|
||||
"deprecated and will be removed in version 3.0. Please update it to "
|
||||
"implement `adjust_engine_params` instead.",
|
||||
self.db_engine_spec,
|
||||
)
|
||||
|
||||
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
|
||||
uri=sqlalchemy_url,
|
||||
connect_args=connect_args,
|
||||
catalog=None,
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
effective_username = self.get_effective_user(sqlalchemy_url)
|
||||
# If using MySQL or Presto for example, will set url.username
|
||||
# If using Hive, will not do anything yet since that relies on a
|
||||
# configuration parameter instead.
|
||||
sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
|
||||
sqlalchemy_url, self.impersonate_user, effective_username
|
||||
sqlalchemy_url,
|
||||
self.impersonate_user,
|
||||
effective_username,
|
||||
)
|
||||
|
||||
masked_url = self.get_password_masked_url(sqlalchemy_url)
|
||||
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
|
||||
|
||||
params = extra.get("engine_params", {})
|
||||
if nullpool:
|
||||
params["poolclass"] = NullPool
|
||||
|
||||
connect_args = params.get("connect_args", {})
|
||||
if self.impersonate_user:
|
||||
self.db_engine_spec.update_impersonation_config(
|
||||
connect_args, str(sqlalchemy_url), effective_username
|
||||
connect_args,
|
||||
str(sqlalchemy_url),
|
||||
effective_username,
|
||||
)
|
||||
|
||||
if connect_args:
|
||||
|
|
@ -464,7 +490,11 @@ class Database(
|
|||
source = utils.QuerySource.SQL_LAB
|
||||
|
||||
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
|
||||
sqlalchemy_url, params, effective_username, security_manager, source
|
||||
sqlalchemy_url,
|
||||
params,
|
||||
effective_username,
|
||||
security_manager,
|
||||
source,
|
||||
)
|
||||
try:
|
||||
return create_engine(sqlalchemy_url, **params)
|
||||
|
|
|
|||
|
|
@ -131,3 +131,27 @@ def test_get_schema_from_engine_params() -> None:
|
|||
"Superset is unable to determine the schema of unqualified table "
|
||||
"names and enforce permissions."
|
||||
)
|
||||
|
||||
|
||||
def test_adjust_engine_params() -> None:
    """
    Check that ``adjust_engine_params`` selects the schema via ``search_path``.
    """
    from superset.db_engine_specs.postgres import PostgresEngineSpec

    uri = make_url("postgres://user:password@host/catalog")

    # without pre-existing options only ``search_path`` is set
    adjusted = PostgresEngineSpec.adjust_engine_params(uri, {}, None, "secret")
    assert adjusted == (uri, {"options": "-csearch_path=secret"})

    # an existing ``search_path`` is replaced; other keys and options are kept
    adjusted = PostgresEngineSpec.adjust_engine_params(
        uri,
        {"foo": "bar", "options": "-csearch_path=default -c debug=1"},
        None,
        "secret",
    )
    expected_args = {"foo": "bar", "options": "-csearch_path=secret -cdebug=1"}
    assert adjusted == (uri, expected_args)
|
||||
|
|
|
|||
Loading…
Reference in New Issue