From b5d03c85c578bd12bfb99446a54543bd029361ad Mon Sep 17 00:00:00 2001 From: agrawaldevesh Date: Fri, 22 Feb 2019 20:32:46 -0800 Subject: [PATCH] Do label name mutation before anything else on the dataframe (#6831) This problem especially happens with pinot when you select two metrics with different aliases but same function. For example, effectively the sql like 'select type, count(*) as one, count(*) as two from bar group by type'. In such a case, pinot will return two columns, both named count_star. So when we try to do a df['count_star'], the result is a Dataframe and not a Series. This causes a KeyError in the get_df method. So we push the DB specific label mutation inside get_df before we do any other mutation. --- superset/connectors/sqla/models.py | 10 +++++++--- superset/models/core.py | 5 ++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ca0db02f6..e788ce7c2 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -836,9 +836,8 @@ class SqlaTable(Model, BaseDatasource): sql = query_str_ext.sql status = utils.QueryStatus.SUCCESS error_message = None - df = None - try: - df = self.database.get_df(sql, self.schema) + + def mutator(df): labels_expected = query_str_ext.labels_expected if df is not None and not df.empty: if len(df.columns) != len(labels_expected): @@ -846,7 +845,12 @@ class SqlaTable(Model, BaseDatasource): f' differs from {labels_expected}') else: df.columns = labels_expected + return df + + try: + df = self.database.get_df(sql, self.schema, mutator) except Exception as e: + df = None status = utils.QueryStatus.FAILED logging.exception(f'Query {sql} on schema {self.schema} failed') db_engine_spec = self.database.db_engine_spec diff --git a/superset/models/core.py b/superset/models/core.py index f68e94e0d..c8435c2c3 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -825,7 +825,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): def get_quoter(self): return self.get_dialect().identifier_preparer.quote - def get_df(self, sql, schema): + def get_df(self, sql, schema, mutator=None): sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)] source_key = None if request and request.referrer: @@ -869,6 +869,9 @@ class Database(Model, AuditMixinNullable, ImportMixin): coerce_float=True, ) + if mutator: + df = mutator(df) + for k, v in df.dtypes.items(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates)