From e789a3555843d9791b9230a61454a3abb8cb07e0 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 22 May 2020 20:31:21 -0700 Subject: [PATCH] [mypy] Enforcing typing for superset.models (#9883) Co-authored-by: John Bodley --- setup.cfg | 2 +- superset/connectors/sqla/models.py | 8 +- superset/legacy.py | 3 +- superset/models/annotations.py | 6 +- superset/models/core.py | 21 ++- superset/models/dashboard.py | 26 ++-- superset/models/helpers.py | 140 ++++++++++-------- superset/models/schedules.py | 6 +- superset/models/slice.py | 12 +- superset/models/sql_lab.py | 22 +-- superset/models/sql_types/presto_sql_types.py | 22 +-- superset/models/tags.py | 59 ++++++-- superset/utils/cache.py | 6 +- superset/utils/core.py | 4 +- 14 files changed, 207 insertions(+), 130 deletions(-) diff --git a/setup.cfg b/setup.cfg index dc3e7013e..bfef8affc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ order_by_type = false ignore_missing_imports = true no_implicit_optional = true -[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*] +[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*] check_untyped_defs = true disallow_untyped_calls = true disallow_untyped_defs = true diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ad45474c3..3e50280d1 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -18,7 +18,7 @@ import logging import re from collections import OrderedDict -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union import pandas as pd @@ -103,7 +103,11 @@ class AnnotationDatasource(BaseDatasource): logger.exception(ex) error_message = utils.error_msg_from_exception(ex) return QueryResult( - status=status, df=df, duration=0, query="", error_message=error_message + status=status, + df=df, + duration=timedelta(0), + query="", + error_message=error_message, ) def get_query_str(self, query_obj): diff --git a/superset/legacy.py b/superset/legacy.py index b867edcc4..168b9c0b6 100644 --- a/superset/legacy.py +++ b/superset/legacy.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. """Code related with dealing with legacy / change management""" +from typing import Any, Dict -def update_time_range(form_data): +def update_time_range(form_data: Dict[str, Any]) -> None: """Move since and until to time_range.""" if "since" in form_data or "until" in form_data: form_data["time_range"] = "{} : {}".format( diff --git a/superset/models/annotations.py b/superset/models/annotations.py index 07e235174..ec8d3c045 100644 --- a/superset/models/annotations.py +++ b/superset/models/annotations.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """a collection of Annotation-related models""" +from typing import Any, Dict + from flask_appbuilder import Model from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text from sqlalchemy.orm import relationship @@ -31,7 +33,7 @@ class AnnotationLayer(Model, AuditMixinNullable): name = Column(String(250)) descr = Column(Text) - def __repr__(self): + def __repr__(self) -> str: return self.name @@ -52,7 +54,7 @@ class Annotation(Model, AuditMixinNullable): __table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),) @property - def data(self): + def data(self) -> Dict[str, Any]: return { "layer_id": self.layer_id, "start_dttm": self.start_dttm, diff --git a/superset/models/core.py b/superset/models/core.py index abcb210d7..69d306a4d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -152,7 +152,7 @@ class Database( ] export_children = ["tables"] - def __repr__(self): + def __repr__(self) -> str: return self.name @property @@ -234,7 +234,9 @@ class Database( return self.get_extra().get("default_schemas", []) @classmethod - def get_password_masked_url_from_uri(cls, uri: str): # pylint: disable=invalid-name + def get_password_masked_url_from_uri( # pylint: disable=invalid-name + cls, uri: str + ) -> URL: sqlalchemy_url = make_url(uri) return cls.get_password_masked_url(sqlalchemy_url) @@ -279,7 +281,7 @@ class Database( effective_username = g.user.username return effective_username - @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra")) + @utils.memoized(watch=["impersonate_user", "sqlalchemy_uri_decrypted", "extra"]) def get_sqla_engine( self, schema: Optional[str] = None, @@ -339,7 +341,7 @@ class Database( def get_reserved_words(self) -> Set[str]: return self.get_dialect().preparer.reserved_words - def get_quoter(self): + def get_quoter(self) -> Callable: return self.get_dialect().identifier_preparer.quote def get_df( # pylint: disable=too-many-locals @@ -405,7 +407,7 @@ class Database( indent: bool = True, latest_partition: bool = False, cols: Optional[List[Dict[str, Any]]] = None, - ): + ) -> str: """Generates a ``select *`` statement in the proper dialect""" eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) return self.db_engine_spec.select_star( @@ -436,7 +438,10 @@ class Database( attribute_in_key="id", ) def get_all_table_names_in_database( - self, cache: bool = False, cache_timeout: Optional[bool] = None, force=False + self, + cache: bool = False, + cache_timeout: Optional[bool] = None, + force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: @@ -547,7 +552,7 @@ class Database( @classmethod def get_db_engine_spec_for_backend( - cls, backend + cls, backend: str ) -> Type[db_engine_specs.BaseEngineSpec]: return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) @@ -565,7 +570,7 @@ class Database( def get_extra(self) -> Dict[str, Any]: return self.db_engine_spec.get_extra_params(self) - def get_encrypted_extra(self): + def get_encrypted_extra(self) -> Dict[str, Any]: encrypted_extra = {} if self.encrypted_extra: try: diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index c86f8ff83..de4228563 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -36,7 +36,9 @@ from sqlalchemy import ( Text, UniqueConstraint, ) +from sqlalchemy.engine.base import Connection from sqlalchemy.orm import relationship, sessionmaker, subqueryload +from sqlalchemy.orm.mapper import Mapper from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager from superset.models.helpers import AuditMixinNullable, ImportMixin @@ -59,7 +61,7 @@ config = app.config logger = logging.getLogger(__name__) -def copy_dashboard(mapper, connection, target): +def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None: # pylint: disable=unused-argument dashboard_id = config["DASHBOARD_TEMPLATE_ID"] if dashboard_id is None: @@ -140,7 +142,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes "slug", ] - def __repr__(self): + def __repr__(self) -> str: return self.dashboard_title or str(self.id) @property @@ -202,13 +204,13 @@ class Dashboard( # pylint: disable=too-many-instance-attributes return f"/api/v1/dashboard/{self.id}/thumbnail/{self.digest}/" @property - def changed_by_name(self): + def changed_by_name(self) -> str: if not self.changed_by: return "" return str(self.changed_by) @property - def changed_by_url(self): + def changed_by_url(self) -> str: if not self.changed_by: return "" return f"/superset/profile/{self.changed_by.username}" @@ -229,8 +231,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes "position_json": positions, } - @property - def params(self) -> str: + @property # type: ignore + def params(self) -> str: # type: ignore return self.json_metadata @params.setter @@ -257,7 +259,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes Audit metadata isn't copied over. """ - def alter_positions(dashboard, old_to_new_slc_id_dict): + def alter_positions( + dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int] + ) -> None: """ Updates slice_ids in the position json. Sample position_json data: @@ -291,9 +295,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes if ( isinstance(value, dict) and value.get("meta") - and value.get("meta").get("chartId") + and value.get("meta", {}).get("chartId") ): - old_slice_id = value.get("meta").get("chartId") + old_slice_id = value["meta"]["chartId"] if old_slice_id in old_to_new_slc_id_dict: value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id] @@ -470,8 +474,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes def event_after_dashboard_changed( # pylint: disable=unused-argument - mapper, connection, target -): + mapper: Mapper, connection: Connection, target: Dashboard +) -> None: cache_dashboard_thumbnail.delay(target.id, force=True) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 93b3b7f3e..42169e61a 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -18,8 +18,8 @@ import json import logging import re -from datetime import datetime -from typing import Any, Dict, List, Optional +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Set, Union # isort and pylint disagree, isort should win # pylint: disable=ungrouped-imports @@ -30,8 +30,10 @@ import yaml from flask import escape, g, Markup from flask_appbuilder.models.decorators import renders from flask_appbuilder.models.mixins import AuditMixin +from flask_appbuilder.security.sqla.models import User from sqlalchemy import and_, or_, UniqueConstraint from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import Session from sqlalchemy.orm.exc import MultipleResultsFound from superset.utils.core import QueryStatus @@ -39,7 +41,7 @@ from superset.utils.core import QueryStatus logger = logging.getLogger(__name__) -def json_to_dict(json_str): +def json_to_dict(json_str: str) -> Dict[Any, Any]: if json_str: val = re.sub(",[ \t\r\n]+}", "}", json_str) val = re.sub( @@ -64,48 +66,56 @@ class ImportMixin: # that are available for import and export @classmethod - def _parent_foreign_key_mappings(cls): + def _parent_foreign_key_mappings(cls) -> Dict[str, str]: """Get a mapping of foreign name to the local name of foreign keys""" - parent_rel = cls.__mapper__.relationships.get(cls.export_parent) + parent_rel = cls.__mapper__.relationships.get(cls.export_parent) # type: ignore if parent_rel: return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs} return {} @classmethod - def _unique_constrains(cls): + def _unique_constrains(cls) -> List[Set[str]]: """Get all (single column and multi column) unique constraints""" unique = [ {c.name for c in u.columns} - for u in cls.__table_args__ + for u in cls.__table_args__ # type: ignore if isinstance(u, UniqueConstraint) ] - unique.extend({c.name} for c in cls.__table__.columns if c.unique) + unique.extend( # type: ignore + {c.name} for c in cls.__table__.columns if c.unique # type: ignore + ) return unique @classmethod - def export_schema(cls, recursive=True, include_parent_ref=False): + def export_schema( + cls, recursive: bool = True, include_parent_ref: bool = False + ) -> Dict[str, Any]: """Export schema as a dictionary""" - parent_excludes = {} + parent_excludes = set() if not include_parent_ref: - parent_ref = cls.__mapper__.relationships.get(cls.export_parent) + parent_ref = cls.__mapper__.relationships.get( # type: ignore + cls.export_parent + ) if parent_ref: parent_excludes = {column.name for column in parent_ref.local_columns} - def formatter(column): + def formatter(column: sa.Column) -> str: return ( "{0} Default ({1})".format(str(column.type), column.default.arg) if column.default else str(column.type) ) - schema = { + schema: Dict[str, Any] = { column.name: formatter(column) - for column in cls.__table__.columns + for column in cls.__table__.columns # type: ignore if (column.name in cls.export_fields and column.name not in parent_excludes) } if recursive: for column in cls.export_children: - child_class = cls.__mapper__.relationships[column].argument.class_ + child_class = cls.__mapper__.relationships[ # type: ignore + column + ].argument.class_ schema[column] = [ child_class.export_schema( recursive=recursive, include_parent_ref=include_parent_ref @@ -114,17 +124,20 @@ class ImportMixin: return schema @classmethod - def import_from_dict( - cls, session, dict_rep, parent=None, recursive=True, sync=None - ): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches + def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals + cls, + session: Session, + dict_rep: Dict[Any, Any], + parent: Optional[Any] = None, + recursive: bool = True, + sync: Optional[List[str]] = None, + ) -> Any: # pylint: disable=too-many-arguments,too-many-locals,too-many-branches """Import obj from a dictionary""" if sync is None: sync = [] parent_refs = cls._parent_foreign_key_mappings() export_fields = set(cls.export_fields) | set(parent_refs.keys()) - new_children = { - c: dict_rep.get(c) for c in cls.export_children if c in dict_rep - } + new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep} unique_constrains = cls._unique_constrains() filters = [] # Using these filters to check if obj already exists @@ -178,7 +191,7 @@ class ImportMixin: if not obj: is_new_obj = True # Create new DB object - obj = cls(**dict_rep) + obj = cls(**dict_rep) # type: ignore logger.info("Importing new %s %s", obj.__tablename__, str(obj)) if cls.export_parent and parent: setattr(obj, cls.export_parent, parent) @@ -193,7 +206,9 @@ class ImportMixin: # Recursively create children if recursive: for child in cls.export_children: - child_class = cls.__mapper__.relationships[child].argument.class_ + child_class = cls.__mapper__.relationships[ # type: ignore + child + ].argument.class_ added = [] for c_obj in new_children.get(child, []): added.append( @@ -221,18 +236,23 @@ class ImportMixin: return obj def export_to_dict( - self, recursive=True, include_parent_ref=False, include_defaults=False - ): + self, + recursive: bool = True, + include_parent_ref: bool = False, + include_defaults: bool = False, + ) -> Dict[Any, Any]: """Export obj to dictionary""" cls = self.__class__ - parent_excludes = {} + parent_excludes = set() if recursive and not include_parent_ref: - parent_ref = cls.__mapper__.relationships.get(cls.export_parent) + parent_ref = cls.__mapper__.relationships.get( # type: ignore + cls.export_parent + ) if parent_ref: parent_excludes = {c.name for c in parent_ref.local_columns} dict_rep = { c.name: getattr(self, c.name) - for c in cls.__table__.columns + for c in cls.__table__.columns # type: ignore if ( c.name in self.export_fields and c.name not in parent_excludes @@ -262,18 +282,18 @@ class ImportMixin: return dict_rep - def override(self, obj): + def override(self, obj: Any) -> None: """Overrides the plain fields of the dashboard.""" for field in obj.__class__.export_fields: setattr(self, field, getattr(obj, field)) - def copy(self): + def copy(self) -> Any: """Creates a copy of the dashboard without relationships.""" new_obj = self.__class__() new_obj.override(self) return new_obj - def alter_params(self, **kwargs): + def alter_params(self, **kwargs: Any) -> None: params = self.params_dict params.update(kwargs) self.params = json.dumps(params) @@ -283,7 +303,7 @@ class ImportMixin: params.pop(param_to_remove, None) self.params = json.dumps(params) - def reset_ownership(self): + def reset_ownership(self) -> None: """ object will belong to the user the current user """ # make sure the object doesn't have relations to a user # it will be filled by appbuilder on save @@ -297,15 +317,15 @@ class ImportMixin: self.owners = [] @property - def params_dict(self): + def params_dict(self) -> Dict[Any, Any]: return json_to_dict(self.params) @property - def template_params_dict(self): - return json_to_dict(self.template_params) + def template_params_dict(self) -> Dict[Any, Any]: + return json_to_dict(self.template_params) # type: ignore -def _user_link(user): # pylint: disable=no-self-use +def _user_link(user: User) -> Union[Markup, str]: # pylint: disable=no-self-use if not user: return "" url = "/superset/profile/{}/".format(user.username) @@ -325,7 +345,7 @@ class AuditMixinNullable(AuditMixin): ) @declared_attr - def created_by_fk(self): + def created_by_fk(self) -> sa.Column: return sa.Column( sa.Integer, sa.ForeignKey("ab_user.id"), @@ -334,7 +354,7 @@ class AuditMixinNullable(AuditMixin): ) @declared_attr - def changed_by_fk(self): + def changed_by_fk(self) -> sa.Column: return sa.Column( sa.Integer, sa.ForeignKey("ab_user.id"), @@ -343,29 +363,29 @@ class AuditMixinNullable(AuditMixin): nullable=True, ) - def changed_by_name(self): + def changed_by_name(self) -> str: if self.created_by: return escape("{}".format(self.created_by)) return "" @renders("created_by") - def creator(self): + def creator(self) -> Union[Markup, str]: return _user_link(self.created_by) @property - def changed_by_(self): + def changed_by_(self) -> Union[Markup, str]: return _user_link(self.changed_by) @renders("changed_on") - def changed_on_(self): + def changed_on_(self) -> Markup: return Markup(f'{self.changed_on}') @property - def changed_on_humanized(self): + def changed_on_humanized(self) -> str: return humanize.naturaltime(datetime.now() - self.changed_on) @renders("changed_on") - def modified(self): + def modified(self) -> Markup: return Markup(f'{self.changed_on_humanized}') @@ -375,19 +395,19 @@ class QueryResult: # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-arguments self, - df, - query, - duration, - status=QueryStatus.SUCCESS, - error_message=None, - errors=None, - ): - self.df: pd.DataFrame = df - self.query: str = query - self.duration: int = duration - self.status: str = status - self.error_message: Optional[str] = error_message - self.errors: List[Dict[str, Any]] = errors or [] + df: pd.DataFrame, + query: str, + duration: timedelta, + status: str = QueryStatus.SUCCESS, + error_message: Optional[str] = None, + errors: Optional[List[Dict[str, Any]]] = None, + ) -> None: + self.df = df + self.query = query + self.duration = duration + self.status = status + self.error_message = error_message + self.errors = errors or [] class ExtraJSONMixin: @@ -396,16 +416,16 @@ class ExtraJSONMixin: extra_json = sa.Column(sa.Text, default="{}") @property - def extra(self): + def extra(self) -> Dict[str, Any]: try: return json.loads(self.extra_json) except Exception: # pylint: disable=broad-except return {} - def set_extra_json(self, extras): + def set_extra_json(self, extras: Dict[str, Any]) -> None: self.extra_json = json.dumps(extras) - def set_extra_json_key(self, key, value): + def set_extra_json_key(self, key: str, value: Any) -> None: extra = self.extra extra[key] = value self.extra_json = json.dumps(extra) diff --git a/superset/models/schedules.py b/superset/models/schedules.py index f70b076d6..0eb31a5d7 100644 --- a/superset/models/schedules.py +++ b/superset/models/schedules.py @@ -21,7 +21,7 @@ from typing import Optional, Type from flask_appbuilder import Model from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, RelationshipProperty from superset import security_manager from superset.models.helpers import AuditMixinNullable, ImportMixin @@ -55,11 +55,11 @@ class EmailSchedule: crontab = Column(String(50)) @declared_attr - def user_id(self): + def user_id(self) -> int: return Column(Integer, ForeignKey("ab_user.id")) @declared_attr - def user(self): + def user(self) -> RelationshipProperty: return relationship( security_manager.user_model, backref=self.__tablename__, diff --git a/superset/models/slice.py b/superset/models/slice.py index a570bcf85..76eb457a7 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -24,7 +24,9 @@ from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from markupsafe import escape, Markup from sqlalchemy import Column, ForeignKey, Integer, String, Table, Text +from sqlalchemy.engine.base import Connection from sqlalchemy.orm import make_transient, relationship +from sqlalchemy.orm.mapper import Mapper from superset import ConnectorRegistry, db, is_feature_enabled, security_manager from superset.legacy import update_time_range @@ -92,7 +94,7 @@ class Slice( "cache_timeout", ] - def __repr__(self): + def __repr__(self) -> str: return self.slice_name or str(self.id) @property @@ -263,7 +265,7 @@ class Slice( @property def changed_by_url(self) -> str: - return f"/superset/profile/{self.created_by.username}" + return f"/superset/profile/{self.created_by.username}" # type: ignore @property def icons(self) -> str: @@ -324,7 +326,7 @@ class Slice( return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D" -def set_related_perm(mapper, connection, target): +def set_related_perm(mapper: Mapper, connection: Connection, target: Slice) -> None: # pylint: disable=unused-argument src_class = target.cls_model id_ = target.datasource_id @@ -336,8 +338,8 @@ def set_related_perm(mapper, connection, target): def event_after_chart_changed( # pylint: disable=unused-argument - mapper, connection, target -): + mapper: Mapper, connection: Connection, target: Slice +) -> None: cache_chart_thumbnail.delay(target.id, force=True) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 654bc2fb4..9c3c239b3 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -17,6 +17,7 @@ """A collection of ORM sqlalchemy models for SQL Lab""" import re from datetime import datetime +from typing import Any, Dict # pylint: disable=ungrouped-imports import simplejson as json @@ -33,6 +34,7 @@ from sqlalchemy import ( String, Text, ) +from sqlalchemy.engine.url import URL from sqlalchemy.orm import backref, relationship from superset import security_manager @@ -99,7 +101,7 @@ class Query(Model, ExtraJSONMixin): __table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "changedOn": self.changed_on, "changed_on": self.changed_on.isoformat(), @@ -130,7 +132,7 @@ class Query(Model, ExtraJSONMixin): } @property - def name(self): + def name(self) -> str: """Name property""" ts = datetime.now().isoformat() ts = ts.replace("-", "").replace(":", "").split(".")[0] @@ -139,11 +141,11 @@ class Query(Model, ExtraJSONMixin): return f"sqllab_{tab}_{ts}" @property - def database_name(self): + def database_name(self) -> str: return self.database.name @property - def username(self): + def username(self) -> str: return self.user.username @@ -170,7 +172,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin): ) @property - def pop_tab_link(self): + def pop_tab_link(self) -> Markup: return Markup( f""" @@ -180,14 +182,14 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin): ) @property - def user_email(self): + def user_email(self) -> str: return self.user.email @property - def sqlalchemy_uri(self): + def sqlalchemy_uri(self) -> URL: return self.database.sqlalchemy_uri - def url(self): + def url(self) -> str: return "/superset/sqllab?savedQueryId={0}".format(self.id) @@ -226,7 +228,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin): autorun = Column(Boolean, default=False) template_params = Column(Text) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "id": self.id, "user_id": self.user_id, @@ -260,7 +262,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin): expanded = Column(Boolean, default=False) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: try: description = json.loads(self.description) except json.JSONDecodeError: diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py index f0b46fac0..a50b4c2f7 100644 --- a/superset/models/sql_types/presto_sql_types.py +++ b/superset/models/sql_types/presto_sql_types.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, Optional, Type from sqlalchemy import types from sqlalchemy.sql.sqltypes import Integer from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.sql.visitors import Visitable # _compiler_dispatch is defined to help with type compilation @@ -27,11 +29,11 @@ class TinyInteger(Integer): A type for tiny ``int`` integers. """ - def python_type(self): + def python_type(self) -> Type: return int @classmethod - def _compiler_dispatch(cls, _visitor, **_kw): + def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "TINYINT" @@ -40,11 +42,11 @@ class Interval(TypeEngine): A type for intervals. """ - def python_type(self): + def python_type(self) -> Optional[Type]: return None @classmethod - def _compiler_dispatch(cls, _visitor, **_kw): + def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "INTERVAL" @@ -53,11 +55,11 @@ class Array(TypeEngine): A type for arrays. """ - def python_type(self): + def python_type(self) -> Optional[Type]: return list @classmethod - def _compiler_dispatch(cls, _visitor, **_kw): + def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "ARRAY" @@ -66,11 +68,11 @@ class Map(TypeEngine): A type for maps. """ - def python_type(self): + def python_type(self) -> Optional[Type]: return dict @classmethod - def _compiler_dispatch(cls, _visitor, **_kw): + def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "MAP" @@ -79,11 +81,11 @@ class Row(TypeEngine): A type for rows. """ - def python_type(self): + def python_type(self) -> Optional[Type]: return None @classmethod - def _compiler_dispatch(cls, _visitor, **_kw): + def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "ROW" diff --git a/superset/models/tags.py b/superset/models/tags.py index 0cb00cc4d..c09bb1685 100644 --- a/superset/models/tags.py +++ b/superset/models/tags.py @@ -17,15 +17,23 @@ from __future__ import absolute_import, division, print_function, unicode_literals import enum -from typing import Optional +from typing import List, Optional, TYPE_CHECKING, Union from flask_appbuilder import Model from sqlalchemy import Column, Enum, ForeignKey, Integer, String -from sqlalchemy.orm import relationship, sessionmaker +from sqlalchemy.engine.base import Connection +from sqlalchemy.orm import relationship, Session, sessionmaker from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.mapper import Mapper from superset.models.helpers import AuditMixinNullable +if TYPE_CHECKING: + from superset.models.core import FavStar # pylint: disable=unused-import + from superset.models.dashboard import Dashboard # pylint: disable=unused-import + from superset.models.slice import Slice # pylint: disable=unused-import + from superset.models.sql_lab import Query # pylint: disable=unused-import + Session = sessionmaker(autoflush=False) @@ -80,7 +88,7 @@ class TaggedObject(Model, AuditMixinNullable): tag = relationship("Tag", backref="objects") -def get_tag(name, session, type_): +def get_tag(name: str, session: Session, type_: TagTypes) -> Tag: try: tag = session.query(Tag).filter_by(name=name, type=type_).one() except NoResultFound: @@ -91,7 +99,7 @@ def get_tag(name, session, type_): return tag -def get_object_type(class_name): +def get_object_type(class_name: str) -> ObjectTypes: mapping = { "slice": ObjectTypes.chart, "dashboard": ObjectTypes.dashboard, @@ -108,11 +116,15 @@ class ObjectUpdater: object_type: Optional[str] = None @classmethod - def get_owners_ids(cls, target): + def get_owners_ids( + cls, target: Union["Dashboard", "FavStar", "Slice"] + ) -> List[int]: raise NotImplementedError("Subclass should implement `get_owners_ids`") @classmethod - def _add_owners(cls, session, target): + def _add_owners( + cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"] + ) -> None: for owner_id in cls.get_owners_ids(target): name = "owner:{0}".format(owner_id) tag = get_tag(name, session, TagTypes.owner) @@ -122,7 +134,12 @@ class ObjectUpdater: session.add(tagged_object) @classmethod - def after_insert(cls, mapper, connection, target): + def after_insert( + cls, + mapper: Mapper, + connection: Connection, + target: Union["Dashboard", "FavStar", "Slice"], + ) -> None: # pylint: disable=unused-argument session = Session(bind=connection) @@ -139,7 +156,12 @@ class ObjectUpdater: session.commit() @classmethod - def after_update(cls, mapper, connection, target): + def after_update( + cls, + mapper: Mapper, + connection: Connection, + target: Union["Dashboard", "FavStar", "Slice"], + ) -> None: # pylint: disable=unused-argument session = Session(bind=connection) @@ -164,7 +186,12 @@ class ObjectUpdater: session.commit() @classmethod - def after_delete(cls, mapper, connection, target): + def after_delete( + cls, + mapper: Mapper, + connection: Connection, + target: Union["Dashboard", "FavStar", "Slice"], + ) -> None: # pylint: disable=unused-argument session = Session(bind=connection) @@ -182,7 +209,7 @@ class ChartUpdater(ObjectUpdater): object_type = "chart" @classmethod - def get_owners_ids(cls, target): + def get_owners_ids(cls, target: "Slice") -> List[int]: return [owner.id for owner in target.owners] @@ -191,7 +218,7 @@ class DashboardUpdater(ObjectUpdater): object_type = "dashboard" @classmethod - def get_owners_ids(cls, target): + def get_owners_ids(cls, target: "Dashboard") -> List[int]: return [owner.id for owner in target.owners] @@ -200,13 +227,15 @@ class QueryUpdater(ObjectUpdater): object_type = "query" @classmethod - def get_owners_ids(cls, target): + def get_owners_ids(cls, target: "Query") -> List[int]: return [target.user_id] class FavStarUpdater: @classmethod - def after_insert(cls, mapper, connection, target): + def after_insert( + cls, mapper: Mapper, connection: Connection, target: "FavStar" + ) -> None: # pylint: disable=unused-argument session = Session(bind=connection) name = "favorited_by:{0}".format(target.user_id) @@ -221,7 +250,9 @@ class FavStarUpdater: session.commit() @classmethod - def after_delete(cls, mapper, connection, target): + def after_delete( + cls, mapper: Mapper, connection: Connection, target: "FavStar" + ) -> None: # pylint: disable=unused-argument session = Session(bind=connection) name = "favorited_by:{0}".format(target.user_id) diff --git a/superset/utils/cache.py b/superset/utils/cache.py index a7a5dc6b8..b55500509 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Callable, Optional + from flask import request from superset.extensions import cache_manager @@ -24,7 +26,9 @@ def view_cache_key(*_, **__) -> str: return "view/{}/{}".format(request.path, args_hash) -def memoized_func(key=view_cache_key, attribute_in_key=None): +def memoized_func( + key: Callable = view_cache_key, attribute_in_key: Optional[str] = None +) -> Callable: """Use this decorator to cache functions that have predefined first arg. enable_cache is treated as True by default, diff --git a/superset/utils/core.py b/superset/utils/core.py index 19e39ab08..e093d3d05 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -143,7 +143,7 @@ class _memoized: self.func = func self.cache = {} self.is_method = False - self.watch = watch + self.watch = watch or [] def __call__(self, *args, **kwargs): key = [args, frozenset(kwargs.items())] @@ -172,7 +172,7 @@ class _memoized: return functools.partial(self.__call__, obj) -def memoized(func=None, watch=None): +def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None): if func: return _memoized(func) else: