style: enforcing mypy typing for connectors (#9824)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-05-25 12:32:49 -07:00 committed by GitHub
parent 9edfc8f68d
commit 7f6dbf838e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 392 additions and 240 deletions

View File

@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false
@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -20,12 +20,12 @@ from typing import Any, Dict, Hashable, List, Optional, Type
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import foreign, Query, relationship
from sqlalchemy.orm import foreign, Query, relationship, RelationshipProperty
from superset.constants import NULL_STRING
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.models.slice import Slice
from superset.typing import FilterValue, FilterValues
from superset.typing import FilterValue, FilterValues, QueryObjectDict
from superset.utils import core as utils
METRIC_FORM_DATA_PARAMS = [
@ -93,7 +93,7 @@ class BaseDatasource(
update_from_object_fields: List[str]
@declared_attr
def slices(self):
def slices(self) -> RelationshipProperty:
return relationship(
"Slice",
primaryjoin=lambda: and_(
@ -117,7 +117,7 @@ class BaseDatasource(
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property
def columns_types(self) -> Dict:
def columns_types(self) -> Dict[str, str]:
return {c.column_name: c.type for c in self.columns}
@property
@ -125,7 +125,7 @@ class BaseDatasource(
return "timestamp"
@property
def datasource_name(self):
def datasource_name(self) -> str:
raise NotImplementedError()
@property
@ -143,7 +143,7 @@ class BaseDatasource(
return sorted([c.column_name for c in self.columns if c.filterable])
@property
def dttm_cols(self) -> List:
def dttm_cols(self) -> List[str]:
return []
@property
@ -182,7 +182,7 @@ class BaseDatasource(
}
@property
def select_star(self):
def select_star(self) -> Optional[str]:
pass
@property
@ -336,18 +336,18 @@ class BaseDatasource(
values = None
return values
def external_metadata(self):
def external_metadata(self) -> List[Dict[str, str]]:
"""Returns column information from the external system"""
raise NotImplementedError()
def get_query_str(self, query_obj) -> str:
def get_query_str(self, query_obj: QueryObjectDict) -> str:
"""Returns a query as a string
This is used to be displayed to the user so that she/he can
understand what is taking place behind the scene"""
raise NotImplementedError()
def query(self, query_obj) -> QueryResult:
def query(self, query_obj: QueryObjectDict) -> QueryResult:
"""Executes the query and returns a dataframe
query_obj is a dictionary representing Superset's query interface.
@ -363,7 +363,7 @@ class BaseDatasource(
raise NotImplementedError()
@staticmethod
def default_query(qry) -> Query:
def default_query(qry: Query) -> Query:
return qry
def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
@ -376,8 +376,8 @@ class BaseDatasource(
@staticmethod
def get_fk_many_from_list(
object_list, fkmany, fkmany_class, key_attr
): # pylint: disable=too-many-locals
object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str,
) -> List[Column]: # pylint: disable=too-many-locals
"""Update ORM one-to-many list from object list
Used for syncing metrics and columns using the same code"""
@ -390,8 +390,9 @@ class BaseDatasource(
# sync existing fks
for fk in fkmany:
obj = object_dict.get(getattr(fk, key_attr))
for attr in fkmany_class.update_from_object_fields:
setattr(fk, attr, obj.get(attr))
if obj:
for attr in fkmany_class.update_from_object_fields:
setattr(fk, attr, obj.get(attr))
# create new fks
new_fks = []
@ -409,7 +410,7 @@ class BaseDatasource(
fkmany += new_fks
return fkmany
def update_from_object(self, obj) -> None:
def update_from_object(self, obj: Dict[str, Any]) -> None:
"""Update datasource from a data structure
The UI's table editor crafts a complex data structure that
@ -426,18 +427,26 @@ class BaseDatasource(
self.owners = obj.get("owners", [])
# Syncing metrics
metrics = self.get_fk_many_from_list(
obj.get("metrics"), self.metrics, self.metric_class, "metric_name"
metrics = (
self.get_fk_many_from_list(
obj["metrics"], self.metrics, self.metric_class, "metric_name"
)
if self.metric_class and "metrics" in obj
else []
)
self.metrics = metrics
# Syncing columns
self.columns = self.get_fk_many_from_list(
obj.get("columns"), self.columns, self.column_class, "column_name"
self.columns = (
self.get_fk_many_from_list(
obj["columns"], self.columns, self.column_class, "column_name"
)
if self.column_class and "columns" in obj
else []
)
def get_extra_cache_keys( # pylint: disable=no-self-use
self, query_obj: Dict[str, Any] # pylint: disable=unused-argument
self, query_obj: QueryObjectDict # pylint: disable=unused-argument
) -> List[Hashable]:
""" If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
@ -474,7 +483,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
# [optional] Set this to support import/export functionality
export_fields: List[Any] = []
def __repr__(self):
def __repr__(self) -> str:
return self.column_name
num_types = (
@ -505,11 +514,11 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
return self.type and any(map(lambda t: t in self.type.upper(), self.str_types))
@property
def expression(self):
def expression(self) -> Column:
raise NotImplementedError()
@property
def python_date_format(self):
def python_date_format(self) -> Column:
raise NotImplementedError()
@property
@ -557,11 +566,11 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
"""
@property
def perm(self):
def perm(self) -> Optional[str]:
raise NotImplementedError()
@property
def expression(self):
def expression(self) -> Column:
raise NotImplementedError()
@property

View File

@ -16,12 +16,13 @@
# under the License.
from flask import Markup
from superset.connectors.base.models import BaseDatasource
from superset.exceptions import SupersetException
from superset.views.base import SupersetModelView
class DatasourceModelView(SupersetModelView):
def pre_delete(self, item):
def pre_delete(self, item: BaseDatasource) -> None:
if item.slices:
raise SupersetException(
Markup(

View File

@ -24,7 +24,18 @@ from copy import deepcopy
from datetime import datetime, timedelta
from distutils.version import LooseVersion
from multiprocessing.pool import ThreadPool
from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import pandas as pd
import sqlalchemy as sa
@ -54,7 +65,7 @@ from superset.constants import NULL_STRING
from superset.exceptions import SupersetException
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.typing import FilterValues
from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict
from superset.utils import core as utils, import_datasource
try:
@ -99,7 +110,7 @@ logger = logging.getLogger(__name__)
try:
# Postaggregator might not have been imported.
class JavascriptPostAggregator(Postaggregator):
def __init__(self, name, field_names, function):
def __init__(self, name: str, field_names: List[str], function: str) -> None:
self.post_aggregator = {
"type": "javascript",
"fieldNames": field_names,
@ -111,7 +122,7 @@ try:
class CustomPostAggregator(Postaggregator):
"""A way to allow users to specify completely custom PostAggregators"""
def __init__(self, name, post_aggregator):
def __init__(self, name: str, post_aggregator: Dict[str, Any]) -> None:
self.name = name
self.post_aggregator = post_aggregator
@ -121,7 +132,7 @@ except NameError:
# Function wrapper because bound methods cannot
# be passed to processes
def _fetch_metadata_for(datasource):
def _fetch_metadata_for(datasource: "DruidDatasource") -> Optional[Dict[str, Any]]:
return datasource.latest_metadata()
@ -155,10 +166,10 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
update_from_object_fields = export_fields
export_children = ["datasources"]
def __repr__(self):
def __repr__(self) -> str:
return self.verbose_name if self.verbose_name else self.cluster_name
def __html__(self):
def __html__(self) -> str:
return self.__repr__()
@property
@ -166,7 +177,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
@staticmethod
def get_base_url(host, port) -> str:
def get_base_url(host: str, port: int) -> str:
if not re.match("http(s)?://", host):
host = "http://" + host
@ -335,7 +346,7 @@ class DruidColumn(Model, BaseColumn):
update_from_object_fields = export_fields
export_parent = "datasource"
def __repr__(self):
def __repr__(self) -> str:
return self.column_name or str(self.id)
@property
@ -380,7 +391,7 @@ class DruidColumn(Model, BaseColumn):
@classmethod
def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn":
def lookup_obj(lookup_column: "DruidColumn") -> Optional["DruidColumn"]:
def lookup_obj(lookup_column: DruidColumn) -> Optional[DruidColumn]:
return (
db.session.query(DruidColumn)
.filter(
@ -423,7 +434,7 @@ class DruidMetric(Model, BaseMetric):
export_parent = "datasource"
@property
def expression(self):
def expression(self) -> Column:
return self.json
@property
@ -558,8 +569,8 @@ class DruidDatasource(Model, BaseDatasource):
obj=self
)
def update_from_object(self, obj):
return NotImplementedError()
def update_from_object(self, obj: Dict[str, Any]) -> None:
raise NotImplementedError()
@property
def link(self) -> Markup:
@ -594,7 +605,7 @@ class DruidDatasource(Model, BaseDatasource):
"time_grains": ["now"],
}
def __repr__(self):
def __repr__(self) -> str:
return self.datasource_name
@renders("datasource_name")
@ -634,7 +645,7 @@ class DruidDatasource(Model, BaseDatasource):
db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
)
def latest_metadata(self):
def latest_metadata(self) -> Optional[Dict[str, Any]]:
"""Returns segment metadata from the latest segment"""
logger.info("Syncing datasource [{}]".format(self.datasource_name))
client = self.cluster.get_pydruid_client()
@ -686,6 +697,7 @@ class DruidDatasource(Model, BaseDatasource):
logger.exception(ex)
if segment_metadata:
return segment_metadata[-1]["columns"]
return None
def refresh_metrics(self) -> None:
for col in self.columns:
@ -772,7 +784,7 @@ class DruidDatasource(Model, BaseDatasource):
session.commit()
@staticmethod
def time_offset(granularity: Union[str, Dict]) -> int:
def time_offset(granularity: Granularity) -> int:
if granularity == "week_ending_saturday":
return 6 * 24 * 3600 * 1000 # 6 days
return 0
@ -795,7 +807,7 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod
def granularity(
period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None
) -> Union[str, Dict]:
) -> Union[Dict[str, str], str]:
if not period_name or period_name == "all":
return "all"
iso_8601_dict = {
@ -817,7 +829,7 @@ class DruidDatasource(Model, BaseDatasource):
"year": "P1Y",
}
granularity: Dict[str, Union[str, float]] = {"type": "period"}
granularity = {"type": "period"}
if timezone:
granularity["timeZone"] = timezone
@ -840,12 +852,12 @@ class DruidDatasource(Model, BaseDatasource):
else:
granularity["type"] = "duration"
granularity["duration"] = (
utils.parse_human_timedelta(period_name).total_seconds() * 1000
utils.parse_human_timedelta(period_name).total_seconds() * 1000 # type: ignore
)
return granularity
@staticmethod
def get_post_agg(mconf: Dict) -> "Postaggregator":
def get_post_agg(mconf: Dict[str, Any]) -> "Postaggregator":
"""
For a metric specified as `postagg` returns the
kind of post aggregation for pydruid.
@ -904,7 +916,13 @@ class DruidDatasource(Model, BaseDatasource):
return list(set(field_names))
@staticmethod
def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict):
def resolve_postagg(
postagg: DruidMetric,
post_aggs: Dict[str, Any],
agg_names: Set[str],
visited_postaggs: Set[str],
metrics_dict: Dict[str, DruidMetric],
) -> None:
mconf = postagg.json_obj
required_fields = set(
DruidDatasource.recursive_get_fields(mconf) + mconf.get("fieldNames", [])
@ -939,9 +957,7 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod
def metrics_and_post_aggs(
metrics: List[Union[Dict, str]],
metrics_dict: Dict[str, DruidMetric],
druid_version=None,
metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric],
) -> Tuple[OrderedDict, OrderedDict]:
# Separate metrics into those that are aggregations
# and those that are post aggregations
@ -998,10 +1014,17 @@ class DruidDatasource(Model, BaseDatasource):
df = client.export_pandas()
return df[column_name].to_list()
def get_query_str(self, query_obj, phase=1, client=None):
def get_query_str(
self,
query_obj: QueryObjectDict,
phase: int = 1,
client: Optional["PyDruid"] = None,
) -> str:
return self.run_query(client=client, phase=phase, **query_obj)
def _add_filter_from_pre_query_data(self, df: pd.DataFrame, dimensions, dim_filter):
def _add_filter_from_pre_query_data(
self, df: pd.DataFrame, dimensions: List[Any], dim_filter: "Filter"
) -> "Filter":
ret = dim_filter
if not df.empty:
new_filters = []
@ -1043,7 +1066,7 @@ class DruidDatasource(Model, BaseDatasource):
return ret
@staticmethod
def druid_type_from_adhoc_metric(adhoc_metric: Dict) -> str:
def druid_type_from_adhoc_metric(adhoc_metric: Dict[str, Any]) -> str:
column_type = adhoc_metric["column"]["type"].lower()
aggregate = adhoc_metric["aggregate"].lower()
@ -1115,12 +1138,14 @@ class DruidDatasource(Model, BaseDatasource):
)
@staticmethod
def _dimensions_to_values(dimensions):
def _dimensions_to_values(
dimensions: List[Union[Dict[str, str], str]]
) -> List[Union[Dict[str, str], str]]:
"""
Replace dimensions specs with their `dimension`
values, and ignore those without
"""
values = []
values: List[Union[Dict[str, str], str]] = []
for dimension in dimensions:
if isinstance(dimension, dict):
if "extractionFn" in dimension:
@ -1133,37 +1158,37 @@ class DruidDatasource(Model, BaseDatasource):
return values
@staticmethod
def sanitize_metric_object(metric: Dict) -> None:
def sanitize_metric_object(metric: Metric) -> None:
"""
Update a metric with the correct type if necessary.
:param dict metric: The metric to sanitize
"""
if (
utils.is_adhoc_metric(metric)
and metric["column"]["type"].upper() == "FLOAT"
and metric["column"]["type"].upper() == "FLOAT" # type: ignore
):
metric["column"]["type"] = "DOUBLE"
metric["column"]["type"] = "DOUBLE" # type: ignore
def run_query( # druid
self,
metrics,
granularity,
from_dttm,
to_dttm,
columns=None,
groupby=None,
filter=None,
is_timeseries=True,
timeseries_limit=None,
timeseries_limit_metric=None,
row_limit=None,
inner_from_dttm=None,
inner_to_dttm=None,
orderby=None,
extras=None,
phase=2,
client=None,
order_desc=True,
metrics: List[Metric],
granularity: str,
from_dttm: datetime,
to_dttm: datetime,
columns: Optional[List[str]] = None,
groupby: Optional[List[str]] = None,
filter: Optional[List[Dict[str, Any]]] = None,
is_timeseries: Optional[bool] = True,
timeseries_limit: Optional[int] = None,
timeseries_limit_metric: Optional[Metric] = None,
row_limit: Optional[int] = None,
inner_from_dttm: Optional[datetime] = None,
inner_to_dttm: Optional[datetime] = None,
orderby: Optional[Any] = None,
extras: Optional[Dict[str, Any]] = None,
phase: int = 2,
client: Optional["PyDruid"] = None,
order_desc: bool = True,
) -> str:
"""Runs a query against Druid and returns a dataframe.
"""
@ -1190,17 +1215,16 @@ class DruidDatasource(Model, BaseDatasource):
) < LooseVersion("0.11.0"):
for metric in metrics:
self.sanitize_metric_object(metric)
self.sanitize_metric_object(timeseries_limit_metric)
if timeseries_limit_metric:
self.sanitize_metric_object(timeseries_limit_metric)
aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict
)
# the dimensions list with dimensionSpecs expanded
dimensions = self.get_dimensions(
columns if IS_SIP_38 else groupby, columns_dict
)
columns_ = columns if IS_SIP_38 else groupby
dimensions = self.get_dimensions(columns_, columns_dict) if columns_ else []
extras = extras or {}
qry = dict(
@ -1217,17 +1241,24 @@ class DruidDatasource(Model, BaseDatasource):
if is_timeseries:
qry["context"] = dict(skipEmptyBuckets=True)
filters = DruidDatasource.get_filters(filter, self.num_cols, columns_dict)
filters = (
DruidDatasource.get_filters(filter, self.num_cols, columns_dict)
if filter
else None
)
if filters:
qry["filter"] = filters
having_filters = self.get_having_filters(extras.get("having_druid"))
if having_filters:
qry["having"] = having_filters
if "having_druid" in extras:
having_filters = self.get_having_filters(extras["having_druid"])
if having_filters:
qry["having"] = having_filters
else:
having_filters = None
order_direction = "descending" if order_desc else "ascending"
if (IS_SIP_38 and not metrics and "__time" not in columns) or (
if (IS_SIP_38 and not metrics and columns and "__time" not in columns) or (
not IS_SIP_38 and columns
):
columns.append("__time")
@ -1240,7 +1271,7 @@ class DruidDatasource(Model, BaseDatasource):
qry["limit"] = row_limit
client.scan(**qry)
elif (IS_SIP_38 and columns) or (
not IS_SIP_38 and len(groupby) == 0 and not having_filters
not IS_SIP_38 and not groupby and not having_filters
):
logger.info("Running timeseries query for no groupby values")
del qry["dimensions"]
@ -1249,13 +1280,14 @@ class DruidDatasource(Model, BaseDatasource):
not having_filters
and order_desc
and (
(IS_SIP_38 and len(columns) == 1)
or (not IS_SIP_38 and len(groupby) == 1)
(IS_SIP_38 and columns and len(columns) == 1)
or (not IS_SIP_38 and groupby and len(groupby) == 1)
)
):
dim = list(qry["dimensions"])[0]
logger.info("Running two-phase topn query for dimension [{}]".format(dim))
pre_qry = deepcopy(qry)
order_by: Optional[str] = None
if timeseries_limit_metric:
order_by = utils.get_metric_name(timeseries_limit_metric)
aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs(
@ -1275,7 +1307,7 @@ class DruidDatasource(Model, BaseDatasource):
pre_qry["granularity"] = "all"
pre_qry["threshold"] = min(row_limit, timeseries_limit or row_limit)
pre_qry["metric"] = order_by
pre_qry["dimension"] = self._dimensions_to_values(qry.get("dimensions"))[0]
pre_qry["dimension"] = self._dimensions_to_values(qry["dimensions"])[0]
del pre_qry["dimensions"]
client.topn(**pre_qry)
@ -1303,10 +1335,7 @@ class DruidDatasource(Model, BaseDatasource):
qry["metric"] = list(qry["aggregations"].keys())[0]
client.topn(**qry)
logger.info("Phase 2 Complete")
elif (
having_filters
or ((IS_SIP_38 and columns) or (not IS_SIP_38 and len(groupby))) > 0
):
elif having_filters or ((IS_SIP_38 and columns) or (not IS_SIP_38 and groupby)):
# If grouping on multiple fields or using a having filter
# we have to force a groupby query
logger.info("Running groupby query for dimensions [{}]".format(dimensions))
@ -1322,13 +1351,13 @@ class DruidDatasource(Model, BaseDatasource):
set([x for x in pre_qry_dims if not isinstance(x, dict)])
)
dict_dims = [x for x in pre_qry_dims if isinstance(x, dict)]
pre_qry["dimensions"] = non_dict_dims + dict_dims
pre_qry["dimensions"] = non_dict_dims + dict_dims # type: ignore
order_by = None
if metrics:
order_by = utils.get_metric_name(metrics[0])
else:
order_by = pre_qry_dims[0]
order_by = pre_qry_dims[0] # type: ignore
if timeseries_limit_metric:
order_by = utils.get_metric_name(timeseries_limit_metric)
@ -1366,7 +1395,7 @@ class DruidDatasource(Model, BaseDatasource):
if df is None:
df = pd.DataFrame()
qry["filter"] = self._add_filter_from_pre_query_data(
df, pre_qry["dimensions"], filters
df, pre_qry["dimensions"], qry["filter"]
)
qry["limit_spec"] = None
if row_limit:
@ -1446,7 +1475,7 @@ class DruidDatasource(Model, BaseDatasource):
time_offset = DruidDatasource.time_offset(query_obj["granularity"])
def increment_timestamp(ts):
def increment_timestamp(ts: str) -> datetime:
dt = utils.parse_human_datetime(ts).replace(tzinfo=DRUID_TZ)
return dt + timedelta(milliseconds=time_offset)
@ -1458,7 +1487,17 @@ class DruidDatasource(Model, BaseDatasource):
)
@staticmethod
def _create_extraction_fn(dim_spec):
def _create_extraction_fn(
dim_spec: Dict[str, Any]
) -> Tuple[
str,
Union[
"MapLookupExtraction",
"RegexExtraction",
"RegisteredLookupExtraction",
"TimeFormatExtraction",
],
]:
extraction_fn = None
if dim_spec and "extractionFn" in dim_spec:
col = dim_spec["dimension"]
@ -1487,7 +1526,12 @@ class DruidDatasource(Model, BaseDatasource):
return (col, extraction_fn)
@classmethod
def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
def get_filters(
cls,
raw_filters: List[Dict[str, Any]],
num_cols: List[str],
columns_dict: Dict[str, DruidColumn],
) -> "Filter":
"""Given Superset filter data structure, returns pydruid Filter(s)"""
filters = None
for flt in raw_filters:
@ -1641,7 +1685,9 @@ class DruidDatasource(Model, BaseDatasource):
return cond
def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having":
def get_having_filters(
self, raw_filters: List[Dict[str, Any]]
) -> Optional["Having"]:
filters = None
reversed_op_map = {
FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value,
@ -1673,16 +1719,18 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod
def query_datasources_by_name(
cls, session: Session, database: Database, datasource_name: str, schema=None
cls,
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
) -> List["DruidDatasource"]:
return []
def external_metadata(self) -> List[Dict]:
def external_metadata(self) -> List[Dict[str, Any]]:
self.merge_flag = True
return [
{"name": k, "type": v.get("type")}
for k, v in self.latest_metadata().items()
]
latest_metadata = self.latest_metadata() or {}
return [{"name": k, "type": v.get("type")} for k, v in latest_metadata.items()]
sa.event.listen(DruidDatasource, "after_insert", security_manager.set_perm)

View File

@ -31,6 +31,7 @@ from superset import app, appbuilder, db, security_manager
from superset.connectors.base.views import DatasourceModelView
from superset.connectors.connector_registry import ConnectorRegistry
from superset.constants import RouteMethod
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import (
BaseSupersetView,
@ -106,7 +107,7 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView):
edit_form_extra_fields = add_form_extra_fields
def pre_update(self, col):
def pre_update(self, col: "DruidColumnInlineView") -> None:
# If a dimension spec JSON is given, ensure that it is
# valid JSON and that `outputName` is specified
if col.dimension_spec_json:
@ -128,10 +129,10 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView):
)
)
def post_update(self, col):
def post_update(self, col: "DruidColumnInlineView") -> None:
col.refresh_metrics()
def post_add(self, col):
def post_add(self, col: "DruidColumnInlineView") -> None:
self.post_update(col)
@ -240,13 +241,13 @@ class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin):
yaml_dict_key = "databases"
def pre_add(self, cluster):
def pre_add(self, cluster: "DruidClusterModelView") -> None:
security_manager.add_permission_view_menu("database_access", cluster.perm)
def pre_update(self, cluster):
def pre_update(self, cluster: "DruidClusterModelView") -> None:
self.pre_add(cluster)
def _delete(self, pk):
def _delete(self, pk: int) -> None:
DeleteMixin._delete(self, pk)
@ -334,7 +335,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
"modified": _("Modified"),
}
def pre_add(self, datasource):
def pre_add(self, datasource: "DruidDatasourceModelView") -> None:
with db.session.no_autoflush:
query = db.session.query(models.DruidDatasource).filter(
models.DruidDatasource.datasource_name == datasource.datasource_name,
@ -343,7 +344,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
if db.session.query(query.exists()).scalar():
raise Exception(get_datasource_exist_error_msg(datasource.full_name))
def post_add(self, datasource):
def post_add(self, datasource: "DruidDatasourceModelView") -> None:
datasource.refresh_metrics()
security_manager.add_permission_view_menu(
"datasource_access", datasource.get_perm()
@ -353,10 +354,10 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
"schema_access", datasource.schema_perm
)
def post_update(self, datasource):
def post_update(self, datasource: "DruidDatasourceModelView") -> None:
self.post_add(datasource)
def _delete(self, pk):
def _delete(self, pk: int) -> None:
DeleteMixin._delete(self, pk)
@ -365,7 +366,7 @@ class Druid(BaseSupersetView):
@has_access
@expose("/refresh_datasources/")
def refresh_datasources(self, refresh_all=True):
def refresh_datasources(self, refresh_all: bool = True) -> FlaskResponse:
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidCluster = ConnectorRegistry.sources["druid"].cluster_class
@ -397,7 +398,7 @@ class Druid(BaseSupersetView):
@has_access
@expose("/scan_new_datasources/")
def scan_new_datasources(self):
def scan_new_datasources(self) -> FlaskResponse:
"""
Calling this endpoint will cause a scan for new
datasources only and add them.

View File

@ -54,10 +54,15 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetr
from superset.constants import NULL_STRING
from superset.db_engine_specs.base import TimestampExpression
from superset.exceptions import DatabaseNotFound
from superset.jinja_context import ExtraCache, get_template_processor
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
get_template_processor,
)
from superset.models.annotations import Annotation
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, QueryResult
from superset.typing import Metric, QueryObjectDict
from superset.utils import core as utils, import_datasource
config = app.config
@ -86,7 +91,7 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
changed_on = None
def query(self, query_obj: Dict[str, Any]) -> QueryResult:
def query(self, query_obj: QueryObjectDict) -> QueryResult:
error_message = None
qry = db.session.query(Annotation)
qry = qry.filter(Annotation.layer_id == query_obj["filter"][0]["val"])
@ -110,10 +115,10 @@ class AnnotationDatasource(BaseDatasource):
error_message=error_message,
)
def get_query_str(self, query_obj):
def get_query_str(self, query_obj: QueryObjectDict) -> str:
raise NotImplementedError()
def values_for_column(self, column_name, limit=10000):
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
raise NotImplementedError()
@ -239,8 +244,8 @@ class TableColumn(Model, BaseColumn):
return self.table.make_sqla_column_compatible(time_expr, label)
@classmethod
def import_obj(cls, i_column):
def lookup_obj(lookup_column):
def import_obj(cls, i_column: "TableColumn") -> "TableColumn":
def lookup_obj(lookup_column: TableColumn) -> TableColumn:
return (
db.session.query(TableColumn)
.filter(
@ -343,8 +348,8 @@ class SqlMetric(Model, BaseMetric):
return self.perm
@classmethod
def import_obj(cls, i_metric):
def lookup_obj(lookup_metric):
def import_obj(cls, i_metric: "SqlMetric") -> "SqlMetric":
def lookup_obj(lookup_metric: SqlMetric) -> SqlMetric:
return (
db.session.query(SqlMetric)
.filter(
@ -442,7 +447,7 @@ class SqlaTable(Model, BaseDatasource):
sqla_col._df_label_expected = label_expected
return sqla_col
def __repr__(self):
def __repr__(self) -> str:
return self.name
@property
@ -521,14 +526,14 @@ class SqlaTable(Model, BaseDatasource):
)
@property
def dttm_cols(self) -> List:
def dttm_cols(self) -> List[str]:
l = [c.column_name for c in self.columns if c.is_dttm]
if self.main_dttm_col and self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
def num_cols(self) -> List:
def num_cols(self) -> List[str]:
return [c.column_name for c in self.columns if c.is_numeric]
@property
@ -550,7 +555,7 @@ class SqlaTable(Model, BaseDatasource):
def sql_url(self) -> str:
return self.database.sql_url + "?table_name=" + str(self.table_name)
def external_metadata(self):
def external_metadata(self) -> List[Dict[str, str]]:
cols = self.database.get_columns(self.table_name, schema=self.schema)
for col in cols:
try:
@ -567,7 +572,7 @@ class SqlaTable(Model, BaseDatasource):
}
@property
def select_star(self) -> str:
def select_star(self) -> Optional[str]:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
@ -589,7 +594,7 @@ class SqlaTable(Model, BaseDatasource):
d["is_sqllab_view"] = self.is_sqllab_view
return d
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@ -626,10 +631,10 @@ class SqlaTable(Model, BaseDatasource):
sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database)
return sql
def get_template_processor(self, **kwargs):
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(table=self, database=self.database, **kwargs)
def get_query_str_extended(self, query_obj: Dict[str, Any]) -> QueryStringExtended:
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
logger.info(sql)
@ -639,18 +644,20 @@ class SqlaTable(Model, BaseDatasource):
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
)
def get_query_str(self, query_obj: Dict[str, Any]) -> str:
def get_query_str(self, query_obj: QueryObjectDict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"
def get_sqla_table(self):
def get_sqla_table(self) -> table:
tbl = table(self.table_name)
if self.schema:
tbl.schema = self.schema
return tbl
def get_from_clause(self, template_processor=None):
def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> Union[table, TextAsFrom]:
# Supporting arbitrary SQL statements in place of tables
if self.sql:
from_sql = self.sql
@ -687,7 +694,9 @@ class SqlaTable(Model, BaseDatasource):
return self.make_sqla_column_compatible(sqla_metric, label)
def _get_sqla_row_level_filters(self, template_processor) -> List[str]:
def _get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[str]:
"""
Return the appropriate row level security filters for this table and the current user.
@ -702,22 +711,22 @@ class SqlaTable(Model, BaseDatasource):
def get_sqla_query( # sqla
self,
metrics,
granularity,
from_dttm,
to_dttm,
columns=None,
groupby=None,
filter=None,
is_timeseries=True,
timeseries_limit=15,
timeseries_limit_metric=None,
row_limit=None,
inner_from_dttm=None,
inner_to_dttm=None,
orderby=None,
extras=None,
order_desc=True,
metrics: List[Metric],
granularity: str,
from_dttm: datetime,
to_dttm: datetime,
columns: Optional[List[str]] = None,
groupby: Optional[List[str]] = None,
filter: Optional[List[Dict[str, Any]]] = None,
is_timeseries: bool = True,
timeseries_limit: int = 15,
timeseries_limit_metric: Optional[Metric] = None,
row_limit: Optional[int] = None,
inner_from_dttm: Optional[datetime] = None,
inner_to_dttm: Optional[datetime] = None,
orderby: Optional[List[Tuple[ColumnElement, bool]]] = None,
extras: Optional[Dict[str, Any]] = None,
order_desc: bool = True,
) -> SqlaQuery:
"""Querying any sqla table from this common interface"""
template_kwargs = {
@ -765,8 +774,9 @@ class SqlaTable(Model, BaseDatasource):
metrics_exprs: List[ColumnElement] = []
for m in metrics:
if utils.is_adhoc_metric(m):
assert isinstance(m, dict)
metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict:
elif isinstance(m, str) and m in metrics_dict:
metrics_exprs.append(metrics_dict[m].get_sqla_col())
else:
raise Exception(_("Metric '%(metric)s' does not exist", metric=m))
@ -781,7 +791,9 @@ class SqlaTable(Model, BaseDatasource):
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
# dedup columns while preserving order
groupby = list(dict.fromkeys(columns if is_sip_38 else groupby))
columns_ = columns if is_sip_38 else groupby
assert columns_
groupby = list(dict.fromkeys(columns_))
select_exprs = []
for s in groupby:
@ -802,6 +814,7 @@ class SqlaTable(Model, BaseDatasource):
)
metrics_exprs = []
assert extras is not None
time_range_endpoints = extras.get("time_range_endpoints")
groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items())
if granularity:
@ -845,7 +858,8 @@ class SqlaTable(Model, BaseDatasource):
where_clause_and = []
having_clause_and: List = []
for flt in filter:
for flt in filter: # type: ignore
if not all([flt.get(s) for s in ["col", "op"]]):
continue
col = flt["col"]
@ -1029,12 +1043,20 @@ class SqlaTable(Model, BaseDatasource):
prequeries=prequeries,
)
def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
def _get_timeseries_orderby(
self,
timeseries_limit_metric: Metric,
metrics_dict: Dict[str, SqlMetric],
cols: Dict[str, Column],
) -> Optional[Column]:
if utils.is_adhoc_metric(timeseries_limit_metric):
assert isinstance(timeseries_limit_metric, dict)
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
elif timeseries_limit_metric in metrics_dict:
timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
ob = timeseries_limit_metric.get_sqla_col()
elif (
isinstance(timeseries_limit_metric, str)
and timeseries_limit_metric in metrics_dict
):
ob = metrics_dict[timeseries_limit_metric].get_sqla_col()
else:
raise Exception(
_("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)
@ -1054,7 +1076,7 @@ class SqlaTable(Model, BaseDatasource):
return or_(*groups)
def query(self, query_obj: Dict[str, Any]) -> QueryResult:
def query(self, query_obj: QueryObjectDict) -> QueryResult:
qry_start_dttm = datetime.now()
query_str_ext = self.get_query_str_extended(query_obj)
sql = query_str_ext.sql
@ -1101,7 +1123,7 @@ class SqlaTable(Model, BaseDatasource):
def get_sqla_table_object(self) -> Table:
return self.database.get_table(self.table_name, schema=self.schema)
def fetch_metadata(self, commit=True) -> None:
def fetch_metadata(self, commit: bool = True) -> None:
"""Fetches the metadata for the table and merges it in"""
try:
table = self.get_sqla_table_object()
@ -1166,7 +1188,9 @@ class SqlaTable(Model, BaseDatasource):
db.session.commit()
@classmethod
def import_obj(cls, i_datasource, import_time=None) -> int:
def import_obj(
cls, i_datasource: "SqlaTable", import_time: Optional[int] = None
) -> int:
"""Imports the datasource from the object to the database.
Metrics, columns and the datasource will be overridden if they exist.
@ -1174,7 +1198,7 @@ class SqlaTable(Model, BaseDatasource):
superset instances. Audit metadata isn't copied over.
"""
def lookup_sqlatable(table):
def lookup_sqlatable(table: "SqlaTable") -> "SqlaTable":
return (
db.session.query(SqlaTable)
.join(Database)
@ -1186,7 +1210,7 @@ class SqlaTable(Model, BaseDatasource):
.first()
)
def lookup_database(table):
def lookup_database(table: SqlaTable) -> Database:
try:
return (
db.session.query(Database)
@ -1207,7 +1231,11 @@ class SqlaTable(Model, BaseDatasource):
@classmethod
def query_datasources_by_name(
cls, session: Session, database: Database, datasource_name: str, schema=None
cls,
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
) -> List["SqlaTable"]:
query = (
session.query(cls)
@ -1219,10 +1247,10 @@ class SqlaTable(Model, BaseDatasource):
return query.all()
@staticmethod
def default_query(qry) -> Query:
def default_query(qry: Query) -> Query:
return qry.filter_by(is_sqllab_view=False)
def has_extra_cache_key_calls(self, query_obj: Dict[str, Any]) -> bool:
def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool:
"""
Detects the presence of calls to `ExtraCache` methods in items in query_obj that
can be templated. If any are present, the query must be evaluated to extract
@ -1248,7 +1276,7 @@ class SqlaTable(Model, BaseDatasource):
return True
return False
def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]:
"""
The cache key of a SqlaTable needs to consider any keys added by the parent
class and any keys added via `ExtraCache`.

View File

@ -18,6 +18,7 @@
"""Views used by the SqlAlchemy connector"""
import logging
import re
from typing import List, Union
from flask import flash, Markup, redirect
from flask_appbuilder import CompactCRUDMixin, expose
@ -32,6 +33,7 @@ from wtforms.validators import Regexp
from superset import app, db, security_manager
from superset.connectors.base.views import DatasourceModelView
from superset.constants import RouteMethod
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import (
create_table_permissions,
@ -375,10 +377,10 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
)
}
def pre_add(self, table):
def pre_add(self, table: "TableModelView") -> None:
validate_sqlatable(table)
def post_add(self, table, flash_message=True):
def post_add(self, table: "TableModelView", flash_message: bool = True) -> None:
table.fetch_metadata()
create_table_permissions(table)
if flash_message:
@ -392,15 +394,15 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
"info",
)
def post_update(self, table):
def post_update(self, table: "TableModelView") -> None:
self.post_add(table, flash_message=False)
def _delete(self, pk):
def _delete(self, pk: int) -> None:
DeleteMixin._delete(self, pk)
@expose("/edit/<pk>", methods=["GET", "POST"])
@has_access
def edit(self, pk):
def edit(self, pk: int) -> FlaskResponse:
"""Simple hack to redirect to explore view after saving"""
resp = super(TableModelView, self).edit(pk)
if isinstance(resp, str):
@ -410,7 +412,9 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
@action(
"refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh"
)
def refresh(self, tables):
def refresh(
self, tables: Union["TableModelView", List["TableModelView"]]
) -> FlaskResponse:
if not isinstance(tables, list):
tables = [tables]
successes = []
@ -439,7 +443,7 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
@expose("/list/")
@has_access
def list(self):
def list(self) -> FlaskResponse:
if not app.config["ENABLE_REACT_CRUD_VIEWS"]:
return super().list()

View File

@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
import dataclasses
import hashlib
import json
import logging
@ -34,6 +33,7 @@ from typing import (
Union,
)
import dataclasses
import pandas as pd
import sqlparse
from flask import g

View File

@ -15,10 +15,11 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-few-public-methods,invalid-name
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional
from dataclasses import dataclass
class SupersetErrorType(str, Enum):
"""

View File

@ -17,7 +17,7 @@
"""Defines the templating context for SQL Lab"""
import inspect
import re
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from flask import g, request
from jinja2.sandbox import SandboxedEnvironment
@ -26,6 +26,13 @@ from superset import jinja_base_context
from superset.extensions import jinja_context_manager
from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters
if TYPE_CHECKING:
from superset.connectors.sqla.models import ( # pylint: disable=unused-import
SqlaTable,
)
from superset.models.core import Database # pylint: disable=unused-import
from superset.models.sql_lab import Query # pylint: disable=unused-import
def filter_values(column: str, default: Optional[str] = None) -> List[str]:
""" Gets a values for a particular filter as a list
@ -200,12 +207,12 @@ class BaseTemplateProcessor: # pylint: disable=too-few-public-methods
def __init__(
self,
database=None,
query=None,
table=None,
database: Optional["Database"] = None,
query: Optional["Query"] = None,
table: Optional["SqlaTable"] = None,
extra_cache_keys: Optional[List[Any]] = None,
**kwargs
):
**kwargs: Any,
) -> None:
self.database = database
self.query = query
self.schema = None
@ -230,7 +237,7 @@ class BaseTemplateProcessor: # pylint: disable=too-few-public-methods
self.context[self.engine] = self
self.env = SandboxedEnvironment()
def process_template(self, sql: str, **kwargs) -> str:
def process_template(self, sql: str, **kwargs: Any) -> str:
"""Processes a sql template
>>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
@ -279,12 +286,14 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
"""
table_name, schema = self._schema_table(table_name, self.schema)
return self.database.db_engine_spec.latest_partition(
assert self.database
return self.database.db_engine_spec.latest_partition( # type: ignore
table_name, schema, self.database
)[1]
def latest_sub_partition(self, table_name, **kwargs):
table_name, schema = self._schema_table(table_name, self.schema)
assert self.database
return self.database.db_engine_spec.latest_sub_partition(
table_name=table_name, schema=schema, database=self.database, **kwargs
)
@ -305,7 +314,12 @@ for k in keys:
template_processors[o.engine] = o
def get_template_processor(database, table=None, query=None, **kwargs):
def get_template_processor(
database: "Database",
table: Optional["SqlaTable"] = None,
query: Optional["Query"] = None,
**kwargs: Any,
) -> BaseTemplateProcessor:
template_processor = template_processors.get(
database.backend, BaseTemplateProcessor
)

View File

@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import logging
from dataclasses import dataclass
from typing import List, Optional, Set
from urllib import parse
import sqlparse
from dataclasses import dataclass
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

View File

@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from flask import Flask
from flask_caching import Cache
from werkzeug.wrappers import Response
CacheConfig = Union[Callable[[Flask], Cache], Dict[str, Any]]
DbapiDescriptionRow = Tuple[
@ -27,4 +28,15 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, .
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
FilterValue = Union[float, int, str]
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
Granularity = Union[str, Dict[str, Union[str, float]]]
Metric = Union[Dict[str, str], str]
QueryObjectDict = Dict[str, Any]
VizData = Optional[Union[List[Any], Dict[Any, Any]]]
# Flask response.
Base = Union[bytes, str]
Status = Union[int, str]
Headers = Dict[str, Any]
FlaskResponse = Union[
Response, Base, Tuple[Base, Status], Tuple[Base, Status, Headers],
]

View File

@ -43,10 +43,12 @@ from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
@ -79,6 +81,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
from superset.typing import Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
try:
@ -101,7 +104,7 @@ JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
try:
# Having might not have been imported.
class DimSelector(Having):
def __init__(self, **args):
def __init__(self, **args: Any) -> None:
# Just a hack to prevent any exceptions
Having.__init__(self, type="equalTo", aggregation=None, value=None)
@ -118,7 +121,7 @@ except NameError:
pass
def flasher(msg, severity=None):
def flasher(msg: str, severity: str) -> None:
"""Flask's flash if available, logging call if not"""
try:
flash(msg, severity)
@ -235,7 +238,7 @@ def list_minus(l: List, minus: List) -> List:
return [o for o in l if o not in minus]
def parse_human_datetime(s: Optional[str]) -> Optional[datetime]:
def parse_human_datetime(s: str) -> datetime:
"""
Returns ``datetime.datetime`` from human readable strings
@ -256,8 +259,6 @@ def parse_human_datetime(s: Optional[str]) -> Optional[datetime]:
>>> year_ago_1 == year_ago_2
True
"""
if not s:
return None
try:
dttm = parse(s)
except Exception:
@ -564,7 +565,9 @@ def generic_find_uq_constraint_name(table, columns, insp):
return uq["name"]
def get_datasource_full_name(database_name, datasource_name, schema=None):
def get_datasource_full_name(
database_name: str, datasource_name: str, schema: Optional[str] = None
) -> str:
if not schema:
return "[{}].[{}]".format(database_name, datasource_name)
return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)
@ -792,7 +795,7 @@ def get_email_address_list(address_string: str) -> List[str]:
return [x.strip() for x in address_string_list if x.strip()]
def choicify(values):
def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
"""Takes an iterable and makes an iterable of tuples with it"""
return [(v, v) for v in values]
@ -967,7 +970,7 @@ def get_example_database() -> "Database":
return get_or_create_db("examples", db_uri)
def is_adhoc_metric(metric) -> bool:
def is_adhoc_metric(metric: Metric) -> bool:
return bool(
isinstance(metric, dict)
and (
@ -985,11 +988,11 @@ def is_adhoc_metric(metric) -> bool:
)
def get_metric_name(metric):
return metric["label"] if is_adhoc_metric(metric) else metric
def get_metric_name(metric: Metric) -> str:
return metric["label"] if is_adhoc_metric(metric) else metric # type: ignore
def get_metric_names(metrics):
def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
return [get_metric_name(metric) for metric in metrics]

View File

@ -15,15 +15,22 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Callable, Optional
from flask_appbuilder import Model
from sqlalchemy.orm import Session
from sqlalchemy.orm.session import make_transient
logger = logging.getLogger(__name__)
def import_datasource(
session, i_datasource, lookup_database, lookup_datasource, import_time
):
session: Session,
i_datasource: Model,
lookup_database: Callable,
lookup_datasource: Callable,
import_time: Optional[int] = None,
) -> int:
"""Imports the datasource from the object to the database.
Metrics, columns and the datasource will be overridden if they exist.
@ -75,7 +82,7 @@ def import_datasource(
return datasource.id
def import_simple_obj(session, i_obj, lookup_obj):
def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None

View File

@ -14,13 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import dataclasses
import functools
import logging
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
import dataclasses
import simplejson as json
import yaml
from flask import abort, flash, g, get_flashed_messages, redirect, Response, session
@ -48,10 +48,16 @@ from superset.connectors.sqla import models
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.translations.utils import get_language_pack
from superset.typing import FlaskResponse
from superset.utils import core as utils
from .utils import bootstrap_user_data
if TYPE_CHECKING:
from superset.connectors.druid.views import ( # pylint: disable=unused-import
DruidClusterModelView,
)
FRONTEND_CONF_KEYS = (
"SUPERSET_WEBSERVER_TIMEOUT",
"SUPERSET_DASHBOARD_POSITION_DATA_LIMIT",
@ -305,7 +311,7 @@ class SupersetModelView(ModelView):
page_size = 100
list_widget = SupersetListWidget
def render_app_template(self):
def render_app_template(self) -> FlaskResponse:
payload = {
"user": bootstrap_user_data(g.user),
"common": common_bootstrap_payload(),
@ -359,7 +365,9 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods
class DeleteMixin: # pylint: disable=too-few-public-methods
def _delete(self, primary_key):
def _delete(
self: Union[BaseView, "DeleteMixin", "DruidClusterModelView"], primary_key: int,
) -> None:
"""
Delete function logic, override to implement different logic
deletes the record with primary_key = primary_key
@ -367,11 +375,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
:param primary_key:
record primary key to delete
"""
item = self.datamodel.get(primary_key, self._base_filters)
item = self.datamodel.get(primary_key, self._base_filters) # type: ignore
if not item:
abort(404)
try:
self.pre_delete(item)
self.pre_delete(item) # type: ignore
except Exception as ex: # pylint: disable=broad-except
flash(str(ex), "danger")
else:
@ -384,8 +392,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
.all()
)
if self.datamodel.delete(item):
self.post_delete(item)
if self.datamodel.delete(item): # type: ignore
self.post_delete(item) # type: ignore
for pv in pvs:
security_manager.get_session.delete(pv)
@ -395,8 +403,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
security_manager.get_session.commit()
flash(*self.datamodel.message)
self.update_redirect()
flash(*self.datamodel.message) # type: ignore
self.update_redirect() # type: ignore
@action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False

View File

@ -21,7 +21,6 @@ These objects represent the backend of all the visualizations that
Superset can render.
"""
import copy
import dataclasses
import hashlib
import inspect
import logging
@ -34,6 +33,7 @@ from datetime import datetime, timedelta
from itertools import product
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import dataclasses
import geohash
import numpy as np
import pandas as pd
@ -1049,9 +1049,9 @@ class BubbleViz(NVD3Viz):
# dedup groupby if it happens to be the same
d["groupby"] = list(dict.fromkeys(d["groupby"]))
self.x_metric = form_data.get("x")
self.y_metric = form_data.get("y")
self.z_metric = form_data.get("size")
self.x_metric = form_data["x"]
self.y_metric = form_data["y"]
self.z_metric = form_data["size"]
self.entity = form_data.get("entity")
self.series = form_data.get("series") or self.entity
d["row_limit"] = form_data.get("limit")
@ -1093,7 +1093,7 @@ class BulletViz(NVD3Viz):
def query_obj(self):
form_data = self.form_data
d = super().query_obj()
self.metric = form_data.get("metric")
self.metric = form_data["metric"]
d["metrics"] = [self.metric]
if not self.metric:
@ -1451,8 +1451,8 @@ class NVD3DualLineViz(NVD3Viz):
_("Pick a time granularity for your time series")
)
metric = utils.get_metric_name(fd.get("metric"))
metric_2 = utils.get_metric_name(fd.get("metric_2"))
metric = utils.get_metric_name(fd["metric"])
metric_2 = utils.get_metric_name(fd["metric_2"])
df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2])
chart_data = self.to_series(df)
@ -1507,7 +1507,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
df = df.pivot_table(
index=DTTM_ALIAS,
columns="series",
values=utils.get_metric_name(fd.get("metric")),
values=utils.get_metric_name(fd["metric"]),
)
chart_data = self.to_series(df)
for serie in chart_data:
@ -1690,8 +1690,12 @@ class SunburstViz(BaseViz):
fd = self.form_data
cols = fd.get("groupby") or []
cols.extend(["m1", "m2"])
metric = utils.get_metric_name(fd.get("metric"))
secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
metric = utils.get_metric_name(fd["metric"])
secondary_metric = (
utils.get_metric_name(fd["secondary_metric"])
if "secondary_metric" in fd
else None
)
if metric == secondary_metric or secondary_metric is None:
df.rename(columns={df.columns[-1]: "m1"}, inplace=True)
df["m2"] = df["m1"]
@ -1872,8 +1876,12 @@ class WorldMapViz(BaseViz):
fd = self.form_data
cols = [fd.get("entity")]
metric = utils.get_metric_name(fd.get("metric"))
secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
metric = utils.get_metric_name(fd["metric"])
secondary_metric = (
utils.get_metric_name(fd["secondary_metric"])
if "secondary_metric" in fd
else None
)
columns = ["country", "m1", "m2"]
if metric == secondary_metric:
ndf = df[cols]

View File

@ -21,7 +21,6 @@ These objects represent the backend of all the visualizations that
Superset can render.
"""
import copy
import dataclasses
import hashlib
import inspect
import logging
@ -34,6 +33,7 @@ from datetime import datetime, timedelta
from itertools import product
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import dataclasses
import geohash
import numpy as np
import pandas as pd
@ -1077,9 +1077,9 @@ class BubbleViz(NVD3Viz):
form_data = self.form_data
d = super().query_obj()
self.x_metric = form_data.get("x")
self.y_metric = form_data.get("y")
self.z_metric = form_data.get("size")
self.x_metric = form_data["x"]
self.y_metric = form_data["y"]
self.z_metric = form_data["size"]
self.entity = form_data.get("entity")
self.series = form_data.get("series") or self.entity
d["row_limit"] = form_data.get("limit")
@ -1476,8 +1476,8 @@ class NVD3DualLineViz(NVD3Viz):
_("Pick a time granularity for your time series")
)
metric = utils.get_metric_name(fd.get("metric"))
metric_2 = utils.get_metric_name(fd.get("metric_2"))
metric = utils.get_metric_name(fd["metric"])
metric_2 = utils.get_metric_name(fd["metric_2"])
df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2])
chart_data = self.to_series(df)
@ -1532,7 +1532,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
df = df.pivot_table(
index=DTTM_ALIAS,
columns="series",
values=utils.get_metric_name(fd.get("metric")),
values=utils.get_metric_name(fd["metric"]),
)
chart_data = self.to_series(df)
for serie in chart_data:
@ -1710,8 +1710,12 @@ class SunburstViz(BaseViz):
fd = self.form_data
cols = fd.get("groupby") or []
cols.extend(["m1", "m2"])
metric = utils.get_metric_name(fd.get("metric"))
secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
metric = utils.get_metric_name(fd["metric"])
secondary_metric = (
utils.get_metric_name(fd["secondary_metric"])
if "secondary_metric" in fd
else None
)
if metric == secondary_metric or secondary_metric is None:
df.rename(columns={df.columns[-1]: "m1"}, inplace=True)
df["m2"] = df["m1"]
@ -1868,8 +1872,12 @@ class WorldMapViz(BaseViz):
fd = self.form_data
cols = [fd.get("entity")]
metric = utils.get_metric_name(fd.get("metric"))
secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
metric = utils.get_metric_name(fd["metric"])
secondary_metric = (
utils.get_metric_name(fd["secondary_metric"])
if "secondary_metric" in fd
else None
)
columns = ["country", "m1", "m2"]
if metric == secondary_metric:
ndf = df[cols]