Allow MetricsControl to aggregate on a column with an expression (#5021)
* Allow MetricsControl to aggregate on a column with an expression * Adding test case for metrics based on columns
This commit is contained in:
parent
b312cdad2f
commit
b8aeb1a825
|
|
@ -450,10 +450,24 @@ class SqlaTable(Model, BaseDatasource):
|
|||
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
|
||||
return self.get_sqla_table()
|
||||
|
||||
def adhoc_metric_to_sa(self, metric):
|
||||
def adhoc_metric_to_sa(self, metric, cols):
|
||||
"""
|
||||
Turn an adhoc metric into a sqlalchemy column.
|
||||
|
||||
:param dict metric: Adhoc metric definition
|
||||
:param dict cols: Columns for the current table
|
||||
:returns: The metric defined as a sqlalchemy column
|
||||
:rtype: sqlalchemy.sql.column
|
||||
"""
|
||||
expressionType = metric.get('expressionType')
|
||||
if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
|
||||
sa_column = column(metric.get('column').get('column_name'))
|
||||
column_name = metric.get('column').get('column_name')
|
||||
sa_column = column(column_name)
|
||||
table_column = cols.get(column_name)
|
||||
|
||||
if table_column:
|
||||
sa_column = table_column.sqla_col
|
||||
|
||||
sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
|
||||
sa_metric = sa_metric.label(metric.get('label'))
|
||||
return sa_metric
|
||||
|
|
@ -518,7 +532,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
metrics_exprs = []
|
||||
for m in metrics:
|
||||
if utils.is_adhoc_metric(m):
|
||||
metrics_exprs.append(self.adhoc_metric_to_sa(m))
|
||||
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
|
||||
elif m in metrics_dict:
|
||||
metrics_exprs.append(metrics_dict.get(m).sqla_col)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import polyline
|
|||
|
||||
from superset import app, db, security_manager, utils
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import TableColumn
|
||||
from superset.models import core as models
|
||||
|
||||
# Shortcuts
|
||||
|
|
@ -585,6 +586,10 @@ def load_birth_names():
|
|||
obj.main_dttm_col = 'ds'
|
||||
obj.database = utils.get_or_create_main_db()
|
||||
obj.filter_select_enabled = True
|
||||
obj.columns.append(TableColumn(
|
||||
column_name='num_california',
|
||||
expression="CASE WHEN state = 'CA' THEN num ELSE 0 END"
|
||||
))
|
||||
db.session.merge(obj)
|
||||
db.session.commit()
|
||||
obj.fetch_metadata()
|
||||
|
|
@ -737,6 +742,24 @@ def load_birth_names():
|
|||
'val': ['girl'],
|
||||
}],
|
||||
subheader='total female participants')),
|
||||
Slice(
|
||||
slice_name="Number of California Births",
|
||||
viz_type='big_number_total',
|
||||
datasource_type='table',
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
metric={
|
||||
"expressionType": "SIMPLE",
|
||||
"column": {
|
||||
"column_name": "num_california",
|
||||
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
|
||||
},
|
||||
"aggregate": "SUM",
|
||||
"label": "SUM(num_california)",
|
||||
},
|
||||
viz_type="big_number_total",
|
||||
granularity_sqla="ds")),
|
||||
]
|
||||
for slc in slices:
|
||||
merge_slice(slc)
|
||||
|
|
|
|||
Loading…
Reference in New Issue