style: enforcing mypy typing for connectors (#9824)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
parent 9edfc8f68d
commit 7f6dbf838e

@@ -45,7 +45,7 @@ combine_as_imports = true
 include_trailing_comma = true
 line_length = 88
 known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
 multi_line_output = 3
 order_by_type = false
@@ -53,7 +53,7 @@ order_by_type = false
 ignore_missing_imports = true
 no_implicit_optional = true

-[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
+[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
 check_untyped_defs = true
 disallow_untyped_calls = true
 disallow_untyped_defs = true
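
Adding `superset.connectors.*` to that `[mypy-...]` section turns on `check_untyped_defs`, `disallow_untyped_calls`, and `disallow_untyped_defs` for the connector packages, which is what forces the annotations in the rest of this diff. A minimal sketch of what mypy then rejects under these flags (hypothetical module, for illustration only):

    from typing import Optional

    def lookup(name: str) -> Optional[int]:  # accepted: fully annotated
        return {"a": 1}.get(name)

    def untyped_lookup(name):  # rejected: function is missing a type annotation
        return {"a": 1}.get(name)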

@@ -20,12 +20,12 @@ from typing import Any, Dict, Hashable, List, Optional, Type
 from flask_appbuilder.security.sqla.models import User
 from sqlalchemy import and_, Boolean, Column, Integer, String, Text
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import foreign, Query, relationship
+from sqlalchemy.orm import foreign, Query, relationship, RelationshipProperty

 from superset.constants import NULL_STRING
 from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
 from superset.models.slice import Slice
-from superset.typing import FilterValue, FilterValues
+from superset.typing import FilterValue, FilterValues, QueryObjectDict
 from superset.utils import core as utils

 METRIC_FORM_DATA_PARAMS = [
@@ -93,7 +93,7 @@ class BaseDatasource(
     update_from_object_fields: List[str]

     @declared_attr
-    def slices(self):
+    def slices(self) -> RelationshipProperty:
         return relationship(
             "Slice",
             primaryjoin=lambda: and_(
@@ -117,7 +117,7 @@ class BaseDatasource(
         return sorted([c.column_name for c in self.columns], key=lambda x: x or "")

     @property
-    def columns_types(self) -> Dict:
+    def columns_types(self) -> Dict[str, str]:
         return {c.column_name: c.type for c in self.columns}

     @property
@@ -125,7 +125,7 @@ class BaseDatasource(
         return "timestamp"

     @property
-    def datasource_name(self):
+    def datasource_name(self) -> str:
         raise NotImplementedError()

     @property
@@ -143,7 +143,7 @@ class BaseDatasource(
         return sorted([c.column_name for c in self.columns if c.filterable])

     @property
-    def dttm_cols(self) -> List:
+    def dttm_cols(self) -> List[str]:
         return []

     @property
@@ -182,7 +182,7 @@ class BaseDatasource(
         }

     @property
-    def select_star(self):
+    def select_star(self) -> Optional[str]:
         pass

     @property
@@ -336,18 +336,18 @@ class BaseDatasource(
             values = None
         return values

-    def external_metadata(self):
+    def external_metadata(self) -> List[Dict[str, str]]:
         """Returns column information from the external system"""
         raise NotImplementedError()

-    def get_query_str(self, query_obj) -> str:
+    def get_query_str(self, query_obj: QueryObjectDict) -> str:
         """Returns a query as a string

         This is used to be displayed to the user so that she/he can
         understand what is taking place behind the scene"""
         raise NotImplementedError()

-    def query(self, query_obj) -> QueryResult:
+    def query(self, query_obj: QueryObjectDict) -> QueryResult:
         """Executes the query and returns a dataframe

         query_obj is a dictionary representing Superset's query interface.
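
`QueryObjectDict` is just an alias for `Dict[str, Any]` (see the superset.typing hunk further down), so these signatures document intent without constraining the dict's shape. A small sketch, with the alias re-declared locally for self-containment:

    from typing import Any, Dict

    QueryObjectDict = Dict[str, Any]  # mirrors the alias added in superset.typing

    def describe_query(query_obj: QueryObjectDict) -> str:
        # Toy renderer; each connector overrides the real get_query_str().
        return ", ".join(f"{k}={v!r}" for k, v in sorted(query_obj.items()))

    print(describe_query({"granularity": "P1D", "row_limit": 100}))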

@@ -363,7 +363,7 @@ class BaseDatasource(
         raise NotImplementedError()

     @staticmethod
-    def default_query(qry) -> Query:
+    def default_query(qry: Query) -> Query:
         return qry

     def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
@@ -376,8 +376,8 @@ class BaseDatasource(

     @staticmethod
     def get_fk_many_from_list(
-        object_list, fkmany, fkmany_class, key_attr
-    ):  # pylint: disable=too-many-locals
+        object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str,
+    ) -> List[Column]:  # pylint: disable=too-many-locals
         """Update ORM one-to-many list from object list

         Used for syncing metrics and columns using the same code"""
@@ -390,8 +390,9 @@ class BaseDatasource(
         # sync existing fks
         for fk in fkmany:
             obj = object_dict.get(getattr(fk, key_attr))
-            for attr in fkmany_class.update_from_object_fields:
-                setattr(fk, attr, obj.get(attr))
+            if obj:
+                for attr in fkmany_class.update_from_object_fields:
+                    setattr(fk, attr, obj.get(attr))

         # create new fks
         new_fks = []
@@ -409,7 +410,7 @@ class BaseDatasource(
         fkmany += new_fks
         return fkmany

-    def update_from_object(self, obj) -> None:
+    def update_from_object(self, obj: Dict[str, Any]) -> None:
         """Update datasource from a data structure

         The UI's table editor crafts a complex data structure that
@@ -426,18 +427,26 @@ class BaseDatasource(
         self.owners = obj.get("owners", [])

         # Syncing metrics
-        metrics = self.get_fk_many_from_list(
-            obj.get("metrics"), self.metrics, self.metric_class, "metric_name"
+        metrics = (
+            self.get_fk_many_from_list(
+                obj["metrics"], self.metrics, self.metric_class, "metric_name"
+            )
+            if self.metric_class and "metrics" in obj
+            else []
         )
         self.metrics = metrics

         # Syncing columns
-        self.columns = self.get_fk_many_from_list(
-            obj.get("columns"), self.columns, self.column_class, "column_name"
+        self.columns = (
+            self.get_fk_many_from_list(
+                obj["columns"], self.columns, self.column_class, "column_name"
+            )
+            if self.column_class and "columns" in obj
+            else []
         )

     def get_extra_cache_keys(  # pylint: disable=no-self-use
-        self, query_obj: Dict[str, Any]  # pylint: disable=unused-argument
+        self, query_obj: QueryObjectDict  # pylint: disable=unused-argument
     ) -> List[Hashable]:
         """ If a datasource needs to provide additional keys for calculation of
         cache keys, those can be provided via this method
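
The new `if obj:` guard in `get_fk_many_from_list` matters because `object_dict.get(...)` returns `None` for incoming objects with no match; the old code would then call `obj.get(attr)` on `None` and raise `AttributeError`. A stripped-down illustration:

    object_dict = {"count": {"metric_name": "count", "verbose_name": "COUNT(*)"}}

    for key in ["count", "sum__num"]:
        obj = object_dict.get(key)  # None when the key is missing
        if obj:  # without the guard, obj.get(...) raises AttributeError for "sum__num"
            print(key, obj.get("verbose_name"))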

@@ -474,7 +483,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
     # [optional] Set this to support import/export functionality
     export_fields: List[Any] = []

-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.column_name

     num_types = (
@@ -505,11 +514,11 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
         return self.type and any(map(lambda t: t in self.type.upper(), self.str_types))

     @property
-    def expression(self):
+    def expression(self) -> Column:
         raise NotImplementedError()

     @property
-    def python_date_format(self):
+    def python_date_format(self) -> Column:
         raise NotImplementedError()

     @property
@@ -557,11 +566,11 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
     """

     @property
-    def perm(self):
+    def perm(self) -> Optional[str]:
         raise NotImplementedError()

     @property
-    def expression(self):
+    def expression(self) -> Column:
         raise NotImplementedError()

     @property

@@ -16,12 +16,13 @@
 # under the License.
 from flask import Markup

+from superset.connectors.base.models import BaseDatasource
 from superset.exceptions import SupersetException
 from superset.views.base import SupersetModelView


 class DatasourceModelView(SupersetModelView):
-    def pre_delete(self, item):
+    def pre_delete(self, item: BaseDatasource) -> None:
         if item.slices:
             raise SupersetException(
                 Markup(

@@ -24,7 +24,18 @@ from copy import deepcopy
 from datetime import datetime, timedelta
 from distutils.version import LooseVersion
 from multiprocessing.pool import ThreadPool
-from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)

 import pandas as pd
 import sqlalchemy as sa
@@ -54,7 +65,7 @@ from superset.constants import NULL_STRING
 from superset.exceptions import SupersetException
 from superset.models.core import Database
 from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
-from superset.typing import FilterValues
+from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict
 from superset.utils import core as utils, import_datasource

 try:
@@ -99,7 +110,7 @@ logger = logging.getLogger(__name__)
 try:
     # Postaggregator might not have been imported.
     class JavascriptPostAggregator(Postaggregator):
-        def __init__(self, name, field_names, function):
+        def __init__(self, name: str, field_names: List[str], function: str) -> None:
             self.post_aggregator = {
                 "type": "javascript",
                 "fieldNames": field_names,
@@ -111,7 +122,7 @@ try:
     class CustomPostAggregator(Postaggregator):
         """A way to allow users to specify completely custom PostAggregators"""

-        def __init__(self, name, post_aggregator):
+        def __init__(self, name: str, post_aggregator: Dict[str, Any]) -> None:
             self.name = name
             self.post_aggregator = post_aggregator

@@ -121,7 +132,7 @@ except NameError:

 # Function wrapper because bound methods cannot
 # be passed to processes
-def _fetch_metadata_for(datasource):
+def _fetch_metadata_for(datasource: "DruidDatasource") -> Optional[Dict[str, Any]]:
     return datasource.latest_metadata()


@@ -155,10 +166,10 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
     update_from_object_fields = export_fields
     export_children = ["datasources"]

-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.verbose_name if self.verbose_name else self.cluster_name

-    def __html__(self):
+    def __html__(self) -> str:
         return self.__repr__()

     @property
@@ -166,7 +177,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
         return {"id": self.id, "name": self.cluster_name, "backend": "druid"}

     @staticmethod
-    def get_base_url(host, port) -> str:
+    def get_base_url(host: str, port: int) -> str:
         if not re.match("http(s)?://", host):
             host = "http://" + host

@@ -335,7 +346,7 @@ class DruidColumn(Model, BaseColumn):
     update_from_object_fields = export_fields
     export_parent = "datasource"

-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.column_name or str(self.id)

     @property
@@ -380,7 +391,7 @@ class DruidColumn(Model, BaseColumn):

     @classmethod
     def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn":
-        def lookup_obj(lookup_column: "DruidColumn") -> Optional["DruidColumn"]:
+        def lookup_obj(lookup_column: DruidColumn) -> Optional[DruidColumn]:
             return (
                 db.session.query(DruidColumn)
                 .filter(
@@ -423,7 +434,7 @@ class DruidMetric(Model, BaseMetric):
     export_parent = "datasource"

     @property
-    def expression(self):
+    def expression(self) -> Column:
         return self.json

     @property
@@ -558,8 +569,8 @@ class DruidDatasource(Model, BaseDatasource):
             obj=self
         )

-    def update_from_object(self, obj):
-        return NotImplementedError()
+    def update_from_object(self, obj: Dict[str, Any]) -> None:
+        raise NotImplementedError()

     @property
     def link(self) -> Markup:
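
Note the behavioral fix folded into the annotation above: the old body returned a `NotImplementedError` instance rather than raising it, so callers silently received an exception object. Minimal repro:

    def broken():
        return NotImplementedError()  # builds an instance; nothing is raised

    def fixed() -> None:
        raise NotImplementedError()

    print(broken())  # prints the exception object; no error surfaces
    try:
        fixed()
    except NotImplementedError:
        print("raised as expected")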

@@ -594,7 +605,7 @@ class DruidDatasource(Model, BaseDatasource):
         "time_grains": ["now"],
     }

-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.datasource_name

     @renders("datasource_name")
@@ -634,7 +645,7 @@ class DruidDatasource(Model, BaseDatasource):
             db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
         )

-    def latest_metadata(self):
+    def latest_metadata(self) -> Optional[Dict[str, Any]]:
         """Returns segment metadata from the latest segment"""
         logger.info("Syncing datasource [{}]".format(self.datasource_name))
         client = self.cluster.get_pydruid_client()
@@ -686,6 +697,7 @@ class DruidDatasource(Model, BaseDatasource):
             logger.exception(ex)
         if segment_metadata:
             return segment_metadata[-1]["columns"]
+        return None

     def refresh_metrics(self) -> None:
         for col in self.columns:
@@ -772,7 +784,7 @@ class DruidDatasource(Model, BaseDatasource):
         session.commit()

     @staticmethod
-    def time_offset(granularity: Union[str, Dict]) -> int:
+    def time_offset(granularity: Granularity) -> int:
         if granularity == "week_ending_saturday":
             return 6 * 24 * 3600 * 1000  # 6 days
         return 0
@@ -795,7 +807,7 @@ class DruidDatasource(Model, BaseDatasource):
     @staticmethod
     def granularity(
         period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None
-    ) -> Union[str, Dict]:
+    ) -> Union[Dict[str, str], str]:
         if not period_name or period_name == "all":
             return "all"
         iso_8601_dict = {
@@ -817,7 +829,7 @@ class DruidDatasource(Model, BaseDatasource):
             "year": "P1Y",
         }

-        granularity: Dict[str, Union[str, float]] = {"type": "period"}
+        granularity = {"type": "period"}
         if timezone:
             granularity["timeZone"] = timezone
@@ -840,12 +852,12 @@ class DruidDatasource(Model, BaseDatasource):
         else:
             granularity["type"] = "duration"
             granularity["duration"] = (
-                utils.parse_human_timedelta(period_name).total_seconds() * 1000
+                utils.parse_human_timedelta(period_name).total_seconds() * 1000  # type: ignore
             )
         return granularity

     @staticmethod
-    def get_post_agg(mconf: Dict) -> "Postaggregator":
+    def get_post_agg(mconf: Dict[str, Any]) -> "Postaggregator":
         """
         For a metric specified as `postagg` returns the
         kind of post aggregation for pydruid.
@@ -904,7 +916,13 @@ class DruidDatasource(Model, BaseDatasource):
         return list(set(field_names))

     @staticmethod
-    def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict):
+    def resolve_postagg(
+        postagg: DruidMetric,
+        post_aggs: Dict[str, Any],
+        agg_names: Set[str],
+        visited_postaggs: Set[str],
+        metrics_dict: Dict[str, DruidMetric],
+    ) -> None:
         mconf = postagg.json_obj
         required_fields = set(
             DruidDatasource.recursive_get_fields(mconf) + mconf.get("fieldNames", [])
@@ -939,9 +957,7 @@ class DruidDatasource(Model, BaseDatasource):

     @staticmethod
     def metrics_and_post_aggs(
-        metrics: List[Union[Dict, str]],
-        metrics_dict: Dict[str, DruidMetric],
-        druid_version=None,
+        metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric],
     ) -> Tuple[OrderedDict, OrderedDict]:
         # Separate metrics into those that are aggregations
         # and those that are post aggregations
@@ -998,10 +1014,17 @@ class DruidDatasource(Model, BaseDatasource):
         df = client.export_pandas()
         return df[column_name].to_list()

-    def get_query_str(self, query_obj, phase=1, client=None):
+    def get_query_str(
+        self,
+        query_obj: QueryObjectDict,
+        phase: int = 1,
+        client: Optional["PyDruid"] = None,
+    ) -> str:
         return self.run_query(client=client, phase=phase, **query_obj)

-    def _add_filter_from_pre_query_data(self, df: pd.DataFrame, dimensions, dim_filter):
+    def _add_filter_from_pre_query_data(
+        self, df: pd.DataFrame, dimensions: List[Any], dim_filter: "Filter"
+    ) -> "Filter":
         ret = dim_filter
         if not df.empty:
             new_filters = []
@@ -1043,7 +1066,7 @@ class DruidDatasource(Model, BaseDatasource):
         return ret

     @staticmethod
-    def druid_type_from_adhoc_metric(adhoc_metric: Dict) -> str:
+    def druid_type_from_adhoc_metric(adhoc_metric: Dict[str, Any]) -> str:
         column_type = adhoc_metric["column"]["type"].lower()
         aggregate = adhoc_metric["aggregate"].lower()

@@ -1115,12 +1138,14 @@ class DruidDatasource(Model, BaseDatasource):
         )

     @staticmethod
-    def _dimensions_to_values(dimensions):
+    def _dimensions_to_values(
+        dimensions: List[Union[Dict[str, str], str]]
+    ) -> List[Union[Dict[str, str], str]]:
         """
         Replace dimensions specs with their `dimension`
         values, and ignore those without
         """
-        values = []
+        values: List[Union[Dict[str, str], str]] = []
         for dimension in dimensions:
             if isinstance(dimension, dict):
                 if "extractionFn" in dimension:
@@ -1133,37 +1158,37 @@ class DruidDatasource(Model, BaseDatasource):
         return values

     @staticmethod
-    def sanitize_metric_object(metric: Dict) -> None:
+    def sanitize_metric_object(metric: Metric) -> None:
         """
         Update a metric with the correct type if necessary.
         :param dict metric: The metric to sanitize
         """
         if (
             utils.is_adhoc_metric(metric)
-            and metric["column"]["type"].upper() == "FLOAT"
+            and metric["column"]["type"].upper() == "FLOAT"  # type: ignore
         ):
-            metric["column"]["type"] = "DOUBLE"
+            metric["column"]["type"] = "DOUBLE"  # type: ignore

     def run_query(  # druid
         self,
-        metrics,
-        granularity,
-        from_dttm,
-        to_dttm,
-        columns=None,
-        groupby=None,
-        filter=None,
-        is_timeseries=True,
-        timeseries_limit=None,
-        timeseries_limit_metric=None,
-        row_limit=None,
-        inner_from_dttm=None,
-        inner_to_dttm=None,
-        orderby=None,
-        extras=None,
-        phase=2,
-        client=None,
-        order_desc=True,
+        metrics: List[Metric],
+        granularity: str,
+        from_dttm: datetime,
+        to_dttm: datetime,
+        columns: Optional[List[str]] = None,
+        groupby: Optional[List[str]] = None,
+        filter: Optional[List[Dict[str, Any]]] = None,
+        is_timeseries: Optional[bool] = True,
+        timeseries_limit: Optional[int] = None,
+        timeseries_limit_metric: Optional[Metric] = None,
+        row_limit: Optional[int] = None,
+        inner_from_dttm: Optional[datetime] = None,
+        inner_to_dttm: Optional[datetime] = None,
+        orderby: Optional[Any] = None,
+        extras: Optional[Dict[str, Any]] = None,
+        phase: int = 2,
+        client: Optional["PyDruid"] = None,
+        order_desc: bool = True,
     ) -> str:
         """Runs a query against Druid and returns a dataframe.
         """
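
The defaults above are spelled `Optional[...] = None` explicitly because the config sets `no_implicit_optional = true`: mypy no longer widens a parameter's type just because its default is `None`. The rule in isolation:

    from typing import List, Optional

    def ok(columns: Optional[List[str]] = None) -> List[str]:
        return columns or []

    # With no_implicit_optional = true, mypy rejects this spelling:
    #   def bad(columns: List[str] = None) -> List[str]: ...
    #   error: Incompatible default for argument "columns"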

@@ -1190,17 +1215,16 @@ class DruidDatasource(Model, BaseDatasource):
         ) < LooseVersion("0.11.0"):
             for metric in metrics:
                 self.sanitize_metric_object(metric)
-            self.sanitize_metric_object(timeseries_limit_metric)
+            if timeseries_limit_metric:
+                self.sanitize_metric_object(timeseries_limit_metric)

         aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics, metrics_dict
         )

         # the dimensions list with dimensionSpecs expanded
-        dimensions = self.get_dimensions(
-            columns if IS_SIP_38 else groupby, columns_dict
-        )
+        columns_ = columns if IS_SIP_38 else groupby
+        dimensions = self.get_dimensions(columns_, columns_dict) if columns_ else []

         extras = extras or {}
         qry = dict(
@@ -1217,17 +1241,24 @@ class DruidDatasource(Model, BaseDatasource):
         if is_timeseries:
             qry["context"] = dict(skipEmptyBuckets=True)

-        filters = DruidDatasource.get_filters(filter, self.num_cols, columns_dict)
+        filters = (
+            DruidDatasource.get_filters(filter, self.num_cols, columns_dict)
+            if filter
+            else None
+        )
         if filters:
             qry["filter"] = filters

-        having_filters = self.get_having_filters(extras.get("having_druid"))
-        if having_filters:
-            qry["having"] = having_filters
+        if "having_druid" in extras:
+            having_filters = self.get_having_filters(extras["having_druid"])
+            if having_filters:
+                qry["having"] = having_filters
+        else:
+            having_filters = None

         order_direction = "descending" if order_desc else "ascending"

-        if (IS_SIP_38 and not metrics and "__time" not in columns) or (
+        if (IS_SIP_38 and not metrics and columns and "__time" not in columns) or (
             not IS_SIP_38 and columns
         ):
             columns.append("__time")
@@ -1240,7 +1271,7 @@ class DruidDatasource(Model, BaseDatasource):
             qry["limit"] = row_limit
             client.scan(**qry)
         elif (IS_SIP_38 and columns) or (
-            not IS_SIP_38 and len(groupby) == 0 and not having_filters
+            not IS_SIP_38 and not groupby and not having_filters
         ):
             logger.info("Running timeseries query for no groupby values")
             del qry["dimensions"]
@@ -1249,13 +1280,14 @@ class DruidDatasource(Model, BaseDatasource):
             not having_filters
             and order_desc
             and (
-                (IS_SIP_38 and len(columns) == 1)
-                or (not IS_SIP_38 and len(groupby) == 1)
+                (IS_SIP_38 and columns and len(columns) == 1)
+                or (not IS_SIP_38 and groupby and len(groupby) == 1)
             )
         ):
             dim = list(qry["dimensions"])[0]
             logger.info("Running two-phase topn query for dimension [{}]".format(dim))
             pre_qry = deepcopy(qry)
+            order_by: Optional[str] = None
             if timeseries_limit_metric:
                 order_by = utils.get_metric_name(timeseries_limit_metric)
                 aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs(
@@ -1275,7 +1307,7 @@ class DruidDatasource(Model, BaseDatasource):
             pre_qry["granularity"] = "all"
             pre_qry["threshold"] = min(row_limit, timeseries_limit or row_limit)
             pre_qry["metric"] = order_by
-            pre_qry["dimension"] = self._dimensions_to_values(qry.get("dimensions"))[0]
+            pre_qry["dimension"] = self._dimensions_to_values(qry["dimensions"])[0]
             del pre_qry["dimensions"]

             client.topn(**pre_qry)
@@ -1303,10 +1335,7 @@ class DruidDatasource(Model, BaseDatasource):
             qry["metric"] = list(qry["aggregations"].keys())[0]
             client.topn(**qry)
             logger.info("Phase 2 Complete")
-        elif (
-            having_filters
-            or ((IS_SIP_38 and columns) or (not IS_SIP_38 and len(groupby))) > 0
-        ):
+        elif having_filters or ((IS_SIP_38 and columns) or (not IS_SIP_38 and groupby)):
             # If grouping on multiple fields or using a having filter
             # we have to force a groupby query
             logger.info("Running groupby query for dimensions [{}]".format(dimensions))
@@ -1322,13 +1351,13 @@ class DruidDatasource(Model, BaseDatasource):
                 set([x for x in pre_qry_dims if not isinstance(x, dict)])
             )
             dict_dims = [x for x in pre_qry_dims if isinstance(x, dict)]
-            pre_qry["dimensions"] = non_dict_dims + dict_dims
+            pre_qry["dimensions"] = non_dict_dims + dict_dims  # type: ignore

             order_by = None
             if metrics:
                 order_by = utils.get_metric_name(metrics[0])
             else:
-                order_by = pre_qry_dims[0]
+                order_by = pre_qry_dims[0]  # type: ignore

             if timeseries_limit_metric:
                 order_by = utils.get_metric_name(timeseries_limit_metric)
@@ -1366,7 +1395,7 @@ class DruidDatasource(Model, BaseDatasource):
                 if df is None:
                     df = pd.DataFrame()
                 qry["filter"] = self._add_filter_from_pre_query_data(
-                    df, pre_qry["dimensions"], filters
+                    df, pre_qry["dimensions"], qry["filter"]
                 )
             qry["limit_spec"] = None
         if row_limit:
@@ -1446,7 +1475,7 @@ class DruidDatasource(Model, BaseDatasource):

         time_offset = DruidDatasource.time_offset(query_obj["granularity"])

-        def increment_timestamp(ts):
+        def increment_timestamp(ts: str) -> datetime:
             dt = utils.parse_human_datetime(ts).replace(tzinfo=DRUID_TZ)
             return dt + timedelta(milliseconds=time_offset)

@@ -1458,7 +1487,17 @@ class DruidDatasource(Model, BaseDatasource):
         )

     @staticmethod
-    def _create_extraction_fn(dim_spec):
+    def _create_extraction_fn(
+        dim_spec: Dict[str, Any]
+    ) -> Tuple[
+        str,
+        Union[
+            "MapLookupExtraction",
+            "RegexExtraction",
+            "RegisteredLookupExtraction",
+            "TimeFormatExtraction",
+        ],
+    ]:
         extraction_fn = None
         if dim_spec and "extractionFn" in dim_spec:
             col = dim_spec["dimension"]
@@ -1487,7 +1526,12 @@ class DruidDatasource(Model, BaseDatasource):
         return (col, extraction_fn)

     @classmethod
-    def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
+    def get_filters(
+        cls,
+        raw_filters: List[Dict[str, Any]],
+        num_cols: List[str],
+        columns_dict: Dict[str, DruidColumn],
+    ) -> "Filter":
         """Given Superset filter data structure, returns pydruid Filter(s)"""
         filters = None
         for flt in raw_filters:
@@ -1641,7 +1685,9 @@ class DruidDatasource(Model, BaseDatasource):

         return cond

-    def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having":
+    def get_having_filters(
+        self, raw_filters: List[Dict[str, Any]]
+    ) -> Optional["Having"]:
         filters = None
         reversed_op_map = {
             FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value,
@@ -1673,16 +1719,18 @@ class DruidDatasource(Model, BaseDatasource):

     @classmethod
     def query_datasources_by_name(
-        cls, session: Session, database: Database, datasource_name: str, schema=None
+        cls,
+        session: Session,
+        database: Database,
+        datasource_name: str,
+        schema: Optional[str] = None,
     ) -> List["DruidDatasource"]:
         return []

-    def external_metadata(self) -> List[Dict]:
+    def external_metadata(self) -> List[Dict[str, Any]]:
         self.merge_flag = True
-        return [
-            {"name": k, "type": v.get("type")}
-            for k, v in self.latest_metadata().items()
-        ]
+        latest_metadata = self.latest_metadata() or {}
+        return [{"name": k, "type": v.get("type")} for k, v in latest_metadata.items()]


 sa.event.listen(DruidDatasource, "after_insert", security_manager.set_perm)
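
Since `latest_metadata()` is now typed `Optional[Dict[str, Any]]`, `external_metadata` substitutes an empty dict before iterating; `x or {}` is the narrowing idiom mypy accepts here. In isolation:

    from typing import Any, Dict, Optional

    def latest_metadata() -> Optional[Dict[str, Any]]:
        return None  # e.g. no segments were found

    metadata = latest_metadata() or {}  # never None, safe to iterate
    print([{"name": k, "type": v.get("type")} for k, v in metadata.items()])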

@@ -31,6 +31,7 @@ from superset import app, appbuilder, db, security_manager
 from superset.connectors.base.views import DatasourceModelView
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.constants import RouteMethod
+from superset.typing import FlaskResponse
 from superset.utils import core as utils
 from superset.views.base import (
     BaseSupersetView,
@@ -106,7 +107,7 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView):

     edit_form_extra_fields = add_form_extra_fields

-    def pre_update(self, col):
+    def pre_update(self, col: "DruidColumnInlineView") -> None:
         # If a dimension spec JSON is given, ensure that it is
         # valid JSON and that `outputName` is specified
         if col.dimension_spec_json:
@@ -128,10 +129,10 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView):
             )
         )

-    def post_update(self, col):
+    def post_update(self, col: "DruidColumnInlineView") -> None:
         col.refresh_metrics()

-    def post_add(self, col):
+    def post_add(self, col: "DruidColumnInlineView") -> None:
         self.post_update(col)


@@ -240,13 +241,13 @@ class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin):

     yaml_dict_key = "databases"

-    def pre_add(self, cluster):
+    def pre_add(self, cluster: "DruidClusterModelView") -> None:
         security_manager.add_permission_view_menu("database_access", cluster.perm)

-    def pre_update(self, cluster):
+    def pre_update(self, cluster: "DruidClusterModelView") -> None:
         self.pre_add(cluster)

-    def _delete(self, pk):
+    def _delete(self, pk: int) -> None:
         DeleteMixin._delete(self, pk)


@@ -334,7 +335,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
         "modified": _("Modified"),
     }

-    def pre_add(self, datasource):
+    def pre_add(self, datasource: "DruidDatasourceModelView") -> None:
         with db.session.no_autoflush:
             query = db.session.query(models.DruidDatasource).filter(
                 models.DruidDatasource.datasource_name == datasource.datasource_name,
@@ -343,7 +344,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
         if db.session.query(query.exists()).scalar():
             raise Exception(get_datasource_exist_error_msg(datasource.full_name))

-    def post_add(self, datasource):
+    def post_add(self, datasource: "DruidDatasourceModelView") -> None:
         datasource.refresh_metrics()
         security_manager.add_permission_view_menu(
             "datasource_access", datasource.get_perm()
@@ -353,10 +354,10 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
             "schema_access", datasource.schema_perm
         )

-    def post_update(self, datasource):
+    def post_update(self, datasource: "DruidDatasourceModelView") -> None:
         self.post_add(datasource)

-    def _delete(self, pk):
+    def _delete(self, pk: int) -> None:
         DeleteMixin._delete(self, pk)


@@ -365,7 +366,7 @@ class Druid(BaseSupersetView):

     @has_access
     @expose("/refresh_datasources/")
-    def refresh_datasources(self, refresh_all=True):
+    def refresh_datasources(self, refresh_all: bool = True) -> FlaskResponse:
         """endpoint that refreshes druid datasources metadata"""
         session = db.session()
         DruidCluster = ConnectorRegistry.sources["druid"].cluster_class
@@ -397,7 +398,7 @@ class Druid(BaseSupersetView):

     @has_access
     @expose("/scan_new_datasources/")
-    def scan_new_datasources(self):
+    def scan_new_datasources(self) -> FlaskResponse:
         """
         Calling this endpoint will cause a scan for new
         datasources only and add them.

@@ -54,10 +54,15 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
 from superset.constants import NULL_STRING
 from superset.db_engine_specs.base import TimestampExpression
 from superset.exceptions import DatabaseNotFound
-from superset.jinja_context import ExtraCache, get_template_processor
+from superset.jinja_context import (
+    BaseTemplateProcessor,
+    ExtraCache,
+    get_template_processor,
+)
 from superset.models.annotations import Annotation
 from superset.models.core import Database
 from superset.models.helpers import AuditMixinNullable, QueryResult
+from superset.typing import Metric, QueryObjectDict
 from superset.utils import core as utils, import_datasource

 config = app.config
@@ -86,7 +91,7 @@ class AnnotationDatasource(BaseDatasource):
     cache_timeout = 0
     changed_on = None

-    def query(self, query_obj: Dict[str, Any]) -> QueryResult:
+    def query(self, query_obj: QueryObjectDict) -> QueryResult:
         error_message = None
         qry = db.session.query(Annotation)
         qry = qry.filter(Annotation.layer_id == query_obj["filter"][0]["val"])
@@ -110,10 +115,10 @@ class AnnotationDatasource(BaseDatasource):
             error_message=error_message,
         )

-    def get_query_str(self, query_obj):
+    def get_query_str(self, query_obj: QueryObjectDict) -> str:
         raise NotImplementedError()

-    def values_for_column(self, column_name, limit=10000):
+    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
         raise NotImplementedError()


@@ -239,8 +244,8 @@ class TableColumn(Model, BaseColumn):
         return self.table.make_sqla_column_compatible(time_expr, label)

     @classmethod
-    def import_obj(cls, i_column):
-        def lookup_obj(lookup_column):
+    def import_obj(cls, i_column: "TableColumn") -> "TableColumn":
+        def lookup_obj(lookup_column: TableColumn) -> TableColumn:
             return (
                 db.session.query(TableColumn)
                 .filter(
@@ -343,8 +348,8 @@ class SqlMetric(Model, BaseMetric):
         return self.perm

     @classmethod
-    def import_obj(cls, i_metric):
-        def lookup_obj(lookup_metric):
+    def import_obj(cls, i_metric: "SqlMetric") -> "SqlMetric":
+        def lookup_obj(lookup_metric: SqlMetric) -> SqlMetric:
             return (
                 db.session.query(SqlMetric)
                 .filter(
@@ -442,7 +447,7 @@ class SqlaTable(Model, BaseDatasource):
         sqla_col._df_label_expected = label_expected
         return sqla_col

-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.name

     @property
@@ -521,14 +526,14 @@ class SqlaTable(Model, BaseDatasource):
         )

     @property
-    def dttm_cols(self) -> List:
+    def dttm_cols(self) -> List[str]:
         l = [c.column_name for c in self.columns if c.is_dttm]
         if self.main_dttm_col and self.main_dttm_col not in l:
             l.append(self.main_dttm_col)
         return l

     @property
-    def num_cols(self) -> List:
+    def num_cols(self) -> List[str]:
         return [c.column_name for c in self.columns if c.is_numeric]

     @property
@@ -550,7 +555,7 @@ class SqlaTable(Model, BaseDatasource):
     def sql_url(self) -> str:
         return self.database.sql_url + "?table_name=" + str(self.table_name)

-    def external_metadata(self):
+    def external_metadata(self) -> List[Dict[str, str]]:
         cols = self.database.get_columns(self.table_name, schema=self.schema)
         for col in cols:
             try:
@@ -567,7 +572,7 @@ class SqlaTable(Model, BaseDatasource):
         }

     @property
-    def select_star(self) -> str:
+    def select_star(self) -> Optional[str]:
         # show_cols and latest_partition set to false to avoid
         # the expensive cost of inspecting the DB
         return self.database.select_star(
@@ -589,7 +594,7 @@ class SqlaTable(Model, BaseDatasource):
         d["is_sqllab_view"] = self.is_sqllab_view
         return d

-    def values_for_column(self, column_name: str, limit: int = 10000) -> List:
+    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
         """Runs query against sqla to retrieve some
         sample values for the given column.
         """
@@ -626,10 +631,10 @@ class SqlaTable(Model, BaseDatasource):
         sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database)
         return sql

-    def get_template_processor(self, **kwargs):
+    def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
         return get_template_processor(table=self, database=self.database, **kwargs)

-    def get_query_str_extended(self, query_obj: Dict[str, Any]) -> QueryStringExtended:
+    def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
         sqlaq = self.get_sqla_query(**query_obj)
         sql = self.database.compile_sqla_query(sqlaq.sqla_query)
         logger.info(sql)
|
@ -639,18 +644,20 @@ class SqlaTable(Model, BaseDatasource):
|
|||
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
|
||||
)
|
||||
|
||||
def get_query_str(self, query_obj: Dict[str, Any]) -> str:
|
||||
def get_query_str(self, query_obj: QueryObjectDict) -> str:
|
||||
query_str_ext = self.get_query_str_extended(query_obj)
|
||||
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
|
||||
return ";\n\n".join(all_queries) + ";"
|
||||
|
||||
def get_sqla_table(self):
|
||||
def get_sqla_table(self) -> table:
|
||||
tbl = table(self.table_name)
|
||||
if self.schema:
|
||||
tbl.schema = self.schema
|
||||
return tbl
|
||||
|
||||
def get_from_clause(self, template_processor=None):
|
||||
def get_from_clause(
|
||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
||||
) -> Union[table, TextAsFrom]:
|
||||
# Supporting arbitrary SQL statements in place of tables
|
||||
if self.sql:
|
||||
from_sql = self.sql
|
||||
|
|
@ -687,7 +694,9 @@ class SqlaTable(Model, BaseDatasource):
|
|||
|
||||
return self.make_sqla_column_compatible(sqla_metric, label)
|
||||
|
||||
def _get_sqla_row_level_filters(self, template_processor) -> List[str]:
|
||||
def _get_sqla_row_level_filters(
|
||||
self, template_processor: BaseTemplateProcessor
|
||||
) -> List[str]:
|
||||
"""
|
||||
Return the appropriate row level security filters for this table and the current user.
|
||||
|
||||
|
|
@ -702,22 +711,22 @@ class SqlaTable(Model, BaseDatasource):
|
|||
|
||||
def get_sqla_query( # sqla
|
||||
self,
|
||||
metrics,
|
||||
granularity,
|
||||
from_dttm,
|
||||
to_dttm,
|
||||
columns=None,
|
||||
groupby=None,
|
||||
filter=None,
|
||||
is_timeseries=True,
|
||||
timeseries_limit=15,
|
||||
timeseries_limit_metric=None,
|
||||
row_limit=None,
|
||||
inner_from_dttm=None,
|
||||
inner_to_dttm=None,
|
||||
orderby=None,
|
||||
extras=None,
|
||||
order_desc=True,
|
||||
metrics: List[Metric],
|
||||
granularity: str,
|
||||
from_dttm: datetime,
|
||||
to_dttm: datetime,
|
||||
columns: Optional[List[str]] = None,
|
||||
groupby: Optional[List[str]] = None,
|
||||
filter: Optional[List[Dict[str, Any]]] = None,
|
||||
is_timeseries: bool = True,
|
||||
timeseries_limit: int = 15,
|
||||
timeseries_limit_metric: Optional[Metric] = None,
|
||||
row_limit: Optional[int] = None,
|
||||
inner_from_dttm: Optional[datetime] = None,
|
||||
inner_to_dttm: Optional[datetime] = None,
|
||||
orderby: Optional[List[Tuple[ColumnElement, bool]]] = None,
|
||||
extras: Optional[Dict[str, Any]] = None,
|
||||
order_desc: bool = True,
|
||||
) -> SqlaQuery:
|
||||
"""Querying any sqla table from this common interface"""
|
||||
template_kwargs = {
|
||||
|
|
@ -765,8 +774,9 @@ class SqlaTable(Model, BaseDatasource):
|
|||
metrics_exprs: List[ColumnElement] = []
|
||||
for m in metrics:
|
||||
if utils.is_adhoc_metric(m):
|
||||
assert isinstance(m, dict)
|
||||
metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
|
||||
elif m in metrics_dict:
|
||||
elif isinstance(m, str) and m in metrics_dict:
|
||||
metrics_exprs.append(metrics_dict[m].get_sqla_col())
|
||||
else:
|
||||
raise Exception(_("Metric '%(metric)s' does not exist", metric=m))
|
||||
|
|
@ -781,7 +791,9 @@ class SqlaTable(Model, BaseDatasource):
|
|||
|
||||
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
|
||||
# dedup columns while preserving order
|
||||
groupby = list(dict.fromkeys(columns if is_sip_38 else groupby))
|
||||
columns_ = columns if is_sip_38 else groupby
|
||||
assert columns_
|
||||
groupby = list(dict.fromkeys(columns_))
|
||||
|
||||
select_exprs = []
|
||||
for s in groupby:
|
||||
|
|
@ -802,6 +814,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
)
|
||||
metrics_exprs = []
|
||||
|
||||
assert extras is not None
|
||||
time_range_endpoints = extras.get("time_range_endpoints")
|
||||
groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items())
|
||||
if granularity:
|
||||
|
|
@ -845,7 +858,8 @@ class SqlaTable(Model, BaseDatasource):
|
|||
|
||||
where_clause_and = []
|
||||
having_clause_and: List = []
|
||||
for flt in filter:
|
||||
|
||||
for flt in filter: # type: ignore
|
||||
if not all([flt.get(s) for s in ["col", "op"]]):
|
||||
continue
|
||||
col = flt["col"]
|
||||
|
|
@ -1029,12 +1043,20 @@ class SqlaTable(Model, BaseDatasource):
|
|||
prequeries=prequeries,
|
||||
)
|
||||
|
||||
def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
|
||||
def _get_timeseries_orderby(
|
||||
self,
|
||||
timeseries_limit_metric: Metric,
|
||||
metrics_dict: Dict[str, SqlMetric],
|
||||
cols: Dict[str, Column],
|
||||
) -> Optional[Column]:
|
||||
if utils.is_adhoc_metric(timeseries_limit_metric):
|
||||
assert isinstance(timeseries_limit_metric, dict)
|
||||
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
|
||||
elif timeseries_limit_metric in metrics_dict:
|
||||
timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
|
||||
ob = timeseries_limit_metric.get_sqla_col()
|
||||
elif (
|
||||
isinstance(timeseries_limit_metric, str)
|
||||
and timeseries_limit_metric in metrics_dict
|
||||
):
|
||||
ob = metrics_dict[timeseries_limit_metric].get_sqla_col()
|
||||
else:
|
||||
raise Exception(
|
||||
_("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)
|
||||
|
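
The `_get_timeseries_orderby` rewrite narrows the `Metric` union branch by branch: an `assert isinstance(..., dict)` for adhoc metrics, and an explicit `isinstance(..., str)` before the dict lookup, so mypy can type each arm. The narrowing pattern in isolation:

    from typing import Dict, Union

    Metric = Union[Dict[str, str], str]  # mirrors the alias in superset.typing

    def metric_label(metric: Metric) -> str:
        if isinstance(metric, dict):  # narrowed to Dict[str, str] here
            return metric["label"]
        return metric  # narrowed to str here

    print(metric_label({"label": "SUM(num)"}), metric_label("count"))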
|
@ -1054,7 +1076,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
|
||||
return or_(*groups)
|
||||
|
||||
def query(self, query_obj: Dict[str, Any]) -> QueryResult:
|
||||
def query(self, query_obj: QueryObjectDict) -> QueryResult:
|
||||
qry_start_dttm = datetime.now()
|
||||
query_str_ext = self.get_query_str_extended(query_obj)
|
||||
sql = query_str_ext.sql
|
||||
|
|
@ -1101,7 +1123,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
def get_sqla_table_object(self) -> Table:
|
||||
return self.database.get_table(self.table_name, schema=self.schema)
|
||||
|
||||
def fetch_metadata(self, commit=True) -> None:
|
||||
def fetch_metadata(self, commit: bool = True) -> None:
|
||||
"""Fetches the metadata for the table and merges it in"""
|
||||
try:
|
||||
table = self.get_sqla_table_object()
|
||||
|
|
@ -1166,7 +1188,9 @@ class SqlaTable(Model, BaseDatasource):
|
|||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, i_datasource, import_time=None) -> int:
|
||||
def import_obj(
|
||||
cls, i_datasource: "SqlaTable", import_time: Optional[int] = None
|
||||
) -> int:
|
||||
"""Imports the datasource from the object to the database.
|
||||
|
||||
Metrics and columns and datasource will be overrided if exists.
|
||||
|
|
@ -1174,7 +1198,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
superset instances. Audit metadata isn't copies over.
|
||||
"""
|
||||
|
||||
def lookup_sqlatable(table):
|
||||
def lookup_sqlatable(table: "SqlaTable") -> "SqlaTable":
|
||||
return (
|
||||
db.session.query(SqlaTable)
|
||||
.join(Database)
|
||||
|
|
@ -1186,7 +1210,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
.first()
|
||||
)
|
||||
|
||||
def lookup_database(table):
|
||||
def lookup_database(table: SqlaTable) -> Database:
|
||||
try:
|
||||
return (
|
||||
db.session.query(Database)
|
||||
|
|
@ -1207,7 +1231,11 @@ class SqlaTable(Model, BaseDatasource):
|
|||
|
||||
@classmethod
|
||||
def query_datasources_by_name(
|
||||
cls, session: Session, database: Database, datasource_name: str, schema=None
|
||||
cls,
|
||||
session: Session,
|
||||
database: Database,
|
||||
datasource_name: str,
|
||||
schema: Optional[str] = None,
|
||||
) -> List["SqlaTable"]:
|
||||
query = (
|
||||
session.query(cls)
|
||||
|
|
@ -1219,10 +1247,10 @@ class SqlaTable(Model, BaseDatasource):
|
|||
return query.all()
|
||||
|
||||
@staticmethod
|
||||
def default_query(qry) -> Query:
|
||||
def default_query(qry: Query) -> Query:
|
||||
return qry.filter_by(is_sqllab_view=False)
|
||||
|
||||
def has_extra_cache_key_calls(self, query_obj: Dict[str, Any]) -> bool:
|
||||
def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool:
|
||||
"""
|
||||
Detects the presence of calls to `ExtraCache` methods in items in query_obj that
|
||||
can be templated. If any are present, the query must be evaluated to extract
|
||||
|
|
@ -1248,7 +1276,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
return True
|
||||
return False
|
||||
|
||||
def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
|
||||
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]:
|
||||
"""
|
||||
The cache key of a SqlaTable needs to consider any keys added by the parent
|
||||
class and any keys added via `ExtraCache`.
|
||||
|
|
|
|||
|
|

@@ -18,6 +18,7 @@
 """Views used by the SqlAlchemy connector"""
 import logging
 import re
+from typing import List, Union

 from flask import flash, Markup, redirect
 from flask_appbuilder import CompactCRUDMixin, expose
@@ -32,6 +33,7 @@ from wtforms.validators import Regexp
 from superset import app, db, security_manager
 from superset.connectors.base.views import DatasourceModelView
 from superset.constants import RouteMethod
+from superset.typing import FlaskResponse
 from superset.utils import core as utils
 from superset.views.base import (
     create_table_permissions,
@@ -375,10 +377,10 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
         )
     }

-    def pre_add(self, table):
+    def pre_add(self, table: "TableModelView") -> None:
         validate_sqlatable(table)

-    def post_add(self, table, flash_message=True):
+    def post_add(self, table: "TableModelView", flash_message: bool = True) -> None:
         table.fetch_metadata()
         create_table_permissions(table)
         if flash_message:
@@ -392,15 +394,15 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
             "info",
         )

-    def post_update(self, table):
+    def post_update(self, table: "TableModelView") -> None:
         self.post_add(table, flash_message=False)

-    def _delete(self, pk):
+    def _delete(self, pk: int) -> None:
         DeleteMixin._delete(self, pk)

     @expose("/edit/<pk>", methods=["GET", "POST"])
     @has_access
-    def edit(self, pk):
+    def edit(self, pk: int) -> FlaskResponse:
         """Simple hack to redirect to explore view after saving"""
         resp = super(TableModelView, self).edit(pk)
         if isinstance(resp, str):
@@ -410,7 +412,9 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
     @action(
         "refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh"
     )
-    def refresh(self, tables):
+    def refresh(
+        self, tables: Union["TableModelView", List["TableModelView"]]
+    ) -> FlaskResponse:
         if not isinstance(tables, list):
             tables = [tables]
         successes = []
@@ -439,7 +443,7 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):

     @expose("/list/")
     @has_access
-    def list(self):
+    def list(self) -> FlaskResponse:
         if not app.config["ENABLE_REACT_CRUD_VIEWS"]:
             return super().list()

@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-argument
-import dataclasses
 import hashlib
 import json
 import logging
@@ -34,6 +33,7 @@ from typing import (
     Union,
 )

+import dataclasses
 import pandas as pd
 import sqlparse
 from flask import g

@@ -15,10 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=too-few-public-methods,invalid-name
-from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Optional

+from dataclasses import dataclass
+

 class SupersetErrorType(str, Enum):
     """

@@ -17,7 +17,7 @@
 """Defines the templating context for SQL Lab"""
 import inspect
 import re
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING

 from flask import g, request
 from jinja2.sandbox import SandboxedEnvironment
@@ -26,6 +26,13 @@ from superset import jinja_base_context
 from superset.extensions import jinja_context_manager
 from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters

+if TYPE_CHECKING:
+    from superset.connectors.sqla.models import (  # pylint: disable=unused-import
+        SqlaTable,
+    )
+    from superset.models.core import Database  # pylint: disable=unused-import
+    from superset.models.sql_lab import Query  # pylint: disable=unused-import
+

 def filter_values(column: str, default: Optional[str] = None) -> List[str]:
     """ Gets a values for a particular filter as a list
@@ -200,12 +207,12 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods

     def __init__(
         self,
-        database=None,
-        query=None,
-        table=None,
+        database: Optional["Database"] = None,
+        query: Optional["Query"] = None,
+        table: Optional["SqlaTable"] = None,
         extra_cache_keys: Optional[List[Any]] = None,
-        **kwargs
-    ):
+        **kwargs: Any,
+    ) -> None:
         self.database = database
         self.query = query
         self.schema = None
@@ -230,7 +237,7 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
         self.context[self.engine] = self
         self.env = SandboxedEnvironment()

-    def process_template(self, sql: str, **kwargs) -> str:
+    def process_template(self, sql: str, **kwargs: Any) -> str:
         """Processes a sql template

         >>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
@@ -279,12 +286,14 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
         """

         table_name, schema = self._schema_table(table_name, self.schema)
-        return self.database.db_engine_spec.latest_partition(
+        assert self.database
+        return self.database.db_engine_spec.latest_partition(  # type: ignore
             table_name, schema, self.database
         )[1]

     def latest_sub_partition(self, table_name, **kwargs):
         table_name, schema = self._schema_table(table_name, self.schema)
+        assert self.database
         return self.database.db_engine_spec.latest_sub_partition(
             table_name=table_name, schema=schema, database=self.database, **kwargs
         )
@@ -305,7 +314,12 @@ for k in keys:
     template_processors[o.engine] = o


-def get_template_processor(database, table=None, query=None, **kwargs):
+def get_template_processor(
+    database: "Database",
+    table: Optional["SqlaTable"] = None,
+    query: Optional["Query"] = None,
+    **kwargs: Any,
+) -> BaseTemplateProcessor:
     template_processor = template_processors.get(
         database.backend, BaseTemplateProcessor
     )
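
The `if TYPE_CHECKING:` block above is the standard way to annotate against types whose modules would create an import cycle at runtime: the imports only execute under the type checker, and the annotations are written as string literals. The pattern in isolation (toy function for illustration):

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Imported only by mypy; avoids a circular import at runtime.
        from superset.models.core import Database

    def describe(database: Optional["Database"] = None) -> str:
        return database.backend if database else "no database"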

@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Set
 from urllib import parse

 import sqlparse
-from dataclasses import dataclass
 from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
 from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
 from sqlparse.utils import imt
@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 from flask import Flask
 from flask_caching import Cache
+from werkzeug.wrappers import Response

 CacheConfig = Union[Callable[[Flask], Cache], Dict[str, Any]]
 DbapiDescriptionRow = Tuple[
@@ -27,4 +28,15 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]]
 DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
 FilterValue = Union[float, int, str]
 FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
+Granularity = Union[str, Dict[str, Union[str, float]]]
+Metric = Union[Dict[str, str], str]
+QueryObjectDict = Dict[str, Any]
 VizData = Optional[Union[List[Any], Dict[Any, Any]]]
+
+# Flask response.
+Base = Union[bytes, str]
+Status = Union[int, str]
+Headers = Dict[str, Any]
+FlaskResponse = Union[
+    Response, Base, Tuple[Base, Status], Tuple[Base, Status, Headers],
+]
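
`FlaskResponse` captures the shapes a Flask view may legally return: a response object, a bare body, or a `(body, status)` / `(body, status, headers)` tuple. A sketch with the aliases re-declared locally (werkzeug's `Response` omitted so the snippet has no dependencies):

    from typing import Any, Dict, Tuple, Union

    Base = Union[bytes, str]
    Status = Union[int, str]
    Headers = Dict[str, Any]
    FlaskResponse = Union[Base, Tuple[Base, Status], Tuple[Base, Status, Headers]]

    def health() -> FlaskResponse:
        return "OK", 200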
|
|
|||
|
|
@ -43,10 +43,12 @@ from typing import (
|
|||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
|
|
@ -79,6 +81,7 @@ from superset.exceptions import (
|
|||
SupersetException,
|
||||
SupersetTimeoutException,
|
||||
)
|
||||
from superset.typing import Metric
|
||||
from superset.utils.dates import datetime_to_epoch, EPOCH
|
||||
|
||||
try:
|
||||
|
|
@ -101,7 +104,7 @@ JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
|
|||
try:
|
||||
# Having might not have been imported.
|
||||
class DimSelector(Having):
|
||||
def __init__(self, **args):
|
||||
def __init__(self, **args: Any) -> None:
|
||||
# Just a hack to prevent any exceptions
|
||||
Having.__init__(self, type="equalTo", aggregation=None, value=None)
|
||||
|
||||
|
|
@@ -118,7 +121,7 @@ except NameError:
     pass
 
 
-def flasher(msg, severity=None):
+def flasher(msg: str, severity: str) -> None:
     """Flask's flash if available, logging call if not"""
     try:
         flash(msg, severity)

@@ -235,7 +238,7 @@ def list_minus(l: List, minus: List) -> List:
     return [o for o in l if o not in minus]
 
 
-def parse_human_datetime(s: Optional[str]) -> Optional[datetime]:
+def parse_human_datetime(s: str) -> datetime:
     """
     Returns ``datetime.datetime`` from human readable strings
 

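The signature narrows from Optional in and Optional out to `str` in and `datetime` out, so the empty-input check moves to call sites (the next hunk deletes the internal `if not s` guard accordingly). A small sketch of the caller-side consequence, using a simplified stand-in for the real parser:

    from datetime import datetime
    from dateutil.parser import parse

    def parse_human_datetime(s: str) -> datetime:
        return parse(s)  # stand-in: the real function also tries parsedatetime

    raw = ""
    # Callers now guard empty input themselves instead of testing for None:
    dttm = parse_human_datetime(raw) if raw else None
    print(dttm)  # None
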
@@ -256,8 +259,6 @@ def parse_human_datetime(s: Optional[str]) -> Optional[datetime]:
     >>> year_ago_1 == year_ago_2
     True
     """
-    if not s:
-        return None
     try:
         dttm = parse(s)
     except Exception:

@@ -564,7 +565,9 @@ def generic_find_uq_constraint_name(table, columns, insp):
     return uq["name"]
 
 
-def get_datasource_full_name(database_name, datasource_name, schema=None):
+def get_datasource_full_name(
+    database_name: str, datasource_name: str, schema: Optional[str] = None
+) -> str:
     if not schema:
         return "[{}].[{}]".format(database_name, datasource_name)
     return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)

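The formatting above, exercised directly with illustrative names (note the schema sits between database and datasource):

    assert get_datasource_full_name("examples", "birth_names") == "[examples].[birth_names]"
    assert (
        get_datasource_full_name("examples", "birth_names", schema="public")
        == "[examples].[public].[birth_names]"
    )
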
@@ -792,7 +795,7 @@ def get_email_address_list(address_string: str) -> List[str]:
     return [x.strip() for x in address_string_list if x.strip()]
 
 
-def choicify(values):
+def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
     """Takes an iterable and makes an iterable of tuples with it"""
     return [(v, v) for v in values]
 

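With the new annotation the behavior is unchanged; for example:

    assert choicify(["red", "green"]) == [("red", "red"), ("green", "green")]
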
@@ -967,7 +970,7 @@ def get_example_database() -> "Database":
     return get_or_create_db("examples", db_uri)
 
 
-def is_adhoc_metric(metric) -> bool:
+def is_adhoc_metric(metric: Metric) -> bool:
     return bool(
         isinstance(metric, dict)
         and (

@@ -985,11 +988,11 @@ def is_adhoc_metric(metric) -> bool:
     )
 
 
-def get_metric_name(metric):
-    return metric["label"] if is_adhoc_metric(metric) else metric
+def get_metric_name(metric: Metric) -> str:
+    return metric["label"] if is_adhoc_metric(metric) else metric  # type: ignore
 
 
-def get_metric_names(metrics):
+def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
     return [get_metric_name(metric) for metric in metrics]
 

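The trailing `# type: ignore` is there because mypy cannot narrow the `Metric` union through the `is_adhoc_metric` call. For reference, the two shapes the union covers, assuming the ad-hoc dict layout that `is_adhoc_metric` checks (an `expressionType`, the aggregate fields, and a `label`):

    assert get_metric_name("count") == "count"
    assert get_metric_name(
        {
            "expressionType": "SIMPLE",
            "column": {"column_name": "num"},
            "aggregate": "SUM",
            "label": "SUM(num)",
        }
    ) == "SUM(num)"
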
@@ -15,15 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+from typing import Callable, Optional
 
+from flask_appbuilder import Model
+from sqlalchemy.orm import Session
 from sqlalchemy.orm.session import make_transient
 
 logger = logging.getLogger(__name__)
 
 
 def import_datasource(
-    session, i_datasource, lookup_database, lookup_datasource, import_time
-):
+    session: Session,
+    i_datasource: Model,
+    lookup_database: Callable,
+    lookup_datasource: Callable,
+    import_time: Optional[int] = None,
+) -> int:
     """Imports the datasource from the object to the database.
 
     Metrics and columns and datasource will be overrided if exists.

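A minimal, self-contained sketch of the typed shape above: the two lookups are plain `Callable`s that return an existing object or None, and `import_time` becomes keyword-optional (everything here is illustrative, not the real import logic):

    from typing import Callable, Optional

    def import_item(
        name: str,
        lookup: Callable[[str], Optional[str]],
        import_time: Optional[int] = None,
    ) -> str:
        existing = lookup(name)
        # Overwrite-if-exists semantics, loosely mirroring the docstring above.
        return existing if existing is not None else name

    print(import_item("birth_names", lambda n: None))  # birth_names
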
@@ -75,7 +82,7 @@ def import_datasource(
     return datasource.id
 
 
-def import_simple_obj(session, i_obj, lookup_obj):
+def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model:
     make_transient(i_obj)
     i_obj.id = None
     i_obj.table = None

@@ -14,13 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import dataclasses
 import functools
 import logging
 import traceback
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
+import dataclasses
 import simplejson as json
 import yaml
 from flask import abort, flash, g, get_flashed_messages, redirect, Response, session

@@ -48,10 +48,16 @@ from superset.connectors.sqla import models
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import SupersetException, SupersetSecurityException
 from superset.translations.utils import get_language_pack
+from superset.typing import FlaskResponse
 from superset.utils import core as utils
 
 from .utils import bootstrap_user_data
 
+if TYPE_CHECKING:
+    from superset.connectors.druid.views import (  # pylint: disable=unused-import
+        DruidClusterModelView,
+    )
+
 FRONTEND_CONF_KEYS = (
     "SUPERSET_WEBSERVER_TIMEOUT",
     "SUPERSET_DASHBOARD_POSITION_DATA_LIMIT",

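The `if TYPE_CHECKING:` block above makes `DruidClusterModelView` available to annotations without importing it at runtime, which sidesteps what would otherwise be a circular import between the views modules. The generic shape of the pattern:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from decimal import Decimal  # illustrative annotation-only import

    def fmt(d: "Decimal") -> str:
        # The annotation is quoted, so it is never evaluated at runtime.
        return str(d)
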
@@ -305,7 +311,7 @@ class SupersetModelView(ModelView):
     page_size = 100
     list_widget = SupersetListWidget
 
-    def render_app_template(self):
+    def render_app_template(self) -> FlaskResponse:
         payload = {
             "user": bootstrap_user_data(g.user),
             "common": common_bootstrap_payload(),

@@ -359,7 +365,9 @@ class YamlExportMixin:  # pylint: disable=too-few-public-methods
 
 
 class DeleteMixin:  # pylint: disable=too-few-public-methods
-    def _delete(self, primary_key):
+    def _delete(
+        self: Union[BaseView, "DeleteMixin", "DruidClusterModelView"], primary_key: int,
+    ) -> None:
         """
         Delete function logic, override to implement diferent logic
         deletes the record with primary_key = primary_key

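Annotating `self` with a Union, as `_delete` does above, tells mypy which concrete hosts the mixin method may run on. A self-contained sketch (Host and Mixin are illustrative):

    from typing import Union

    class Host:
        name = "host"

    class Mixin:
        def describe(self: Union["Host", "Mixin"]) -> str:
            return getattr(self, "name", "mixin")

    class Combined(Host, Mixin):
        pass

    print(Combined().describe())  # host
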
@@ -367,11 +375,11 @@ class DeleteMixin:  # pylint: disable=too-few-public-methods
         :param primary_key:
             record primary key to delete
         """
-        item = self.datamodel.get(primary_key, self._base_filters)
+        item = self.datamodel.get(primary_key, self._base_filters)  # type: ignore
         if not item:
             abort(404)
         try:
-            self.pre_delete(item)
+            self.pre_delete(item)  # type: ignore
         except Exception as ex:  # pylint: disable=broad-except
             flash(str(ex), "danger")
         else:

@@ -384,8 +392,8 @@ class DeleteMixin:  # pylint: disable=too-few-public-methods
             .all()
         )
 
-        if self.datamodel.delete(item):
-            self.post_delete(item)
+        if self.datamodel.delete(item):  # type: ignore
+            self.post_delete(item)  # type: ignore
 
         for pv in pvs:
             security_manager.get_session.delete(pv)

@@ -395,8 +403,8 @@ class DeleteMixin:  # pylint: disable=too-few-public-methods
 
         security_manager.get_session.commit()
 
-        flash(*self.datamodel.message)
-        self.update_redirect()
+        flash(*self.datamodel.message)  # type: ignore
+        self.update_redirect()  # type: ignore
 
     @action(
         "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False

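The scattered `# type: ignore` comments in `_delete` cover members such as `datamodel`, `pre_delete`, and `update_redirect`, which exist on the Flask-AppBuilder host class but are invisible to mypy from inside the mixin. The minimal equivalent:

    class LoggerMixin:
        def log(self) -> None:
            # `name` is defined on the host class, so mypy cannot see it here.
            print(self.name)  # type: ignore

    class Host(LoggerMixin):
        name = "host"

    Host().log()  # prints: host
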
@@ -21,7 +21,6 @@ These objects represent the backend of all the visualizations that
 Superset can render.
 """
 import copy
-import dataclasses
 import hashlib
 import inspect
 import logging

@@ -34,6 +33,7 @@ from datetime import datetime, timedelta
 from itertools import product
 from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
 
+import dataclasses
 import geohash
 import numpy as np
 import pandas as pd

@@ -1049,9 +1049,9 @@ class BubbleViz(NVD3Viz):
         # dedup groupby if it happens to be the same
         d["groupby"] = list(dict.fromkeys(d["groupby"]))
 
-        self.x_metric = form_data.get("x")
-        self.y_metric = form_data.get("y")
-        self.z_metric = form_data.get("size")
+        self.x_metric = form_data["x"]
+        self.y_metric = form_data["y"]
+        self.z_metric = form_data["size"]
         self.entity = form_data.get("entity")
         self.series = form_data.get("series") or self.entity
         d["row_limit"] = form_data.get("limit")

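The `.get(...)` to `[...]` changes in these viz hunks swap a silent None for a fast KeyError on required form fields, and give mypy a non-Optional type to work with. In miniature, with illustrative form data:

    form_data = {"x": "sum__SP_POP_TOTL"}
    x = form_data["x"]            # str; raises KeyError if the key is absent
    size = form_data.get("size")  # Optional[str]; silently None when absent
    print(x, size)  # sum__SP_POP_TOTL None
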
@@ -1093,7 +1093,7 @@ class BulletViz(NVD3Viz):
     def query_obj(self):
         form_data = self.form_data
         d = super().query_obj()
-        self.metric = form_data.get("metric")
+        self.metric = form_data["metric"]
 
         d["metrics"] = [self.metric]
         if not self.metric:

@@ -1451,8 +1451,8 @@ class NVD3DualLineViz(NVD3Viz):
                 _("Pick a time granularity for your time series")
             )
 
-        metric = utils.get_metric_name(fd.get("metric"))
-        metric_2 = utils.get_metric_name(fd.get("metric_2"))
+        metric = utils.get_metric_name(fd["metric"])
+        metric_2 = utils.get_metric_name(fd["metric_2"])
         df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2])
 
         chart_data = self.to_series(df)

@@ -1507,7 +1507,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
         df = df.pivot_table(
             index=DTTM_ALIAS,
             columns="series",
-            values=utils.get_metric_name(fd.get("metric")),
+            values=utils.get_metric_name(fd["metric"]),
         )
         chart_data = self.to_series(df)
         for serie in chart_data:

@@ -1690,8 +1690,12 @@ class SunburstViz(BaseViz):
         fd = self.form_data
         cols = fd.get("groupby") or []
         cols.extend(["m1", "m2"])
-        metric = utils.get_metric_name(fd.get("metric"))
-        secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
+        metric = utils.get_metric_name(fd["metric"])
+        secondary_metric = (
+            utils.get_metric_name(fd["secondary_metric"])
+            if "secondary_metric" in fd
+            else None
+        )
 
         if metric == secondary_metric or secondary_metric is None:
             df.rename(columns={df.columns[-1]: "m1"}, inplace=True)
             df["m2"] = df["m1"]

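Where a key is genuinely optional, as with `secondary_metric` above, the rewrite keeps an explicit None instead of subscripting. Standalone:

    fd = {"metric": "count"}
    secondary = fd["secondary_metric"] if "secondary_metric" in fd else None
    print(secondary)  # None
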
@@ -1872,8 +1876,12 @@ class WorldMapViz(BaseViz):
 
         fd = self.form_data
         cols = [fd.get("entity")]
-        metric = utils.get_metric_name(fd.get("metric"))
-        secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
+        metric = utils.get_metric_name(fd["metric"])
+        secondary_metric = (
+            utils.get_metric_name(fd["secondary_metric"])
+            if "secondary_metric" in fd
+            else None
+        )
         columns = ["country", "m1", "m2"]
         if metric == secondary_metric:
             ndf = df[cols]

@@ -21,7 +21,6 @@ These objects represent the backend of all the visualizations that
 Superset can render.
 """
 import copy
-import dataclasses
 import hashlib
 import inspect
 import logging

@@ -34,6 +33,7 @@ from datetime import datetime, timedelta
 from itertools import product
 from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
 
+import dataclasses
 import geohash
 import numpy as np
 import pandas as pd

@@ -1077,9 +1077,9 @@ class BubbleViz(NVD3Viz):
         form_data = self.form_data
         d = super().query_obj()
 
-        self.x_metric = form_data.get("x")
-        self.y_metric = form_data.get("y")
-        self.z_metric = form_data.get("size")
+        self.x_metric = form_data["x"]
+        self.y_metric = form_data["y"]
+        self.z_metric = form_data["size"]
         self.entity = form_data.get("entity")
         self.series = form_data.get("series") or self.entity
         d["row_limit"] = form_data.get("limit")

@@ -1476,8 +1476,8 @@ class NVD3DualLineViz(NVD3Viz):
                 _("Pick a time granularity for your time series")
             )
 
-        metric = utils.get_metric_name(fd.get("metric"))
-        metric_2 = utils.get_metric_name(fd.get("metric_2"))
+        metric = utils.get_metric_name(fd["metric"])
+        metric_2 = utils.get_metric_name(fd["metric_2"])
         df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2])
 
         chart_data = self.to_series(df)

@@ -1532,7 +1532,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
         df = df.pivot_table(
             index=DTTM_ALIAS,
             columns="series",
-            values=utils.get_metric_name(fd.get("metric")),
+            values=utils.get_metric_name(fd["metric"]),
         )
         chart_data = self.to_series(df)
         for serie in chart_data:

@@ -1710,8 +1710,12 @@ class SunburstViz(BaseViz):
         fd = self.form_data
         cols = fd.get("groupby") or []
         cols.extend(["m1", "m2"])
-        metric = utils.get_metric_name(fd.get("metric"))
-        secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
+        metric = utils.get_metric_name(fd["metric"])
+        secondary_metric = (
+            utils.get_metric_name(fd["secondary_metric"])
+            if "secondary_metric" in fd
+            else None
+        )
 
         if metric == secondary_metric or secondary_metric is None:
             df.rename(columns={df.columns[-1]: "m1"}, inplace=True)
             df["m2"] = df["m1"]

@@ -1868,8 +1872,12 @@ class WorldMapViz(BaseViz):
 
         fd = self.form_data
         cols = [fd.get("entity")]
-        metric = utils.get_metric_name(fd.get("metric"))
-        secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
+        metric = utils.get_metric_name(fd["metric"])
+        secondary_metric = (
+            utils.get_metric_name(fd["secondary_metric"])
+            if "secondary_metric" in fd
+            else None
+        )
         columns = ["country", "m1", "m2"]
         if metric == secondary_metric:
             ndf = df[cols]