From 2d456e88ebf1ef6fe4da7bdca1e25dcf098f6533 Mon Sep 17 00:00:00 2001
From: John Bodley <4567245+john-bodley@users.noreply.github.com>
Date: Wed, 8 Jan 2020 11:50:26 -0800
Subject: [PATCH] [fix] Enforce the query result to contain a data-frame
(#8935)
---
superset/connectors/druid/models.py | 8 +-
superset/connectors/sqla/models.py | 4 +-
superset/models/helpers.py | 11 +-
superset/viz.py | 160 ++++++++++++++--------------
tests/model_tests.py | 7 +-
5 files changed, 97 insertions(+), 93 deletions(-)
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 57d9d041c..eacf5dd8e 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -988,11 +988,9 @@ class DruidDatasource(Model, BaseDatasource):
def get_query_str(self, query_obj, phase=1, client=None):
return self.run_query(client=client, phase=phase, **query_obj)
- def _add_filter_from_pre_query_data(
- self, df: Optional[pd.DataFrame], dimensions, dim_filter
- ):
+ def _add_filter_from_pre_query_data(self, df: pd.DataFrame, dimensions, dim_filter):
ret = dim_filter
- if df is not None and not df.empty:
+ if not df.empty:
new_filters = []
for unused, row in df.iterrows():
fields = []
@@ -1379,7 +1377,7 @@ class DruidDatasource(Model, BaseDatasource):
if df is None or df.size == 0:
return QueryResult(
- df=pd.DataFrame([]),
+ df=pd.DataFrame(),
query=query_str,
duration=datetime.now() - qry_start_dttm,
)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index deee43fc4..238a4a030 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -85,7 +85,6 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
def query(self, query_obj: Dict[str, Any]) -> QueryResult:
- df = None
error_message = None
qry = db.session.query(Annotation)
qry = qry.filter(Annotation.layer_id == query_obj["filter"][0]["val"])
@@ -97,6 +96,7 @@ class AnnotationDatasource(BaseDatasource):
try:
df = pd.read_sql_query(qry.statement, db.engine)
except Exception as e:
+ df = pd.DataFrame()
status = utils.QueryStatus.FAILED
logging.exception(e)
error_message = utils.error_msg_from_exception(e)
@@ -995,7 +995,7 @@ class SqlaTable(Model, BaseDatasource):
try:
df = self.database.get_df(sql, self.schema, mutator)
except Exception as e:
- df = None
+ df = pd.DataFrame()
status = utils.QueryStatus.FAILED
logging.exception(f"Query {sql} on schema {self.schema} failed")
db_engine_spec = self.database.db_engine_spec
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index b11fb3861..086c4b500 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -24,6 +24,7 @@ from typing import List, Optional
# isort and pylint disagree, isort should win
# pylint: disable=ungrouped-imports
import humanize
+import pandas as pd
import sqlalchemy as sa
import yaml
from flask import escape, g, Markup
@@ -368,11 +369,11 @@ class QueryResult: # pylint: disable=too-few-public-methods
def __init__( # pylint: disable=too-many-arguments
self, df, query, duration, status=QueryStatus.SUCCESS, error_message=None
):
- self.df = df # pylint: disable=invalid-name
- self.query = query
- self.duration = duration
- self.status = status
- self.error_message = error_message
+ self.df: pd.DataFrame = df # pylint: disable=invalid-name
+ self.query: str = query
+ self.duration: int = duration
+ self.status: str = status
+ self.error_message: Optional[str] = error_message
class ExtraJSONMixin:
diff --git a/superset/viz.py b/superset/viz.py
index beff59d78..617ccdc89 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -32,7 +32,7 @@ from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from functools import reduce
from itertools import product
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import geohash
import numpy as np
@@ -77,6 +77,8 @@ METRIC_KEYS = [
"size",
]
+VizData = Optional[Union[List[Any], Dict[Any, Any]]]
+
class BaseViz:
@@ -109,10 +111,10 @@ class BaseViz:
self.groupby = self.form_data.get("groupby") or []
self.time_shift = timedelta()
- self.status = None
+ self.status: Optional[str] = None
self.error_msg = ""
self.results: Optional[QueryResult] = None
- self.error_message = None
+ self.error_message: Optional[str] = None
self.force = force
# Keeping track of whether some data came from cache
@@ -190,14 +192,12 @@ class BaseViz:
df = self.get_df(query_obj)
return df.to_dict(orient="records")
- def get_df(
- self, query_obj: Optional[Dict[str, Any]] = None
- ) -> Optional[pd.DataFrame]:
+ def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
query_obj = self.query_obj()
if not query_obj:
- return None
+ return pd.DataFrame()
self.error_msg = ""
@@ -219,7 +219,7 @@ class BaseViz:
# be considered as the default ISO date format
# If the datetime format is unix, the parse will use the corresponding
# parsing logic.
- if df is not None and not df.empty:
+ if not df.empty:
if DTTM_ALIAS in df.columns:
if timestamp_format in ("epoch_s", "epoch_ms"):
# Column has already been formatted as a timestamp.
@@ -439,11 +439,7 @@ class BaseViz:
and self.status != utils.QueryStatus.FAILED
):
try:
- cache_value = dict(
- dttm=cached_dttm,
- df=df if df is not None else None,
- query=self.query,
- )
+ cache_value = dict(dttm=cached_dttm, df=df, query=self.query)
cache_value = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
logging.info(
@@ -500,7 +496,7 @@ class BaseViz:
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, **config["CSV_EXPORT"])
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
return df.to_dict(orient="records")
@property
@@ -563,7 +559,7 @@ class TableViz(BaseViz):
d["is_timeseries"] = self.should_be_timeseries()
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
if not self.should_be_timeseries() and df is not None and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
@@ -631,7 +627,7 @@ class TimeTableViz(BaseViz):
)
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
columns = None
values = self.metric_labels
@@ -644,7 +640,7 @@ class TimeTableViz(BaseViz):
return dict(
records=pt.to_dict(orient="index"),
columns=list(pt.columns),
- is_group_by=len(fd.get("groupby")) > 0,
+ is_group_by=len(fd.get("groupby", [])) > 0,
)
@@ -684,7 +680,7 @@ class PivotTableViz(BaseViz):
raise Exception(_("Group By' and 'Columns' can't overlap"))
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
@@ -701,7 +697,9 @@ class PivotTableViz(BaseViz):
df = df.pivot_table(
index=groupby,
columns=columns,
- values=[utils.get_metric_name(m) for m in self.form_data.get("metrics")],
+ values=[
+ utils.get_metric_name(m) for m in self.form_data.get("metrics", [])
+ ],
aggfunc=aggfunc,
margins=self.form_data.get("pivot_margins"),
)
@@ -731,12 +729,10 @@ class MarkupViz(BaseViz):
def query_obj(self):
return None
- def get_df(
- self, query_obj: Optional[Dict[str, Any]] = None
- ) -> Optional[pd.DataFrame]:
- return None
+ def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+ return pd.DataFrame()
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
markup_type = self.form_data.get("markup_type")
code = self.form_data.get("code", "")
if markup_type == "markdown":
@@ -790,7 +786,7 @@ class TreemapViz(BaseViz):
]
return result
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
df = df.set_index(self.form_data.get("groupby"))
chart_data = [
{"name": metric, "children": self._nest(metric, df)}
@@ -808,7 +804,7 @@ class CalHeatmapViz(BaseViz):
credits = "cal-heatmap"
is_timeseries = True
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
form_data = self.form_data
data = {}
@@ -842,9 +838,9 @@ class CalHeatmapViz(BaseViz):
elif domain == "week":
range_ = diff_delta.years * 53 + diff_delta.weeks + 1
elif domain == "day":
- range_ = diff_secs // (24 * 60 * 60) + 1
+ range_ = diff_secs // (24 * 60 * 60) + 1 # type: ignore
else:
- range_ = diff_secs // (60 * 60) + 1
+ range_ = diff_secs // (60 * 60) + 1 # type: ignore
return {
"data": data,
@@ -900,7 +896,7 @@ class BoxPlotViz(NVD3Viz):
chart_data.append({"label": chart_label, "values": box})
return chart_data
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
form_data = self.form_data
# conform to NVD3 names
@@ -929,8 +925,10 @@ class BoxPlotViz(NVD3Viz):
def whisker_low(series):
return series.min()
- elif " percentiles" in whisker_type:
- low, high = whisker_type.replace(" percentiles", "").split("/")
+ elif " percentiles" in whisker_type: # type: ignore
+ low, high = whisker_type.replace(" percentiles", "").split( # type: ignore
+ "/"
+ )
def whisker_high(series):
return np.nanpercentile(series, int(high))
@@ -981,14 +979,14 @@ class BubbleViz(NVD3Viz):
raise Exception(_("Pick a metric for x, y and size"))
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
df["x"] = df[[utils.get_metric_name(self.x_metric)]]
df["y"] = df[[utils.get_metric_name(self.y_metric)]]
df["size"] = df[[utils.get_metric_name(self.z_metric)]]
df["shape"] = "circle"
df["group"] = df[[self.series]]
- series = defaultdict(list)
+ series: Dict[Any, List[Any]] = defaultdict(list)
for row in df.to_dict(orient="records"):
series[row["group"]].append(row)
chart_data = []
@@ -1029,7 +1027,7 @@ class BulletViz(NVD3Viz):
raise Exception(_("Pick a metric to display"))
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
df["metric"] = df[[utils.get_metric_name(self.metric)]]
values = df["metric"].values
return {
@@ -1155,6 +1153,9 @@ class NVD3TimeSeriesViz(NVD3Viz):
if fd.get("granularity") == "all":
raise Exception(_("Pick a time granularity for your time series"))
+ if df.empty:
+ return df
+
if aggregate:
df = df.pivot_table(
index=DTTM_ALIAS,
@@ -1236,7 +1237,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
df2 = self.process_data(df2)
self._extra_chart_data.append((label, df2))
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
comparison_type = fd.get("comparison_type") or "values"
df = self.process_data(df)
@@ -1298,7 +1299,7 @@ class MultiLineViz(NVD3Viz):
def query_obj(self):
return None
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
# Late imports to avoid circular import issues
from superset.models.slice import Slice
@@ -1371,7 +1372,7 @@ class NVD3DualLineViz(NVD3Viz):
chart_data.append(d)
return chart_data
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
if self.form_data.get("granularity") == "all":
@@ -1407,7 +1408,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
d["metrics"] = [self.form_data.get("metric")]
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
df = self.process_data(df)
freq = to_offset(fd.get("freq"))
@@ -1465,7 +1466,7 @@ class DistributionPieViz(NVD3Viz):
verbose_name = _("Distribution - NVD3 - Pie Chart")
is_timeseries = False
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
metric = self.metric_labels[0]
df = df.pivot_table(index=self.groupby, values=[metric])
df.sort_values(by=metric, ascending=False, inplace=True)
@@ -1505,7 +1506,7 @@ class HistogramViz(BaseViz):
labels = [column] + labels
return "__".join(labels)
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
"""Returns the chart data"""
chart_data = []
if len(self.groupby) > 0:
@@ -1546,7 +1547,7 @@ class DistributionBarViz(DistributionPieViz):
raise Exception(_("Pick at least one field for [Series]"))
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
metrics = self.metric_labels
columns = fd.get("columns") or []
@@ -1597,13 +1598,13 @@ class SunburstViz(BaseViz):
'@bl.ocks.org'
)
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
cols = fd.get("groupby")
metric = utils.get_metric_name(fd.get("metric"))
secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
if metric == secondary_metric or secondary_metric is None:
- df.columns = cols + ["m1"]
+ df.columns = cols + ["m1"] # type: ignore
df["m2"] = df["m1"]
return json.loads(df.to_json(orient="values"))
@@ -1633,13 +1634,13 @@ class SankeyViz(BaseViz):
qry["metrics"] = [self.form_data["metric"]]
return qry
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
df.columns = ["source", "target", "value"]
df["source"] = df["source"].astype(str)
df["target"] = df["target"].astype(str)
recs = df.to_dict(orient="records")
- hierarchy = defaultdict(set)
+ hierarchy: Dict[str, Set[str]] = defaultdict(set)
for row in recs:
hierarchy[row["source"]].add(row["target"])
@@ -1686,7 +1687,7 @@ class DirectedForceViz(BaseViz):
qry["metrics"] = [self.form_data["metric"]]
return qry
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
df.columns = ["source", "target", "value"]
return df.to_dict(orient="records")
@@ -1707,7 +1708,7 @@ class ChordViz(BaseViz):
qry["metrics"] = [utils.get_metric_name(fd.get("metric"))]
return qry
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
df.columns = ["source", "target", "value"]
# Preparing a symetrical matrix like d3.chords calls for
@@ -1736,7 +1737,7 @@ class CountryMapViz(BaseViz):
qry["groupby"] = [self.form_data["entity"]]
return qry
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
cols = [fd.get("entity")]
metric = self.metric_labels[0]
@@ -1762,7 +1763,7 @@ class WorldMapViz(BaseViz):
qry["groupby"] = [self.form_data["entity"]]
return qry
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
from superset.examples import countries
fd = self.form_data
@@ -1830,7 +1831,7 @@ class FilterBoxViz(BaseViz):
df = self.get_df_payload(query_obj=qry).get("df")
self.dataframes[col] = df
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
filters = self.form_data.get("filter_configs") or []
d = {}
for flt in filters:
@@ -1867,10 +1868,10 @@ class IFrameViz(BaseViz):
def query_obj(self):
return None
- def get_df(self, query_obj: Dict[str, Any] = None) -> Optional[pd.DataFrame]:
- return None
+ def get_df(self, query_obj: Dict[str, Any] = None) -> pd.DataFrame:
+ return pd.DataFrame()
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
return {}
@@ -1896,7 +1897,7 @@ class ParallelCoordinatesViz(BaseViz):
d["groupby"] = [fd.get("series")]
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
return df.to_dict(orient="records")
@@ -1919,7 +1920,7 @@ class HeatmapViz(BaseViz):
d["groupby"] = [fd.get("all_columns_x"), fd.get("all_columns_y")]
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
x = fd.get("all_columns_x")
y = fd.get("all_columns_y")
@@ -2026,20 +2027,20 @@ class MapboxViz(BaseViz):
)
return d
- def get_data(self, df):
- if df is None:
+ def get_data(self, df: pd.DataFrame) -> VizData:
+ if df.empty:
return None
fd = self.form_data
label_col = fd.get("mapbox_label")
has_custom_metric = label_col is not None and len(label_col) > 0
metric_col = [None] * len(df.index)
if has_custom_metric:
- if label_col[0] == fd.get("all_columns_x"):
+ if label_col[0] == fd.get("all_columns_x"): # type: ignore
metric_col = df[fd.get("all_columns_x")]
- elif label_col[0] == fd.get("all_columns_y"):
+ elif label_col[0] == fd.get("all_columns_y"): # type: ignore
metric_col = df[fd.get("all_columns_y")]
else:
- metric_col = df[label_col[0]]
+ metric_col = df[label_col[0]] # type: ignore
point_radius_col = (
[None] * len(df.index)
if fd.get("point_radius") == "Auto"
@@ -2106,7 +2107,7 @@ class DeckGLMultiLayer(BaseViz):
def query_obj(self):
return None
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
# Late imports to avoid circular import issues
from superset.models.slice import Slice
@@ -2248,8 +2249,8 @@ class BaseDeckGLViz(BaseViz):
cols = self.form_data.get("js_columns") or []
return {col: d.get(col) for col in cols}
- def get_data(self, df):
- if df is None:
+ def get_data(self, df: pd.DataFrame) -> VizData:
+ if df.empty:
return None
# Processing spatial info
@@ -2310,13 +2311,13 @@ class DeckScatterViz(BaseDeckGLViz):
DTTM_ALIAS: d.get(DTTM_ALIAS),
}
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
self.point_radius_fixed = fd.get("point_radius_fixed")
self.fixed_value = None
self.dim = self.form_data.get("dimension")
- if self.point_radius_fixed.get("type") != "metric":
+ if self.point_radius_fixed and self.point_radius_fixed.get("type") != "metric":
self.fixed_value = self.point_radius_fixed.get("value")
return super().get_data(df)
@@ -2342,7 +2343,7 @@ class DeckScreengrid(BaseDeckGLViz):
"__timestamp": d.get(DTTM_ALIAS) or d.get("__time"),
}
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super().get_data(df)
@@ -2358,7 +2359,7 @@ class DeckGrid(BaseDeckGLViz):
def get_properties(self, d):
return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super().get_data(df)
@@ -2416,7 +2417,7 @@ class DeckPathViz(BaseDeckGLViz):
d["__timestamp"] = d.get(DTTM_ALIAS) or d.get("__time")
return d
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super().get_data(df)
@@ -2462,7 +2463,7 @@ class DeckHex(BaseDeckGLViz):
def get_properties(self, d):
return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
self.metric_label = utils.get_metric_name(self.metric)
return super(DeckHex, self).get_data(df)
@@ -2509,10 +2510,13 @@ class DeckArc(BaseDeckGLViz):
DTTM_ALIAS: d.get(DTTM_ALIAS),
}
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
d = super().get_data(df)
- return {"features": d["features"], "mapboxApiKey": config["MAPBOX_API_KEY"]}
+ return {
+ "features": d["features"], # type: ignore
+ "mapboxApiKey": config["MAPBOX_API_KEY"],
+ }
class EventFlowViz(BaseViz):
@@ -2543,7 +2547,7 @@ class EventFlowViz(BaseViz):
return query
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
return df.to_dict(orient="records")
@@ -2556,7 +2560,7 @@ class PairedTTestViz(BaseViz):
sort_series = False
is_timeseries = True
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
"""
Transform received data frame into an object of the form:
{
@@ -2582,7 +2586,7 @@ class PairedTTestViz(BaseViz):
else:
cols.append(col)
df.columns = cols
- data = {}
+ data: Dict = {}
series = df.to_dict("series")
for nameSet in df.columns:
# If no groups are defined, nameSet will be the metric name
@@ -2607,10 +2611,10 @@ class RoseViz(NVD3TimeSeriesViz):
sort_series = False
is_timeseries = True
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
data = super().get_data(df)
- result = {}
- for datum in data:
+ result: Dict = {}
+ for datum in data: # type: ignore
key = datum["key"]
for val in datum["values"]:
timestamp = val["x"].value
@@ -2761,7 +2765,7 @@ class PartitionViz(NVD3TimeSeriesViz):
for i in procs[level][dims].columns
]
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
groups = fd.get("groupby", [])
time_op = fd.get("time_series_option", "not_time")
diff --git a/tests/model_tests.py b/tests/model_tests.py
index 4081c111c..24a87b6e5 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -251,7 +251,7 @@ class SqlaTableModelTestCase(SupersetTestCase):
else:
self.assertNotIn("JOIN", sql.upper())
spec.allows_joins = old_inner_join
- self.assertIsNotNone(qr.df)
+ self.assertFalse(qr.df.empty)
return qr.df
def test_query_with_expr_groupby_timeseries(self):
@@ -262,8 +262,9 @@ class SqlaTableModelTestCase(SupersetTestCase):
df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True)
df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
- self.assertIsNotNone(df2) # df1 can be none if the db does not support join
- if df1 is not None:
+ self.assertFalse(df2.empty)
+ # df1 can be empty if the db does not support join
+ if not df1.empty:
pandas.testing.assert_frame_equal(
cannonicalize_df(df1), cannonicalize_df(df2)
)