[typing] superset/models/core.py (#8284)
This commit is contained in:
parent
4c35de1d1f
commit
9a29116d6b
|
|
@ -505,7 +505,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return utils.error_msg_from_exception(e)
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema: str):
|
||||
def adjust_database_uri(cls, uri, selected_schema: Optional[str]):
|
||||
"""Based on a URI and selected schema, return a new URI
|
||||
|
||||
The URI here represents the URI as entered when saving the database,
|
||||
|
|
@ -718,19 +718,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return costs
|
||||
|
||||
@classmethod
|
||||
def modify_url_for_impersonation(cls, url, impersonate_user: bool, username: str):
|
||||
def modify_url_for_impersonation(
|
||||
cls, url, impersonate_user: bool, username: Optional[str]
|
||||
):
|
||||
"""
|
||||
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
|
||||
:param url: SQLAlchemy URL object
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
:param username: Effective username
|
||||
"""
|
||||
if impersonate_user is not None and username is not None:
|
||||
if impersonate_user and username is not None:
|
||||
url.username = username
|
||||
|
||||
@classmethod
|
||||
def get_configuration_for_impersonation( # pylint: disable=invalid-name
|
||||
cls, uri: str, impersonate_user: bool, username: str
|
||||
cls, uri: str, impersonate_user: bool, username: Optional[str]
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Return a configuration dictionary that can be merged with other configs
|
||||
|
|
|
|||
|
|
@ -376,7 +376,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def modify_url_for_impersonation(cls, url, impersonate_user: bool, username: str):
|
||||
def modify_url_for_impersonation(
|
||||
cls, url, impersonate_user: bool, username: Optional[str]
|
||||
):
|
||||
"""
|
||||
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
|
||||
:param url: SQLAlchemy URL object
|
||||
|
|
@ -389,7 +391,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def get_configuration_for_impersonation(
|
||||
cls, uri: str, impersonate_user: bool, username: str
|
||||
cls, uri: str, impersonate_user: bool, username: Optional[str]
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Return a configuration dictionary that can be merged with other configs
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import textwrap
|
|||
from contextlib import closing
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
import numpy
|
||||
|
|
@ -45,22 +45,29 @@ from sqlalchemy import (
|
|||
Table,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.engine import url
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.engine import Dialect, Engine, url
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
|
||||
from sqlalchemy.orm.session import make_transient
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.schema import UniqueConstraint
|
||||
from sqlalchemy.sql import Select
|
||||
from sqlalchemy_utils import EncryptedType
|
||||
|
||||
from superset import app, db, db_engine_specs, is_feature_enabled, security_manager
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.db_engine_specs.base import TimeGrain
|
||||
from superset.legacy import update_time_range
|
||||
from superset.models.helpers import AuditMixinNullable, ImportMixin
|
||||
from superset.models.tags import ChartUpdater, DashboardUpdater, FavStarUpdater
|
||||
from superset.models.user_attributes import UserAttribute
|
||||
from superset.utils import cache as cache_util, core as utils
|
||||
from superset.viz import viz_types
|
||||
from superset.viz import BaseViz, viz_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
||||
|
||||
config = app.config
|
||||
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
|
||||
|
|
@ -180,14 +187,14 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
return self.slice_name or str(self.id)
|
||||
|
||||
@property
|
||||
def cls_model(self):
|
||||
def cls_model(self) -> Type["BaseDatasource"]:
|
||||
return ConnectorRegistry.sources[self.datasource_type]
|
||||
|
||||
@property
|
||||
def datasource(self):
|
||||
def datasource(self) -> "BaseDatasource":
|
||||
return self.get_datasource
|
||||
|
||||
def clone(self):
|
||||
def clone(self) -> "Slice":
|
||||
return Slice(
|
||||
slice_name=self.slice_name,
|
||||
datasource_id=self.datasource_id,
|
||||
|
|
@ -201,46 +208,45 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
|
||||
@datasource.getter # type: ignore
|
||||
@utils.memoized
|
||||
def get_datasource(self):
|
||||
def get_datasource(self) -> Optional["BaseDatasource"]:
|
||||
return db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
|
||||
|
||||
@renders("datasource_name")
|
||||
def datasource_link(self):
|
||||
def datasource_link(self) -> Optional[Markup]:
|
||||
# pylint: disable=no-member
|
||||
datasource = self.datasource
|
||||
return datasource.link if datasource else None
|
||||
|
||||
def datasource_name_text(self):
|
||||
def datasource_name_text(self) -> Optional[str]:
|
||||
# pylint: disable=no-member
|
||||
datasource = self.datasource
|
||||
return datasource.name if datasource else None
|
||||
|
||||
@property
|
||||
def datasource_edit_url(self):
|
||||
def datasource_edit_url(self) -> Optional[str]:
|
||||
# pylint: disable=no-member
|
||||
datasource = self.datasource
|
||||
return datasource.url if datasource else None
|
||||
|
||||
@property # type: ignore
|
||||
@utils.memoized
|
||||
def viz(self):
|
||||
def viz(self) -> BaseViz:
|
||||
d = json.loads(self.params)
|
||||
viz_class = viz_types[self.viz_type]
|
||||
# pylint: disable=no-member
|
||||
return viz_class(datasource=self.datasource, form_data=d)
|
||||
|
||||
@property
|
||||
def description_markeddown(self):
|
||||
def description_markeddown(self) -> str:
|
||||
return utils.markdown(self.description)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> Dict[str, Any]:
|
||||
"""Data used to render slice in templates"""
|
||||
d = {}
|
||||
d: Dict[str, Any] = {}
|
||||
self.token = ""
|
||||
try:
|
||||
d = self.viz.data
|
||||
self.token = d.get("token")
|
||||
self.token = d.get("token") # type: ignore
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
d["error"] = str(e)
|
||||
|
|
@ -259,12 +265,12 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
}
|
||||
|
||||
@property
|
||||
def json_data(self):
|
||||
def json_data(self) -> str:
|
||||
return json.dumps(self.data)
|
||||
|
||||
@property
|
||||
def form_data(self):
|
||||
form_data = {}
|
||||
def form_data(self) -> Dict[str, Any]:
|
||||
form_data: Dict[str, Any] = {}
|
||||
try:
|
||||
form_data = json.loads(self.params)
|
||||
except Exception as e:
|
||||
|
|
@ -283,7 +289,11 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
update_time_range(form_data)
|
||||
return form_data
|
||||
|
||||
def get_explore_url(self, base_url="/superset/explore", overrides=None):
|
||||
def get_explore_url(
|
||||
self,
|
||||
base_url: str = "/superset/explore",
|
||||
overrides: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
overrides = overrides or {}
|
||||
form_data = {"slice_id": self.id}
|
||||
form_data.update(overrides)
|
||||
|
|
@ -291,30 +301,30 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
return f"{base_url}/?form_data={params}"
|
||||
|
||||
@property
|
||||
def slice_url(self):
|
||||
def slice_url(self) -> str:
|
||||
"""Defines the url to access the slice"""
|
||||
return self.get_explore_url()
|
||||
|
||||
@property
|
||||
def explore_json_url(self):
|
||||
def explore_json_url(self) -> str:
|
||||
"""Defines the url to access the slice"""
|
||||
return self.get_explore_url("/superset/explore_json")
|
||||
|
||||
@property
|
||||
def edit_url(self):
|
||||
return "/chart/edit/{}".format(self.id)
|
||||
def edit_url(self) -> str:
|
||||
return f"/chart/edit/{self.id}"
|
||||
|
||||
@property
|
||||
def chart(self):
|
||||
def chart(self) -> str:
|
||||
return self.slice_name or "<empty>"
|
||||
|
||||
@property
|
||||
def slice_link(self):
|
||||
def slice_link(self) -> Markup:
|
||||
url = self.slice_url
|
||||
name = escape(self.chart)
|
||||
return Markup(f'<a href="{url}">{name}</a>')
|
||||
|
||||
def get_viz(self, force=False):
|
||||
def get_viz(self, force: bool = False) -> BaseViz:
|
||||
"""Creates :py:class:viz.BaseViz object from the url_params_multidict.
|
||||
|
||||
:return: object of the 'viz_type' type that is taken from the
|
||||
|
|
@ -332,7 +342,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
)
|
||||
|
||||
@property
|
||||
def icons(self):
|
||||
def icons(self) -> str:
|
||||
return f"""
|
||||
<a
|
||||
href="{self.datasource_edit_url}"
|
||||
|
|
@ -343,7 +353,12 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, slc_to_import, slc_to_override, import_time=None):
|
||||
def import_obj(
|
||||
cls,
|
||||
slc_to_import: "Slice",
|
||||
slc_to_override: Optional["Slice"],
|
||||
import_time: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Inserts or overrides slc in the database.
|
||||
|
||||
remote_id and import_time fields in params_dict are set to track the
|
||||
|
|
@ -363,7 +378,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
slc_to_import = slc_to_import.copy()
|
||||
slc_to_import.reset_ownership()
|
||||
params = slc_to_import.params_dict
|
||||
slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
|
||||
slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name( # type: ignore
|
||||
session,
|
||||
slc_to_import.datasource_type,
|
||||
params["datasource_name"],
|
||||
|
|
@ -380,10 +395,8 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
return slc_to_import.id
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return "/superset/explore/?form_data=%7B%22slice_id%22%3A%20{0}%7D".format(
|
||||
self.id
|
||||
)
|
||||
def url(self) -> str:
|
||||
return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"
|
||||
|
||||
|
||||
sqla.event.listen(Slice, "before_insert", set_related_perm)
|
||||
|
|
@ -437,12 +450,12 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
|
|||
return self.dashboard_title or str(self.id)
|
||||
|
||||
@property
|
||||
def table_names(self):
|
||||
def table_names(self) -> str:
|
||||
# pylint: disable=no-member
|
||||
return ", ".join({"{}".format(s.datasource.full_name) for s in self.slices})
|
||||
return ", ".join(str(s.datasource.full_name) for s in self.slices)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
if self.json_metadata:
|
||||
# add default_filters to the preselect_filters of dashboard
|
||||
json_metadata = json.loads(self.json_metadata)
|
||||
|
|
@ -457,28 +470,28 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
|
|||
)
|
||||
except Exception:
|
||||
pass
|
||||
return "/superset/dashboard/{}/".format(self.slug or self.id)
|
||||
return f"/superset/dashboard/{self.slug or self.id}/"
|
||||
|
||||
@property
|
||||
def datasources(self):
|
||||
def datasources(self) -> Set[Optional["BaseDatasource"]]:
|
||||
return {slc.datasource for slc in self.slices}
|
||||
|
||||
@property
|
||||
def charts(self):
|
||||
def charts(self) -> List[Optional["BaseDatasource"]]:
|
||||
return [slc.chart for slc in self.slices]
|
||||
|
||||
@property
|
||||
def sqla_metadata(self):
|
||||
def sqla_metadata(self) -> None:
|
||||
# pylint: disable=no-member
|
||||
metadata = MetaData(bind=self.get_sqla_engine())
|
||||
return metadata.reflect()
|
||||
metadata.reflect()
|
||||
|
||||
def dashboard_link(self):
|
||||
def dashboard_link(self) -> Markup:
|
||||
title = escape(self.dashboard_title or "<empty>")
|
||||
return Markup(f'<a href="{self.url}">{title}</a>')
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> Dict[str, Any]:
|
||||
positions = self.position_json
|
||||
if positions:
|
||||
positions = json.loads(positions)
|
||||
|
|
@ -494,21 +507,23 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
|
|||
}
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
def params(self) -> str:
|
||||
return self.json_metadata
|
||||
|
||||
@params.setter
|
||||
def params(self, value):
|
||||
def params(self, value: str) -> None:
|
||||
self.json_metadata = value
|
||||
|
||||
@property
|
||||
def position(self):
|
||||
def position(self) -> Dict:
|
||||
if self.position_json:
|
||||
return json.loads(self.position_json)
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, dashboard_to_import, import_time=None):
|
||||
def import_obj(
|
||||
cls, dashboard_to_import: "Dashboard", import_time: Optional[int] = None
|
||||
) -> int:
|
||||
"""Imports the dashboard from the object to the database.
|
||||
|
||||
Once dashboard is imported, json_metadata field is extended and stores
|
||||
|
|
@ -652,10 +667,10 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
|
|||
dashboard_to_import.slices = new_slices
|
||||
session.add(dashboard_to_import)
|
||||
session.flush()
|
||||
return dashboard_to_import.id
|
||||
return dashboard_to_import.id # type: ignore
|
||||
|
||||
@classmethod
|
||||
def export_dashboards(cls, dashboard_ids):
|
||||
def export_dashboards(cls, dashboard_ids: List) -> str:
|
||||
copied_dashboards = []
|
||||
datasource_ids = set()
|
||||
for dashboard_id in dashboard_ids:
|
||||
|
|
@ -688,22 +703,22 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
|
|||
copied_dashboard.alter_params(remote_id=dashboard_id)
|
||||
copied_dashboards.append(copied_dashboard)
|
||||
|
||||
eager_datasources = []
|
||||
for datasource_id, datasource_type in datasource_ids:
|
||||
eager_datasource = ConnectorRegistry.get_eager_datasource(
|
||||
db.session, datasource_type, datasource_id
|
||||
)
|
||||
copied_datasource = eager_datasource.copy()
|
||||
copied_datasource.alter_params(
|
||||
remote_id=eager_datasource.id,
|
||||
database_name=eager_datasource.database.name,
|
||||
)
|
||||
datasource_class = copied_datasource.__class__
|
||||
for field_name in datasource_class.export_children:
|
||||
field_val = getattr(eager_datasource, field_name).copy()
|
||||
# set children without creating ORM relations
|
||||
copied_datasource.__dict__[field_name] = field_val
|
||||
eager_datasources.append(copied_datasource)
|
||||
eager_datasources = []
|
||||
for datasource_id, datasource_type in datasource_ids:
|
||||
eager_datasource = ConnectorRegistry.get_eager_datasource(
|
||||
db.session, datasource_type, datasource_id
|
||||
)
|
||||
copied_datasource = eager_datasource.copy()
|
||||
copied_datasource.alter_params(
|
||||
remote_id=eager_datasource.id,
|
||||
database_name=eager_datasource.database.name,
|
||||
)
|
||||
datasource_class = copied_datasource.__class__
|
||||
for field_name in datasource_class.export_children:
|
||||
field_val = getattr(eager_datasource, field_name).copy()
|
||||
# set children without creating ORM relations
|
||||
copied_datasource.__dict__[field_name] = field_val
|
||||
eager_datasources.append(copied_datasource)
|
||||
|
||||
return json.dumps(
|
||||
{"dashboards": copied_dashboards, "datasources": eager_datasources},
|
||||
|
|
@ -767,25 +782,27 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
return self.name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self.verbose_name if self.verbose_name else self.database_name
|
||||
|
||||
@property
|
||||
def allows_subquery(self):
|
||||
def allows_subquery(self) -> bool:
|
||||
return self.db_engine_spec.allows_subqueries
|
||||
|
||||
@property
|
||||
def allows_cost_estimate(self) -> bool:
|
||||
extra = self.get_extra()
|
||||
|
||||
database_version = extra.get("version")
|
||||
cost_estimate_enabled = extra.get("cost_estimate_enabled")
|
||||
cost_estimate_enabled: bool = extra.get("cost_estimate_enabled") # type: ignore
|
||||
|
||||
return (
|
||||
self.db_engine_spec.get_allow_cost_estimate(database_version)
|
||||
and cost_estimate_enabled
|
||||
)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.database_name,
|
||||
|
|
@ -796,55 +813,55 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
}
|
||||
|
||||
@property
|
||||
def unique_name(self):
|
||||
def unique_name(self) -> str:
|
||||
return self.database_name
|
||||
|
||||
@property
|
||||
def url_object(self):
|
||||
def url_object(self) -> URL:
|
||||
return make_url(self.sqlalchemy_uri_decrypted)
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
def backend(self) -> str:
|
||||
url = make_url(self.sqlalchemy_uri_decrypted)
|
||||
return url.get_backend_name()
|
||||
|
||||
@property
|
||||
def metadata_cache_timeout(self):
|
||||
def metadata_cache_timeout(self) -> Dict[str, Any]:
|
||||
return self.get_extra().get("metadata_cache_timeout", {})
|
||||
|
||||
@property
|
||||
def schema_cache_enabled(self):
|
||||
def schema_cache_enabled(self) -> bool:
|
||||
return "schema_cache_timeout" in self.metadata_cache_timeout
|
||||
|
||||
@property
|
||||
def schema_cache_timeout(self):
|
||||
def schema_cache_timeout(self) -> Optional[int]:
|
||||
return self.metadata_cache_timeout.get("schema_cache_timeout")
|
||||
|
||||
@property
|
||||
def table_cache_enabled(self):
|
||||
def table_cache_enabled(self) -> bool:
|
||||
return "table_cache_timeout" in self.metadata_cache_timeout
|
||||
|
||||
@property
|
||||
def table_cache_timeout(self):
|
||||
def table_cache_timeout(self) -> Optional[int]:
|
||||
return self.metadata_cache_timeout.get("table_cache_timeout")
|
||||
|
||||
@property
|
||||
def default_schemas(self):
|
||||
def default_schemas(self) -> List[str]:
|
||||
return self.get_extra().get("default_schemas", [])
|
||||
|
||||
@classmethod
|
||||
def get_password_masked_url_from_uri(cls, uri):
|
||||
def get_password_masked_url_from_uri(cls, uri: str):
|
||||
url = make_url(uri)
|
||||
return cls.get_password_masked_url(url)
|
||||
|
||||
@classmethod
|
||||
def get_password_masked_url(cls, url):
|
||||
def get_password_masked_url(cls, url: URL) -> URL:
|
||||
url_copy = deepcopy(url)
|
||||
if url_copy.password is not None and url_copy.password != PASSWORD_MASK:
|
||||
if url_copy.password is not None:
|
||||
url_copy.password = PASSWORD_MASK
|
||||
return url_copy
|
||||
|
||||
def set_sqlalchemy_uri(self, uri):
|
||||
def set_sqlalchemy_uri(self, uri: str) -> None:
|
||||
conn = sqla.engine.url.make_url(uri.strip())
|
||||
if conn.password != PASSWORD_MASK and not custom_password_store:
|
||||
# do not over-write the password with the password mask
|
||||
|
|
@ -852,7 +869,9 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
conn.password = PASSWORD_MASK if conn.password else None
|
||||
self.sqlalchemy_uri = str(conn) # hides the password
|
||||
|
||||
def get_effective_user(self, url, user_name=None):
|
||||
def get_effective_user(
|
||||
self, url: URL, user_name: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the effective user, especially during impersonation.
|
||||
:param url: SQL Alchemy URL object
|
||||
|
|
@ -873,7 +892,13 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
return effective_username
|
||||
|
||||
@utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
|
||||
def get_sqla_engine(self, schema=None, nullpool=True, user_name=None, source=None):
|
||||
def get_sqla_engine(
|
||||
self,
|
||||
schema: Optional[str] = None,
|
||||
nullpool: bool = True,
|
||||
user_name: Optional[str] = None,
|
||||
source: Optional[int] = None,
|
||||
) -> Engine:
|
||||
extra = self.get_extra()
|
||||
url = make_url(self.sqlalchemy_uri_decrypted)
|
||||
url = self.db_engine_spec.adjust_database_uri(url, schema)
|
||||
|
|
@ -893,7 +918,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
params["poolclass"] = NullPool
|
||||
|
||||
# If using Hive, this will set hive.server2.proxy.user=$effective_username
|
||||
configuration = {}
|
||||
configuration: Dict[str, Any] = {}
|
||||
configuration.update(
|
||||
self.db_engine_spec.get_configuration_for_impersonation(
|
||||
str(url), self.impersonate_user, effective_username
|
||||
|
|
@ -913,14 +938,16 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
)
|
||||
return create_engine(url, **params)
|
||||
|
||||
def get_reserved_words(self):
|
||||
def get_reserved_words(self) -> Set[str]:
|
||||
return self.get_dialect().preparer.reserved_words
|
||||
|
||||
def get_quoter(self):
|
||||
return self.get_dialect().identifier_preparer.quote
|
||||
|
||||
def get_df(self, sql, schema, mutator=None):
|
||||
sqls = [str(s).strip().strip(";") for s in sqlparse.parse(sql)]
|
||||
def get_df(
|
||||
self, sql: str, schema: str, mutator: Optional[Callable] = None
|
||||
) -> pd.DataFrame:
|
||||
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]
|
||||
source_key = None
|
||||
if request and request.referrer:
|
||||
if "/superset/dashboard/" in request.referrer:
|
||||
|
|
@ -928,18 +955,14 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
elif "/superset/explore/" in request.referrer:
|
||||
source_key = "chart"
|
||||
engine = self.get_sqla_engine(
|
||||
schema=schema, source=utils.sources.get(source_key, None)
|
||||
schema=schema, source=utils.sources[source_key] if source_key else None
|
||||
)
|
||||
username = utils.get_username()
|
||||
|
||||
def needs_conversion(df_series):
|
||||
if df_series.empty:
|
||||
return False
|
||||
if isinstance(df_series[0], (list, dict)):
|
||||
return True
|
||||
return False
|
||||
def needs_conversion(df_series: pd.Series) -> bool:
|
||||
return not df_series.empty and isinstance(df_series[0], (list, dict))
|
||||
|
||||
def _log_query(sql):
|
||||
def _log_query(sql: str) -> None:
|
||||
if log_query:
|
||||
log_query(engine.url, sql, schema, username, __name__, security_manager)
|
||||
|
||||
|
|
@ -970,7 +993,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
df[k] = df[k].apply(utils.json_dumps_w_dates)
|
||||
return df
|
||||
|
||||
def compile_sqla_query(self, qry, schema=None):
|
||||
def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
|
||||
engine = self.get_sqla_engine(schema=schema)
|
||||
|
||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
|
|
@ -982,13 +1005,13 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
|
||||
def select_star(
|
||||
self,
|
||||
table_name,
|
||||
schema=None,
|
||||
limit=100,
|
||||
show_cols=False,
|
||||
indent=True,
|
||||
latest_partition=False,
|
||||
cols=None,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = False,
|
||||
cols: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
eng = self.get_sqla_engine(
|
||||
|
|
@ -1006,14 +1029,14 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
cols=cols,
|
||||
)
|
||||
|
||||
def apply_limit_to_sql(self, sql, limit=1000):
|
||||
def apply_limit_to_sql(self, sql: str, limit: int = 1000) -> str:
|
||||
return self.db_engine_spec.apply_limit_to_sql(sql, limit, self)
|
||||
|
||||
def safe_sqlalchemy_uri(self):
|
||||
def safe_sqlalchemy_uri(self) -> str:
|
||||
return self.sqlalchemy_uri
|
||||
|
||||
@property
|
||||
def inspector(self):
|
||||
def inspector(self) -> Inspector:
|
||||
engine = self.get_sqla_engine()
|
||||
return sqla.inspect(engine)
|
||||
|
||||
|
|
@ -1030,7 +1053,8 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
return self.db_engine_spec.get_all_datasource_names(self, "table")
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: "db:{}:schema:None:view_list", attribute_in_key="id"
|
||||
key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
|
||||
attribute_in_key="id", # type: ignore
|
||||
)
|
||||
def get_all_view_names_in_database(
|
||||
self, cache: bool = False, cache_timeout: bool = None, force: bool = False
|
||||
|
|
@ -1041,9 +1065,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
return self.db_engine_spec.get_all_datasource_names(self, "view")
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: "db:{{}}:schema:{}:table_list".format(
|
||||
kwargs.get("schema")
|
||||
),
|
||||
key=lambda *args, **kwargs: f"db:{{}}:schema:{kwargs.get('schema')}:table_list", # type: ignore
|
||||
attribute_in_key="id",
|
||||
)
|
||||
def get_all_table_names_in_schema(
|
||||
|
|
@ -1052,7 +1074,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
cache: bool = False,
|
||||
cache_timeout: int = None,
|
||||
force: bool = False,
|
||||
):
|
||||
) -> List[utils.DatasourceName]:
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
|
|
@ -1075,9 +1097,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
logging.exception(e)
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: "db:{{}}:schema:{}:view_list".format(
|
||||
kwargs.get("schema")
|
||||
),
|
||||
key=lambda *args, **kwargs: f"db:{{}}:schema:{kwargs.get('schema')}:view_list", # type: ignore
|
||||
attribute_in_key="id",
|
||||
)
|
||||
def get_all_view_names_in_schema(
|
||||
|
|
@ -1086,7 +1106,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
cache: bool = False,
|
||||
cache_timeout: int = None,
|
||||
force: bool = False,
|
||||
):
|
||||
) -> List[utils.DatasourceName]:
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
|
|
@ -1125,14 +1145,16 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
return self.db_engine_spec.get_schema_names(self.inspector)
|
||||
|
||||
@property
|
||||
def db_engine_spec(self):
|
||||
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
return db_engine_specs.engines.get(self.backend, db_engine_specs.BaseEngineSpec)
|
||||
|
||||
@classmethod
|
||||
def get_db_engine_spec_for_backend(cls, backend):
|
||||
def get_db_engine_spec_for_backend(
|
||||
cls, backend
|
||||
) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)
|
||||
|
||||
def grains(self):
|
||||
def grains(self) -> Tuple[TimeGrain, ...]:
|
||||
"""Defines time granularity database-specific expressions.
|
||||
|
||||
The idea here is to make it easy for users to change the time grain
|
||||
|
|
@ -1143,8 +1165,8 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
"""
|
||||
return self.db_engine_spec.get_time_grains()
|
||||
|
||||
def get_extra(self):
|
||||
extra = {}
|
||||
def get_extra(self) -> Dict[str, Any]:
|
||||
extra: Dict[str, Any] = {}
|
||||
if self.extra:
|
||||
try:
|
||||
extra = json.loads(self.extra)
|
||||
|
|
@ -1163,7 +1185,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
raise e
|
||||
return encrypted_extra
|
||||
|
||||
def get_table(self, table_name, schema=None):
|
||||
def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
|
||||
extra = self.get_extra()
|
||||
meta = MetaData(**extra.get("metadata_params", {}))
|
||||
return Table(
|
||||
|
|
@ -1174,23 +1196,31 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
autoload_with=self.get_sqla_engine(),
|
||||
)
|
||||
|
||||
def get_columns(self, table_name, schema=None):
|
||||
def get_columns(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.db_engine_spec.get_columns(self.inspector, table_name, schema)
|
||||
|
||||
def get_indexes(self, table_name, schema=None):
|
||||
def get_indexes(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.inspector.get_indexes(table_name, schema)
|
||||
|
||||
def get_pk_constraint(self, table_name, schema=None):
|
||||
def get_pk_constraint(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
return self.inspector.get_pk_constraint(table_name, schema)
|
||||
|
||||
def get_foreign_keys(self, table_name, schema=None):
|
||||
def get_foreign_keys(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.inspector.get_foreign_keys(table_name, schema)
|
||||
|
||||
def get_schema_access_for_csv_upload(self):
|
||||
def get_schema_access_for_csv_upload(self) -> List[str]:
|
||||
return self.get_extra().get("schemas_allowed_for_csv_upload", [])
|
||||
|
||||
@property
|
||||
def sqlalchemy_uri_decrypted(self):
|
||||
def sqlalchemy_uri_decrypted(self) -> str:
|
||||
conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
|
||||
if custom_password_store:
|
||||
conn.password = custom_password_store(conn)
|
||||
|
|
@ -1199,22 +1229,22 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
return str(conn)
|
||||
|
||||
@property
|
||||
def sql_url(self):
|
||||
return "/superset/sql/{}/".format(self.id)
|
||||
def sql_url(self) -> str:
|
||||
return f"/superset/sql/{self.id}/"
|
||||
|
||||
def get_perm(self):
|
||||
return ("[{obj.database_name}].(id:{obj.id})").format(obj=self)
|
||||
def get_perm(self) -> str:
|
||||
return f"[{self.database_name}].(id:{self.id})"
|
||||
|
||||
def has_table(self, table):
|
||||
def has_table(self, table: Table) -> bool:
|
||||
engine = self.get_sqla_engine()
|
||||
return engine.has_table(table.table_name, table.schema or None)
|
||||
|
||||
def has_table_by_name(self, table_name, schema=None):
|
||||
def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
|
||||
engine = self.get_sqla_engine()
|
||||
return engine.has_table(table_name, schema)
|
||||
|
||||
@utils.memoized
|
||||
def get_dialect(self):
|
||||
def get_dialect(self) -> Dialect:
|
||||
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
|
||||
return sqla_url.get_dialect()()
|
||||
|
||||
|
|
@ -1265,30 +1295,29 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
|
|||
ROLES_BLACKLIST = set(config["ROBOT_PERMISSION_ROLES"])
|
||||
|
||||
@property
|
||||
def cls_model(self):
|
||||
def cls_model(self) -> Type["BaseDatasource"]:
|
||||
return ConnectorRegistry.sources[self.datasource_type]
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
def username(self) -> Markup:
|
||||
return self.creator()
|
||||
|
||||
@property
|
||||
def datasource(self):
|
||||
def datasource(self) -> "BaseDatasource":
|
||||
return self.get_datasource
|
||||
|
||||
@datasource.getter # type: ignore
|
||||
@utils.memoized
|
||||
def get_datasource(self):
|
||||
# pylint: disable=no-member
|
||||
def get_datasource(self) -> "BaseDatasource":
|
||||
ds = db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
|
||||
return ds
|
||||
|
||||
@property
|
||||
def datasource_link(self):
|
||||
def datasource_link(self) -> Optional[Markup]:
|
||||
return self.datasource.link # pylint: disable=no-member
|
||||
|
||||
@property
|
||||
def roles_with_datasource(self):
|
||||
def roles_with_datasource(self) -> str:
|
||||
action_list = ""
|
||||
perm = self.datasource.perm # pylint: disable=no-member
|
||||
pv = security_manager.find_permission_view_menu("datasource_access", perm)
|
||||
|
|
@ -1306,7 +1335,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
|
|||
return "<ul>" + action_list + "</ul>"
|
||||
|
||||
@property
|
||||
def user_roles(self):
|
||||
def user_roles(self) -> str:
|
||||
action_list = ""
|
||||
for r in self.created_by.roles: # pylint: disable=no-member
|
||||
# pylint: disable=no-member
|
||||
|
|
|
|||
Loading…
Reference in New Issue