diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 2ccb24555..72ec4d767 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -505,7 +505,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return utils.error_msg_from_exception(e) @classmethod - def adjust_database_uri(cls, uri, selected_schema: str): + def adjust_database_uri(cls, uri, selected_schema: Optional[str]): """Based on a URI and selected schema, return a new URI The URI here represents the URI as entered when saving the database, @@ -718,19 +718,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return costs @classmethod - def modify_url_for_impersonation(cls, url, impersonate_user: bool, username: str): + def modify_url_for_impersonation( + cls, url, impersonate_user: bool, username: Optional[str] + ): """ Modify the SQL Alchemy URL object with the user to impersonate if applicable. :param url: SQLAlchemy URL object :param impersonate_user: Flag indicating if impersonation is enabled :param username: Effective username """ - if impersonate_user is not None and username is not None: + if impersonate_user and username is not None: url.username = username @classmethod def get_configuration_for_impersonation( # pylint: disable=invalid-name - cls, uri: str, impersonate_user: bool, username: str + cls, uri: str, impersonate_user: bool, username: Optional[str] ) -> Dict[str, str]: """ Return a configuration dictionary that can be merged with other configs diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 712af14da..94413fa14 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -376,7 +376,9 @@ class HiveEngineSpec(PrestoEngineSpec): ) @classmethod - def modify_url_for_impersonation(cls, url, impersonate_user: bool, username: str): + def modify_url_for_impersonation( + cls, url, impersonate_user: bool, username: Optional[str] + ): """ Modify the SQL Alchemy URL object with the user to impersonate if applicable. :param url: SQLAlchemy URL object @@ -389,7 +391,7 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def get_configuration_for_impersonation( - cls, uri: str, impersonate_user: bool, username: str + cls, uri: str, impersonate_user: bool, username: Optional[str] ) -> Dict[str, str]: """ Return a configuration dictionary that can be merged with other configs diff --git a/superset/models/core.py b/superset/models/core.py index a177d26c0..ccc909af0 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -22,7 +22,7 @@ import textwrap from contextlib import closing from copy import copy, deepcopy from datetime import datetime -from typing import List +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING from urllib import parse import numpy @@ -45,22 +45,29 @@ from sqlalchemy import ( Table, Text, ) -from sqlalchemy.engine import url -from sqlalchemy.engine.url import make_url +from sqlalchemy.engine import Dialect, Engine, url +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import relationship, sessionmaker, subqueryload from sqlalchemy.orm.session import make_transient from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint +from sqlalchemy.sql import Select from sqlalchemy_utils import EncryptedType from superset import app, db, db_engine_specs, is_feature_enabled, security_manager from superset.connectors.connector_registry import ConnectorRegistry +from superset.db_engine_specs.base import TimeGrain from superset.legacy import update_time_range from superset.models.helpers import AuditMixinNullable, ImportMixin from superset.models.tags import ChartUpdater, DashboardUpdater, FavStarUpdater from superset.models.user_attributes import UserAttribute from superset.utils import cache as cache_util, core as utils -from superset.viz import viz_types +from superset.viz import BaseViz, viz_types + +if TYPE_CHECKING: + from superset.connectors.base.models import BaseDatasource + config = app.config custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] @@ -180,14 +187,14 @@ class Slice(Model, AuditMixinNullable, ImportMixin): return self.slice_name or str(self.id) @property - def cls_model(self): + def cls_model(self) -> Type["BaseDatasource"]: return ConnectorRegistry.sources[self.datasource_type] @property - def datasource(self): + def datasource(self) -> "BaseDatasource": return self.get_datasource - def clone(self): + def clone(self) -> "Slice": return Slice( slice_name=self.slice_name, datasource_id=self.datasource_id, @@ -201,46 +208,45 @@ class Slice(Model, AuditMixinNullable, ImportMixin): @datasource.getter # type: ignore @utils.memoized - def get_datasource(self): + def get_datasource(self) -> Optional["BaseDatasource"]: return db.session.query(self.cls_model).filter_by(id=self.datasource_id).first() @renders("datasource_name") - def datasource_link(self): + def datasource_link(self) -> Optional[Markup]: # pylint: disable=no-member datasource = self.datasource return datasource.link if datasource else None - def datasource_name_text(self): + def datasource_name_text(self) -> Optional[str]: # pylint: disable=no-member datasource = self.datasource return datasource.name if datasource else None @property - def datasource_edit_url(self): + def datasource_edit_url(self) -> Optional[str]: # pylint: disable=no-member datasource = self.datasource return datasource.url if datasource else None @property # type: ignore @utils.memoized - def viz(self): + def viz(self) -> BaseViz: d = json.loads(self.params) viz_class = viz_types[self.viz_type] - # pylint: disable=no-member return viz_class(datasource=self.datasource, form_data=d) @property - def description_markeddown(self): + def description_markeddown(self) -> str: return utils.markdown(self.description) @property - def data(self): + def data(self) -> Dict[str, Any]: """Data used to render slice in templates""" - d = {} + d: Dict[str, Any] = {} self.token = "" try: d = self.viz.data - self.token = d.get("token") + self.token = d.get("token") # type: ignore except Exception as e: logging.exception(e) d["error"] = str(e) @@ -259,12 +265,12 @@ class Slice(Model, AuditMixinNullable, ImportMixin): } @property - def json_data(self): + def json_data(self) -> str: return json.dumps(self.data) @property - def form_data(self): - form_data = {} + def form_data(self) -> Dict[str, Any]: + form_data: Dict[str, Any] = {} try: form_data = json.loads(self.params) except Exception as e: @@ -283,7 +289,11 @@ class Slice(Model, AuditMixinNullable, ImportMixin): update_time_range(form_data) return form_data - def get_explore_url(self, base_url="/superset/explore", overrides=None): + def get_explore_url( + self, + base_url: str = "/superset/explore", + overrides: Optional[Dict[str, Any]] = None, + ) -> str: overrides = overrides or {} form_data = {"slice_id": self.id} form_data.update(overrides) @@ -291,30 +301,30 @@ class Slice(Model, AuditMixinNullable, ImportMixin): return f"{base_url}/?form_data={params}" @property - def slice_url(self): + def slice_url(self) -> str: """Defines the url to access the slice""" return self.get_explore_url() @property - def explore_json_url(self): + def explore_json_url(self) -> str: """Defines the url to access the slice""" return self.get_explore_url("/superset/explore_json") @property - def edit_url(self): - return "/chart/edit/{}".format(self.id) + def edit_url(self) -> str: + return f"/chart/edit/{self.id}" @property - def chart(self): + def chart(self) -> str: return self.slice_name or "" @property - def slice_link(self): + def slice_link(self) -> Markup: url = self.slice_url name = escape(self.chart) return Markup(f'{name}') - def get_viz(self, force=False): + def get_viz(self, force: bool = False) -> BaseViz: """Creates :py:class:viz.BaseViz object from the url_params_multidict. :return: object of the 'viz_type' type that is taken from the @@ -332,7 +342,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin): ) @property - def icons(self): + def icons(self) -> str: return f""" int: """Inserts or overrides slc in the database. remote_id and import_time fields in params_dict are set to track the @@ -363,7 +378,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin): slc_to_import = slc_to_import.copy() slc_to_import.reset_ownership() params = slc_to_import.params_dict - slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name( + slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name( # type: ignore session, slc_to_import.datasource_type, params["datasource_name"], @@ -380,10 +395,8 @@ class Slice(Model, AuditMixinNullable, ImportMixin): return slc_to_import.id @property - def url(self): - return "/superset/explore/?form_data=%7B%22slice_id%22%3A%20{0}%7D".format( - self.id - ) + def url(self) -> str: + return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D" sqla.event.listen(Slice, "before_insert", set_related_perm) @@ -437,12 +450,12 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin): return self.dashboard_title or str(self.id) @property - def table_names(self): + def table_names(self) -> str: # pylint: disable=no-member - return ", ".join({"{}".format(s.datasource.full_name) for s in self.slices}) + return ", ".join(str(s.datasource.full_name) for s in self.slices) @property - def url(self): + def url(self) -> str: if self.json_metadata: # add default_filters to the preselect_filters of dashboard json_metadata = json.loads(self.json_metadata) @@ -457,28 +470,28 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin): ) except Exception: pass - return "/superset/dashboard/{}/".format(self.slug or self.id) + return f"/superset/dashboard/{self.slug or self.id}/" @property - def datasources(self): + def datasources(self) -> Set[Optional["BaseDatasource"]]: return {slc.datasource for slc in self.slices} @property - def charts(self): + def charts(self) -> List[Optional["BaseDatasource"]]: return [slc.chart for slc in self.slices] @property - def sqla_metadata(self): + def sqla_metadata(self) -> None: # pylint: disable=no-member metadata = MetaData(bind=self.get_sqla_engine()) - return metadata.reflect() + metadata.reflect() - def dashboard_link(self): + def dashboard_link(self) -> Markup: title = escape(self.dashboard_title or "") return Markup(f'{title}') @property - def data(self): + def data(self) -> Dict[str, Any]: positions = self.position_json if positions: positions = json.loads(positions) @@ -494,21 +507,23 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin): } @property - def params(self): + def params(self) -> str: return self.json_metadata @params.setter - def params(self, value): + def params(self, value: str) -> None: self.json_metadata = value @property - def position(self): + def position(self) -> Dict: if self.position_json: return json.loads(self.position_json) return {} @classmethod - def import_obj(cls, dashboard_to_import, import_time=None): + def import_obj( + cls, dashboard_to_import: "Dashboard", import_time: Optional[int] = None + ) -> int: """Imports the dashboard from the object to the database. Once dashboard is imported, json_metadata field is extended and stores @@ -652,10 +667,10 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin): dashboard_to_import.slices = new_slices session.add(dashboard_to_import) session.flush() - return dashboard_to_import.id + return dashboard_to_import.id # type: ignore @classmethod - def export_dashboards(cls, dashboard_ids): + def export_dashboards(cls, dashboard_ids: List) -> str: copied_dashboards = [] datasource_ids = set() for dashboard_id in dashboard_ids: @@ -688,22 +703,22 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin): copied_dashboard.alter_params(remote_id=dashboard_id) copied_dashboards.append(copied_dashboard) - eager_datasources = [] - for datasource_id, datasource_type in datasource_ids: - eager_datasource = ConnectorRegistry.get_eager_datasource( - db.session, datasource_type, datasource_id - ) - copied_datasource = eager_datasource.copy() - copied_datasource.alter_params( - remote_id=eager_datasource.id, - database_name=eager_datasource.database.name, - ) - datasource_class = copied_datasource.__class__ - for field_name in datasource_class.export_children: - field_val = getattr(eager_datasource, field_name).copy() - # set children without creating ORM relations - copied_datasource.__dict__[field_name] = field_val - eager_datasources.append(copied_datasource) + eager_datasources = [] + for datasource_id, datasource_type in datasource_ids: + eager_datasource = ConnectorRegistry.get_eager_datasource( + db.session, datasource_type, datasource_id + ) + copied_datasource = eager_datasource.copy() + copied_datasource.alter_params( + remote_id=eager_datasource.id, + database_name=eager_datasource.database.name, + ) + datasource_class = copied_datasource.__class__ + for field_name in datasource_class.export_children: + field_val = getattr(eager_datasource, field_name).copy() + # set children without creating ORM relations + copied_datasource.__dict__[field_name] = field_val + eager_datasources.append(copied_datasource) return json.dumps( {"dashboards": copied_dashboards, "datasources": eager_datasources}, @@ -767,25 +782,27 @@ class Database(Model, AuditMixinNullable, ImportMixin): return self.name @property - def name(self): + def name(self) -> str: return self.verbose_name if self.verbose_name else self.database_name @property - def allows_subquery(self): + def allows_subquery(self) -> bool: return self.db_engine_spec.allows_subqueries @property def allows_cost_estimate(self) -> bool: extra = self.get_extra() + database_version = extra.get("version") - cost_estimate_enabled = extra.get("cost_estimate_enabled") + cost_estimate_enabled: bool = extra.get("cost_estimate_enabled") # type: ignore + return ( self.db_engine_spec.get_allow_cost_estimate(database_version) and cost_estimate_enabled ) @property - def data(self): + def data(self) -> Dict[str, Any]: return { "id": self.id, "name": self.database_name, @@ -796,55 +813,55 @@ class Database(Model, AuditMixinNullable, ImportMixin): } @property - def unique_name(self): + def unique_name(self) -> str: return self.database_name @property - def url_object(self): + def url_object(self) -> URL: return make_url(self.sqlalchemy_uri_decrypted) @property - def backend(self): + def backend(self) -> str: url = make_url(self.sqlalchemy_uri_decrypted) return url.get_backend_name() @property - def metadata_cache_timeout(self): + def metadata_cache_timeout(self) -> Dict[str, Any]: return self.get_extra().get("metadata_cache_timeout", {}) @property - def schema_cache_enabled(self): + def schema_cache_enabled(self) -> bool: return "schema_cache_timeout" in self.metadata_cache_timeout @property - def schema_cache_timeout(self): + def schema_cache_timeout(self) -> Optional[int]: return self.metadata_cache_timeout.get("schema_cache_timeout") @property - def table_cache_enabled(self): + def table_cache_enabled(self) -> bool: return "table_cache_timeout" in self.metadata_cache_timeout @property - def table_cache_timeout(self): + def table_cache_timeout(self) -> Optional[int]: return self.metadata_cache_timeout.get("table_cache_timeout") @property - def default_schemas(self): + def default_schemas(self) -> List[str]: return self.get_extra().get("default_schemas", []) @classmethod - def get_password_masked_url_from_uri(cls, uri): + def get_password_masked_url_from_uri(cls, uri: str): url = make_url(uri) return cls.get_password_masked_url(url) @classmethod - def get_password_masked_url(cls, url): + def get_password_masked_url(cls, url: URL) -> URL: url_copy = deepcopy(url) - if url_copy.password is not None and url_copy.password != PASSWORD_MASK: + if url_copy.password is not None: url_copy.password = PASSWORD_MASK return url_copy - def set_sqlalchemy_uri(self, uri): + def set_sqlalchemy_uri(self, uri: str) -> None: conn = sqla.engine.url.make_url(uri.strip()) if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask @@ -852,7 +869,9 @@ class Database(Model, AuditMixinNullable, ImportMixin): conn.password = PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password - def get_effective_user(self, url, user_name=None): + def get_effective_user( + self, url: URL, user_name: Optional[str] = None + ) -> Optional[str]: """ Get the effective user, especially during impersonation. :param url: SQL Alchemy URL object @@ -873,7 +892,13 @@ class Database(Model, AuditMixinNullable, ImportMixin): return effective_username @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra")) - def get_sqla_engine(self, schema=None, nullpool=True, user_name=None, source=None): + def get_sqla_engine( + self, + schema: Optional[str] = None, + nullpool: bool = True, + user_name: Optional[str] = None, + source: Optional[int] = None, + ) -> Engine: extra = self.get_extra() url = make_url(self.sqlalchemy_uri_decrypted) url = self.db_engine_spec.adjust_database_uri(url, schema) @@ -893,7 +918,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): params["poolclass"] = NullPool # If using Hive, this will set hive.server2.proxy.user=$effective_username - configuration = {} + configuration: Dict[str, Any] = {} configuration.update( self.db_engine_spec.get_configuration_for_impersonation( str(url), self.impersonate_user, effective_username @@ -913,14 +938,16 @@ class Database(Model, AuditMixinNullable, ImportMixin): ) return create_engine(url, **params) - def get_reserved_words(self): + def get_reserved_words(self) -> Set[str]: return self.get_dialect().preparer.reserved_words def get_quoter(self): return self.get_dialect().identifier_preparer.quote - def get_df(self, sql, schema, mutator=None): - sqls = [str(s).strip().strip(";") for s in sqlparse.parse(sql)] + def get_df( + self, sql: str, schema: str, mutator: Optional[Callable] = None + ) -> pd.DataFrame: + sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)] source_key = None if request and request.referrer: if "/superset/dashboard/" in request.referrer: @@ -928,18 +955,14 @@ class Database(Model, AuditMixinNullable, ImportMixin): elif "/superset/explore/" in request.referrer: source_key = "chart" engine = self.get_sqla_engine( - schema=schema, source=utils.sources.get(source_key, None) + schema=schema, source=utils.sources[source_key] if source_key else None ) username = utils.get_username() - def needs_conversion(df_series): - if df_series.empty: - return False - if isinstance(df_series[0], (list, dict)): - return True - return False + def needs_conversion(df_series: pd.Series) -> bool: + return not df_series.empty and isinstance(df_series[0], (list, dict)) - def _log_query(sql): + def _log_query(sql: str) -> None: if log_query: log_query(engine.url, sql, schema, username, __name__, security_manager) @@ -970,7 +993,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): df[k] = df[k].apply(utils.json_dumps_w_dates) return df - def compile_sqla_query(self, qry, schema=None): + def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str: engine = self.get_sqla_engine(schema=schema) sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) @@ -982,13 +1005,13 @@ class Database(Model, AuditMixinNullable, ImportMixin): def select_star( self, - table_name, - schema=None, - limit=100, - show_cols=False, - indent=True, - latest_partition=False, - cols=None, + table_name: str, + schema: Optional[str] = None, + limit: int = 100, + show_cols: bool = False, + indent: bool = True, + latest_partition: bool = False, + cols: Optional[List[Dict[str, Any]]] = None, ): """Generates a ``select *`` statement in the proper dialect""" eng = self.get_sqla_engine( @@ -1006,14 +1029,14 @@ class Database(Model, AuditMixinNullable, ImportMixin): cols=cols, ) - def apply_limit_to_sql(self, sql, limit=1000): + def apply_limit_to_sql(self, sql: str, limit: int = 1000) -> str: return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) - def safe_sqlalchemy_uri(self): + def safe_sqlalchemy_uri(self) -> str: return self.sqlalchemy_uri @property - def inspector(self): + def inspector(self) -> Inspector: engine = self.get_sqla_engine() return sqla.inspect(engine) @@ -1030,7 +1053,8 @@ class Database(Model, AuditMixinNullable, ImportMixin): return self.db_engine_spec.get_all_datasource_names(self, "table") @cache_util.memoized_func( - key=lambda *args, **kwargs: "db:{}:schema:None:view_list", attribute_in_key="id" + key=lambda *args, **kwargs: "db:{}:schema:None:view_list", + attribute_in_key="id", # type: ignore ) def get_all_view_names_in_database( self, cache: bool = False, cache_timeout: bool = None, force: bool = False @@ -1041,9 +1065,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): return self.db_engine_spec.get_all_datasource_names(self, "view") @cache_util.memoized_func( - key=lambda *args, **kwargs: "db:{{}}:schema:{}:table_list".format( - kwargs.get("schema") - ), + key=lambda *args, **kwargs: f"db:{{}}:schema:{kwargs.get('schema')}:table_list", # type: ignore attribute_in_key="id", ) def get_all_table_names_in_schema( @@ -1052,7 +1074,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): cache: bool = False, cache_timeout: int = None, force: bool = False, - ): + ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -1075,9 +1097,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): logging.exception(e) @cache_util.memoized_func( - key=lambda *args, **kwargs: "db:{{}}:schema:{}:view_list".format( - kwargs.get("schema") - ), + key=lambda *args, **kwargs: f"db:{{}}:schema:{kwargs.get('schema')}:view_list", # type: ignore attribute_in_key="id", ) def get_all_view_names_in_schema( @@ -1086,7 +1106,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): cache: bool = False, cache_timeout: int = None, force: bool = False, - ): + ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -1125,14 +1145,16 @@ class Database(Model, AuditMixinNullable, ImportMixin): return self.db_engine_spec.get_schema_names(self.inspector) @property - def db_engine_spec(self): + def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: return db_engine_specs.engines.get(self.backend, db_engine_specs.BaseEngineSpec) @classmethod - def get_db_engine_spec_for_backend(cls, backend): + def get_db_engine_spec_for_backend( + cls, backend + ) -> Type[db_engine_specs.BaseEngineSpec]: return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) - def grains(self): + def grains(self) -> Tuple[TimeGrain, ...]: """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain @@ -1143,8 +1165,8 @@ class Database(Model, AuditMixinNullable, ImportMixin): """ return self.db_engine_spec.get_time_grains() - def get_extra(self): - extra = {} + def get_extra(self) -> Dict[str, Any]: + extra: Dict[str, Any] = {} if self.extra: try: extra = json.loads(self.extra) @@ -1163,7 +1185,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): raise e return encrypted_extra - def get_table(self, table_name, schema=None): + def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) return Table( @@ -1174,23 +1196,31 @@ class Database(Model, AuditMixinNullable, ImportMixin): autoload_with=self.get_sqla_engine(), ) - def get_columns(self, table_name, schema=None): + def get_columns( + self, table_name: str, schema: Optional[str] = None + ) -> List[Dict[str, Any]]: return self.db_engine_spec.get_columns(self.inspector, table_name, schema) - def get_indexes(self, table_name, schema=None): + def get_indexes( + self, table_name: str, schema: Optional[str] = None + ) -> List[Dict[str, Any]]: return self.inspector.get_indexes(table_name, schema) - def get_pk_constraint(self, table_name, schema=None): + def get_pk_constraint( + self, table_name: str, schema: Optional[str] = None + ) -> Dict[str, Any]: return self.inspector.get_pk_constraint(table_name, schema) - def get_foreign_keys(self, table_name, schema=None): + def get_foreign_keys( + self, table_name: str, schema: Optional[str] = None + ) -> List[Dict[str, Any]]: return self.inspector.get_foreign_keys(table_name, schema) - def get_schema_access_for_csv_upload(self): + def get_schema_access_for_csv_upload(self) -> List[str]: return self.get_extra().get("schemas_allowed_for_csv_upload", []) @property - def sqlalchemy_uri_decrypted(self): + def sqlalchemy_uri_decrypted(self) -> str: conn = sqla.engine.url.make_url(self.sqlalchemy_uri) if custom_password_store: conn.password = custom_password_store(conn) @@ -1199,22 +1229,22 @@ class Database(Model, AuditMixinNullable, ImportMixin): return str(conn) @property - def sql_url(self): - return "/superset/sql/{}/".format(self.id) + def sql_url(self) -> str: + return f"/superset/sql/{self.id}/" - def get_perm(self): - return ("[{obj.database_name}].(id:{obj.id})").format(obj=self) + def get_perm(self) -> str: + return f"[{self.database_name}].(id:{self.id})" - def has_table(self, table): + def has_table(self, table: Table) -> bool: engine = self.get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) - def has_table_by_name(self, table_name, schema=None): + def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool: engine = self.get_sqla_engine() return engine.has_table(table_name, schema) @utils.memoized - def get_dialect(self): + def get_dialect(self) -> Dialect: sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()() @@ -1265,30 +1295,29 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): ROLES_BLACKLIST = set(config["ROBOT_PERMISSION_ROLES"]) @property - def cls_model(self): + def cls_model(self) -> Type["BaseDatasource"]: return ConnectorRegistry.sources[self.datasource_type] @property - def username(self): + def username(self) -> Markup: return self.creator() @property - def datasource(self): + def datasource(self) -> "BaseDatasource": return self.get_datasource @datasource.getter # type: ignore @utils.memoized - def get_datasource(self): - # pylint: disable=no-member + def get_datasource(self) -> "BaseDatasource": ds = db.session.query(self.cls_model).filter_by(id=self.datasource_id).first() return ds @property - def datasource_link(self): + def datasource_link(self) -> Optional[Markup]: return self.datasource.link # pylint: disable=no-member @property - def roles_with_datasource(self): + def roles_with_datasource(self) -> str: action_list = "" perm = self.datasource.perm # pylint: disable=no-member pv = security_manager.find_permission_view_menu("datasource_access", perm) @@ -1306,7 +1335,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): return "" @property - def user_roles(self): + def user_roles(self) -> str: action_list = "" for r in self.created_by.roles: # pylint: disable=no-member # pylint: disable=no-member