diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ca0db02f6..e788ce7c2 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -836,9 +836,8 @@ class SqlaTable(Model, BaseDatasource): sql = query_str_ext.sql status = utils.QueryStatus.SUCCESS error_message = None - df = None - try: - df = self.database.get_df(sql, self.schema) + + def mutator(df): labels_expected = query_str_ext.labels_expected if df is not None and not df.empty: if len(df.columns) != len(labels_expected): @@ -846,7 +845,12 @@ class SqlaTable(Model, BaseDatasource): f' differs from {labels_expected}') else: df.columns = labels_expected + return df + + try: + df = self.database.get_df(sql, self.schema, mutator) except Exception as e: + df = None status = utils.QueryStatus.FAILED logging.exception(f'Query {sql} on schema {self.schema} failed') db_engine_spec = self.database.db_engine_spec diff --git a/superset/models/core.py b/superset/models/core.py index f68e94e0d..c8435c2c3 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -825,7 +825,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): def get_quoter(self): return self.get_dialect().identifier_preparer.quote - def get_df(self, sql, schema): + def get_df(self, sql, schema, mutator=None): sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)] source_key = None if request and request.referrer: @@ -869,6 +869,9 @@ class Database(Model, AuditMixinNullable, ImportMixin): coerce_float=True, ) + if mutator: + df = mutator(df) + for k, v in df.dtypes.items(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates)