[fix] Enforce the query result to contain a data-frame (#8935)

This commit is contained in:
John Bodley 2020-01-08 11:50:26 -08:00 committed by GitHub
parent 2a94150097
commit 2d456e88eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 93 deletions

View File

@ -988,11 +988,9 @@ class DruidDatasource(Model, BaseDatasource):
def get_query_str(self, query_obj, phase=1, client=None):
return self.run_query(client=client, phase=phase, **query_obj)
def _add_filter_from_pre_query_data(
self, df: Optional[pd.DataFrame], dimensions, dim_filter
):
def _add_filter_from_pre_query_data(self, df: pd.DataFrame, dimensions, dim_filter):
ret = dim_filter
if df is not None and not df.empty:
if not df.empty:
new_filters = []
for unused, row in df.iterrows():
fields = []
@ -1379,7 +1377,7 @@ class DruidDatasource(Model, BaseDatasource):
if df is None or df.size == 0:
return QueryResult(
df=pd.DataFrame([]),
df=pd.DataFrame(),
query=query_str,
duration=datetime.now() - qry_start_dttm,
)

View File

@ -85,7 +85,6 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
def query(self, query_obj: Dict[str, Any]) -> QueryResult:
df = None
error_message = None
qry = db.session.query(Annotation)
qry = qry.filter(Annotation.layer_id == query_obj["filter"][0]["val"])
@ -97,6 +96,7 @@ class AnnotationDatasource(BaseDatasource):
try:
df = pd.read_sql_query(qry.statement, db.engine)
except Exception as e:
df = pd.DataFrame()
status = utils.QueryStatus.FAILED
logging.exception(e)
error_message = utils.error_msg_from_exception(e)
@ -995,7 +995,7 @@ class SqlaTable(Model, BaseDatasource):
try:
df = self.database.get_df(sql, self.schema, mutator)
except Exception as e:
df = None
df = pd.DataFrame()
status = utils.QueryStatus.FAILED
logging.exception(f"Query {sql} on schema {self.schema} failed")
db_engine_spec = self.database.db_engine_spec

View File

@ -24,6 +24,7 @@ from typing import List, Optional
# isort and pylint disagree, isort should win
# pylint: disable=ungrouped-imports
import humanize
import pandas as pd
import sqlalchemy as sa
import yaml
from flask import escape, g, Markup
@ -368,11 +369,11 @@ class QueryResult: # pylint: disable=too-few-public-methods
def __init__( # pylint: disable=too-many-arguments
self, df, query, duration, status=QueryStatus.SUCCESS, error_message=None
):
self.df = df # pylint: disable=invalid-name
self.query = query
self.duration = duration
self.status = status
self.error_message = error_message
self.df: pd.DataFrame = df # pylint: disable=invalid-name
self.query: str = query
self.duration: int = duration
self.status: str = status
self.error_message: Optional[str] = error_message
class ExtraJSONMixin:

View File

@ -32,7 +32,7 @@ from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from functools import reduce
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import geohash
import numpy as np
@ -77,6 +77,8 @@ METRIC_KEYS = [
"size",
]
VizData = Optional[Union[List[Any], Dict[Any, Any]]]
class BaseViz:
@ -109,10 +111,10 @@ class BaseViz:
self.groupby = self.form_data.get("groupby") or []
self.time_shift = timedelta()
self.status = None
self.status: Optional[str] = None
self.error_msg = ""
self.results: Optional[QueryResult] = None
self.error_message = None
self.error_message: Optional[str] = None
self.force = force
# Keeping track of whether some data came from cache
@ -190,14 +192,12 @@ class BaseViz:
df = self.get_df(query_obj)
return df.to_dict(orient="records")
def get_df(
self, query_obj: Optional[Dict[str, Any]] = None
) -> Optional[pd.DataFrame]:
def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
query_obj = self.query_obj()
if not query_obj:
return None
return pd.DataFrame()
self.error_msg = ""
@ -219,7 +219,7 @@ class BaseViz:
# be considered as the default ISO date format
# If the datetime format is unix, the parse will use the corresponding
# parsing logic.
if df is not None and not df.empty:
if not df.empty:
if DTTM_ALIAS in df.columns:
if timestamp_format in ("epoch_s", "epoch_ms"):
# Column has already been formatted as a timestamp.
@ -439,11 +439,7 @@ class BaseViz:
and self.status != utils.QueryStatus.FAILED
):
try:
cache_value = dict(
dttm=cached_dttm,
df=df if df is not None else None,
query=self.query,
)
cache_value = dict(dttm=cached_dttm, df=df, query=self.query)
cache_value = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
logging.info(
@ -500,7 +496,7 @@ class BaseViz:
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, **config["CSV_EXPORT"])
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
return df.to_dict(orient="records")
@property
@ -563,7 +559,7 @@ class TableViz(BaseViz):
d["is_timeseries"] = self.should_be_timeseries()
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
if not self.should_be_timeseries() and df is not None and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
@ -631,7 +627,7 @@ class TimeTableViz(BaseViz):
)
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
columns = None
values = self.metric_labels
@ -644,7 +640,7 @@ class TimeTableViz(BaseViz):
return dict(
records=pt.to_dict(orient="index"),
columns=list(pt.columns),
is_group_by=len(fd.get("groupby")) > 0,
is_group_by=len(fd.get("groupby", [])) > 0,
)
@ -684,7 +680,7 @@ class PivotTableViz(BaseViz):
raise Exception(_("Group By' and 'Columns' can't overlap"))
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
@ -701,7 +697,9 @@ class PivotTableViz(BaseViz):
df = df.pivot_table(
index=groupby,
columns=columns,
values=[utils.get_metric_name(m) for m in self.form_data.get("metrics")],
values=[
utils.get_metric_name(m) for m in self.form_data.get("metrics", [])
],
aggfunc=aggfunc,
margins=self.form_data.get("pivot_margins"),
)
@ -731,12 +729,10 @@ class MarkupViz(BaseViz):
def query_obj(self):
return None
def get_df(
self, query_obj: Optional[Dict[str, Any]] = None
) -> Optional[pd.DataFrame]:
return None
def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
return pd.DataFrame()
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
markup_type = self.form_data.get("markup_type")
code = self.form_data.get("code", "")
if markup_type == "markdown":
@ -790,7 +786,7 @@ class TreemapViz(BaseViz):
]
return result
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
df = df.set_index(self.form_data.get("groupby"))
chart_data = [
{"name": metric, "children": self._nest(metric, df)}
@ -808,7 +804,7 @@ class CalHeatmapViz(BaseViz):
credits = "<a href=https://github.com/wa0x6e/cal-heatmap>cal-heatmap</a>"
is_timeseries = True
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
form_data = self.form_data
data = {}
@ -842,9 +838,9 @@ class CalHeatmapViz(BaseViz):
elif domain == "week":
range_ = diff_delta.years * 53 + diff_delta.weeks + 1
elif domain == "day":
range_ = diff_secs // (24 * 60 * 60) + 1
range_ = diff_secs // (24 * 60 * 60) + 1 # type: ignore
else:
range_ = diff_secs // (60 * 60) + 1
range_ = diff_secs // (60 * 60) + 1 # type: ignore
return {
"data": data,
@ -900,7 +896,7 @@ class BoxPlotViz(NVD3Viz):
chart_data.append({"label": chart_label, "values": box})
return chart_data
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
form_data = self.form_data
# conform to NVD3 names
@ -929,8 +925,10 @@ class BoxPlotViz(NVD3Viz):
def whisker_low(series):
return series.min()
elif " percentiles" in whisker_type:
low, high = whisker_type.replace(" percentiles", "").split("/")
elif " percentiles" in whisker_type: # type: ignore
low, high = whisker_type.replace(" percentiles", "").split( # type: ignore
"/"
)
def whisker_high(series):
return np.nanpercentile(series, int(high))
@ -981,14 +979,14 @@ class BubbleViz(NVD3Viz):
raise Exception(_("Pick a metric for x, y and size"))
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
df["x"] = df[[utils.get_metric_name(self.x_metric)]]
df["y"] = df[[utils.get_metric_name(self.y_metric)]]
df["size"] = df[[utils.get_metric_name(self.z_metric)]]
df["shape"] = "circle"
df["group"] = df[[self.series]]
series = defaultdict(list)
series: Dict[Any, List[Any]] = defaultdict(list)
for row in df.to_dict(orient="records"):
series[row["group"]].append(row)
chart_data = []
@ -1029,7 +1027,7 @@ class BulletViz(NVD3Viz):
raise Exception(_("Pick a metric to display"))
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
df["metric"] = df[[utils.get_metric_name(self.metric)]]
values = df["metric"].values
return {
@ -1155,6 +1153,9 @@ class NVD3TimeSeriesViz(NVD3Viz):
if fd.get("granularity") == "all":
raise Exception(_("Pick a time granularity for your time series"))
if df.empty:
return df
if aggregate:
df = df.pivot_table(
index=DTTM_ALIAS,
@ -1236,7 +1237,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
df2 = self.process_data(df2)
self._extra_chart_data.append((label, df2))
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
comparison_type = fd.get("comparison_type") or "values"
df = self.process_data(df)
@ -1298,7 +1299,7 @@ class MultiLineViz(NVD3Viz):
def query_obj(self):
return None
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
# Late imports to avoid circular import issues
from superset.models.slice import Slice
@ -1371,7 +1372,7 @@ class NVD3DualLineViz(NVD3Viz):
chart_data.append(d)
return chart_data
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
if self.form_data.get("granularity") == "all":
@ -1407,7 +1408,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
d["metrics"] = [self.form_data.get("metric")]
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
df = self.process_data(df)
freq = to_offset(fd.get("freq"))
@ -1465,7 +1466,7 @@ class DistributionPieViz(NVD3Viz):
verbose_name = _("Distribution - NVD3 - Pie Chart")
is_timeseries = False
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
metric = self.metric_labels[0]
df = df.pivot_table(index=self.groupby, values=[metric])
df.sort_values(by=metric, ascending=False, inplace=True)
@ -1505,7 +1506,7 @@ class HistogramViz(BaseViz):
labels = [column] + labels
return "__".join(labels)
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
"""Returns the chart data"""
chart_data = []
if len(self.groupby) > 0:
@ -1546,7 +1547,7 @@ class DistributionBarViz(DistributionPieViz):
raise Exception(_("Pick at least one field for [Series]"))
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
metrics = self.metric_labels
columns = fd.get("columns") or []
@ -1597,13 +1598,13 @@ class SunburstViz(BaseViz):
'@<a href="https://bl.ocks.org/kerryrodden/7090426">bl.ocks.org</a>'
)
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
cols = fd.get("groupby")
metric = utils.get_metric_name(fd.get("metric"))
secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
if metric == secondary_metric or secondary_metric is None:
df.columns = cols + ["m1"]
df.columns = cols + ["m1"] # type: ignore
df["m2"] = df["m1"]
return json.loads(df.to_json(orient="values"))
@ -1633,13 +1634,13 @@ class SankeyViz(BaseViz):
qry["metrics"] = [self.form_data["metric"]]
return qry
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
df.columns = ["source", "target", "value"]
df["source"] = df["source"].astype(str)
df["target"] = df["target"].astype(str)
recs = df.to_dict(orient="records")
hierarchy = defaultdict(set)
hierarchy: Dict[str, Set[str]] = defaultdict(set)
for row in recs:
hierarchy[row["source"]].add(row["target"])
@ -1686,7 +1687,7 @@ class DirectedForceViz(BaseViz):
qry["metrics"] = [self.form_data["metric"]]
return qry
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
df.columns = ["source", "target", "value"]
return df.to_dict(orient="records")
@ -1707,7 +1708,7 @@ class ChordViz(BaseViz):
qry["metrics"] = [utils.get_metric_name(fd.get("metric"))]
return qry
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
df.columns = ["source", "target", "value"]
# Preparing a symmetrical matrix like d3.chords calls for
@ -1736,7 +1737,7 @@ class CountryMapViz(BaseViz):
qry["groupby"] = [self.form_data["entity"]]
return qry
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
cols = [fd.get("entity")]
metric = self.metric_labels[0]
@ -1762,7 +1763,7 @@ class WorldMapViz(BaseViz):
qry["groupby"] = [self.form_data["entity"]]
return qry
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
from superset.examples import countries
fd = self.form_data
@ -1830,7 +1831,7 @@ class FilterBoxViz(BaseViz):
df = self.get_df_payload(query_obj=qry).get("df")
self.dataframes[col] = df
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
filters = self.form_data.get("filter_configs") or []
d = {}
for flt in filters:
@ -1867,10 +1868,10 @@ class IFrameViz(BaseViz):
def query_obj(self):
return None
def get_df(self, query_obj: Dict[str, Any] = None) -> Optional[pd.DataFrame]:
return None
def get_df(self, query_obj: Dict[str, Any] = None) -> pd.DataFrame:
return pd.DataFrame()
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
return {}
@ -1896,7 +1897,7 @@ class ParallelCoordinatesViz(BaseViz):
d["groupby"] = [fd.get("series")]
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
return df.to_dict(orient="records")
@ -1919,7 +1920,7 @@ class HeatmapViz(BaseViz):
d["groupby"] = [fd.get("all_columns_x"), fd.get("all_columns_y")]
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
x = fd.get("all_columns_x")
y = fd.get("all_columns_y")
@ -2026,20 +2027,20 @@ class MapboxViz(BaseViz):
)
return d
def get_data(self, df):
if df is None:
def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None
fd = self.form_data
label_col = fd.get("mapbox_label")
has_custom_metric = label_col is not None and len(label_col) > 0
metric_col = [None] * len(df.index)
if has_custom_metric:
if label_col[0] == fd.get("all_columns_x"):
if label_col[0] == fd.get("all_columns_x"): # type: ignore
metric_col = df[fd.get("all_columns_x")]
elif label_col[0] == fd.get("all_columns_y"):
elif label_col[0] == fd.get("all_columns_y"): # type: ignore
metric_col = df[fd.get("all_columns_y")]
else:
metric_col = df[label_col[0]]
metric_col = df[label_col[0]] # type: ignore
point_radius_col = (
[None] * len(df.index)
if fd.get("point_radius") == "Auto"
@ -2106,7 +2107,7 @@ class DeckGLMultiLayer(BaseViz):
def query_obj(self):
return None
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
# Late imports to avoid circular import issues
from superset.models.slice import Slice
@ -2248,8 +2249,8 @@ class BaseDeckGLViz(BaseViz):
cols = self.form_data.get("js_columns") or []
return {col: d.get(col) for col in cols}
def get_data(self, df):
if df is None:
def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None
# Processing spatial info
@ -2310,13 +2311,13 @@ class DeckScatterViz(BaseDeckGLViz):
DTTM_ALIAS: d.get(DTTM_ALIAS),
}
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
self.point_radius_fixed = fd.get("point_radius_fixed")
self.fixed_value = None
self.dim = self.form_data.get("dimension")
if self.point_radius_fixed.get("type") != "metric":
if self.point_radius_fixed and self.point_radius_fixed.get("type") != "metric":
self.fixed_value = self.point_radius_fixed.get("value")
return super().get_data(df)
@ -2342,7 +2343,7 @@ class DeckScreengrid(BaseDeckGLViz):
"__timestamp": d.get(DTTM_ALIAS) or d.get("__time"),
}
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super().get_data(df)
@ -2358,7 +2359,7 @@ class DeckGrid(BaseDeckGLViz):
def get_properties(self, d):
return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super().get_data(df)
@ -2416,7 +2417,7 @@ class DeckPathViz(BaseDeckGLViz):
d["__timestamp"] = d.get(DTTM_ALIAS) or d.get("__time")
return d
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super().get_data(df)
@ -2462,7 +2463,7 @@ class DeckHex(BaseDeckGLViz):
def get_properties(self, d):
return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super(DeckHex, self).get_data(df)
@ -2509,10 +2510,13 @@ class DeckArc(BaseDeckGLViz):
DTTM_ALIAS: d.get(DTTM_ALIAS),
}
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
d = super().get_data(df)
return {"features": d["features"], "mapboxApiKey": config["MAPBOX_API_KEY"]}
return {
"features": d["features"], # type: ignore
"mapboxApiKey": config["MAPBOX_API_KEY"],
}
class EventFlowViz(BaseViz):
@ -2543,7 +2547,7 @@ class EventFlowViz(BaseViz):
return query
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
return df.to_dict(orient="records")
@ -2556,7 +2560,7 @@ class PairedTTestViz(BaseViz):
sort_series = False
is_timeseries = True
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
"""
Transform received data frame into an object of the form:
{
@ -2582,7 +2586,7 @@ class PairedTTestViz(BaseViz):
else:
cols.append(col)
df.columns = cols
data = {}
data: Dict = {}
series = df.to_dict("series")
for nameSet in df.columns:
# If no groups are defined, nameSet will be the metric name
@ -2607,10 +2611,10 @@ class RoseViz(NVD3TimeSeriesViz):
sort_series = False
is_timeseries = True
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
data = super().get_data(df)
result = {}
for datum in data:
result: Dict = {}
for datum in data: # type: ignore
key = datum["key"]
for val in datum["values"]:
timestamp = val["x"].value
@ -2761,7 +2765,7 @@ class PartitionViz(NVD3TimeSeriesViz):
for i in procs[level][dims].columns
]
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
groups = fd.get("groupby", [])
time_op = fd.get("time_series_option", "not_time")

View File

@ -251,7 +251,7 @@ class SqlaTableModelTestCase(SupersetTestCase):
else:
self.assertNotIn("JOIN", sql.upper())
spec.allows_joins = old_inner_join
self.assertIsNotNone(qr.df)
self.assertFalse(qr.df.empty)
return qr.df
def test_query_with_expr_groupby_timeseries(self):
@ -262,8 +262,9 @@ class SqlaTableModelTestCase(SupersetTestCase):
df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True)
df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
self.assertIsNotNone(df2)  # df1 can be None if the db does not support join
if df1 is not None:
self.assertFalse(df2.empty)
# df1 can be empty if the db does not support join
if not df1.empty:
pandas.testing.assert_frame_equal(
cannonicalize_df(df1), cannonicalize_df(df2)
)