1013 lines
36 KiB
Python
Executable File
1013 lines
36 KiB
Python
Executable File
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
# pylint: disable=line-too-long,too-many-lines
|
|
"""A collection of ORM sqlalchemy models for Superset"""
|
|
import builtins
|
|
import enum
|
|
import json
|
|
import logging
|
|
import textwrap
|
|
from ast import literal_eval
|
|
from contextlib import closing, contextmanager, nullcontext
|
|
from copy import deepcopy
|
|
from datetime import datetime
|
|
from functools import lru_cache
|
|
from typing import Any, Callable, Optional, TYPE_CHECKING
|
|
|
|
import numpy
|
|
import pandas as pd
|
|
import sqlalchemy as sqla
|
|
import sshtunnel
|
|
from flask import g, request
|
|
from flask_appbuilder import Model
|
|
from sqlalchemy import (
|
|
Boolean,
|
|
Column,
|
|
create_engine,
|
|
DateTime,
|
|
ForeignKey,
|
|
Integer,
|
|
MetaData,
|
|
String,
|
|
Table,
|
|
Text,
|
|
)
|
|
from sqlalchemy.engine import Connection, Dialect, Engine
|
|
from sqlalchemy.engine.reflection import Inspector
|
|
from sqlalchemy.engine.url import URL
|
|
from sqlalchemy.exc import NoSuchModuleError
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
from sqlalchemy.orm import relationship
|
|
from sqlalchemy.pool import NullPool
|
|
from sqlalchemy.schema import UniqueConstraint
|
|
from sqlalchemy.sql import ColumnElement, expression, Select
|
|
|
|
from superset import app, db_engine_specs
|
|
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
|
|
from superset.databases.commands.exceptions import DatabaseInvalidError
|
|
from superset.databases.utils import make_url_safe
|
|
from superset.db_engine_specs.base import MetricType, TimeGrain
|
|
from superset.extensions import (
|
|
cache_manager,
|
|
encrypted_field_factory,
|
|
security_manager,
|
|
ssh_manager_factory,
|
|
)
|
|
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
|
from superset.result_set import SupersetResultSet
|
|
from superset.utils import cache as cache_util, core as utils
|
|
from superset.utils.core import get_username
|
|
|
|
config = app.config
# optional callable that resolves DB passwords from an external store (vault etc.)
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
stats_logger = config["STATS_LOGGER"]
# optional callable invoked for every executed query (auditing hook)
log_query = config["QUERY_LOGGER"]
metadata = Model.metadata  # pylint: disable=no-member
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from superset.databases.ssh_tunnel.models import SSHTunnel
    from superset.models.sql_lab import Query

# deployment hook allowing the engine URL/params to be mutated before connecting
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
|
|
|
|
|
|
class Url(Model, AuditMixinNullable):

    """Used for the short url feature"""

    __tablename__ = "url"
    id = Column(Integer, primary_key=True)
    # the long URL that the short link redirects to
    url = Column(Text)
|
|
|
|
|
|
class KeyValue(Model):  # pylint: disable=too-few-public-methods

    """Used for any type of key-value store"""

    __tablename__ = "keyvalue"
    id = Column(Integer, primary_key=True)
    # arbitrary payload; interpretation is left to the caller
    value = Column(Text, nullable=False)
|
|
|
|
|
|
class CssTemplate(Model, AuditMixinNullable):

    """CSS templates for dashboards"""

    __tablename__ = "css_templates"
    id = Column(Integer, primary_key=True)
    template_name = Column(String(250))
    # raw CSS applied on top of the dashboard styles
    css = Column(Text, default="")
|
|
|
|
|
|
class ConfigurationMethod(str, enum.Enum):
    # how a database connection was configured in the UI:
    # a raw SQLAlchemy URI, or the engine-specific dynamic parameter form
    SQLALCHEMY_FORM = "sqlalchemy_form"
    DYNAMIC_FORM = "dynamic_form"
|
|
|
|
|
|
class Database(
    Model, AuditMixinNullable, ImportExportMixin
):  # pylint: disable=too-many-public-methods

    """An ORM object that stores Database related information"""

    __tablename__ = "dbs"
    type = "table"
    __table_args__ = (UniqueConstraint("database_name"),)

    id = Column(Integer, primary_key=True)
    # optional display name; ``name`` falls back to database_name when unset
    verbose_name = Column(String(250), unique=True)
    # short unique name, used in permissions
    database_name = Column(String(250), unique=True, nullable=False)
    # URI as stored (password masked); see sqlalchemy_uri_decrypted
    sqlalchemy_uri = Column(String(1024), nullable=False)
    # real password, encrypted at rest
    password = Column(encrypted_field_factory.create(String(1024)))
    # default cache timeout (seconds) for charts on this database
    cache_timeout = Column(Integer)
    select_as_create_table_as = Column(Boolean, default=False)
    expose_in_sqllab = Column(Boolean, default=True)
    # see ConfigurationMethod enum
    configuration_method = Column(
        String(255), server_default=ConfigurationMethod.SQLALCHEMY_FORM.value
    )
    allow_run_async = Column(Boolean, default=False)
    allow_file_upload = Column(Boolean, default=False)
    allow_ctas = Column(Boolean, default=False)
    allow_cvas = Column(Boolean, default=False)
    allow_dml = Column(Boolean, default=False)
    # schema forced onto CREATE TABLE AS statements, when set
    force_ctas_schema = Column(String(250))
    # free-form JSON blob of engine/metadata options; parsed by get_extra()
    extra = Column(
        Text,
        default=textwrap.dedent(
            """\
    {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_file_upload": []
    }
    """
        ),
    )
    # engine-specific secrets (e.g. service account JSON), encrypted at rest
    encrypted_extra = Column(encrypted_field_factory.create(Text), nullable=True)
    impersonate_user = Column(Boolean, default=False)
    server_cert = Column(encrypted_field_factory.create(Text), nullable=True)
    is_managed_externally = Column(Boolean, nullable=False, default=False)
    external_url = Column(Text, nullable=True)

    # fields serialized by ImportExportMixin on export
    export_fields = [
        "database_name",
        "sqlalchemy_uri",
        "cache_timeout",
        "expose_in_sqllab",
        "allow_run_async",
        "allow_ctas",
        "allow_cvas",
        "allow_dml",
        "allow_file_upload",
        "extra",
    ]
    # accepted on import but never exported (secrets / deployment-specific)
    extra_import_fields = [
        "password",
        "is_managed_externally",
        "external_url",
        "encrypted_extra",
    ]
    export_children = ["tables"]
|
|
|
|
def __repr__(self) -> str:
|
|
return self.name
|
|
|
|
@property
|
|
def name(self) -> str:
|
|
return self.verbose_name if self.verbose_name else self.database_name
|
|
|
|
@property
|
|
def allows_subquery(self) -> bool:
|
|
return self.db_engine_spec.allows_subqueries
|
|
|
|
@property
|
|
def function_names(self) -> list[str]:
|
|
try:
|
|
return self.db_engine_spec.get_function_names(self)
|
|
except Exception as ex: # pylint: disable=broad-except
|
|
# function_names property is used in bulk APIs and should not hard crash
|
|
# more info in: https://github.com/apache/superset/issues/9678
|
|
logger.error(
|
|
"Failed to fetch database function names with error: %s",
|
|
str(ex),
|
|
exc_info=True,
|
|
)
|
|
return []
|
|
|
|
@property
|
|
def allows_cost_estimate(self) -> bool:
|
|
extra = self.get_extra() or {}
|
|
cost_estimate_enabled: bool = extra.get("cost_estimate_enabled") # type: ignore
|
|
|
|
return (
|
|
self.db_engine_spec.get_allow_cost_estimate(extra) and cost_estimate_enabled
|
|
)
|
|
|
|
@property
|
|
def allows_virtual_table_explore(self) -> bool:
|
|
extra = self.get_extra()
|
|
|
|
return bool(extra.get("allows_virtual_table_explore", True))
|
|
|
|
@property
|
|
def explore_database_id(self) -> int:
|
|
return self.get_extra().get("explore_database_id", self.id)
|
|
|
|
@property
|
|
def disable_data_preview(self) -> bool:
|
|
# this will prevent any 'trash value' strings from going through
|
|
if self.get_extra().get("disable_data_preview", False) is not True:
|
|
return False
|
|
return True
|
|
|
|
    @property
    def data(self) -> dict[str, Any]:
        """Serializable summary of the database, used by the frontend payloads."""
        return {
            "id": self.id,
            "name": self.database_name,
            "backend": self.backend,
            "configuration_method": self.configuration_method,
            "allows_subquery": self.allows_subquery,
            "allows_cost_estimate": self.allows_cost_estimate,
            "allows_virtual_table_explore": self.allows_virtual_table_explore,
            "explore_database_id": self.explore_database_id,
            # masked parameters only — never exposes credentials
            "parameters": self.parameters,
            "disable_data_preview": self.disable_data_preview,
            "parameters_schema": self.parameters_schema,
            "engine_information": self.engine_information,
        }
|
|
|
|
@property
|
|
def unique_name(self) -> str:
|
|
return self.database_name
|
|
|
|
@property
|
|
def url_object(self) -> URL:
|
|
return make_url_safe(self.sqlalchemy_uri_decrypted)
|
|
|
|
@property
|
|
def backend(self) -> str:
|
|
return self.url_object.get_backend_name()
|
|
|
|
@property
|
|
def driver(self) -> str:
|
|
return self.url_object.get_driver_name()
|
|
|
|
@property
|
|
def masked_encrypted_extra(self) -> Optional[str]:
|
|
return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra)
|
|
|
|
    @property
    def parameters(self) -> dict[str, Any]:
        """Connection parameters derived from the masked URI and extras."""
        # Database parameters are a dictionary of values that are used to make up
        # the sqlalchemy_uri
        # When returning the parameters we should use the masked SQLAlchemy URI and the
        # masked ``encrypted_extra`` to prevent exposing sensitive credentials.
        masked_uri = make_url_safe(self.sqlalchemy_uri)
        encrypted_config = {}
        if (masked_encrypted_extra := self.masked_encrypted_extra) is not None:
            try:
                encrypted_config = json.loads(masked_encrypted_extra)
            except (TypeError, json.JSONDecodeError):
                # malformed extras are treated as empty rather than failing
                pass

        try:
            # pylint: disable=useless-suppression
            parameters = self.db_engine_spec.get_parameters_from_uri(  # type: ignore
                masked_uri,
                encrypted_extra=encrypted_config,
            )
        except Exception:  # pylint: disable=broad-except
            # not every engine spec implements parameter extraction
            parameters = {}

        return parameters
|
|
|
|
@property
|
|
def parameters_schema(self) -> dict[str, Any]:
|
|
try:
|
|
parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore
|
|
except Exception: # pylint: disable=broad-except
|
|
parameters_schema = {}
|
|
return parameters_schema
|
|
|
|
@property
|
|
def metadata_cache_timeout(self) -> dict[str, Any]:
|
|
return self.get_extra().get("metadata_cache_timeout", {})
|
|
|
|
@property
|
|
def schema_cache_enabled(self) -> bool:
|
|
return "schema_cache_timeout" in self.metadata_cache_timeout
|
|
|
|
@property
|
|
def schema_cache_timeout(self) -> Optional[int]:
|
|
return self.metadata_cache_timeout.get("schema_cache_timeout")
|
|
|
|
@property
|
|
def table_cache_enabled(self) -> bool:
|
|
return "table_cache_timeout" in self.metadata_cache_timeout
|
|
|
|
@property
|
|
def table_cache_timeout(self) -> Optional[int]:
|
|
return self.metadata_cache_timeout.get("table_cache_timeout")
|
|
|
|
@property
|
|
def default_schemas(self) -> list[str]:
|
|
return self.get_extra().get("default_schemas", [])
|
|
|
|
@property
|
|
def connect_args(self) -> dict[str, Any]:
|
|
return self.get_extra().get("engine_params", {}).get("connect_args", {})
|
|
|
|
@property
|
|
def engine_information(self) -> dict[str, Any]:
|
|
try:
|
|
engine_information = self.db_engine_spec.get_public_information()
|
|
except Exception: # pylint: disable=broad-except
|
|
engine_information = {}
|
|
return engine_information
|
|
|
|
@classmethod
|
|
def get_password_masked_url_from_uri( # pylint: disable=invalid-name
|
|
cls, uri: str
|
|
) -> URL:
|
|
sqlalchemy_url = make_url_safe(uri)
|
|
return cls.get_password_masked_url(sqlalchemy_url)
|
|
|
|
@classmethod
|
|
def get_password_masked_url(cls, masked_url: URL) -> URL:
|
|
url_copy = deepcopy(masked_url)
|
|
if url_copy.password is not None:
|
|
url_copy = url_copy.set(password=PASSWORD_MASK)
|
|
return url_copy
|
|
|
|
def set_sqlalchemy_uri(self, uri: str) -> None:
|
|
conn = make_url_safe(uri.strip())
|
|
if conn.password != PASSWORD_MASK and not custom_password_store:
|
|
# do not over-write the password with the password mask
|
|
self.password = conn.password
|
|
conn = conn.set(password=PASSWORD_MASK if conn.password else None)
|
|
self.sqlalchemy_uri = str(conn) # hides the password
|
|
|
|
def get_effective_user(self, object_url: URL) -> Optional[str]:
|
|
"""
|
|
Get the effective user, especially during impersonation.
|
|
|
|
:param object_url: SQL Alchemy URL object
|
|
:return: The effective username
|
|
"""
|
|
|
|
return ( # pylint: disable=used-before-assignment
|
|
username
|
|
if (username := get_username())
|
|
else object_url.username
|
|
if self.impersonate_user
|
|
else None
|
|
)
|
|
|
|
    @contextmanager
    def get_sqla_engine_with_context(
        self,
        schema: Optional[str] = None,
        nullpool: bool = True,
        source: Optional[utils.QuerySource] = None,
        override_ssh_tunnel: Optional["SSHTunnel"] = None,
    ) -> Engine:
        """
        Context manager yielding a SQLAlchemy engine for this database.

        When an SSH tunnel is configured (or supplied via
        ``override_ssh_tunnel``) the tunnel is opened first and the engine is
        pointed at its local bind address; the tunnel lives as long as the
        context.

        :param schema: default schema for the connection
        :param nullpool: disable connection pooling
        :param source: where the query originated (dashboard/chart/SQL Lab)
        :param override_ssh_tunnel: use this tunnel instead of the stored one
        """
        from superset.databases.dao import (  # pylint: disable=import-outside-toplevel
            DatabaseDAO,
        )

        sqlalchemy_uri = self.sqlalchemy_uri_decrypted
        # nullcontext keeps the single ``with`` below valid when no tunnel is used
        engine_context = nullcontext()
        ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
            database_id=self.id
        )

        if ssh_tunnel:
            # if ssh_tunnel is available build engine with information
            engine_context = ssh_manager_factory.instance.create_tunnel(
                ssh_tunnel=ssh_tunnel,
                sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted,
            )

        with engine_context as server_context:
            if ssh_tunnel and server_context:
                logger.info(
                    "[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s ssh_timeout at %s",
                    sshtunnel.TUNNEL_TIMEOUT,
                    sshtunnel.SSH_TIMEOUT,
                    server_context.local_bind_address,
                )
                # rewrite the URI so the engine connects through the tunnel
                sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
                    sqlalchemy_uri, server_context
                )
            yield self._get_sqla_engine(
                schema=schema,
                nullpool=nullpool,
                source=source,
                sqlalchemy_uri=sqlalchemy_uri,
            )
|
|
|
|
def _get_sqla_engine(
|
|
self,
|
|
schema: Optional[str] = None,
|
|
nullpool: bool = True,
|
|
source: Optional[utils.QuerySource] = None,
|
|
sqlalchemy_uri: Optional[str] = None,
|
|
) -> Engine:
|
|
sqlalchemy_url = make_url_safe(
|
|
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
|
|
)
|
|
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
|
|
|
|
extra = self.get_extra()
|
|
params = extra.get("engine_params", {})
|
|
if nullpool:
|
|
params["poolclass"] = NullPool
|
|
connect_args = params.get("connect_args", {})
|
|
|
|
# The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and
|
|
# had its signature changed in order to support more DB engine specs. Since DB
|
|
# engine specs can be released as 3rd party modules we want to make sure the old
|
|
# method is still supported so we don't introduce a breaking change.
|
|
if hasattr(self.db_engine_spec, "adjust_database_uri"):
|
|
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(
|
|
sqlalchemy_url,
|
|
schema,
|
|
)
|
|
logger.warning(
|
|
"DB engine spec %s implements the method `adjust_database_uri`, which is "
|
|
"deprecated and will be removed in version 3.0. Please update it to "
|
|
"implement `adjust_engine_params` instead.",
|
|
self.db_engine_spec,
|
|
)
|
|
|
|
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
|
|
uri=sqlalchemy_url,
|
|
connect_args=connect_args,
|
|
catalog=None,
|
|
schema=schema,
|
|
)
|
|
|
|
effective_username = self.get_effective_user(sqlalchemy_url)
|
|
# If using MySQL or Presto for example, will set url.username
|
|
# If using Hive, will not do anything yet since that relies on a
|
|
# configuration parameter instead.
|
|
sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
|
|
sqlalchemy_url,
|
|
self.impersonate_user,
|
|
effective_username,
|
|
)
|
|
|
|
masked_url = self.get_password_masked_url(sqlalchemy_url)
|
|
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
|
|
|
|
if self.impersonate_user:
|
|
self.db_engine_spec.update_impersonation_config(
|
|
connect_args,
|
|
str(sqlalchemy_url),
|
|
effective_username,
|
|
)
|
|
|
|
if connect_args:
|
|
params["connect_args"] = connect_args
|
|
|
|
self.update_params_from_encrypted_extra(params)
|
|
|
|
if DB_CONNECTION_MUTATOR:
|
|
if not source and request and request.referrer:
|
|
if "/superset/dashboard/" in request.referrer:
|
|
source = utils.QuerySource.DASHBOARD
|
|
elif "/explore/" in request.referrer:
|
|
source = utils.QuerySource.CHART
|
|
elif "/superset/sqllab" in request.referrer:
|
|
source = utils.QuerySource.SQL_LAB
|
|
|
|
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
|
|
sqlalchemy_url,
|
|
params,
|
|
effective_username,
|
|
security_manager,
|
|
source,
|
|
)
|
|
try:
|
|
return create_engine(sqlalchemy_url, **params)
|
|
except Exception as ex:
|
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
|
|
|
@contextmanager
|
|
def get_raw_connection(
|
|
self,
|
|
schema: Optional[str] = None,
|
|
nullpool: bool = True,
|
|
source: Optional[utils.QuerySource] = None,
|
|
) -> Connection:
|
|
with self.get_sqla_engine_with_context(
|
|
schema=schema, nullpool=nullpool, source=source
|
|
) as engine:
|
|
with closing(engine.raw_connection()) as conn:
|
|
yield conn
|
|
|
|
def get_default_schema_for_query(self, query: "Query") -> Optional[str]:
|
|
"""
|
|
Return the default schema for a given query.
|
|
|
|
This is used to determine if the user has access to a query that reads from table
|
|
names without a specific schema, eg:
|
|
|
|
SELECT * FROM `foo`
|
|
|
|
The schema of the `foo` table depends on the DB engine spec. Some DB engine specs
|
|
can change the default schema on a per-query basis; in other DB engine specs the
|
|
default schema is defined in the SQLAlchemy URI; and in others the default schema
|
|
might be determined by the database itself (like `public` for Postgres).
|
|
"""
|
|
return self.db_engine_spec.get_default_schema_for_query(self, query)
|
|
|
|
@property
|
|
def quote_identifier(self) -> Callable[[str], str]:
|
|
"""Add quotes to potential identifiter expressions if needed"""
|
|
return self.get_dialect().identifier_preparer.quote
|
|
|
|
def get_reserved_words(self) -> set[str]:
|
|
return self.get_dialect().preparer.reserved_words
|
|
|
|
    def get_df(  # pylint: disable=too-many-locals
        self,
        sql: str,
        schema: Optional[str] = None,
        mutator: Optional[Callable[[pd.DataFrame], None]] = None,
    ) -> pd.DataFrame:
        """Run ``sql`` (possibly several statements) and return the final result.

        All statements but the last are executed and their results discarded;
        only the last statement's result set becomes the DataFrame.

        :param sql: SQL script to execute
        :param schema: default schema for the connection
        :param mutator: optional hook applied to the resulting DataFrame
        """
        sqls = self.db_engine_spec.parse_sql(sql)
        engine = self._get_sqla_engine(schema)
        mutate_after_split = config["MUTATE_AFTER_SPLIT"]
        sql_query_mutator = config["SQL_QUERY_MUTATOR"]

        def needs_conversion(df_series: pd.Series) -> bool:
            # list/dict cells are JSON-encoded below so they serialize cleanly
            return (
                not df_series.empty
                and isinstance(df_series, pd.Series)
                and isinstance(df_series[0], (list, dict))
            )

        def _log_query(sql: str) -> None:
            if log_query:
                log_query(
                    engine.url,
                    sql,
                    schema,
                    __name__,
                    security_manager,
                )

        with self.get_raw_connection(schema=schema) as conn:
            cursor = conn.cursor()
            # run every statement except the last, discarding their results
            for sql_ in sqls[:-1]:
                if mutate_after_split:
                    sql_ = sql_query_mutator(
                        sql_,
                        security_manager=security_manager,
                        database=None,
                    )
                _log_query(sql_)
                self.db_engine_spec.execute(cursor, sql_)
                cursor.fetchall()

            if mutate_after_split:
                last_sql = sql_query_mutator(
                    sqls[-1],
                    security_manager=security_manager,
                    database=None,
                )
                _log_query(last_sql)
                self.db_engine_spec.execute(cursor, last_sql)
            else:
                _log_query(sqls[-1])
                self.db_engine_spec.execute(cursor, sqls[-1])

            data = self.db_engine_spec.fetch_data(cursor)
            result_set = SupersetResultSet(
                data, cursor.description, self.db_engine_spec
            )
            df = result_set.to_pandas_df()
            if mutator:
                # NOTE(review): ``mutator`` is annotated as returning None but its
                # result is assigned here — confirm mutators return the DataFrame
                df = mutator(df)

            for col, coltype in df.dtypes.to_dict().items():
                if coltype == numpy.object_ and needs_conversion(df[col]):
                    df[col] = df[col].apply(utils.json_dumps_w_dates)

            return df
|
|
|
|
    def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
        """Compile a SQLAlchemy ``Select`` into SQL text for this database.

        Literal values are inlined (``literal_binds``) so the output is a
        complete, executable statement.
        """
        engine = self._get_sqla_engine(schema=schema)

        sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))

        # pylint: disable=protected-access
        # dialects that escape ``%`` double it during compilation; undo that
        # since this SQL is shown/executed verbatim, not via a paramstyle layer
        if engine.dialect.identifier_preparer._double_percents:  # noqa
            sql = sql.replace("%%", "%")

        return sql
|
|
|
|
def select_star( # pylint: disable=too-many-arguments
|
|
self,
|
|
table_name: str,
|
|
schema: Optional[str] = None,
|
|
limit: int = 100,
|
|
show_cols: bool = False,
|
|
indent: bool = True,
|
|
latest_partition: bool = False,
|
|
cols: Optional[list[dict[str, Any]]] = None,
|
|
) -> str:
|
|
"""Generates a ``select *`` statement in the proper dialect"""
|
|
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
|
|
return self.db_engine_spec.select_star(
|
|
self,
|
|
table_name,
|
|
schema=schema,
|
|
engine=eng,
|
|
limit=limit,
|
|
show_cols=show_cols,
|
|
indent=indent,
|
|
latest_partition=latest_partition,
|
|
cols=cols,
|
|
)
|
|
|
|
def apply_limit_to_sql(
|
|
self, sql: str, limit: int = 1000, force: bool = False
|
|
) -> str:
|
|
if self.db_engine_spec.allow_limit_clause:
|
|
return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force)
|
|
return self.db_engine_spec.apply_top_to_sql(sql, limit)
|
|
|
|
def safe_sqlalchemy_uri(self) -> str:
|
|
return self.sqlalchemy_uri
|
|
|
|
@property
|
|
def inspector(self) -> Inspector:
|
|
engine = self._get_sqla_engine()
|
|
return sqla.inspect(engine)
|
|
|
|
@cache_util.memoized_func(
|
|
key="db:{self.id}:schema:{schema}:table_list",
|
|
cache=cache_manager.cache,
|
|
)
|
|
def get_all_table_names_in_schema( # pylint: disable=unused-argument
|
|
self,
|
|
schema: str,
|
|
cache: bool = False,
|
|
cache_timeout: Optional[int] = None,
|
|
force: bool = False,
|
|
) -> set[tuple[str, str]]:
|
|
"""Parameters need to be passed as keyword arguments.
|
|
|
|
For unused parameters, they are referenced in
|
|
cache_util.memoized_func decorator.
|
|
|
|
:param schema: schema name
|
|
:param cache: whether cache is enabled for the function
|
|
:param cache_timeout: timeout in seconds for the cache
|
|
:param force: whether to force refresh the cache
|
|
:return: The table/schema pairs
|
|
"""
|
|
try:
|
|
with self.get_inspector_with_context() as inspector:
|
|
tables = {
|
|
(table, schema)
|
|
for table in self.db_engine_spec.get_table_names(
|
|
database=self,
|
|
inspector=inspector,
|
|
schema=schema,
|
|
)
|
|
}
|
|
return tables
|
|
except Exception as ex:
|
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
|
|
|
@cache_util.memoized_func(
|
|
key="db:{self.id}:schema:{schema}:view_list",
|
|
cache=cache_manager.cache,
|
|
)
|
|
def get_all_view_names_in_schema( # pylint: disable=unused-argument
|
|
self,
|
|
schema: str,
|
|
cache: bool = False,
|
|
cache_timeout: Optional[int] = None,
|
|
force: bool = False,
|
|
) -> set[tuple[str, str]]:
|
|
"""Parameters need to be passed as keyword arguments.
|
|
|
|
For unused parameters, they are referenced in
|
|
cache_util.memoized_func decorator.
|
|
|
|
:param schema: schema name
|
|
:param cache: whether cache is enabled for the function
|
|
:param cache_timeout: timeout in seconds for the cache
|
|
:param force: whether to force refresh the cache
|
|
:return: set of views
|
|
"""
|
|
try:
|
|
with self.get_inspector_with_context() as inspector:
|
|
return {
|
|
(view, schema)
|
|
for view in self.db_engine_spec.get_view_names(
|
|
database=self,
|
|
inspector=inspector,
|
|
schema=schema,
|
|
)
|
|
}
|
|
except Exception as ex:
|
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
|
|
|
@contextmanager
|
|
def get_inspector_with_context(
|
|
self, ssh_tunnel: Optional["SSHTunnel"] = None
|
|
) -> Inspector:
|
|
with self.get_sqla_engine_with_context(
|
|
override_ssh_tunnel=ssh_tunnel
|
|
) as engine:
|
|
yield sqla.inspect(engine)
|
|
|
|
    @cache_util.memoized_func(
        key="db:{self.id}:schema_list",
        cache=cache_manager.cache,
    )
    def get_all_schema_names(  # pylint: disable=unused-argument
        self,
        cache: bool = False,
        cache_timeout: Optional[int] = None,
        force: bool = False,
        ssh_tunnel: Optional["SSHTunnel"] = None,
    ) -> list[str]:
        """Parameters need to be passed as keyword arguments.

        For unused parameters, they are referenced in
        cache_util.memoized_func decorator.

        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :param ssh_tunnel: optional tunnel to connect through
        :return: schema list
        """
        try:
            with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector:
                return self.db_engine_spec.get_schema_names(inspector)
        except Exception as ex:
            # surface the engine-spec-mapped error, chained to the original
            raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
|
|
|
@property
|
|
def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]:
|
|
url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
|
return self.get_db_engine_spec(url)
|
|
|
|
@classmethod
|
|
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
|
def get_db_engine_spec(
|
|
cls, url: URL
|
|
) -> builtins.type[db_engine_specs.BaseEngineSpec]:
|
|
backend = url.get_backend_name()
|
|
try:
|
|
driver = url.get_driver_name()
|
|
except NoSuchModuleError:
|
|
# can't load the driver, fallback for backwards compatibility
|
|
driver = None
|
|
|
|
return db_engine_specs.get_engine_spec(backend, driver)
|
|
|
|
def grains(self) -> tuple[TimeGrain, ...]:
|
|
"""Defines time granularity database-specific expressions.
|
|
|
|
The idea here is to make it easy for users to change the time grain
|
|
from a datetime (maybe the source grain is arbitrary timestamps, daily
|
|
or 5 minutes increments) to another, "truncated" datetime. Since
|
|
each database has slightly different but similar datetime functions,
|
|
this allows a mapping between database engines and actual functions.
|
|
"""
|
|
return self.db_engine_spec.get_time_grains()
|
|
|
|
def get_extra(self) -> dict[str, Any]:
|
|
return self.db_engine_spec.get_extra_params(self)
|
|
|
|
def get_encrypted_extra(self) -> dict[str, Any]:
|
|
encrypted_extra = {}
|
|
if self.encrypted_extra:
|
|
try:
|
|
encrypted_extra = json.loads(self.encrypted_extra)
|
|
except json.JSONDecodeError as ex:
|
|
logger.error(ex, exc_info=True)
|
|
raise ex
|
|
return encrypted_extra
|
|
|
|
# pylint: disable=invalid-name
|
|
def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None:
|
|
self.db_engine_spec.update_params_from_encrypted_extra(self, params)
|
|
|
|
def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
|
|
extra = self.get_extra()
|
|
meta = MetaData(**extra.get("metadata_params", {}))
|
|
with self.get_sqla_engine_with_context() as engine:
|
|
return Table(
|
|
table_name,
|
|
meta,
|
|
schema=schema or None,
|
|
autoload=True,
|
|
autoload_with=engine,
|
|
)
|
|
|
|
def get_table_comment(
|
|
self, table_name: str, schema: Optional[str] = None
|
|
) -> Optional[str]:
|
|
with self.get_inspector_with_context() as inspector:
|
|
return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
|
|
|
|
def get_columns(
|
|
self, table_name: str, schema: Optional[str] = None
|
|
) -> list[dict[str, Any]]:
|
|
with self.get_inspector_with_context() as inspector:
|
|
return self.db_engine_spec.get_columns(inspector, table_name, schema)
|
|
|
|
def get_metrics(
|
|
self,
|
|
table_name: str,
|
|
schema: Optional[str] = None,
|
|
) -> list[MetricType]:
|
|
with self.get_inspector_with_context() as inspector:
|
|
return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
|
|
|
|
def get_indexes(
|
|
self, table_name: str, schema: Optional[str] = None
|
|
) -> list[dict[str, Any]]:
|
|
with self.get_inspector_with_context() as inspector:
|
|
return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
|
|
|
|
def get_pk_constraint(
|
|
self, table_name: str, schema: Optional[str] = None
|
|
) -> dict[str, Any]:
|
|
with self.get_inspector_with_context() as inspector:
|
|
pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
|
|
|
|
def _convert(value: Any) -> Any:
|
|
try:
|
|
return utils.base_json_conv(value)
|
|
except TypeError:
|
|
return None
|
|
|
|
return {key: _convert(value) for key, value in pk_constraint.items()}
|
|
|
|
def get_foreign_keys(
|
|
self, table_name: str, schema: Optional[str] = None
|
|
) -> list[dict[str, Any]]:
|
|
with self.get_inspector_with_context() as inspector:
|
|
return inspector.get_foreign_keys(table_name, schema)
|
|
|
|
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
|
|
self,
|
|
) -> list[str]:
|
|
allowed_databases = self.get_extra().get("schemas_allowed_for_file_upload", [])
|
|
|
|
if isinstance(allowed_databases, str):
|
|
allowed_databases = literal_eval(allowed_databases)
|
|
|
|
if hasattr(g, "user"):
|
|
extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
|
|
self, g.user
|
|
)
|
|
allowed_databases += extra_allowed_databases
|
|
return sorted(set(allowed_databases))
|
|
|
|
@property
|
|
def sqlalchemy_uri_decrypted(self) -> str:
|
|
try:
|
|
conn = make_url_safe(self.sqlalchemy_uri)
|
|
except DatabaseInvalidError:
|
|
# if the URI is invalid, ignore and return a placeholder url
|
|
# (so users see 500 less often)
|
|
return "dialect://invalid_uri"
|
|
if custom_password_store:
|
|
conn = conn.set(password=custom_password_store(conn))
|
|
else:
|
|
conn = conn.set(password=self.password)
|
|
return str(conn)
|
|
|
|
@property
|
|
def sql_url(self) -> str:
|
|
return f"/superset/sql/{self.id}/"
|
|
|
|
@hybrid_property
|
|
def perm(self) -> str:
|
|
return f"[{self.database_name}].(id:{self.id})"
|
|
|
|
    @perm.expression  # type: ignore
    def perm(cls) -> str:  # pylint: disable=no-self-argument
        # SQL-side counterpart of the Python ``perm`` property, so the
        # permission string can be computed/filtered inside queries
        return (
            "[" + cls.database_name + "].(id:" + expression.cast(cls.id, String) + ")"
        )
