style(mypy): Spit-and-polish pass (#10001)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
656cdfb867
commit
91517a56a3
|
|
@ -50,10 +50,12 @@ multi_line_output = 3
|
||||||
order_by_type = false
|
order_by_type = false
|
||||||
|
|
||||||
[mypy]
|
[mypy]
|
||||||
|
disallow_any_generics = true
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
no_implicit_optional = true
|
no_implicit_optional = true
|
||||||
|
warn_unused_ignores = true
|
||||||
|
|
||||||
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,,superset.views.*,superset.viz,superset.viz_sip38]
|
[mypy-superset.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_untyped_calls = true
|
disallow_untyped_calls = true
|
||||||
disallow_untyped_defs = true
|
disallow_untyped_defs = true
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ class SupersetAppInitializer:
|
||||||
|
|
||||||
self.flask_app = app
|
self.flask_app = app
|
||||||
self.config = app.config
|
self.config = app.config
|
||||||
self.manifest: dict = {}
|
self.manifest: Dict[Any, Any] = {}
|
||||||
|
|
||||||
def pre_init(self) -> None:
|
def pre_init(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -542,7 +542,7 @@ class SupersetAppInitializer:
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, environ: Dict[str, Any], start_response: Callable
|
self, environ: Dict[str, Any], start_response: Callable[..., Any]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
|
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
|
||||||
# content-length and read the stream till the end.
|
# content-length and read the stream till the end.
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateChartCommand(BaseCommand):
|
class CreateChartCommand(BaseCommand):
|
||||||
def __init__(self, user: User, data: Dict):
|
def __init__(self, user: User, data: Dict[str, Any]):
|
||||||
self._actor = user
|
self._actor = user
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateChartCommand(BaseCommand):
|
class UpdateChartCommand(BaseCommand):
|
||||||
def __init__(self, user: User, model_id: int, data: Dict):
|
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
|
||||||
self._actor = user
|
self._actor = user
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from pandas import DataFrame
|
||||||
|
|
||||||
from superset import app, is_feature_enabled
|
from superset import app, is_feature_enabled
|
||||||
from superset.exceptions import QueryObjectValidationError
|
from superset.exceptions import QueryObjectValidationError
|
||||||
|
from superset.typing import Metric
|
||||||
from superset.utils import core as utils, pandas_postprocessing
|
from superset.utils import core as utils, pandas_postprocessing
|
||||||
from superset.views.utils import get_time_range_endpoints
|
from superset.views.utils import get_time_range_endpoints
|
||||||
|
|
||||||
|
|
@ -67,11 +68,11 @@ class QueryObject:
|
||||||
row_limit: int
|
row_limit: int
|
||||||
filter: List[Dict[str, Any]]
|
filter: List[Dict[str, Any]]
|
||||||
timeseries_limit: int
|
timeseries_limit: int
|
||||||
timeseries_limit_metric: Optional[Dict]
|
timeseries_limit_metric: Optional[Metric]
|
||||||
order_desc: bool
|
order_desc: bool
|
||||||
extras: Dict
|
extras: Dict[str, Any]
|
||||||
columns: List[str]
|
columns: List[str]
|
||||||
orderby: List[List]
|
orderby: List[List[str]]
|
||||||
post_processing: List[Dict[str, Any]]
|
post_processing: List[Dict[str, Any]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -85,11 +86,11 @@ class QueryObject:
|
||||||
is_timeseries: bool = False,
|
is_timeseries: bool = False,
|
||||||
timeseries_limit: int = 0,
|
timeseries_limit: int = 0,
|
||||||
row_limit: int = app.config["ROW_LIMIT"],
|
row_limit: int = app.config["ROW_LIMIT"],
|
||||||
timeseries_limit_metric: Optional[Dict] = None,
|
timeseries_limit_metric: Optional[Metric] = None,
|
||||||
order_desc: bool = True,
|
order_desc: bool = True,
|
||||||
extras: Optional[Dict] = None,
|
extras: Optional[Dict[str, Any]] = None,
|
||||||
columns: Optional[List[str]] = None,
|
columns: Optional[List[str]] = None,
|
||||||
orderby: Optional[List[List]] = None,
|
orderby: Optional[List[List[str]]] = None,
|
||||||
post_processing: Optional[List[Dict[str, Any]]] = None,
|
post_processing: Optional[List[Dict[str, Any]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
|
||||||
from cachelib.base import BaseCache
|
from cachelib.base import BaseCache
|
||||||
from celery.schedules import crontab
|
from celery.schedules import crontab
|
||||||
from dateutil import tz
|
from dateutil import tz
|
||||||
|
from flask import Blueprint
|
||||||
from flask_appbuilder.security.manager import AUTH_DB
|
from flask_appbuilder.security.manager import AUTH_DB
|
||||||
|
|
||||||
from superset.jinja_context import ( # pylint: disable=unused-import
|
from superset.jinja_context import ( # pylint: disable=unused-import
|
||||||
|
|
@ -421,7 +422,7 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
|
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
|
||||||
ADDITIONAL_MIDDLEWARE: List[Callable] = []
|
ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
|
||||||
|
|
||||||
# 1) https://docs.python-guide.org/writing/logging/
|
# 1) https://docs.python-guide.org/writing/logging/
|
||||||
# 2) https://docs.python.org/2/library/logging.config.html
|
# 2) https://docs.python.org/2/library/logging.config.html
|
||||||
|
|
@ -624,7 +625,7 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
|
||||||
# SQL Lab. The existing context gets updated with this dictionary,
|
# SQL Lab. The existing context gets updated with this dictionary,
|
||||||
# meaning values for existing keys get overwritten by the content of this
|
# meaning values for existing keys get overwritten by the content of this
|
||||||
# dictionary.
|
# dictionary.
|
||||||
JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
|
JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
|
||||||
|
|
||||||
# A dictionary of macro template processors that gets merged into global
|
# A dictionary of macro template processors that gets merged into global
|
||||||
# template processors. The existing template processors get updated with this
|
# template processors. The existing template processors get updated with this
|
||||||
|
|
@ -684,7 +685,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
|
||||||
|
|
||||||
# Integrate external Blueprints to the app by passing them to your
|
# Integrate external Blueprints to the app by passing them to your
|
||||||
# configuration. These blueprints will get integrated in the app
|
# configuration. These blueprints will get integrated in the app
|
||||||
BLUEPRINTS: List[Callable] = []
|
BLUEPRINTS: List[Blueprint] = []
|
||||||
|
|
||||||
# Provide a callable that receives a tracking_url and returns another
|
# Provide a callable that receives a tracking_url and returns another
|
||||||
# URL. This is used to translate internal Hadoop job tracker URL
|
# URL. This is used to translate internal Hadoop job tracker URL
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, Hashable, List, Optional, Type
|
from typing import Any, Dict, Hashable, List, Optional, Type, Union
|
||||||
|
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
|
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
|
||||||
|
|
@ -64,12 +64,12 @@ class BaseDatasource(
|
||||||
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
|
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def column_class(self) -> Type:
|
def column_class(self) -> Type["BaseColumn"]:
|
||||||
# link to derivative of BaseColumn
|
# link to derivative of BaseColumn
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric_class(self) -> Type:
|
def metric_class(self) -> Type["BaseMetric"]:
|
||||||
# link to derivative of BaseMetric
|
# link to derivative of BaseMetric
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
@ -368,7 +368,7 @@ class BaseDatasource(
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
|
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
|
||||||
"""Given a column, returns an iterable of distinct values
|
"""Given a column, returns an iterable of distinct values
|
||||||
|
|
||||||
This is used to populate the dropdown showing a list of
|
This is used to populate the dropdown showing a list of
|
||||||
|
|
@ -389,7 +389,10 @@ class BaseDatasource(
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fk_many_from_list(
|
def get_fk_many_from_list(
|
||||||
object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str,
|
object_list: List[Any],
|
||||||
|
fkmany: List[Column],
|
||||||
|
fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
|
||||||
|
key_attr: str,
|
||||||
) -> List[Column]: # pylint: disable=too-many-locals
|
) -> List[Column]: # pylint: disable=too-many-locals
|
||||||
"""Update ORM one-to-many list from object list
|
"""Update ORM one-to-many list from object list
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
|
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
|
|
@ -22,6 +21,8 @@ from sqlalchemy.orm import Session, subqueryload
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.connectors.base.models import BaseDatasource
|
from superset.connectors.base.models import BaseDatasource
|
||||||
|
|
||||||
|
|
@ -32,7 +33,7 @@ class ConnectorRegistry:
|
||||||
sources: Dict[str, Type["BaseDatasource"]] = {}
|
sources: Dict[str, Type["BaseDatasource"]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_sources(cls, datasource_config: OrderedDict) -> None:
|
def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:
|
||||||
for module_name, class_names in datasource_config.items():
|
for module_name, class_names in datasource_config.items():
|
||||||
class_names = [str(s) for s in class_names]
|
class_names = [str(s) for s in class_names]
|
||||||
module_obj = __import__(module_name, fromlist=class_names)
|
module_obj = __import__(module_name, fromlist=class_names)
|
||||||
|
|
|
||||||
|
|
@ -24,18 +24,7 @@ from copy import deepcopy
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
from typing import (
|
from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
cast,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
@ -173,7 +162,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict:
|
def data(self) -> Dict[str, Any]:
|
||||||
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
|
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -354,7 +343,7 @@ class DruidColumn(Model, BaseColumn):
|
||||||
return self.dimension_spec_json
|
return self.dimension_spec_json
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dimension_spec(self) -> Optional[Dict]:
|
def dimension_spec(self) -> Optional[Dict[str, Any]]:
|
||||||
if self.dimension_spec_json:
|
if self.dimension_spec_json:
|
||||||
return json.loads(self.dimension_spec_json)
|
return json.loads(self.dimension_spec_json)
|
||||||
return None
|
return None
|
||||||
|
|
@ -438,7 +427,7 @@ class DruidMetric(Model, BaseMetric):
|
||||||
return self.json
|
return self.json
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_obj(self) -> Dict:
|
def json_obj(self) -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
obj = json.loads(self.json)
|
obj = json.loads(self.json)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -614,7 +603,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
name = escape(self.datasource_name)
|
name = escape(self.datasource_name)
|
||||||
return Markup(f'<a href="{url}">{name}</a>')
|
return Markup(f'<a href="{url}">{name}</a>')
|
||||||
|
|
||||||
def get_metric_obj(self, metric_name: str) -> Dict:
|
def get_metric_obj(self, metric_name: str) -> Dict[str, Any]:
|
||||||
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
|
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -705,7 +694,11 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync_to_db_from_config(
|
def sync_to_db_from_config(
|
||||||
cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True
|
cls,
|
||||||
|
druid_config: Dict[str, Any],
|
||||||
|
user: User,
|
||||||
|
cluster: DruidCluster,
|
||||||
|
refresh: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Merges the ds config from druid_config into one stored in the db."""
|
"""Merges the ds config from druid_config into one stored in the db."""
|
||||||
session = db.session
|
session = db.session
|
||||||
|
|
@ -901,7 +894,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
return postagg_metrics
|
return postagg_metrics
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def recursive_get_fields(_conf: Dict) -> List[str]:
|
def recursive_get_fields(_conf: Dict[str, Any]) -> List[str]:
|
||||||
_type = _conf.get("type")
|
_type = _conf.get("type")
|
||||||
_field = _conf.get("field")
|
_field = _conf.get("field")
|
||||||
_fields = _conf.get("fields")
|
_fields = _conf.get("fields")
|
||||||
|
|
@ -957,8 +950,8 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def metrics_and_post_aggs(
|
def metrics_and_post_aggs(
|
||||||
metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric],
|
metrics: List[Metric], metrics_dict: Dict[str, DruidMetric],
|
||||||
) -> Tuple[OrderedDict, OrderedDict]:
|
) -> Tuple["OrderedDict[str, Any]", "OrderedDict[str, Any]"]:
|
||||||
# Separate metrics into those that are aggregations
|
# Separate metrics into those that are aggregations
|
||||||
# and those that are post aggregations
|
# and those that are post aggregations
|
||||||
saved_agg_names = set()
|
saved_agg_names = set()
|
||||||
|
|
@ -987,7 +980,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
)
|
)
|
||||||
return aggs, post_aggs
|
return aggs, post_aggs
|
||||||
|
|
||||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
|
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
|
||||||
"""Retrieve some values for the given column"""
|
"""Retrieve some values for the given column"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
|
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
|
||||||
|
|
@ -1079,8 +1072,10 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_aggregations(
|
def get_aggregations(
|
||||||
metrics_dict: Dict, saved_metrics: Set[str], adhoc_metrics: List[Dict] = []
|
metrics_dict: Dict[str, Any],
|
||||||
) -> OrderedDict:
|
saved_metrics: Set[str],
|
||||||
|
adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
) -> "OrderedDict[str, Any]":
|
||||||
"""
|
"""
|
||||||
Returns a dictionary of aggregation metric names to aggregation json objects
|
Returns a dictionary of aggregation metric names to aggregation json objects
|
||||||
|
|
||||||
|
|
@ -1089,7 +1084,9 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
:param adhoc_metrics: list of adhoc metric names
|
:param adhoc_metrics: list of adhoc metric names
|
||||||
:raise SupersetException: if one or more metric names are not aggregations
|
:raise SupersetException: if one or more metric names are not aggregations
|
||||||
"""
|
"""
|
||||||
aggregations: OrderedDict = OrderedDict()
|
if not adhoc_metrics:
|
||||||
|
adhoc_metrics = []
|
||||||
|
aggregations = OrderedDict()
|
||||||
invalid_metric_names = []
|
invalid_metric_names = []
|
||||||
for metric_name in saved_metrics:
|
for metric_name in saved_metrics:
|
||||||
if metric_name in metrics_dict:
|
if metric_name in metrics_dict:
|
||||||
|
|
@ -1115,7 +1112,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
|
|
||||||
def get_dimensions(
|
def get_dimensions(
|
||||||
self, columns: List[str], columns_dict: Dict[str, DruidColumn]
|
self, columns: List[str], columns_dict: Dict[str, DruidColumn]
|
||||||
) -> List[Union[str, Dict]]:
|
) -> List[Union[str, Dict[str, Any]]]:
|
||||||
dimensions = []
|
dimensions = []
|
||||||
columns = [col for col in columns if col in columns_dict]
|
columns = [col for col in columns if col in columns_dict]
|
||||||
for column_name in columns:
|
for column_name in columns:
|
||||||
|
|
@ -1433,7 +1430,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
df[columns] = df[columns].fillna(NULL_STRING).astype("unicode")
|
df[columns] = df[columns].fillna(NULL_STRING).astype("unicode")
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def query(self, query_obj: Dict) -> QueryResult:
|
def query(self, query_obj: QueryObjectDict) -> QueryResult:
|
||||||
qry_start_dttm = datetime.now()
|
qry_start_dttm = datetime.now()
|
||||||
client = self.cluster.get_pydruid_client()
|
client = self.cluster.get_pydruid_client()
|
||||||
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
|
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
|
||||||
|
|
@ -1583,7 +1580,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
dimension=col, value=eq, extraction_function=extraction_fn
|
dimension=col, value=eq, extraction_function=extraction_fn
|
||||||
)
|
)
|
||||||
elif is_list_target:
|
elif is_list_target:
|
||||||
eq = cast(list, eq)
|
eq = cast(List[Any], eq)
|
||||||
fields = []
|
fields = []
|
||||||
# ignore the filter if it has no value
|
# ignore the filter if it has no value
|
||||||
if not len(eq):
|
if not len(eq):
|
||||||
|
|
|
||||||
|
|
@ -597,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict:
|
def data(self) -> Dict[str, Any]:
|
||||||
d = super().data
|
d = super().data
|
||||||
if self.type == "table":
|
if self.type == "table":
|
||||||
grains = self.database.grains() or []
|
grains = self.database.grains() or []
|
||||||
|
|
@ -684,7 +684,9 @@ class SqlaTable(Model, BaseDatasource):
|
||||||
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
|
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
|
||||||
return self.get_sqla_table()
|
return self.get_sqla_table()
|
||||||
|
|
||||||
def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]:
|
def adhoc_metric_to_sqla(
|
||||||
|
self, metric: Dict[str, Any], cols: Dict[str, Any]
|
||||||
|
) -> Optional[Column]:
|
||||||
"""
|
"""
|
||||||
Turn an adhoc metric into a sqlalchemy column.
|
Turn an adhoc metric into a sqlalchemy column.
|
||||||
|
|
||||||
|
|
@ -804,7 +806,7 @@ class SqlaTable(Model, BaseDatasource):
|
||||||
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
|
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
|
||||||
|
|
||||||
select_exprs: List[Column] = []
|
select_exprs: List[Column] = []
|
||||||
groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()
|
groupby_exprs_sans_timestamp = OrderedDict()
|
||||||
|
|
||||||
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
|
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
|
||||||
# dedup columns while preserving order
|
# dedup columns while preserving order
|
||||||
|
|
@ -874,7 +876,7 @@ class SqlaTable(Model, BaseDatasource):
|
||||||
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
|
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
|
||||||
|
|
||||||
where_clause_and = []
|
where_clause_and = []
|
||||||
having_clause_and: List = []
|
having_clause_and = []
|
||||||
|
|
||||||
for flt in filter: # type: ignore
|
for flt in filter: # type: ignore
|
||||||
if not all([flt.get(s) for s in ["col", "op"]]):
|
if not all([flt.get(s) for s in ["col", "op"]]):
|
||||||
|
|
@ -1082,7 +1084,10 @@ class SqlaTable(Model, BaseDatasource):
|
||||||
return ob
|
return ob
|
||||||
|
|
||||||
def _get_top_groups(
|
def _get_top_groups(
|
||||||
self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict
|
self,
|
||||||
|
df: pd.DataFrame,
|
||||||
|
dimensions: List[str],
|
||||||
|
groupby_exprs: "OrderedDict[str, Any]",
|
||||||
) -> ColumnElement:
|
) -> ColumnElement:
|
||||||
groups = []
|
groups = []
|
||||||
for unused, row in df.iterrows():
|
for unused, row in df.iterrows():
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.filters import BaseFilter
|
from flask_appbuilder.models.filters import BaseFilter
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
|
@ -75,7 +75,7 @@ class BaseDAO:
|
||||||
return query.all()
|
return query.all()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, properties: Dict, commit: bool = True) -> Model:
|
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
|
||||||
"""
|
"""
|
||||||
Generic for creating models
|
Generic for creating models
|
||||||
:raises: DAOCreateFailedError
|
:raises: DAOCreateFailedError
|
||||||
|
|
@ -95,7 +95,9 @@ class BaseDAO:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
|
def update(
|
||||||
|
cls, model: Model, properties: Dict[str, Any], commit: bool = True
|
||||||
|
) -> Model:
|
||||||
"""
|
"""
|
||||||
Generic update a model
|
Generic update a model
|
||||||
:raises: DAOCreateFailedError
|
:raises: DAOCreateFailedError
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateDashboardCommand(BaseCommand):
|
class CreateDashboardCommand(BaseCommand):
|
||||||
def __init__(self, user: User, data: Dict):
|
def __init__(self, user: User, data: Dict[str, Any]):
|
||||||
self._actor = user
|
self._actor = user
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateDashboardCommand(BaseCommand):
|
class UpdateDashboardCommand(BaseCommand):
|
||||||
def __init__(self, user: User, model_id: int, data: Dict):
|
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
|
||||||
self._actor = user
|
self._actor = user
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateDatasetCommand(BaseCommand):
|
class CreateDatasetCommand(BaseCommand):
|
||||||
def __init__(self, user: User, data: Dict):
|
def __init__(self, user: User, data: Dict[str, Any]):
|
||||||
self._actor = user
|
self._actor = user
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
@ -48,7 +48,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateDatasetCommand(BaseCommand):
|
class UpdateDatasetCommand(BaseCommand):
|
||||||
def __init__(self, user: User, model_id: int, data: Dict):
|
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
|
||||||
self._actor = user
|
self._actor = user
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
@ -111,7 +111,7 @@ class UpdateDatasetCommand(BaseCommand):
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
def _validate_columns(
|
def _validate_columns(
|
||||||
self, columns: List[Dict], exceptions: List[ValidationError]
|
self, columns: List[Dict[str, Any]], exceptions: List[ValidationError]
|
||||||
) -> None:
|
) -> None:
|
||||||
# Validate duplicates on data
|
# Validate duplicates on data
|
||||||
if self._get_duplicates(columns, "column_name"):
|
if self._get_duplicates(columns, "column_name"):
|
||||||
|
|
@ -133,7 +133,7 @@ class UpdateDatasetCommand(BaseCommand):
|
||||||
exceptions.append(DatasetColumnsExistsValidationError())
|
exceptions.append(DatasetColumnsExistsValidationError())
|
||||||
|
|
||||||
def _validate_metrics(
|
def _validate_metrics(
|
||||||
self, metrics: List[Dict], exceptions: List[ValidationError]
|
self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError]
|
||||||
) -> None:
|
) -> None:
|
||||||
if self._get_duplicates(metrics, "metric_name"):
|
if self._get_duplicates(metrics, "metric_name"):
|
||||||
exceptions.append(DatasetMetricsDuplicateValidationError())
|
exceptions.append(DatasetMetricsDuplicateValidationError())
|
||||||
|
|
@ -152,7 +152,7 @@ class UpdateDatasetCommand(BaseCommand):
|
||||||
exceptions.append(DatasetMetricsExistsValidationError())
|
exceptions.append(DatasetMetricsExistsValidationError())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_duplicates(data: List[Dict], key: str) -> List[str]:
|
def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]:
|
||||||
duplicates = [
|
duplicates = [
|
||||||
name
|
name
|
||||||
for name, count in Counter([item[key] for item in data]).items()
|
for name, count in Counter([item[key] for item in data]).items()
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
@ -116,7 +116,7 @@ class DatasetDAO(BaseDAO):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update(
|
def update(
|
||||||
cls, model: SqlaTable, properties: Dict, commit: bool = True
|
cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
|
||||||
) -> Optional[SqlaTable]:
|
) -> Optional[SqlaTable]:
|
||||||
"""
|
"""
|
||||||
Updates a Dataset model on the metadata DB
|
Updates a Dataset model on the metadata DB
|
||||||
|
|
@ -151,13 +151,13 @@ class DatasetDAO(BaseDAO):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_column(
|
def update_column(
|
||||||
cls, model: TableColumn, properties: Dict, commit: bool = True
|
cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True
|
||||||
) -> Optional[TableColumn]:
|
) -> Optional[TableColumn]:
|
||||||
return DatasetColumnDAO.update(model, properties, commit=commit)
|
return DatasetColumnDAO.update(model, properties, commit=commit)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_column(
|
def create_column(
|
||||||
cls, properties: Dict, commit: bool = True
|
cls, properties: Dict[str, Any], commit: bool = True
|
||||||
) -> Optional[TableColumn]:
|
) -> Optional[TableColumn]:
|
||||||
"""
|
"""
|
||||||
Creates a Dataset model on the metadata DB
|
Creates a Dataset model on the metadata DB
|
||||||
|
|
@ -166,13 +166,13 @@ class DatasetDAO(BaseDAO):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_metric(
|
def update_metric(
|
||||||
cls, model: SqlMetric, properties: Dict, commit: bool = True
|
cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True
|
||||||
) -> Optional[SqlMetric]:
|
) -> Optional[SqlMetric]:
|
||||||
return DatasetMetricDAO.update(model, properties, commit=commit)
|
return DatasetMetricDAO.update(model, properties, commit=commit)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_metric(
|
def create_metric(
|
||||||
cls, properties: Dict, commit: bool = True
|
cls, properties: Dict[str, Any], commit: bool = True
|
||||||
) -> Optional[SqlMetric]:
|
) -> Optional[SqlMetric]:
|
||||||
"""
|
"""
|
||||||
Creates a Dataset model on the metadata DB
|
Creates a Dataset model on the metadata DB
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
||||||
|
|
||||||
# default matching patterns for identifying column types
|
# default matching patterns for identifying column types
|
||||||
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = {
|
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = {
|
||||||
utils.DbColumnType.NUMERIC: (
|
utils.DbColumnType.NUMERIC: (
|
||||||
re.compile(r".*DOUBLE.*", re.IGNORECASE),
|
re.compile(r".*DOUBLE.*", re.IGNORECASE),
|
||||||
re.compile(r".*FLOAT.*", re.IGNORECASE),
|
re.compile(r".*FLOAT.*", re.IGNORECASE),
|
||||||
|
|
@ -296,7 +296,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
return select_exprs
|
return select_exprs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param cursor: Cursor instance
|
:param cursor: Cursor instance
|
||||||
|
|
@ -311,8 +311,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def expand_data(
|
def expand_data(
|
||||||
cls, columns: List[dict], data: List[dict]
|
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
|
||||||
) -> Tuple[List[dict], List[dict], List[dict]]:
|
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
|
||||||
"""
|
"""
|
||||||
Some engines support expanding nested fields. See implementation in Presto
|
Some engines support expanding nested fields. See implementation in Presto
|
||||||
spec for details.
|
spec for details.
|
||||||
|
|
@ -645,7 +645,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
schema: Optional[str],
|
schema: Optional[str],
|
||||||
database: "Database",
|
database: "Database",
|
||||||
query: Select,
|
query: Select,
|
||||||
columns: Optional[List] = None,
|
columns: Optional[List[Dict[str, str]]] = None,
|
||||||
) -> Optional[Select]:
|
) -> Optional[Select]:
|
||||||
"""
|
"""
|
||||||
Add a where clause to a query to reference only the most recent partition
|
Add a where clause to a query to reference only the most recent partition
|
||||||
|
|
@ -925,7 +925,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple]:
|
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
|
||||||
"""
|
"""
|
||||||
Convert pyodbc.Row objects from `fetch_data` to tuples.
|
Convert pyodbc.Row objects from `fetch_data` to tuples.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||||
data = super().fetch_data(cursor, limit)
|
data = super().fetch_data(cursor, limit)
|
||||||
# Support type BigQuery Row, introduced here PR #4071
|
# Support type BigQuery Row, introduced here PR #4071
|
||||||
# google.cloud.bigquery.table.Row
|
# google.cloud.bigquery.table.Row
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||||
data = super().fetch_data(cursor, limit)
|
data = super().fetch_data(cursor, limit)
|
||||||
# Lists of `pyodbc.Row` need to be unpacked further
|
# Lists of `pyodbc.Row` need to be unpacked further
|
||||||
return cls.pyodbc_rows_to_tuples(data)
|
return cls.pyodbc_rows_to_tuples(data)
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||||
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
|
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||||
import pyhive
|
import pyhive
|
||||||
from TCLIService import ttypes
|
from TCLIService import ttypes
|
||||||
|
|
||||||
|
|
@ -304,7 +304,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||||
schema: Optional[str],
|
schema: Optional[str],
|
||||||
database: "Database",
|
database: "Database",
|
||||||
query: Select,
|
query: Select,
|
||||||
columns: Optional[List] = None,
|
columns: Optional[List[Dict[str, str]]] = None,
|
||||||
) -> Optional[Select]:
|
) -> Optional[Select]:
|
||||||
try:
|
try:
|
||||||
col_names, values = cls.latest_partition(
|
col_names, values = cls.latest_partition(
|
||||||
|
|
@ -323,7 +323,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
|
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
|
||||||
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
|
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,7 @@ class MssqlEngineSpec(BaseEngineSpec):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||||
data = super().fetch_data(cursor, limit)
|
data = super().fetch_data(cursor, limit)
|
||||||
# Lists of `pyodbc.Row` need to be unpacked further
|
# Lists of `pyodbc.Row` need to be unpacked further
|
||||||
return cls.pyodbc_rows_to_tuples(data)
|
return cls.pyodbc_rows_to_tuples(data)
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||||
cursor.tzinfo_factory = FixedOffsetTimezone
|
cursor.tzinfo_factory = FixedOffsetTimezone
|
||||||
if not cursor.description:
|
if not cursor.description:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
return [row[0] for row in results]
|
return [row[0] for row in results]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _create_column_info(cls, name: str, data_type: str) -> dict:
|
def _create_column_info(cls, name: str, data_type: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Create column info object
|
Create column info object
|
||||||
:param name: column name
|
:param name: column name
|
||||||
|
|
@ -213,7 +213,10 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches
|
def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches
|
||||||
cls, parent_column_name: str, parent_data_type: str, result: List[dict]
|
cls,
|
||||||
|
parent_column_name: str,
|
||||||
|
parent_data_type: str,
|
||||||
|
result: List[Dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Parse a row or array column
|
Parse a row or array column
|
||||||
|
|
@ -322,7 +325,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
(i.e. column name and data type)
|
(i.e. column name and data type)
|
||||||
"""
|
"""
|
||||||
columns = cls._show_columns(inspector, table_name, schema)
|
columns = cls._show_columns(inspector, table_name, schema)
|
||||||
result: List[dict] = []
|
result: List[Dict[str, Any]] = []
|
||||||
for column in columns:
|
for column in columns:
|
||||||
try:
|
try:
|
||||||
# parse column if it is a row or array
|
# parse column if it is a row or array
|
||||||
|
|
@ -361,7 +364,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
return column_name.startswith('"') and column_name.endswith('"')
|
return column_name.startswith('"') and column_name.endswith('"')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
|
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
|
||||||
"""
|
"""
|
||||||
Format column clauses where names are in quotes and labels are specified
|
Format column clauses where names are in quotes and labels are specified
|
||||||
:param cols: columns
|
:param cols: columns
|
||||||
|
|
@ -561,8 +564,8 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def expand_data( # pylint: disable=too-many-locals
|
def expand_data( # pylint: disable=too-many-locals
|
||||||
cls, columns: List[dict], data: List[dict]
|
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
|
||||||
) -> Tuple[List[dict], List[dict], List[dict]]:
|
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
|
||||||
"""
|
"""
|
||||||
We do not immediately display rows and arrays clearly in the data grid. This
|
We do not immediately display rows and arrays clearly in the data grid. This
|
||||||
method separates out nested fields and data values to help clearly display
|
method separates out nested fields and data values to help clearly display
|
||||||
|
|
@ -590,7 +593,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
# process each column, unnesting ARRAY types and
|
# process each column, unnesting ARRAY types and
|
||||||
# expanding ROW types into new columns
|
# expanding ROW types into new columns
|
||||||
to_process = deque((column, 0) for column in columns)
|
to_process = deque((column, 0) for column in columns)
|
||||||
all_columns: List[dict] = []
|
all_columns: List[Dict[str, Any]] = []
|
||||||
expanded_columns = []
|
expanded_columns = []
|
||||||
current_array_level = None
|
current_array_level = None
|
||||||
while to_process:
|
while to_process:
|
||||||
|
|
@ -843,7 +846,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
schema: Optional[str],
|
schema: Optional[str],
|
||||||
database: "Database",
|
database: "Database",
|
||||||
query: Select,
|
query: Select,
|
||||||
columns: Optional[List] = None,
|
columns: Optional[List[Dict[str, str]]] = None,
|
||||||
) -> Optional[Select]:
|
) -> Optional[Select]:
|
||||||
try:
|
try:
|
||||||
col_names, values = cls.latest_partition(
|
col_names, values = cls.latest_partition(
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,9 @@ class UIManifestProcessor:
|
||||||
self.parse_manifest_json()
|
self.parse_manifest_json()
|
||||||
|
|
||||||
@app.context_processor
|
@app.context_processor
|
||||||
def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable
|
def get_manifest() -> Dict[ # pylint: disable=unused-variable
|
||||||
|
str, Callable[[str], List[str]]
|
||||||
|
]:
|
||||||
loaded_chunks = set()
|
loaded_chunks = set()
|
||||||
|
|
||||||
def get_files(bundle: str, asset_type: str = "js") -> List[str]:
|
def get_files(bundle: str, asset_type: str = "js") -> List[str]:
|
||||||
|
|
@ -131,7 +133,7 @@ appbuilder = AppBuilder(update_perms=False)
|
||||||
cache_manager = CacheManager()
|
cache_manager = CacheManager()
|
||||||
celery_app = celery.Celery()
|
celery_app = celery.Celery()
|
||||||
db = SQLA()
|
db = SQLA()
|
||||||
_event_logger: dict = {}
|
_event_logger: Dict[str, Any] = {}
|
||||||
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
|
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
|
||||||
feature_flag_manager = FeatureFlagManager()
|
feature_flag_manager = FeatureFlagManager()
|
||||||
jinja_context_manager = JinjaContextManager()
|
jinja_context_manager = JinjaContextManager()
|
||||||
|
|
|
||||||
|
|
@ -341,11 +341,14 @@ class Database(
|
||||||
def get_reserved_words(self) -> Set[str]:
|
def get_reserved_words(self) -> Set[str]:
|
||||||
return self.get_dialect().preparer.reserved_words
|
return self.get_dialect().preparer.reserved_words
|
||||||
|
|
||||||
def get_quoter(self) -> Callable:
|
def get_quoter(self) -> Callable[[str, Any], str]:
|
||||||
return self.get_dialect().identifier_preparer.quote
|
return self.get_dialect().identifier_preparer.quote
|
||||||
|
|
||||||
def get_df( # pylint: disable=too-many-locals
|
def get_df( # pylint: disable=too-many-locals
|
||||||
self, sql: str, schema: Optional[str] = None, mutator: Optional[Callable] = None
|
self,
|
||||||
|
sql: str,
|
||||||
|
schema: Optional[str] = None,
|
||||||
|
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]
|
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]
|
||||||
|
|
||||||
|
|
@ -450,7 +453,7 @@ class Database(
|
||||||
|
|
||||||
@cache_util.memoized_func(
|
@cache_util.memoized_func(
|
||||||
key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
|
key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
|
||||||
attribute_in_key="id", # type: ignore
|
attribute_in_key="id",
|
||||||
)
|
)
|
||||||
def get_all_view_names_in_database(
|
def get_all_view_names_in_database(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -240,7 +240,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
||||||
self.json_metadata = value
|
self.json_metadata = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def position(self) -> Dict:
|
def position(self) -> Dict[str, Any]:
|
||||||
if self.position_json:
|
if self.position_json:
|
||||||
return json.loads(self.position_json)
|
return json.loads(self.position_json)
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -315,7 +315,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
||||||
old_to_new_slc_id_dict: Dict[int, int] = {}
|
old_to_new_slc_id_dict: Dict[int, int] = {}
|
||||||
new_timed_refresh_immune_slices = []
|
new_timed_refresh_immune_slices = []
|
||||||
new_expanded_slices = {}
|
new_expanded_slices = {}
|
||||||
new_filter_scopes: Dict[str, Dict] = {}
|
new_filter_scopes = {}
|
||||||
i_params_dict = dashboard_to_import.params_dict
|
i_params_dict = dashboard_to_import.params_dict
|
||||||
remote_id_slice_map = {
|
remote_id_slice_map = {
|
||||||
slc.params_dict["remote_id"]: slc
|
slc.params_dict["remote_id"]: slc
|
||||||
|
|
@ -351,7 +351,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
||||||
# are converted to filter_scopes
|
# are converted to filter_scopes
|
||||||
# but dashboard create from import may still have old dashboard filter metadata
|
# but dashboard create from import may still have old dashboard filter metadata
|
||||||
# here we convert them to new filter_scopes metadata first
|
# here we convert them to new filter_scopes metadata first
|
||||||
filter_scopes: Dict = {}
|
filter_scopes = {}
|
||||||
if (
|
if (
|
||||||
"filter_immune_slices" in i_params_dict
|
"filter_immune_slices" in i_params_dict
|
||||||
or "filter_immune_slice_fields" in i_params_dict
|
or "filter_immune_slice_fields" in i_params_dict
|
||||||
|
|
@ -415,7 +415,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def export_dashboards( # pylint: disable=too-many-locals
|
def export_dashboards( # pylint: disable=too-many-locals
|
||||||
cls, dashboard_ids: List
|
cls, dashboard_ids: List[int]
|
||||||
) -> str:
|
) -> str:
|
||||||
copied_dashboards = []
|
copied_dashboards = []
|
||||||
datasource_ids = set()
|
datasource_ids = set()
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ class ImportMixin:
|
||||||
for u in cls.__table_args__ # type: ignore
|
for u in cls.__table_args__ # type: ignore
|
||||||
if isinstance(u, UniqueConstraint)
|
if isinstance(u, UniqueConstraint)
|
||||||
]
|
]
|
||||||
unique.extend( # type: ignore
|
unique.extend(
|
||||||
{c.name} for c in cls.__table__.columns if c.unique # type: ignore
|
{c.name} for c in cls.__table__.columns if c.unique # type: ignore
|
||||||
)
|
)
|
||||||
return unique
|
return unique
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from superset.tasks.thumbnails import cache_chart_thumbnail
|
||||||
from superset.utils import core as utils
|
from superset.utils import core as utils
|
||||||
|
|
||||||
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
|
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
|
||||||
from superset.viz_sip38 import BaseViz, viz_types # type: ignore
|
from superset.viz_sip38 import BaseViz, viz_types
|
||||||
else:
|
else:
|
||||||
from superset.viz import BaseViz, viz_types # type: ignore
|
from superset.viz import BaseViz, viz_types # type: ignore
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
from sqlalchemy import types
|
from sqlalchemy import types
|
||||||
from sqlalchemy.sql.sqltypes import Integer
|
from sqlalchemy.sql.sqltypes import Integer
|
||||||
|
|
@ -29,7 +29,7 @@ class TinyInteger(Integer):
|
||||||
A type for tiny ``int`` integers.
|
A type for tiny ``int`` integers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def python_type(self) -> Type:
|
def python_type(self) -> Type[int]:
|
||||||
return int
|
return int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -42,7 +42,7 @@ class Interval(TypeEngine):
|
||||||
A type for intervals.
|
A type for intervals.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def python_type(self) -> Optional[Type]:
|
def python_type(self) -> Optional[Type[Any]]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -55,7 +55,7 @@ class Array(TypeEngine):
|
||||||
A type for arrays.
|
A type for arrays.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def python_type(self) -> Optional[Type]:
|
def python_type(self) -> Optional[Type[List[Any]]]:
|
||||||
return list
|
return list
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -68,7 +68,7 @@ class Map(TypeEngine):
|
||||||
A type for maps.
|
A type for maps.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def python_type(self) -> Optional[Type]:
|
def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
|
||||||
return dict
|
return dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -81,7 +81,7 @@ class Row(TypeEngine):
|
||||||
A type for rows.
|
A type for rows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def python_type(self) -> Optional[Type]:
|
def python_type(self) -> Optional[Type[Any]]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Callable
|
from typing import Any
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask_sqlalchemy import BaseQuery
|
from flask_sqlalchemy import BaseQuery
|
||||||
|
|
@ -25,7 +25,7 @@ from superset.views.base import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||||
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
|
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
|
||||||
"""
|
"""
|
||||||
Filter queries to only those owned by current user. If
|
Filter queries to only those owned by current user. If
|
||||||
can_access_all_queries permission is set a user can list all queries
|
can_access_all_queries permission is set a user can list all queries
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
@ -64,7 +64,7 @@ def stringify(obj: Any) -> str:
|
||||||
|
|
||||||
|
|
||||||
def stringify_values(array: np.ndarray) -> np.ndarray:
|
def stringify_values(array: np.ndarray) -> np.ndarray:
|
||||||
vstringify: Callable = np.vectorize(stringify)
|
vstringify = np.vectorize(stringify)
|
||||||
return vstringify(array)
|
return vstringify(array)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -172,7 +172,7 @@ class SupersetResultSet:
|
||||||
return table.to_pandas(integer_object_nulls=True)
|
return table.to_pandas(integer_object_nulls=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def first_nonempty(items: List) -> Any:
|
def first_nonempty(items: List[Any]) -> Any:
|
||||||
return next((i for i in items if i), None)
|
return next((i for i in items if i), None)
|
||||||
|
|
||||||
def is_temporal(self, db_type_str: Optional[str]) -> bool:
|
def is_temporal(self, db_type_str: Optional[str]) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -21,11 +21,11 @@ from typing import Any, Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Uni
|
||||||
|
|
||||||
from flask import current_app, g
|
from flask import current_app, g
|
||||||
from flask_appbuilder import Model
|
from flask_appbuilder import Model
|
||||||
from flask_appbuilder.security.sqla import models as ab_models
|
|
||||||
from flask_appbuilder.security.sqla.manager import SecurityManager
|
from flask_appbuilder.security.sqla.manager import SecurityManager
|
||||||
from flask_appbuilder.security.sqla.models import (
|
from flask_appbuilder.security.sqla.models import (
|
||||||
assoc_permissionview_role,
|
assoc_permissionview_role,
|
||||||
assoc_user_role,
|
assoc_user_role,
|
||||||
|
PermissionView,
|
||||||
)
|
)
|
||||||
from flask_appbuilder.security.views import (
|
from flask_appbuilder.security.views import (
|
||||||
PermissionModelView,
|
PermissionModelView,
|
||||||
|
|
@ -602,11 +602,8 @@ class SupersetSecurityManager(SecurityManager):
|
||||||
|
|
||||||
logger.info("Cleaning faulty perms")
|
logger.info("Cleaning faulty perms")
|
||||||
sesh = self.get_session
|
sesh = self.get_session
|
||||||
pvms = sesh.query(ab_models.PermissionView).filter(
|
pvms = sesh.query(PermissionView).filter(
|
||||||
or_(
|
or_(PermissionView.permission == None, PermissionView.view_menu == None,)
|
||||||
ab_models.PermissionView.permission == None,
|
|
||||||
ab_models.PermissionView.view_menu == None,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
deleted_count = pvms.delete()
|
deleted_count = pvms.delete()
|
||||||
sesh.commit()
|
sesh.commit()
|
||||||
|
|
@ -640,7 +637,9 @@ class SupersetSecurityManager(SecurityManager):
|
||||||
self.get_session.commit()
|
self.get_session.commit()
|
||||||
self.clean_perms()
|
self.clean_perms()
|
||||||
|
|
||||||
def set_role(self, role_name: str, pvm_check: Callable) -> None:
|
def set_role(
|
||||||
|
self, role_name: str, pvm_check: Callable[[PermissionView], bool]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Set the FAB permission/views for the role.
|
Set the FAB permission/views for the role.
|
||||||
|
|
||||||
|
|
@ -650,7 +649,7 @@ class SupersetSecurityManager(SecurityManager):
|
||||||
|
|
||||||
logger.info("Syncing {} perms".format(role_name))
|
logger.info("Syncing {} perms".format(role_name))
|
||||||
sesh = self.get_session
|
sesh = self.get_session
|
||||||
pvms = sesh.query(ab_models.PermissionView).all()
|
pvms = sesh.query(PermissionView).all()
|
||||||
pvms = [p for p in pvms if p.permission and p.view_menu]
|
pvms = [p for p in pvms if p.permission and p.view_menu]
|
||||||
role = self.add_role(role_name)
|
role = self.add_role(role_name)
|
||||||
role_pvms = [p for p in pvms if pvm_check(p)]
|
role_pvms = [p for p in pvms if pvm_check(p)]
|
||||||
|
|
|
||||||
|
|
@ -299,9 +299,10 @@ def _serialize_and_expand_data(
|
||||||
db_engine_spec: BaseEngineSpec,
|
db_engine_spec: BaseEngineSpec,
|
||||||
use_msgpack: Optional[bool] = False,
|
use_msgpack: Optional[bool] = False,
|
||||||
expand_data: bool = False,
|
expand_data: bool = False,
|
||||||
) -> Tuple[Union[bytes, str], list, list, list]:
|
) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
|
||||||
selected_columns: List[Dict] = result_set.columns
|
selected_columns = result_set.columns
|
||||||
expanded_columns: List[Dict]
|
all_columns: List[Any]
|
||||||
|
expanded_columns: List[Any]
|
||||||
|
|
||||||
if use_msgpack:
|
if use_msgpack:
|
||||||
with stats_timing(
|
with stats_timing(
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from superset import create_app
|
||||||
from superset.extensions import celery_app
|
from superset.extensions import celery_app
|
||||||
|
|
||||||
# Init the Flask app / configure everything
|
# Init the Flask app / configure everything
|
||||||
create_app() # type: ignore
|
create_app()
|
||||||
|
|
||||||
# Need to import late, as the celery_app will have been setup by "create_app()"
|
# Need to import late, as the celery_app will have been setup by "create_app()"
|
||||||
# pylint: disable=wrong-import-position, unused-import
|
# pylint: disable=wrong-import-position, unused-import
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import urllib.request
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from email.utils import make_msgid, parseaddr
|
from email.utils import make_msgid, parseaddr
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||||
from urllib.error import URLError # pylint: disable=ungrouped-imports
|
from urllib.error import URLError # pylint: disable=ungrouped-imports
|
||||||
|
|
||||||
import croniter
|
import croniter
|
||||||
|
|
@ -36,7 +36,6 @@ from flask_login import login_user
|
||||||
from retry.api import retry_call
|
from retry.api import retry_call
|
||||||
from selenium.common.exceptions import WebDriverException
|
from selenium.common.exceptions import WebDriverException
|
||||||
from selenium.webdriver import chrome, firefox
|
from selenium.webdriver import chrome, firefox
|
||||||
from werkzeug.datastructures import TypeConversionDict
|
|
||||||
from werkzeug.http import parse_cookie
|
from werkzeug.http import parse_cookie
|
||||||
|
|
||||||
# Superset framework imports
|
# Superset framework imports
|
||||||
|
|
@ -53,6 +52,11 @@ from superset.models.schedules import (
|
||||||
)
|
)
|
||||||
from superset.utils.core import get_email_address_list, send_email_smtp
|
from superset.utils.core import get_email_address_list, send_email_smtp
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
from werkzeug.datastructures import TypeConversionDict
|
||||||
|
|
||||||
|
|
||||||
# Globals
|
# Globals
|
||||||
config = app.config
|
config = app.config
|
||||||
logger = logging.getLogger("tasks.email_reports")
|
logger = logging.getLogger("tasks.email_reports")
|
||||||
|
|
@ -131,7 +135,7 @@ def _generate_mail_content(
|
||||||
return EmailContent(body, data, images)
|
return EmailContent(body, data, images)
|
||||||
|
|
||||||
|
|
||||||
def _get_auth_cookies() -> List[TypeConversionDict]:
|
def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]:
|
||||||
# Login with the user specified to get the reports
|
# Login with the user specified to get the reports
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
|
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
|
||||||
|
|
|
||||||
|
|
@ -27,8 +27,9 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-
|
||||||
|
|
||||||
|
|
||||||
def memoized_func(
|
def memoized_func(
|
||||||
key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
|
key: Callable[..., str] = view_cache_key, # pylint: disable=bad-whitespace
|
||||||
) -> Callable:
|
attribute_in_key: Optional[str] = None,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
"""Use this decorator to cache functions that have predefined first arg.
|
"""Use this decorator to cache functions that have predefined first arg.
|
||||||
|
|
||||||
enable_cache is treated as True by default,
|
enable_cache is treated as True by default,
|
||||||
|
|
@ -45,7 +46,7 @@ def memoized_func(
|
||||||
returns the caching key.
|
returns the caching key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrap(f: Callable) -> Callable:
|
def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
if cache_manager.tables_cache:
|
if cache_manager.tables_cache:
|
||||||
|
|
||||||
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
|
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,7 @@ from superset.exceptions import (
|
||||||
SupersetException,
|
SupersetException,
|
||||||
SupersetTimeoutException,
|
SupersetTimeoutException,
|
||||||
)
|
)
|
||||||
from superset.typing import FormData, Metric
|
from superset.typing import FlaskResponse, FormData, Metric
|
||||||
from superset.utils.dates import datetime_to_epoch, EPOCH
|
from superset.utils.dates import datetime_to_epoch, EPOCH
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -147,7 +147,9 @@ class _memoized:
|
||||||
should account for instance variable changes.
|
should account for instance variable changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, func: Callable, watch: Optional[Tuple[str, ...]] = None) -> None:
|
def __init__(
|
||||||
|
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
|
||||||
|
) -> None:
|
||||||
self.func = func
|
self.func = func
|
||||||
self.cache: Dict[Any, Any] = {}
|
self.cache: Dict[Any, Any] = {}
|
||||||
self.is_method = False
|
self.is_method = False
|
||||||
|
|
@ -173,7 +175,7 @@ class _memoized:
|
||||||
"""Return the function's docstring."""
|
"""Return the function's docstring."""
|
||||||
return self.func.__doc__ or ""
|
return self.func.__doc__ or ""
|
||||||
|
|
||||||
def __get__(self, obj: Any, objtype: Type) -> functools.partial:
|
def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore
|
||||||
if not self.is_method:
|
if not self.is_method:
|
||||||
self.is_method = True
|
self.is_method = True
|
||||||
"""Support instance methods."""
|
"""Support instance methods."""
|
||||||
|
|
@ -181,13 +183,13 @@ class _memoized:
|
||||||
|
|
||||||
|
|
||||||
def memoized(
|
def memoized(
|
||||||
func: Optional[Callable] = None, watch: Optional[Tuple[str, ...]] = None
|
func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
|
||||||
) -> Callable:
|
) -> Callable[..., Any]:
|
||||||
if func:
|
if func:
|
||||||
return _memoized(func)
|
return _memoized(func)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def wrapper(f: Callable) -> Callable:
|
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
return _memoized(f, watch)
|
return _memoized(f, watch)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
@ -1241,7 +1243,9 @@ def create_ssl_cert_file(certificate: str) -> str:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
|
def time_function(
|
||||||
|
func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any
|
||||||
|
) -> Tuple[float, Any]:
|
||||||
"""
|
"""
|
||||||
Measures the amount of time a function takes to execute in ms
|
Measures the amount of time a function takes to execute in ms
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ def convert_filter_scopes(
|
||||||
) -> Dict[int, Dict[str, Dict[str, Any]]]:
|
) -> Dict[int, Dict[str, Dict[str, Any]]]:
|
||||||
filter_scopes = {}
|
filter_scopes = {}
|
||||||
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
|
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
|
||||||
immuned_by_column: Dict = defaultdict(list)
|
immuned_by_column: Dict[str, List[int]] = defaultdict(list)
|
||||||
for slice_id, columns in json_metadata.get(
|
for slice_id, columns in json_metadata.get(
|
||||||
"filter_immune_slice_fields", {}
|
"filter_immune_slice_fields", {}
|
||||||
).items():
|
).items():
|
||||||
|
|
@ -52,7 +52,7 @@ def convert_filter_scopes(
|
||||||
logging.info(f"slice [{filter_id}] has invalid field: {filter_field}")
|
logging.info(f"slice [{filter_id}] has invalid field: {filter_field}")
|
||||||
|
|
||||||
for filter_slice in filters:
|
for filter_slice in filters:
|
||||||
filter_fields: Dict = {}
|
filter_fields: Dict[str, Dict[str, Any]] = {}
|
||||||
filter_id = filter_slice.id
|
filter_id = filter_slice.id
|
||||||
slice_params = json.loads(filter_slice.params or "{}")
|
slice_params = json.loads(filter_slice.params or "{}")
|
||||||
configs = slice_params.get("filter_configs") or []
|
configs = slice_params.get("filter_configs") or []
|
||||||
|
|
@ -77,9 +77,10 @@ def convert_filter_scopes(
|
||||||
|
|
||||||
|
|
||||||
def copy_filter_scopes(
|
def copy_filter_scopes(
|
||||||
old_to_new_slc_id_dict: Dict[int, int], old_filter_scopes: Dict[str, Dict]
|
old_to_new_slc_id_dict: Dict[int, int],
|
||||||
) -> Dict:
|
old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]],
|
||||||
new_filter_scopes: Dict[str, Dict] = {}
|
) -> Dict[str, Dict[Any, Any]]:
|
||||||
|
new_filter_scopes: Dict[str, Dict[Any, Any]] = {}
|
||||||
for (filter_id, scopes) in old_filter_scopes.items():
|
for (filter_id, scopes) in old_filter_scopes.items():
|
||||||
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
|
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
|
||||||
if new_filter_key:
|
if new_filter_key:
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa
|
||||||
stats_logger.timing(stats_key, now_as_float() - start_ts)
|
stats_logger.timing(stats_key, now_as_float() - start_ts)
|
||||||
|
|
||||||
|
|
||||||
def etag_cache(max_age: int, check_perms: Callable) -> Callable:
|
def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
"""
|
"""
|
||||||
A decorator for caching views and handling etag conditional requests.
|
A decorator for caching views and handling etag conditional requests.
|
||||||
|
|
||||||
|
|
@ -60,7 +60,7 @@ def etag_cache(max_age: int, check_perms: Callable) -> Callable:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(f: Callable) -> Callable:
|
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
|
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
|
||||||
# check if the user can access the resource
|
# check if the user can access the resource
|
||||||
|
|
|
||||||
|
|
@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
|
||||||
def import_datasource(
|
def import_datasource(
|
||||||
session: Session,
|
session: Session,
|
||||||
i_datasource: Model,
|
i_datasource: Model,
|
||||||
lookup_database: Callable,
|
lookup_database: Callable[[Model], Model],
|
||||||
lookup_datasource: Callable,
|
lookup_datasource: Callable[[Model], Model],
|
||||||
import_time: Optional[int] = None,
|
import_time: Optional[int] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Imports the datasource from the object to the database.
|
"""Imports the datasource from the object to the database.
|
||||||
|
|
@ -82,7 +82,9 @@ def import_datasource(
|
||||||
return datasource.id
|
return datasource.id
|
||||||
|
|
||||||
|
|
||||||
def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model:
|
def import_simple_obj(
|
||||||
|
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
|
||||||
|
) -> Model:
|
||||||
make_transient(i_obj)
|
make_transient(i_obj)
|
||||||
i_obj.id = None
|
i_obj.id = None
|
||||||
i_obj.table = None
|
i_obj.table = None
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class AbstractEventLogger(ABC):
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def log_this(self, f: Callable) -> Callable:
|
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
user_id = None
|
user_id = None
|
||||||
|
|
@ -124,7 +124,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
event_logger_type = cast(Type, cfg_value)
|
event_logger_type = cast(Type[Any], cfg_value)
|
||||||
result = event_logger_type()
|
result = event_logger_type()
|
||||||
|
|
||||||
# Verify that we have a valid logger impl
|
# Verify that we have a valid logger impl
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator):
|
||||||
|
|
||||||
if app_config["ENABLE_TIME_ROTATE"]:
|
if app_config["ENABLE_TIME_ROTATE"]:
|
||||||
logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
|
logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
|
||||||
handler = TimedRotatingFileHandler( # type: ignore
|
handler = TimedRotatingFileHandler(
|
||||||
app_config["FILENAME"],
|
app_config["FILENAME"],
|
||||||
when=app_config["ROLLOVER"],
|
when=app_config["ROLLOVER"],
|
||||||
interval=app_config["INTERVAL"],
|
interval=app_config["INTERVAL"],
|
||||||
|
|
|
||||||
|
|
@ -72,8 +72,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_column_args(*argnames: str) -> Callable:
|
def validate_column_args(*argnames: str) -> Callable[..., Any]:
|
||||||
def wrapper(func: Callable) -> Callable:
|
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
def wrapped(df: DataFrame, **options: Any) -> Any:
|
def wrapped(df: DataFrame, **options: Any) -> Any:
|
||||||
columns = df.columns.tolist()
|
columns = df.columns.tolist()
|
||||||
for name in argnames:
|
for name in argnames:
|
||||||
|
|
@ -471,7 +471,7 @@ def geodetic_parse(
|
||||||
Parse a string containing a geodetic point and return latitude, longitude
|
Parse a string containing a geodetic point and return latitude, longitude
|
||||||
and altitude
|
and altitude
|
||||||
"""
|
"""
|
||||||
point = Point(location) # type: ignore
|
point = Point(location)
|
||||||
return point[0], point[1], point[2]
|
return point[0], point[1], point[2]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ SELENIUM_HEADSTART = 3
|
||||||
WindowSize = Tuple[int, int]
|
WindowSize = Tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
def get_auth_cookies(user: "User") -> List[Dict]:
|
def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
|
||||||
# Login with the user specified to get the reports
|
# Login with the user specified to get the reports
|
||||||
with current_app.test_request_context("/login"):
|
with current_app.test_request_context("/login"):
|
||||||
login_user(user)
|
login_user(user)
|
||||||
|
|
@ -101,14 +101,14 @@ class AuthWebDriverProxy:
|
||||||
self,
|
self,
|
||||||
driver_type: str,
|
driver_type: str,
|
||||||
window: Optional[WindowSize] = None,
|
window: Optional[WindowSize] = None,
|
||||||
auth_func: Optional[Callable] = None,
|
auth_func: Optional[
|
||||||
|
Callable[..., Any]
|
||||||
|
] = None, # pylint: disable=bad-whitespace
|
||||||
):
|
):
|
||||||
self._driver_type = driver_type
|
self._driver_type = driver_type
|
||||||
self._window: WindowSize = window or (800, 600)
|
self._window: WindowSize = window or (800, 600)
|
||||||
config_auth_func: Callable = current_app.config.get(
|
config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver)
|
||||||
"WEBDRIVER_AUTH_FUNC", auth_driver
|
self._auth_func = auth_func or config_auth_func
|
||||||
)
|
|
||||||
self._auth_func: Callable = auth_func or config_auth_func
|
|
||||||
|
|
||||||
def create(self) -> WebDriver:
|
def create(self) -> WebDriver:
|
||||||
if self._driver_type == "firefox":
|
if self._driver_type == "firefox":
|
||||||
|
|
@ -123,7 +123,7 @@ class AuthWebDriverProxy:
|
||||||
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
|
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
|
||||||
# Prepare args for the webdriver init
|
# Prepare args for the webdriver init
|
||||||
options.add_argument("--headless")
|
options.add_argument("--headless")
|
||||||
kwargs: Dict = dict(options=options)
|
kwargs: Dict[Any, Any] = dict(options=options)
|
||||||
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
|
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
|
||||||
logger.info("Init selenium driver")
|
logger.info("Init selenium driver")
|
||||||
return driver_class(**kwargs)
|
return driver_class(**kwargs)
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,7 @@ def generate_download_headers(
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def api(f: Callable) -> Callable:
|
def api(f: Callable[..., FlaskResponse]) -> Callable[..., FlaskResponse]:
|
||||||
"""
|
"""
|
||||||
A decorator to label an endpoint as an API. Catches uncaught exceptions and
|
A decorator to label an endpoint as an API. Catches uncaught exceptions and
|
||||||
return the response in the JSON format
|
return the response in the JSON format
|
||||||
|
|
@ -383,11 +383,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
|
||||||
:param primary_key:
|
:param primary_key:
|
||||||
record primary key to delete
|
record primary key to delete
|
||||||
"""
|
"""
|
||||||
item = self.datamodel.get(primary_key, self._base_filters) # type: ignore
|
item = self.datamodel.get(primary_key, self._base_filters)
|
||||||
if not item:
|
if not item:
|
||||||
abort(404)
|
abort(404)
|
||||||
try:
|
try:
|
||||||
self.pre_delete(item) # type: ignore
|
self.pre_delete(item)
|
||||||
except Exception as ex: # pylint: disable=broad-except
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
flash(str(ex), "danger")
|
flash(str(ex), "danger")
|
||||||
else:
|
else:
|
||||||
|
|
@ -400,8 +400,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.datamodel.delete(item): # type: ignore
|
if self.datamodel.delete(item):
|
||||||
self.post_delete(item) # type: ignore
|
self.post_delete(item)
|
||||||
|
|
||||||
for pv in pvs:
|
for pv in pvs:
|
||||||
security_manager.get_session.delete(pv)
|
security_manager.get_session.delete(pv)
|
||||||
|
|
@ -411,8 +411,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
security_manager.get_session.commit()
|
security_manager.get_session.commit()
|
||||||
|
|
||||||
flash(*self.datamodel.message) # type: ignore
|
flash(*self.datamodel.message)
|
||||||
self.update_redirect() # type: ignore
|
self.update_redirect()
|
||||||
|
|
||||||
@action(
|
@action(
|
||||||
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
|
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ get_related_schema = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def statsd_metrics(f: Callable) -> Callable:
|
def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
"""
|
"""
|
||||||
Handle sending all statsd metrics from the REST API
|
Handle sending all statsd metrics from the REST API
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
|
||||||
owners_field_name = "owners"
|
owners_field_name = "owners"
|
||||||
|
|
||||||
@post_load
|
@post_load
|
||||||
def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
|
def make_object(
|
||||||
|
self, data: Dict[str, Any], discard: Optional[List[str]] = None
|
||||||
|
) -> Model:
|
||||||
discard = discard or []
|
discard = discard or []
|
||||||
discard.append(self.owners_field_name)
|
discard.append(self.owners_field_name)
|
||||||
instance = super().make_object(data, discard)
|
instance = super().make_object(data, discard)
|
||||||
|
|
|
||||||
|
|
@ -251,7 +251,7 @@ def check_slice_perms(self: "Superset", slice_id: int) -> None:
|
||||||
|
|
||||||
def _deserialize_results_payload(
|
def _deserialize_results_payload(
|
||||||
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
|
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
|
||||||
) -> Dict[Any, Any]:
|
) -> Dict[str, Any]:
|
||||||
logger.debug(f"Deserializing from msgpack: {use_msgpack}")
|
logger.debug(f"Deserializing from msgpack: {use_msgpack}")
|
||||||
if use_msgpack:
|
if use_msgpack:
|
||||||
with stats_timing(
|
with stats_timing(
|
||||||
|
|
@ -278,7 +278,7 @@ def _deserialize_results_payload(
|
||||||
with stats_timing(
|
with stats_timing(
|
||||||
"sqllab.query.results_backend_json_deserialize", stats_logger
|
"sqllab.query.results_backend_json_deserialize", stats_logger
|
||||||
):
|
):
|
||||||
return json.loads(payload) # type: ignore
|
return json.loads(payload)
|
||||||
|
|
||||||
|
|
||||||
def get_cta_schema_name(
|
def get_cta_schema_name(
|
||||||
|
|
@ -1343,7 +1343,7 @@ class Superset(BaseSupersetView):
|
||||||
|
|
||||||
if "timed_refresh_immune_slices" not in md:
|
if "timed_refresh_immune_slices" not in md:
|
||||||
md["timed_refresh_immune_slices"] = []
|
md["timed_refresh_immune_slices"] = []
|
||||||
new_filter_scopes: Dict[str, Dict] = {}
|
new_filter_scopes = {}
|
||||||
if "filter_scopes" in data:
|
if "filter_scopes" in data:
|
||||||
# replace filter_id and immune ids from old slice id to new slice id:
|
# replace filter_id and immune ids from old slice id to new slice id:
|
||||||
# and remove slice ids that are not in dash anymore
|
# and remove slice ids that are not in dash anymore
|
||||||
|
|
@ -2137,7 +2137,7 @@ class Superset(BaseSupersetView):
|
||||||
f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
|
f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
|
||||||
)
|
)
|
||||||
return json_error_response("Not found", 404)
|
return json_error_response("Not found", 404)
|
||||||
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore
|
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
|
||||||
table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
|
table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
|
||||||
# Check that the user can access the datasource
|
# Check that the user can access the datasource
|
||||||
if not self.appbuilder.sm.can_access_datasource(
|
if not self.appbuilder.sm.can_access_datasource(
|
||||||
|
|
@ -2245,7 +2245,7 @@ class Superset(BaseSupersetView):
|
||||||
)
|
)
|
||||||
|
|
||||||
payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
|
payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
|
||||||
obj: dict = _deserialize_results_payload(
|
obj = _deserialize_results_payload(
|
||||||
payload, query, cast(bool, results_backend_use_msgpack)
|
payload, query, cast(bool, results_backend_use_msgpack)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2474,9 +2474,7 @@ class Superset(BaseSupersetView):
|
||||||
schema: str = cast(str, query_params.get("schema"))
|
schema: str = cast(str, query_params.get("schema"))
|
||||||
sql: str = cast(str, query_params.get("sql"))
|
sql: str = cast(str, query_params.get("sql"))
|
||||||
try:
|
try:
|
||||||
template_params: dict = json.loads(
|
template_params = json.loads(query_params.get("templateParams") or "{}")
|
||||||
query_params.get("templateParams") or "{}"
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid template parameter {query_params.get('templateParams')}"
|
f"Invalid template parameter {query_params.get('templateParams')}"
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ def get_col_type(col: Dict[Any, Any]) -> str:
|
||||||
|
|
||||||
def get_table_metadata(
|
def get_table_metadata(
|
||||||
database: Database, table_name: str, schema_name: Optional[str]
|
database: Database, table_name: str, schema_name: Optional[str]
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get table metadata information, including type, pk, fks.
|
Get table metadata information, including type, pk, fks.
|
||||||
This function raises SQLAlchemyError when a schema is not found.
|
This function raises SQLAlchemyError when a schema is not found.
|
||||||
|
|
@ -72,7 +72,7 @@ def get_table_metadata(
|
||||||
:param schema_name: schema name
|
:param schema_name: schema name
|
||||||
:return: Dict table metadata ready for API response
|
:return: Dict table metadata ready for API response
|
||||||
"""
|
"""
|
||||||
keys: List = []
|
keys = []
|
||||||
columns = database.get_columns(table_name, schema_name)
|
columns = database.get_columns(table_name, schema_name)
|
||||||
primary_key = database.get_pk_constraint(table_name, schema_name)
|
primary_key = database.get_pk_constraint(table_name, schema_name)
|
||||||
if primary_key and primary_key.get("constrained_columns"):
|
if primary_key and primary_key.get("constrained_columns"):
|
||||||
|
|
@ -82,7 +82,7 @@ def get_table_metadata(
|
||||||
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
|
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
|
||||||
indexes = get_indexes_metadata(database, table_name, schema_name)
|
indexes = get_indexes_metadata(database, table_name, schema_name)
|
||||||
keys += foreign_keys + indexes
|
keys += foreign_keys + indexes
|
||||||
payload_columns: List[Dict] = []
|
payload_columns: List[Dict[str, Any]] = []
|
||||||
for col in columns:
|
for col in columns:
|
||||||
dtype = get_col_type(col)
|
dtype = get_col_type(col)
|
||||||
payload_columns.append(
|
payload_columns.append(
|
||||||
|
|
@ -90,7 +90,7 @@ def get_table_metadata(
|
||||||
"name": col["name"],
|
"name": col["name"],
|
||||||
"type": dtype.split("(")[0] if "(" in dtype else dtype,
|
"type": dtype.split("(")[0] if "(" in dtype else dtype,
|
||||||
"longType": dtype,
|
"longType": dtype,
|
||||||
"keys": [k for k in keys if col["name"] in k.get("column_names")],
|
"keys": [k for k in keys if col["name"] in k["column_names"]],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
|
|
@ -270,7 +270,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
|
||||||
"""
|
"""
|
||||||
self.incr_stats("init", self.table_metadata.__name__)
|
self.incr_stats("init", self.table_metadata.__name__)
|
||||||
try:
|
try:
|
||||||
table_info: Dict = get_table_metadata(database, table_name, schema_name)
|
table_info = get_table_metadata(database, table_name, schema_name)
|
||||||
except SQLAlchemyError as ex:
|
except SQLAlchemyError as ex:
|
||||||
self.incr_stats("error", self.table_metadata.__name__)
|
self.incr_stats("error", self.table_metadata.__name__)
|
||||||
return self.response_422(error_msg_from_exception(ex))
|
return self.response_422(error_msg_from_exception(ex))
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from superset.views.base_api import BaseSupersetModelRestApi
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_datasource_access(f: Callable) -> Callable:
|
def check_datasource_access(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
"""
|
"""
|
||||||
A Decorator that checks if a user has datasource access
|
A Decorator that checks if a user has datasource access
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import enum
|
import enum
|
||||||
from typing import Type
|
from typing import Type, Union
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from croniter import croniter
|
from croniter import croniter
|
||||||
|
|
@ -55,7 +55,7 @@ class EmailScheduleView(
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schedule_type_model(self) -> Type:
|
def schedule_type_model(self) -> Type[Union[Dashboard, Slice]]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
page_size = 20
|
page_size = 20
|
||||||
|
|
@ -154,9 +154,7 @@ class EmailScheduleView(
|
||||||
info[col] = info[col].username
|
info[col] = info[col].username
|
||||||
|
|
||||||
info["user"] = schedule.user.username
|
info["user"] = schedule.user.username
|
||||||
info[self.schedule_type] = getattr( # type: ignore
|
info[self.schedule_type] = getattr(schedule, self.schedule_type).id
|
||||||
schedule, self.schedule_type
|
|
||||||
).id
|
|
||||||
schedules.append(info)
|
schedules.append(info)
|
||||||
|
|
||||||
return json_success(json.dumps(schedules, default=json_iso_dttm_ser))
|
return json_success(json.dumps(schedules, default=json_iso_dttm_ser))
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Callable
|
from typing import Any
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from flask import g, redirect, request, Response
|
from flask import g, redirect, request, Response
|
||||||
|
|
@ -40,7 +40,7 @@ from .base import (
|
||||||
|
|
||||||
|
|
||||||
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||||
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
|
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
|
||||||
"""
|
"""
|
||||||
Filter queries to only those owned by current user. If
|
Filter queries to only those owned by current user. If
|
||||||
can_access_all_queries permission is set a user can list all queries
|
can_access_all_queries permission is set a user can list all queries
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ from superset.utils.core import QueryStatus, TimeRangeEndpoint
|
||||||
from superset.viz import BaseViz
|
from superset.viz import BaseViz
|
||||||
|
|
||||||
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
|
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
|
||||||
from superset import viz_sip38 as viz # type: ignore
|
from superset import viz_sip38 as viz
|
||||||
else:
|
else:
|
||||||
from superset import viz # type: ignore
|
from superset import viz # type: ignore
|
||||||
|
|
||||||
|
|
@ -318,9 +318,9 @@ def get_dashboard_extra_filters(
|
||||||
|
|
||||||
|
|
||||||
def build_extra_filters(
|
def build_extra_filters(
|
||||||
layout: Dict,
|
layout: Dict[str, Dict[str, Any]],
|
||||||
filter_scopes: Dict,
|
filter_scopes: Dict[str, Dict[str, Any]],
|
||||||
default_filters: Dict[str, Dict[str, List]],
|
default_filters: Dict[str, Dict[str, List[Any]]],
|
||||||
slice_id: int,
|
slice_id: int,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
extra_filters = []
|
extra_filters = []
|
||||||
|
|
@ -343,7 +343,9 @@ def build_extra_filters(
|
||||||
return extra_filters
|
return extra_filters
|
||||||
|
|
||||||
|
|
||||||
def is_slice_in_container(layout: Dict, container_id: str, slice_id: int) -> bool:
|
def is_slice_in_container(
|
||||||
|
layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int
|
||||||
|
) -> bool:
|
||||||
if container_id == "ROOT_ID":
|
if container_id == "ROOT_ID":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2720,7 +2720,7 @@ class PairedTTestViz(BaseViz):
|
||||||
else:
|
else:
|
||||||
cols.append(col)
|
cols.append(col)
|
||||||
df.columns = cols
|
df.columns = cols
|
||||||
data: Dict = {}
|
data: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
series = df.to_dict("series")
|
series = df.to_dict("series")
|
||||||
for nameSet in df.columns:
|
for nameSet in df.columns:
|
||||||
# If no groups are defined, nameSet will be the metric name
|
# If no groups are defined, nameSet will be the metric name
|
||||||
|
|
@ -2750,7 +2750,7 @@ class RoseViz(NVD3TimeSeriesViz):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
data = super().get_data(df)
|
data = super().get_data(df)
|
||||||
result: Dict = {}
|
result: Dict[str, List[Dict[str, str]]] = {}
|
||||||
for datum in data: # type: ignore
|
for datum in data: # type: ignore
|
||||||
key = datum["key"]
|
key = datum["key"]
|
||||||
for val in datum["values"]:
|
for val in datum["values"]:
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@
|
||||||
"""Unit tests for Superset"""
|
"""Unit tests for Superset"""
|
||||||
import imp
|
import imp
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Union, List
|
from typing import Any, Dict, Union, List
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
@ -397,7 +397,9 @@ class SupersetTestCase(TestCase):
|
||||||
mock_method.assert_called_once_with("error", func_name)
|
mock_method.assert_called_once_with("error", func_name)
|
||||||
return rv
|
return rv
|
||||||
|
|
||||||
def post_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
|
def post_assert_metric(
|
||||||
|
self, uri: str, data: Dict[str, Any], func_name: str
|
||||||
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Simple client post with an extra assertion for statsd metrics
|
Simple client post with an extra assertion for statsd metrics
|
||||||
|
|
||||||
|
|
@ -417,7 +419,9 @@ class SupersetTestCase(TestCase):
|
||||||
mock_method.assert_called_once_with("error", func_name)
|
mock_method.assert_called_once_with("error", func_name)
|
||||||
return rv
|
return rv
|
||||||
|
|
||||||
def put_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
|
def put_assert_metric(
|
||||||
|
self, uri: str, data: Dict[str, Any], func_name: str
|
||||||
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Simple client put with an extra assertion for statsd metrics
|
Simple client put with an extra assertion for statsd metrics
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ from copy import copy
|
||||||
from cachelib.redis import RedisCache
|
from cachelib.redis import RedisCache
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
|
||||||
from superset.config import * # type: ignore
|
from superset.config import *
|
||||||
|
|
||||||
AUTH_USER_REGISTRATION_ROLE = "alpha"
|
AUTH_USER_REGISTRATION_ROLE = "alpha"
|
||||||
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
|
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue