Allow MetricsControl to aggregate on a column with an expression (#5021)

* Allow MetricsControl to aggregate on a column with an expression

* Adding test case for metrics based on columns
This commit is contained in:
michellethomas 2018-05-22 09:58:38 -07:00 committed by John Bodley
parent b312cdad2f
commit b8aeb1a825
2 changed files with 40 additions and 3 deletions

View File

@ -450,10 +450,24 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table()
def adhoc_metric_to_sa(self, metric):
def adhoc_metric_to_sa(self, metric, cols):
"""
Turn an adhoc metric into a sqlalchemy column.
:param dict metric: Adhoc metric definition
:param dict cols: Columns for the current table
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
expressionType = metric.get('expressionType')
if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
sa_column = column(metric.get('column').get('column_name'))
column_name = metric.get('column').get('column_name')
sa_column = column(column_name)
table_column = cols.get(column_name)
if table_column:
sa_column = table_column.sqla_col
sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
sa_metric = sa_metric.label(metric.get('label'))
return sa_metric
@ -518,7 +532,7 @@ class SqlaTable(Model, BaseDatasource):
metrics_exprs = []
for m in metrics:
if utils.is_adhoc_metric(m):
metrics_exprs.append(self.adhoc_metric_to_sa(m))
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
elif m in metrics_dict:
metrics_exprs.append(metrics_dict.get(m).sqla_col)
else:

View File

@ -19,6 +19,7 @@ import polyline
from superset import app, db, security_manager, utils
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import TableColumn
from superset.models import core as models
# Shortcuts
@ -585,6 +586,10 @@ def load_birth_names():
obj.main_dttm_col = 'ds'
obj.database = utils.get_or_create_main_db()
obj.filter_select_enabled = True
obj.columns.append(TableColumn(
column_name='num_california',
expression="CASE WHEN state = 'CA' THEN num ELSE 0 END"
))
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
@ -737,6 +742,24 @@ def load_birth_names():
'val': ['girl'],
}],
subheader='total female participants')),
Slice(
slice_name="Number of California Births",
viz_type='big_number_total',
datasource_type='table',
datasource_id=tbl.id,
params=get_slice_json(
defaults,
metric={
"expressionType": "SIMPLE",
"column": {
"column_name": "num_california",
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
},
"aggregate": "SUM",
"label": "SUM(num_california)",
},
viz_type="big_number_total",
granularity_sqla="ds")),
]
for slc in slices:
merge_slice(slc)