|
|
|
|
def get_perm(self) -> str:
|
|
return self.perm # type: ignore
|
|
|
|
def has_table(self, table: Table) -> bool:
|
|
with self.get_sqla_engine_with_context() as engine:
|
|
return engine.has_table(table.table_name, table.schema or None)
|
|
|
|
def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
|
|
with self.get_sqla_engine_with_context() as engine:
|
|
return engine.has_table(table_name, schema)
|
|
|
|
@classmethod
|
|
def _has_view(
|
|
cls,
|
|
conn: Connection,
|
|
dialect: Dialect,
|
|
view_name: str,
|
|
schema: Optional[str] = None,
|
|
) -> bool:
|
|
view_names: list[str] = []
|
|
try:
|
|
view_names = dialect.get_view_names(connection=conn, schema=schema)
|
|
except Exception: # pylint: disable=broad-except
|
|
logger.warning("Has view failed", exc_info=True)
|
|
return view_name in view_names
|
|
|
|
def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
|
|
engine = self._get_sqla_engine()
|
|
return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
|
|
|
|
def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
|
|
return self.has_view(view_name=view_name, schema=schema)
|
|
|
|
def get_dialect(self) -> Dialect:
|
|
sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
|
return sqla_url.get_dialect()()
|
|
|
|
def make_sqla_column_compatible(
|
|
self, sqla_col: ColumnElement, label: Optional[str] = None
|
|
) -> ColumnElement:
|
|
"""Takes a sqlalchemy column object and adds label info if supported by engine.
|
|
:param sqla_col: sqlalchemy column instance
|
|
:param label: alias/label that column is expected to have
|
|
:return: either a sql alchemy column or label instance if supported by engine
|
|
"""
|
|
label_expected = label or sqla_col.name
|
|
# add quotes to tables
|
|
if self.db_engine_spec.allows_alias_in_select:
|
|
label = self.db_engine_spec.make_label_compatible(label_expected)
|
|
sqla_col = sqla_col.label(label)
|
|
sqla_col.key = label_expected
|
|
return sqla_col
|
|
|
|
|
|
# keep security-manager permissions in sync with database lifecycle events
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
sqla.event.listen(Database, "after_delete", security_manager.database_after_delete)
|
|
|
|
|
|
class Log(Model):  # pylint: disable=too-few-public-methods

    """ORM object used to log Superset actions to the database"""

    __tablename__ = "logs"

    id = Column(Integer, primary_key=True)
    # name of the action performed (e.g. endpoint or event name)
    action = Column(String(512))
    user_id = Column(Integer, ForeignKey("ab_user.id"))
    dashboard_id = Column(Integer)
    slice_id = Column(Integer)
    # extra payload serialized as JSON text
    json = Column(Text)
    user = relationship(
        security_manager.user_model, backref="logs", foreign_keys=[user_id]
    )
    dttm = Column(DateTime, default=datetime.utcnow)
    # how long the action took, in milliseconds
    duration_ms = Column(Integer)
    referrer = Column(String(1024))
|
|
|
|
|
|
class FavStarClassName(str, enum.Enum):
    # values match the ``class_name`` strings stored in the favstar table
    CHART = "slice"
    DASHBOARD = "Dashboard"
|
|
|
|
|
|
class FavStar(Model):  # pylint: disable=too-few-public-methods

    """A user's 'favorite' mark on a chart or dashboard."""

    __tablename__ = "favstar"

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("ab_user.id"))
    # type of the favorited object; see FavStarClassName
    class_name = Column(String(50))
    # id of the favorited object (chart or dashboard)
    obj_id = Column(Integer)
    dttm = Column(DateTime, default=datetime.utcnow)
|