From 2d456e88ebf1ef6fe4da7bdca1e25dcf098f6533 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Wed, 8 Jan 2020 11:50:26 -0800 Subject: [PATCH] [fix] Enforce the query result to contain a data-frame (#8935) --- superset/connectors/druid/models.py | 8 +- superset/connectors/sqla/models.py | 4 +- superset/models/helpers.py | 11 +- superset/viz.py | 160 ++++++++++++++-------------- tests/model_tests.py | 7 +- 5 files changed, 97 insertions(+), 93 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 57d9d041c..eacf5dd8e 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -988,11 +988,9 @@ class DruidDatasource(Model, BaseDatasource): def get_query_str(self, query_obj, phase=1, client=None): return self.run_query(client=client, phase=phase, **query_obj) - def _add_filter_from_pre_query_data( - self, df: Optional[pd.DataFrame], dimensions, dim_filter - ): + def _add_filter_from_pre_query_data(self, df: pd.DataFrame, dimensions, dim_filter): ret = dim_filter - if df is not None and not df.empty: + if not df.empty: new_filters = [] for unused, row in df.iterrows(): fields = [] @@ -1379,7 +1377,7 @@ class DruidDatasource(Model, BaseDatasource): if df is None or df.size == 0: return QueryResult( - df=pd.DataFrame([]), + df=pd.DataFrame(), query=query_str, duration=datetime.now() - qry_start_dttm, ) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index deee43fc4..238a4a030 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -85,7 +85,6 @@ class AnnotationDatasource(BaseDatasource): cache_timeout = 0 def query(self, query_obj: Dict[str, Any]) -> QueryResult: - df = None error_message = None qry = db.session.query(Annotation) qry = qry.filter(Annotation.layer_id == query_obj["filter"][0]["val"]) @@ -97,6 +96,7 @@ class AnnotationDatasource(BaseDatasource): try: df = pd.read_sql_query(qry.statement, db.engine) except Exception as e: + df = pd.DataFrame() status = utils.QueryStatus.FAILED logging.exception(e) error_message = utils.error_msg_from_exception(e) @@ -995,7 +995,7 @@ class SqlaTable(Model, BaseDatasource): try: df = self.database.get_df(sql, self.schema, mutator) except Exception as e: - df = None + df = pd.DataFrame() status = utils.QueryStatus.FAILED logging.exception(f"Query {sql} on schema {self.schema} failed") db_engine_spec = self.database.db_engine_spec diff --git a/superset/models/helpers.py b/superset/models/helpers.py index b11fb3861..086c4b500 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -24,6 +24,7 @@ from typing import List, Optional # isort and pylint disagree, isort should win # pylint: disable=ungrouped-imports import humanize +import pandas as pd import sqlalchemy as sa import yaml from flask import escape, g, Markup @@ -368,11 +369,11 @@ class QueryResult: # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-arguments self, df, query, duration, status=QueryStatus.SUCCESS, error_message=None ): - self.df = df # pylint: disable=invalid-name - self.query = query - self.duration = duration - self.status = status - self.error_message = error_message + self.df: pd.DataFrame = df # pylint: disable=invalid-name + self.query: str = query + self.duration: int = duration + self.status: str = status + self.error_message: Optional[str] = error_message class ExtraJSONMixin: diff --git a/superset/viz.py b/superset/viz.py index beff59d78..617ccdc89 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -32,7 +32,7 @@ from collections import defaultdict, OrderedDict from datetime import datetime, timedelta from functools import reduce from itertools import product -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union import geohash import numpy as np @@ -77,6 +77,8 @@ METRIC_KEYS = [ "size", ] +VizData = Optional[Union[List[Any], Dict[Any, Any]]] + class BaseViz: @@ -109,10 +111,10 @@ class BaseViz: self.groupby = self.form_data.get("groupby") or [] self.time_shift = timedelta() - self.status = None + self.status: Optional[str] = None self.error_msg = "" self.results: Optional[QueryResult] = None - self.error_message = None + self.error_message: Optional[str] = None self.force = force # Keeping track of whether some data came from cache @@ -190,14 +192,12 @@ class BaseViz: df = self.get_df(query_obj) return df.to_dict(orient="records") - def get_df( - self, query_obj: Optional[Dict[str, Any]] = None - ) -> Optional[pd.DataFrame]: + def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: """Returns a pandas dataframe based on the query object""" if not query_obj: query_obj = self.query_obj() if not query_obj: - return None + return pd.DataFrame() self.error_msg = "" @@ -219,7 +219,7 @@ class BaseViz: # be considered as the default ISO date format # If the datetime format is unix, the parse will use the corresponding # parsing logic. - if df is not None and not df.empty: + if not df.empty: if DTTM_ALIAS in df.columns: if timestamp_format in ("epoch_s", "epoch_ms"): # Column has already been formatted as a timestamp. @@ -439,11 +439,7 @@ class BaseViz: and self.status != utils.QueryStatus.FAILED ): try: - cache_value = dict( - dttm=cached_dttm, - df=df if df is not None else None, - query=self.query, - ) + cache_value = dict(dttm=cached_dttm, df=df, query=self.query) cache_value = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL) logging.info( @@ -500,7 +496,7 @@ class BaseViz: include_index = not isinstance(df.index, pd.RangeIndex) return df.to_csv(index=include_index, **config["CSV_EXPORT"]) - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: return df.to_dict(orient="records") @property @@ -563,7 +559,7 @@ class TableViz(BaseViz): d["is_timeseries"] = self.should_be_timeseries() return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data if not self.should_be_timeseries() and df is not None and DTTM_ALIAS in df: del df[DTTM_ALIAS] @@ -631,7 +627,7 @@ class TimeTableViz(BaseViz): ) return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data columns = None values = self.metric_labels @@ -644,7 +640,7 @@ class TimeTableViz(BaseViz): return dict( records=pt.to_dict(orient="index"), columns=list(pt.columns), - is_group_by=len(fd.get("groupby")) > 0, + is_group_by=len(fd.get("groupby", [])) > 0, ) @@ -684,7 +680,7 @@ class PivotTableViz(BaseViz): raise Exception(_("Group By' and 'Columns' can't overlap")) return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df: del df[DTTM_ALIAS] @@ -701,7 +697,9 @@ class PivotTableViz(BaseViz): df = df.pivot_table( index=groupby, columns=columns, - values=[utils.get_metric_name(m) for m in self.form_data.get("metrics")], + values=[ + utils.get_metric_name(m) for m in self.form_data.get("metrics", []) + ], aggfunc=aggfunc, margins=self.form_data.get("pivot_margins"), ) @@ -731,12 +729,10 @@ class MarkupViz(BaseViz): def query_obj(self): return None - def get_df( - self, query_obj: Optional[Dict[str, Any]] = None - ) -> Optional[pd.DataFrame]: - return None + def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + return pd.DataFrame() - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: markup_type = self.form_data.get("markup_type") code = self.form_data.get("code", "") if markup_type == "markdown": @@ -790,7 +786,7 @@ class TreemapViz(BaseViz): ] return result - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: df = df.set_index(self.form_data.get("groupby")) chart_data = [ {"name": metric, "children": self._nest(metric, df)} @@ -808,7 +804,7 @@ class CalHeatmapViz(BaseViz): credits = "cal-heatmap" is_timeseries = True - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: form_data = self.form_data data = {} @@ -842,9 +838,9 @@ class CalHeatmapViz(BaseViz): elif domain == "week": range_ = diff_delta.years * 53 + diff_delta.weeks + 1 elif domain == "day": - range_ = diff_secs // (24 * 60 * 60) + 1 + range_ = diff_secs // (24 * 60 * 60) + 1 # type: ignore else: - range_ = diff_secs // (60 * 60) + 1 + range_ = diff_secs // (60 * 60) + 1 # type: ignore return { "data": data, @@ -900,7 +896,7 @@ class BoxPlotViz(NVD3Viz): chart_data.append({"label": chart_label, "values": box}) return chart_data - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: form_data = self.form_data # conform to NVD3 names @@ -929,8 +925,10 @@ class BoxPlotViz(NVD3Viz): def whisker_low(series): return series.min() - elif " percentiles" in whisker_type: - low, high = whisker_type.replace(" percentiles", "").split("/") + elif " percentiles" in whisker_type: # type: ignore + low, high = whisker_type.replace(" percentiles", "").split( # type: ignore + "/" + ) def whisker_high(series): return np.nanpercentile(series, int(high)) @@ -981,14 +979,14 @@ class BubbleViz(NVD3Viz): raise Exception(_("Pick a metric for x, y and size")) return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: df["x"] = df[[utils.get_metric_name(self.x_metric)]] df["y"] = df[[utils.get_metric_name(self.y_metric)]] df["size"] = df[[utils.get_metric_name(self.z_metric)]] df["shape"] = "circle" df["group"] = df[[self.series]] - series = defaultdict(list) + series: Dict[Any, List[Any]] = defaultdict(list) for row in df.to_dict(orient="records"): series[row["group"]].append(row) chart_data = [] @@ -1029,7 +1027,7 @@ class BulletViz(NVD3Viz): raise Exception(_("Pick a metric to display")) return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: df["metric"] = df[[utils.get_metric_name(self.metric)]] values = df["metric"].values return { @@ -1155,6 +1153,9 @@ class NVD3TimeSeriesViz(NVD3Viz): if fd.get("granularity") == "all": raise Exception(_("Pick a time granularity for your time series")) + if df.empty: + return df + if aggregate: df = df.pivot_table( index=DTTM_ALIAS, @@ -1236,7 +1237,7 @@ class NVD3TimeSeriesViz(NVD3Viz): df2 = self.process_data(df2) self._extra_chart_data.append((label, df2)) - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data comparison_type = fd.get("comparison_type") or "values" df = self.process_data(df) @@ -1298,7 +1299,7 @@ class MultiLineViz(NVD3Viz): def query_obj(self): return None - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data # Late imports to avoid circular import issues from superset.models.slice import Slice @@ -1371,7 +1372,7 @@ class NVD3DualLineViz(NVD3Viz): chart_data.append(d) return chart_data - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data if self.form_data.get("granularity") == "all": @@ -1407,7 +1408,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz): d["metrics"] = [self.form_data.get("metric")] return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data df = self.process_data(df) freq = to_offset(fd.get("freq")) @@ -1465,7 +1466,7 @@ class DistributionPieViz(NVD3Viz): verbose_name = _("Distribution - NVD3 - Pie Chart") is_timeseries = False - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: metric = self.metric_labels[0] df = df.pivot_table(index=self.groupby, values=[metric]) df.sort_values(by=metric, ascending=False, inplace=True) @@ -1505,7 +1506,7 @@ class HistogramViz(BaseViz): labels = [column] + labels return "__".join(labels) - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: """Returns the chart data""" chart_data = [] if len(self.groupby) > 0: @@ -1546,7 +1547,7 @@ class DistributionBarViz(DistributionPieViz): raise Exception(_("Pick at least one field for [Series]")) return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data metrics = self.metric_labels columns = fd.get("columns") or [] @@ -1597,13 +1598,13 @@ class SunburstViz(BaseViz): '@bl.ocks.org' ) - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data cols = fd.get("groupby") metric = utils.get_metric_name(fd.get("metric")) secondary_metric = utils.get_metric_name(fd.get("secondary_metric")) if metric == secondary_metric or secondary_metric is None: - df.columns = cols + ["m1"] + df.columns = cols + ["m1"] # type: ignore df["m2"] = df["m1"] return json.loads(df.to_json(orient="values")) @@ -1633,13 +1634,13 @@ class SankeyViz(BaseViz): qry["metrics"] = [self.form_data["metric"]] return qry - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: df.columns = ["source", "target", "value"] df["source"] = df["source"].astype(str) df["target"] = df["target"].astype(str) recs = df.to_dict(orient="records") - hierarchy = defaultdict(set) + hierarchy: Dict[str, Set[str]] = defaultdict(set) for row in recs: hierarchy[row["source"]].add(row["target"]) @@ -1686,7 +1687,7 @@ class DirectedForceViz(BaseViz): qry["metrics"] = [self.form_data["metric"]] return qry - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: df.columns = ["source", "target", "value"] return df.to_dict(orient="records") @@ -1707,7 +1708,7 @@ class ChordViz(BaseViz): qry["metrics"] = [utils.get_metric_name(fd.get("metric"))] return qry - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: df.columns = ["source", "target", "value"] # Preparing a symetrical matrix like d3.chords calls for @@ -1736,7 +1737,7 @@ class CountryMapViz(BaseViz): qry["groupby"] = [self.form_data["entity"]] return qry - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data cols = [fd.get("entity")] metric = self.metric_labels[0] @@ -1762,7 +1763,7 @@ class WorldMapViz(BaseViz): qry["groupby"] = [self.form_data["entity"]] return qry - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: from superset.examples import countries fd = self.form_data @@ -1830,7 +1831,7 @@ class FilterBoxViz(BaseViz): df = self.get_df_payload(query_obj=qry).get("df") self.dataframes[col] = df - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: filters = self.form_data.get("filter_configs") or [] d = {} for flt in filters: @@ -1867,10 +1868,10 @@ class IFrameViz(BaseViz): def query_obj(self): return None - def get_df(self, query_obj: Dict[str, Any] = None) -> Optional[pd.DataFrame]: - return None + def get_df(self, query_obj: Dict[str, Any] = None) -> pd.DataFrame: + return pd.DataFrame() - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: return {} @@ -1896,7 +1897,7 @@ class ParallelCoordinatesViz(BaseViz): d["groupby"] = [fd.get("series")] return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: return df.to_dict(orient="records") @@ -1919,7 +1920,7 @@ class HeatmapViz(BaseViz): d["groupby"] = [fd.get("all_columns_x"), fd.get("all_columns_y")] return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data x = fd.get("all_columns_x") y = fd.get("all_columns_y") @@ -2026,20 +2027,20 @@ class MapboxViz(BaseViz): ) return d - def get_data(self, df): - if df is None: + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: return None fd = self.form_data label_col = fd.get("mapbox_label") has_custom_metric = label_col is not None and len(label_col) > 0 metric_col = [None] * len(df.index) if has_custom_metric: - if label_col[0] == fd.get("all_columns_x"): + if label_col[0] == fd.get("all_columns_x"): # type: ignore metric_col = df[fd.get("all_columns_x")] - elif label_col[0] == fd.get("all_columns_y"): + elif label_col[0] == fd.get("all_columns_y"): # type: ignore metric_col = df[fd.get("all_columns_y")] else: - metric_col = df[label_col[0]] + metric_col = df[label_col[0]] # type: ignore point_radius_col = ( [None] * len(df.index) if fd.get("point_radius") == "Auto" @@ -2106,7 +2107,7 @@ class DeckGLMultiLayer(BaseViz): def query_obj(self): return None - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data # Late imports to avoid circular import issues from superset.models.slice import Slice @@ -2248,8 +2249,8 @@ class BaseDeckGLViz(BaseViz): cols = self.form_data.get("js_columns") or [] return {col: d.get(col) for col in cols} - def get_data(self, df): - if df is None: + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: return None # Processing spatial info @@ -2310,13 +2311,13 @@ class DeckScatterViz(BaseDeckGLViz): DTTM_ALIAS: d.get(DTTM_ALIAS), } - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data self.metric_label = utils.get_metric_name(self.metric) if self.metric else None self.point_radius_fixed = fd.get("point_radius_fixed") self.fixed_value = None self.dim = self.form_data.get("dimension") - if self.point_radius_fixed.get("type") != "metric": + if self.point_radius_fixed and self.point_radius_fixed.get("type") != "metric": self.fixed_value = self.point_radius_fixed.get("value") return super().get_data(df) @@ -2342,7 +2343,7 @@ class DeckScreengrid(BaseDeckGLViz): "__timestamp": d.get(DTTM_ALIAS) or d.get("__time"), } - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: self.metric_label = utils.get_metric_name(self.metric) return super().get_data(df) @@ -2358,7 +2359,7 @@ class DeckGrid(BaseDeckGLViz): def get_properties(self, d): return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1} - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: self.metric_label = utils.get_metric_name(self.metric) return super().get_data(df) @@ -2416,7 +2417,7 @@ class DeckPathViz(BaseDeckGLViz): d["__timestamp"] = d.get(DTTM_ALIAS) or d.get("__time") return d - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: self.metric_label = utils.get_metric_name(self.metric) return super().get_data(df) @@ -2462,7 +2463,7 @@ class DeckHex(BaseDeckGLViz): def get_properties(self, d): return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1} - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: self.metric_label = utils.get_metric_name(self.metric) return super(DeckHex, self).get_data(df) @@ -2509,10 +2510,13 @@ class DeckArc(BaseDeckGLViz): DTTM_ALIAS: d.get(DTTM_ALIAS), } - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: d = super().get_data(df) - return {"features": d["features"], "mapboxApiKey": config["MAPBOX_API_KEY"]} + return { + "features": d["features"], # type: ignore + "mapboxApiKey": config["MAPBOX_API_KEY"], + } class EventFlowViz(BaseViz): @@ -2543,7 +2547,7 @@ class EventFlowViz(BaseViz): return query - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: return df.to_dict(orient="records") @@ -2556,7 +2560,7 @@ class PairedTTestViz(BaseViz): sort_series = False is_timeseries = True - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: """ Transform received data frame into an object of the form: { @@ -2582,7 +2586,7 @@ class PairedTTestViz(BaseViz): else: cols.append(col) df.columns = cols - data = {} + data: Dict = {} series = df.to_dict("series") for nameSet in df.columns: # If no groups are defined, nameSet will be the metric name @@ -2607,10 +2611,10 @@ class RoseViz(NVD3TimeSeriesViz): sort_series = False is_timeseries = True - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: data = super().get_data(df) - result = {} - for datum in data: + result: Dict = {} + for datum in data: # type: ignore key = datum["key"] for val in datum["values"]: timestamp = val["x"].value @@ -2761,7 +2765,7 @@ class PartitionViz(NVD3TimeSeriesViz): for i in procs[level][dims].columns ] - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data groups = fd.get("groupby", []) time_op = fd.get("time_series_option", "not_time") diff --git a/tests/model_tests.py b/tests/model_tests.py index 4081c111c..24a87b6e5 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -251,7 +251,7 @@ class SqlaTableModelTestCase(SupersetTestCase): else: self.assertNotIn("JOIN", sql.upper()) spec.allows_joins = old_inner_join - self.assertIsNotNone(qr.df) + self.assertFalse(qr.df.empty) return qr.df def test_query_with_expr_groupby_timeseries(self): @@ -262,8 +262,9 @@ class SqlaTableModelTestCase(SupersetTestCase): df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True) df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False) - self.assertIsNotNone(df2) # df1 can be none if the db does not support join - if df1 is not None: + self.assertFalse(df2.empty) + # df1 can be empty if the db does not support join + if not df1.empty: pandas.testing.assert_frame_equal( cannonicalize_df(df1), cannonicalize_df(df2) )