[fix] Enforce the query result to contain a data-frame (#8935)
This commit is contained in:
parent
2a94150097
commit
2d456e88eb
|
|
@ -988,11 +988,9 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
def get_query_str(self, query_obj, phase=1, client=None):
|
||||
return self.run_query(client=client, phase=phase, **query_obj)
|
||||
|
||||
def _add_filter_from_pre_query_data(
|
||||
self, df: Optional[pd.DataFrame], dimensions, dim_filter
|
||||
):
|
||||
def _add_filter_from_pre_query_data(self, df: pd.DataFrame, dimensions, dim_filter):
|
||||
ret = dim_filter
|
||||
if df is not None and not df.empty:
|
||||
if not df.empty:
|
||||
new_filters = []
|
||||
for unused, row in df.iterrows():
|
||||
fields = []
|
||||
|
|
@ -1379,7 +1377,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
|
||||
if df is None or df.size == 0:
|
||||
return QueryResult(
|
||||
df=pd.DataFrame([]),
|
||||
df=pd.DataFrame(),
|
||||
query=query_str,
|
||||
duration=datetime.now() - qry_start_dttm,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -85,7 +85,6 @@ class AnnotationDatasource(BaseDatasource):
|
|||
cache_timeout = 0
|
||||
|
||||
def query(self, query_obj: Dict[str, Any]) -> QueryResult:
|
||||
df = None
|
||||
error_message = None
|
||||
qry = db.session.query(Annotation)
|
||||
qry = qry.filter(Annotation.layer_id == query_obj["filter"][0]["val"])
|
||||
|
|
@ -97,6 +96,7 @@ class AnnotationDatasource(BaseDatasource):
|
|||
try:
|
||||
df = pd.read_sql_query(qry.statement, db.engine)
|
||||
except Exception as e:
|
||||
df = pd.DataFrame()
|
||||
status = utils.QueryStatus.FAILED
|
||||
logging.exception(e)
|
||||
error_message = utils.error_msg_from_exception(e)
|
||||
|
|
@ -995,7 +995,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
try:
|
||||
df = self.database.get_df(sql, self.schema, mutator)
|
||||
except Exception as e:
|
||||
df = None
|
||||
df = pd.DataFrame()
|
||||
status = utils.QueryStatus.FAILED
|
||||
logging.exception(f"Query {sql} on schema {self.schema} failed")
|
||||
db_engine_spec = self.database.db_engine_spec
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from typing import List, Optional
|
|||
# isort and pylint disagree, isort should win
|
||||
# pylint: disable=ungrouped-imports
|
||||
import humanize
|
||||
import pandas as pd
|
||||
import sqlalchemy as sa
|
||||
import yaml
|
||||
from flask import escape, g, Markup
|
||||
|
|
@ -368,11 +369,11 @@ class QueryResult: # pylint: disable=too-few-public-methods
|
|||
def __init__( # pylint: disable=too-many-arguments
|
||||
self, df, query, duration, status=QueryStatus.SUCCESS, error_message=None
|
||||
):
|
||||
self.df = df # pylint: disable=invalid-name
|
||||
self.query = query
|
||||
self.duration = duration
|
||||
self.status = status
|
||||
self.error_message = error_message
|
||||
self.df: pd.DataFrame = df # pylint: disable=invalid-name
|
||||
self.query: str = query
|
||||
self.duration: int = duration
|
||||
self.status: str = status
|
||||
self.error_message: Optional[str] = error_message
|
||||
|
||||
|
||||
class ExtraJSONMixin:
|
||||
|
|
|
|||
160
superset/viz.py
160
superset/viz.py
|
|
@ -32,7 +32,7 @@ from collections import defaultdict, OrderedDict
|
|||
from datetime import datetime, timedelta
|
||||
from functools import reduce
|
||||
from itertools import product
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import geohash
|
||||
import numpy as np
|
||||
|
|
@ -77,6 +77,8 @@ METRIC_KEYS = [
|
|||
"size",
|
||||
]
|
||||
|
||||
VizData = Optional[Union[List[Any], Dict[Any, Any]]]
|
||||
|
||||
|
||||
class BaseViz:
|
||||
|
||||
|
|
@ -109,10 +111,10 @@ class BaseViz:
|
|||
self.groupby = self.form_data.get("groupby") or []
|
||||
self.time_shift = timedelta()
|
||||
|
||||
self.status = None
|
||||
self.status: Optional[str] = None
|
||||
self.error_msg = ""
|
||||
self.results: Optional[QueryResult] = None
|
||||
self.error_message = None
|
||||
self.error_message: Optional[str] = None
|
||||
self.force = force
|
||||
|
||||
# Keeping track of whether some data came from cache
|
||||
|
|
@ -190,14 +192,12 @@ class BaseViz:
|
|||
df = self.get_df(query_obj)
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
def get_df(
|
||||
self, query_obj: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[pd.DataFrame]:
|
||||
def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
|
||||
"""Returns a pandas dataframe based on the query object"""
|
||||
if not query_obj:
|
||||
query_obj = self.query_obj()
|
||||
if not query_obj:
|
||||
return None
|
||||
return pd.DataFrame()
|
||||
|
||||
self.error_msg = ""
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class BaseViz:
|
|||
# be considered as the default ISO date format
|
||||
# If the datetime format is unix, the parse will use the corresponding
|
||||
# parsing logic.
|
||||
if df is not None and not df.empty:
|
||||
if not df.empty:
|
||||
if DTTM_ALIAS in df.columns:
|
||||
if timestamp_format in ("epoch_s", "epoch_ms"):
|
||||
# Column has already been formatted as a timestamp.
|
||||
|
|
@ -439,11 +439,7 @@ class BaseViz:
|
|||
and self.status != utils.QueryStatus.FAILED
|
||||
):
|
||||
try:
|
||||
cache_value = dict(
|
||||
dttm=cached_dttm,
|
||||
df=df if df is not None else None,
|
||||
query=self.query,
|
||||
)
|
||||
cache_value = dict(dttm=cached_dttm, df=df, query=self.query)
|
||||
cache_value = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
|
||||
|
||||
logging.info(
|
||||
|
|
@ -500,7 +496,7 @@ class BaseViz:
|
|||
include_index = not isinstance(df.index, pd.RangeIndex)
|
||||
return df.to_csv(index=include_index, **config["CSV_EXPORT"])
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
@property
|
||||
|
|
@ -563,7 +559,7 @@ class TableViz(BaseViz):
|
|||
d["is_timeseries"] = self.should_be_timeseries()
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
if not self.should_be_timeseries() and df is not None and DTTM_ALIAS in df:
|
||||
del df[DTTM_ALIAS]
|
||||
|
|
@ -631,7 +627,7 @@ class TimeTableViz(BaseViz):
|
|||
)
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
columns = None
|
||||
values = self.metric_labels
|
||||
|
|
@ -644,7 +640,7 @@ class TimeTableViz(BaseViz):
|
|||
return dict(
|
||||
records=pt.to_dict(orient="index"),
|
||||
columns=list(pt.columns),
|
||||
is_group_by=len(fd.get("groupby")) > 0,
|
||||
is_group_by=len(fd.get("groupby", [])) > 0,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -684,7 +680,7 @@ class PivotTableViz(BaseViz):
|
|||
raise Exception(_("Group By' and 'Columns' can't overlap"))
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
|
||||
del df[DTTM_ALIAS]
|
||||
|
||||
|
|
@ -701,7 +697,9 @@ class PivotTableViz(BaseViz):
|
|||
df = df.pivot_table(
|
||||
index=groupby,
|
||||
columns=columns,
|
||||
values=[utils.get_metric_name(m) for m in self.form_data.get("metrics")],
|
||||
values=[
|
||||
utils.get_metric_name(m) for m in self.form_data.get("metrics", [])
|
||||
],
|
||||
aggfunc=aggfunc,
|
||||
margins=self.form_data.get("pivot_margins"),
|
||||
)
|
||||
|
|
@ -731,12 +729,10 @@ class MarkupViz(BaseViz):
|
|||
def query_obj(self):
|
||||
return None
|
||||
|
||||
def get_df(
|
||||
self, query_obj: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[pd.DataFrame]:
|
||||
return None
|
||||
def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
markup_type = self.form_data.get("markup_type")
|
||||
code = self.form_data.get("code", "")
|
||||
if markup_type == "markdown":
|
||||
|
|
@ -790,7 +786,7 @@ class TreemapViz(BaseViz):
|
|||
]
|
||||
return result
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
df = df.set_index(self.form_data.get("groupby"))
|
||||
chart_data = [
|
||||
{"name": metric, "children": self._nest(metric, df)}
|
||||
|
|
@ -808,7 +804,7 @@ class CalHeatmapViz(BaseViz):
|
|||
credits = "<a href=https://github.com/wa0x6e/cal-heatmap>cal-heatmap</a>"
|
||||
is_timeseries = True
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
form_data = self.form_data
|
||||
|
||||
data = {}
|
||||
|
|
@ -842,9 +838,9 @@ class CalHeatmapViz(BaseViz):
|
|||
elif domain == "week":
|
||||
range_ = diff_delta.years * 53 + diff_delta.weeks + 1
|
||||
elif domain == "day":
|
||||
range_ = diff_secs // (24 * 60 * 60) + 1
|
||||
range_ = diff_secs // (24 * 60 * 60) + 1 # type: ignore
|
||||
else:
|
||||
range_ = diff_secs // (60 * 60) + 1
|
||||
range_ = diff_secs // (60 * 60) + 1 # type: ignore
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
|
|
@ -900,7 +896,7 @@ class BoxPlotViz(NVD3Viz):
|
|||
chart_data.append({"label": chart_label, "values": box})
|
||||
return chart_data
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
form_data = self.form_data
|
||||
|
||||
# conform to NVD3 names
|
||||
|
|
@ -929,8 +925,10 @@ class BoxPlotViz(NVD3Viz):
|
|||
def whisker_low(series):
|
||||
return series.min()
|
||||
|
||||
elif " percentiles" in whisker_type:
|
||||
low, high = whisker_type.replace(" percentiles", "").split("/")
|
||||
elif " percentiles" in whisker_type: # type: ignore
|
||||
low, high = whisker_type.replace(" percentiles", "").split( # type: ignore
|
||||
"/"
|
||||
)
|
||||
|
||||
def whisker_high(series):
|
||||
return np.nanpercentile(series, int(high))
|
||||
|
|
@ -981,14 +979,14 @@ class BubbleViz(NVD3Viz):
|
|||
raise Exception(_("Pick a metric for x, y and size"))
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
df["x"] = df[[utils.get_metric_name(self.x_metric)]]
|
||||
df["y"] = df[[utils.get_metric_name(self.y_metric)]]
|
||||
df["size"] = df[[utils.get_metric_name(self.z_metric)]]
|
||||
df["shape"] = "circle"
|
||||
df["group"] = df[[self.series]]
|
||||
|
||||
series = defaultdict(list)
|
||||
series: Dict[Any, List[Any]] = defaultdict(list)
|
||||
for row in df.to_dict(orient="records"):
|
||||
series[row["group"]].append(row)
|
||||
chart_data = []
|
||||
|
|
@ -1029,7 +1027,7 @@ class BulletViz(NVD3Viz):
|
|||
raise Exception(_("Pick a metric to display"))
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
df["metric"] = df[[utils.get_metric_name(self.metric)]]
|
||||
values = df["metric"].values
|
||||
return {
|
||||
|
|
@ -1155,6 +1153,9 @@ class NVD3TimeSeriesViz(NVD3Viz):
|
|||
if fd.get("granularity") == "all":
|
||||
raise Exception(_("Pick a time granularity for your time series"))
|
||||
|
||||
if df.empty:
|
||||
return df
|
||||
|
||||
if aggregate:
|
||||
df = df.pivot_table(
|
||||
index=DTTM_ALIAS,
|
||||
|
|
@ -1236,7 +1237,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
|
|||
df2 = self.process_data(df2)
|
||||
self._extra_chart_data.append((label, df2))
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
comparison_type = fd.get("comparison_type") or "values"
|
||||
df = self.process_data(df)
|
||||
|
|
@ -1298,7 +1299,7 @@ class MultiLineViz(NVD3Viz):
|
|||
def query_obj(self):
|
||||
return None
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
# Late imports to avoid circular import issues
|
||||
from superset.models.slice import Slice
|
||||
|
|
@ -1371,7 +1372,7 @@ class NVD3DualLineViz(NVD3Viz):
|
|||
chart_data.append(d)
|
||||
return chart_data
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
|
||||
if self.form_data.get("granularity") == "all":
|
||||
|
|
@ -1407,7 +1408,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
|
|||
d["metrics"] = [self.form_data.get("metric")]
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
df = self.process_data(df)
|
||||
freq = to_offset(fd.get("freq"))
|
||||
|
|
@ -1465,7 +1466,7 @@ class DistributionPieViz(NVD3Viz):
|
|||
verbose_name = _("Distribution - NVD3 - Pie Chart")
|
||||
is_timeseries = False
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
metric = self.metric_labels[0]
|
||||
df = df.pivot_table(index=self.groupby, values=[metric])
|
||||
df.sort_values(by=metric, ascending=False, inplace=True)
|
||||
|
|
@ -1505,7 +1506,7 @@ class HistogramViz(BaseViz):
|
|||
labels = [column] + labels
|
||||
return "__".join(labels)
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
"""Returns the chart data"""
|
||||
chart_data = []
|
||||
if len(self.groupby) > 0:
|
||||
|
|
@ -1546,7 +1547,7 @@ class DistributionBarViz(DistributionPieViz):
|
|||
raise Exception(_("Pick at least one field for [Series]"))
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
metrics = self.metric_labels
|
||||
columns = fd.get("columns") or []
|
||||
|
|
@ -1597,13 +1598,13 @@ class SunburstViz(BaseViz):
|
|||
'@<a href="https://bl.ocks.org/kerryrodden/7090426">bl.ocks.org</a>'
|
||||
)
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
cols = fd.get("groupby")
|
||||
metric = utils.get_metric_name(fd.get("metric"))
|
||||
secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
|
||||
if metric == secondary_metric or secondary_metric is None:
|
||||
df.columns = cols + ["m1"]
|
||||
df.columns = cols + ["m1"] # type: ignore
|
||||
df["m2"] = df["m1"]
|
||||
return json.loads(df.to_json(orient="values"))
|
||||
|
||||
|
|
@ -1633,13 +1634,13 @@ class SankeyViz(BaseViz):
|
|||
qry["metrics"] = [self.form_data["metric"]]
|
||||
return qry
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
df.columns = ["source", "target", "value"]
|
||||
df["source"] = df["source"].astype(str)
|
||||
df["target"] = df["target"].astype(str)
|
||||
recs = df.to_dict(orient="records")
|
||||
|
||||
hierarchy = defaultdict(set)
|
||||
hierarchy: Dict[str, Set[str]] = defaultdict(set)
|
||||
for row in recs:
|
||||
hierarchy[row["source"]].add(row["target"])
|
||||
|
||||
|
|
@ -1686,7 +1687,7 @@ class DirectedForceViz(BaseViz):
|
|||
qry["metrics"] = [self.form_data["metric"]]
|
||||
return qry
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
df.columns = ["source", "target", "value"]
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
|
|
@ -1707,7 +1708,7 @@ class ChordViz(BaseViz):
|
|||
qry["metrics"] = [utils.get_metric_name(fd.get("metric"))]
|
||||
return qry
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
df.columns = ["source", "target", "value"]
|
||||
|
||||
# Preparing a symetrical matrix like d3.chords calls for
|
||||
|
|
@ -1736,7 +1737,7 @@ class CountryMapViz(BaseViz):
|
|||
qry["groupby"] = [self.form_data["entity"]]
|
||||
return qry
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
cols = [fd.get("entity")]
|
||||
metric = self.metric_labels[0]
|
||||
|
|
@ -1762,7 +1763,7 @@ class WorldMapViz(BaseViz):
|
|||
qry["groupby"] = [self.form_data["entity"]]
|
||||
return qry
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
from superset.examples import countries
|
||||
|
||||
fd = self.form_data
|
||||
|
|
@ -1830,7 +1831,7 @@ class FilterBoxViz(BaseViz):
|
|||
df = self.get_df_payload(query_obj=qry).get("df")
|
||||
self.dataframes[col] = df
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
filters = self.form_data.get("filter_configs") or []
|
||||
d = {}
|
||||
for flt in filters:
|
||||
|
|
@ -1867,10 +1868,10 @@ class IFrameViz(BaseViz):
|
|||
def query_obj(self):
|
||||
return None
|
||||
|
||||
def get_df(self, query_obj: Dict[str, Any] = None) -> Optional[pd.DataFrame]:
|
||||
return None
|
||||
def get_df(self, query_obj: Dict[str, Any] = None) -> pd.DataFrame:
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
return {}
|
||||
|
||||
|
||||
|
|
@ -1896,7 +1897,7 @@ class ParallelCoordinatesViz(BaseViz):
|
|||
d["groupby"] = [fd.get("series")]
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
|
||||
|
|
@ -1919,7 +1920,7 @@ class HeatmapViz(BaseViz):
|
|||
d["groupby"] = [fd.get("all_columns_x"), fd.get("all_columns_y")]
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
x = fd.get("all_columns_x")
|
||||
y = fd.get("all_columns_y")
|
||||
|
|
@ -2026,20 +2027,20 @@ class MapboxViz(BaseViz):
|
|||
)
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
if df is None:
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
if df.empty:
|
||||
return None
|
||||
fd = self.form_data
|
||||
label_col = fd.get("mapbox_label")
|
||||
has_custom_metric = label_col is not None and len(label_col) > 0
|
||||
metric_col = [None] * len(df.index)
|
||||
if has_custom_metric:
|
||||
if label_col[0] == fd.get("all_columns_x"):
|
||||
if label_col[0] == fd.get("all_columns_x"): # type: ignore
|
||||
metric_col = df[fd.get("all_columns_x")]
|
||||
elif label_col[0] == fd.get("all_columns_y"):
|
||||
elif label_col[0] == fd.get("all_columns_y"): # type: ignore
|
||||
metric_col = df[fd.get("all_columns_y")]
|
||||
else:
|
||||
metric_col = df[label_col[0]]
|
||||
metric_col = df[label_col[0]] # type: ignore
|
||||
point_radius_col = (
|
||||
[None] * len(df.index)
|
||||
if fd.get("point_radius") == "Auto"
|
||||
|
|
@ -2106,7 +2107,7 @@ class DeckGLMultiLayer(BaseViz):
|
|||
def query_obj(self):
|
||||
return None
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
# Late imports to avoid circular import issues
|
||||
from superset.models.slice import Slice
|
||||
|
|
@ -2248,8 +2249,8 @@ class BaseDeckGLViz(BaseViz):
|
|||
cols = self.form_data.get("js_columns") or []
|
||||
return {col: d.get(col) for col in cols}
|
||||
|
||||
def get_data(self, df):
|
||||
if df is None:
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
# Processing spatial info
|
||||
|
|
@ -2310,13 +2311,13 @@ class DeckScatterViz(BaseDeckGLViz):
|
|||
DTTM_ALIAS: d.get(DTTM_ALIAS),
|
||||
}
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
|
||||
self.point_radius_fixed = fd.get("point_radius_fixed")
|
||||
self.fixed_value = None
|
||||
self.dim = self.form_data.get("dimension")
|
||||
if self.point_radius_fixed.get("type") != "metric":
|
||||
if self.point_radius_fixed and self.point_radius_fixed.get("type") != "metric":
|
||||
self.fixed_value = self.point_radius_fixed.get("value")
|
||||
return super().get_data(df)
|
||||
|
||||
|
|
@ -2342,7 +2343,7 @@ class DeckScreengrid(BaseDeckGLViz):
|
|||
"__timestamp": d.get(DTTM_ALIAS) or d.get("__time"),
|
||||
}
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
self.metric_label = utils.get_metric_name(self.metric)
|
||||
return super().get_data(df)
|
||||
|
||||
|
|
@ -2358,7 +2359,7 @@ class DeckGrid(BaseDeckGLViz):
|
|||
def get_properties(self, d):
|
||||
return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
self.metric_label = utils.get_metric_name(self.metric)
|
||||
return super().get_data(df)
|
||||
|
||||
|
|
@ -2416,7 +2417,7 @@ class DeckPathViz(BaseDeckGLViz):
|
|||
d["__timestamp"] = d.get(DTTM_ALIAS) or d.get("__time")
|
||||
return d
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
self.metric_label = utils.get_metric_name(self.metric)
|
||||
return super().get_data(df)
|
||||
|
||||
|
|
@ -2462,7 +2463,7 @@ class DeckHex(BaseDeckGLViz):
|
|||
def get_properties(self, d):
|
||||
return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
self.metric_label = utils.get_metric_name(self.metric)
|
||||
return super(DeckHex, self).get_data(df)
|
||||
|
||||
|
|
@ -2509,10 +2510,13 @@ class DeckArc(BaseDeckGLViz):
|
|||
DTTM_ALIAS: d.get(DTTM_ALIAS),
|
||||
}
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
d = super().get_data(df)
|
||||
|
||||
return {"features": d["features"], "mapboxApiKey": config["MAPBOX_API_KEY"]}
|
||||
return {
|
||||
"features": d["features"], # type: ignore
|
||||
"mapboxApiKey": config["MAPBOX_API_KEY"],
|
||||
}
|
||||
|
||||
|
||||
class EventFlowViz(BaseViz):
|
||||
|
|
@ -2543,7 +2547,7 @@ class EventFlowViz(BaseViz):
|
|||
|
||||
return query
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
|
||||
|
|
@ -2556,7 +2560,7 @@ class PairedTTestViz(BaseViz):
|
|||
sort_series = False
|
||||
is_timeseries = True
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
"""
|
||||
Transform received data frame into an object of the form:
|
||||
{
|
||||
|
|
@ -2582,7 +2586,7 @@ class PairedTTestViz(BaseViz):
|
|||
else:
|
||||
cols.append(col)
|
||||
df.columns = cols
|
||||
data = {}
|
||||
data: Dict = {}
|
||||
series = df.to_dict("series")
|
||||
for nameSet in df.columns:
|
||||
# If no groups are defined, nameSet will be the metric name
|
||||
|
|
@ -2607,10 +2611,10 @@ class RoseViz(NVD3TimeSeriesViz):
|
|||
sort_series = False
|
||||
is_timeseries = True
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
data = super().get_data(df)
|
||||
result = {}
|
||||
for datum in data:
|
||||
result: Dict = {}
|
||||
for datum in data: # type: ignore
|
||||
key = datum["key"]
|
||||
for val in datum["values"]:
|
||||
timestamp = val["x"].value
|
||||
|
|
@ -2761,7 +2765,7 @@ class PartitionViz(NVD3TimeSeriesViz):
|
|||
for i in procs[level][dims].columns
|
||||
]
|
||||
|
||||
def get_data(self, df):
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
fd = self.form_data
|
||||
groups = fd.get("groupby", [])
|
||||
time_op = fd.get("time_series_option", "not_time")
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ class SqlaTableModelTestCase(SupersetTestCase):
|
|||
else:
|
||||
self.assertNotIn("JOIN", sql.upper())
|
||||
spec.allows_joins = old_inner_join
|
||||
self.assertIsNotNone(qr.df)
|
||||
self.assertFalse(qr.df.empty)
|
||||
return qr.df
|
||||
|
||||
def test_query_with_expr_groupby_timeseries(self):
|
||||
|
|
@ -262,8 +262,9 @@ class SqlaTableModelTestCase(SupersetTestCase):
|
|||
|
||||
df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True)
|
||||
df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
|
||||
self.assertIsNotNone(df2) # df1 can be none if the db does not support join
|
||||
if df1 is not None:
|
||||
self.assertFalse(df2.empty)
|
||||
# df1 can be empty if the db does not support join
|
||||
if not df1.empty:
|
||||
pandas.testing.assert_frame_equal(
|
||||
cannonicalize_df(df1), cannonicalize_df(df2)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue