Fixed finding postaggregations (#4017)

This commit is contained in:
Jeff Niu 2017-12-06 21:55:43 -08:00 committed by Maxime Beauchemin
parent 5bc581fd44
commit cb7c5aa70c
3 changed files with 397 additions and 148 deletions

View File

@ -786,73 +786,123 @@ class DruidDatasource(Model, BaseDatasource):
return granularity
@staticmethod
def _metrics_and_post_aggs(metrics, metrics_dict):
all_metrics = []
post_aggs = {}
def get_post_agg(mconf):
"""
For a metric specified as `postagg` returns the
kind of post aggregation for pydruid.
"""
if mconf.get('type') == 'javascript':
return JavascriptPostAggregator(
name=mconf.get('name', ''),
field_names=mconf.get('fieldNames', []),
function=mconf.get('function', ''))
elif mconf.get('type') == 'quantile':
return Quantile(
mconf.get('name', ''),
mconf.get('probability', ''),
)
elif mconf.get('type') == 'quantiles':
return Quantiles(
mconf.get('name', ''),
mconf.get('probabilities', ''),
)
elif mconf.get('type') == 'fieldAccess':
return Field(mconf.get('name'))
elif mconf.get('type') == 'constant':
return Const(
mconf.get('value'),
output_name=mconf.get('name', ''),
)
elif mconf.get('type') == 'hyperUniqueCardinality':
return HyperUniqueCardinality(
mconf.get('name'),
)
elif mconf.get('type') == 'arithmetic':
return Postaggregator(
mconf.get('fn', '/'),
mconf.get('fields', []),
mconf.get('name', ''))
else:
return CustomPostAggregator(
mconf.get('name', ''),
mconf)
def recursive_get_fields(_conf):
_type = _conf.get('type')
_field = _conf.get('field')
_fields = _conf.get('fields')
@staticmethod
def find_postaggs_for(postagg_names, metrics_dict):
"""Return a list of metrics that are post aggregations"""
postagg_metrics = [
metrics_dict[name] for name in postagg_names
if metrics_dict[name].metric_type == 'postagg'
]
# Remove post aggregations that were found
for postagg in postagg_metrics:
postagg_names.remove(postagg.metric_name)
return postagg_metrics
field_names = []
if _type in ['fieldAccess', 'hyperUniqueCardinality',
'quantile', 'quantiles']:
field_names.append(_conf.get('fieldName', ''))
@staticmethod
def recursive_get_fields(_conf):
_type = _conf.get('type')
_field = _conf.get('field')
_fields = _conf.get('fields')
field_names = []
if _type in ['fieldAccess', 'hyperUniqueCardinality',
'quantile', 'quantiles']:
field_names.append(_conf.get('fieldName', ''))
if _field:
field_names += DruidDatasource.recursive_get_fields(_field)
if _fields:
for _f in _fields:
field_names += DruidDatasource.recursive_get_fields(_f)
return list(set(field_names))
if _field:
field_names += recursive_get_fields(_field)
if _fields:
for _f in _fields:
field_names += recursive_get_fields(_f)
return list(set(field_names))
@staticmethod
def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict):
mconf = postagg.json_obj
required_fields = set(
DruidDatasource.recursive_get_fields(mconf)
+ mconf.get('fieldNames', []))
# Check if the fields are already in aggs
# or is a previous postagg
required_fields = set([
field for field in required_fields
if field not in visited_postaggs and field not in agg_names
])
# First try to find postaggs that match
if len(required_fields) > 0:
missing_postaggs = DruidDatasource.find_postaggs_for(
required_fields, metrics_dict)
for missing_metric in required_fields:
agg_names.add(missing_metric)
for missing_postagg in missing_postaggs:
# Add to visited first to avoid infinite recursion
# if post aggregations are cyclicly dependent
visited_postaggs.add(missing_postagg.metric_name)
for missing_postagg in missing_postaggs:
DruidDatasource.resolve_postagg(
missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
@staticmethod
def metrics_and_post_aggs(metrics, metrics_dict):
# Separate metrics into those that are aggregations
# and those that are post aggregations
agg_names = set()
postagg_names = []
for metric_name in metrics:
metric = metrics_dict[metric_name]
if metric.metric_type != 'postagg':
all_metrics.append(metric_name)
if metrics_dict[metric_name].metric_type != 'postagg':
agg_names.add(metric_name)
else:
mconf = metric.json_obj
all_metrics += recursive_get_fields(mconf)
all_metrics += mconf.get('fieldNames', [])
if mconf.get('type') == 'javascript':
post_aggs[metric_name] = JavascriptPostAggregator(
name=mconf.get('name', ''),
field_names=mconf.get('fieldNames', []),
function=mconf.get('function', ''))
elif mconf.get('type') == 'quantile':
post_aggs[metric_name] = Quantile(
mconf.get('name', ''),
mconf.get('probability', ''),
)
elif mconf.get('type') == 'quantiles':
post_aggs[metric_name] = Quantiles(
mconf.get('name', ''),
mconf.get('probabilities', ''),
)
elif mconf.get('type') == 'fieldAccess':
post_aggs[metric_name] = Field(mconf.get('name'))
elif mconf.get('type') == 'constant':
post_aggs[metric_name] = Const(
mconf.get('value'),
output_name=mconf.get('name', ''),
)
elif mconf.get('type') == 'hyperUniqueCardinality':
post_aggs[metric_name] = HyperUniqueCardinality(
mconf.get('name'),
)
elif mconf.get('type') == 'arithmetic':
post_aggs[metric_name] = Postaggregator(
mconf.get('fn', '/'),
mconf.get('fields', []),
mconf.get('name', ''))
else:
post_aggs[metric_name] = CustomPostAggregator(
mconf.get('name', ''),
mconf)
return all_metrics, post_aggs
postagg_names.append(metric_name)
# Create the post aggregations, maintain order since postaggs
# may depend on previous ones
post_aggs = OrderedDict()
visited_postaggs = set()
for postagg_name in postagg_names:
postagg = metrics_dict[postagg_name]
visited_postaggs.add(postagg_name)
DruidDatasource.resolve_postagg(
postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
return list(agg_names), post_aggs
def values_for_column(self,
column_name,
@ -940,7 +990,7 @@ class DruidDatasource(Model, BaseDatasource):
columns_dict = {c.column_name: c for c in self.columns}
all_metrics, post_aggs = self._metrics_and_post_aggs(
all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics,
metrics_dict)

View File

@ -2,12 +2,26 @@ import json
import unittest
from mock import Mock
import pydruid.utils.postaggregator as postaggs
import superset.connectors.druid.models as models
from superset.connectors.druid.models import (
DruidColumn, DruidDatasource, DruidMetric,
)
def mock_metric(metric_name, is_postagg=False):
metric = Mock()
metric.metric_name = metric_name
metric.metric_type = 'postagg' if is_postagg else 'metric'
return metric
def emplace(metrics_dict, metric_name, is_postagg=False):
metrics_dict[metric_name] = mock_metric(metric_name, is_postagg)
# Unit tests that can be run without initializing base tests
class DruidFuncTestCase(unittest.TestCase):
def test_get_filters_ignores_invalid_filter_objects(self):
@ -271,3 +285,273 @@ class DruidFuncTestCase(unittest.TestCase):
called_args = client.groupby.call_args_list[0][1]
self.assertIn('dimensions', called_args)
self.assertEqual(['col1', 'col2'], called_args['dimensions'])
def test_get_post_agg_returns_correct_agg_type(self):
get_post_agg = DruidDatasource.get_post_agg
# javascript PostAggregators
function = 'function(field1, field2) { return field1 + field2; }'
conf = {
'type': 'javascript',
'name': 'postagg_name',
'fieldNames': ['field1', 'field2'],
'function': function,
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator))
self.assertEqual(postagg.name, 'postagg_name')
self.assertEqual(postagg.post_aggregator['type'], 'javascript')
self.assertEqual(postagg.post_aggregator['fieldNames'], ['field1', 'field2'])
self.assertEqual(postagg.post_aggregator['name'], 'postagg_name')
self.assertEqual(postagg.post_aggregator['function'], function)
# Quantile
conf = {
'type': 'quantile',
'name': 'postagg_name',
'probability': '0.5',
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, postaggs.Quantile))
self.assertEqual(postagg.name, 'postagg_name')
self.assertEqual(postagg.post_aggregator['probability'], '0.5')
# Quantiles
conf = {
'type': 'quantiles',
'name': 'postagg_name',
'probabilities': '0.4,0.5,0.6',
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, postaggs.Quantiles))
self.assertEqual(postagg.name, 'postagg_name')
self.assertEqual(postagg.post_aggregator['probabilities'], '0.4,0.5,0.6')
# FieldAccess
conf = {
'type': 'fieldAccess',
'name': 'field_name',
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, postaggs.Field))
self.assertEqual(postagg.name, 'field_name')
# constant
conf = {
'type': 'constant',
'value': 1234,
'name': 'postagg_name',
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, postaggs.Const))
self.assertEqual(postagg.name, 'postagg_name')
self.assertEqual(postagg.post_aggregator['value'], 1234)
# hyperUniqueCardinality
conf = {
'type': 'hyperUniqueCardinality',
'name': 'unique_name',
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality))
self.assertEqual(postagg.name, 'unique_name')
# arithmetic
conf = {
'type': 'arithmetic',
'fn': '+',
'fields': ['field1', 'field2'],
'name': 'postagg_name',
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, postaggs.Postaggregator))
self.assertEqual(postagg.name, 'postagg_name')
self.assertEqual(postagg.post_aggregator['fn'], '+')
self.assertEqual(postagg.post_aggregator['fields'], ['field1', 'field2'])
# custom post aggregator
conf = {
'type': 'custom',
'name': 'custom_name',
'stuff': 'more_stuff',
}
postagg = get_post_agg(conf)
self.assertTrue(isinstance(postagg, models.CustomPostAggregator))
self.assertEqual(postagg.name, 'custom_name')
self.assertEqual(postagg.post_aggregator['stuff'], 'more_stuff')
def test_find_postaggs_for_returns_postaggs_and_removes(self):
find_postaggs_for = DruidDatasource.find_postaggs_for
postagg_names = set(['pa2', 'pa3', 'pa4', 'm1', 'm2', 'm3', 'm4'])
metrics = {}
for i in range(1, 6):
emplace(metrics, 'pa' + str(i), True)
emplace(metrics, 'm' + str(i), False)
postagg_list = find_postaggs_for(postagg_names, metrics)
self.assertEqual(3, len(postagg_list))
self.assertEqual(4, len(postagg_names))
expected_metrics = ['m1', 'm2', 'm3', 'm4']
expected_postaggs = set(['pa2', 'pa3', 'pa4'])
for postagg in postagg_list:
expected_postaggs.remove(postagg.metric_name)
for metric in expected_metrics:
postagg_names.remove(metric)
self.assertEqual(0, len(expected_postaggs))
self.assertEqual(0, len(postagg_names))
def test_recursive_get_fields(self):
conf = {
'type': 'quantile',
'fieldName': 'f1',
'field': {
'type': 'custom',
'fields': [{
'type': 'fieldAccess',
'fieldName': 'f2',
}, {
'type': 'fieldAccess',
'fieldName': 'f3',
}, {
'type': 'quantiles',
'fieldName': 'f4',
'field': {
'type': 'custom',
},
}, {
'type': 'custom',
'fields': [{
'type': 'fieldAccess',
'fieldName': 'f5',
}, {
'type': 'fieldAccess',
'fieldName': 'f2',
'fields': [{
'type': 'fieldAccess',
'fieldName': 'f3',
}, {
'type': 'fieldIgnoreMe',
'fieldName': 'f6',
}],
}],
}],
},
}
fields = DruidDatasource.recursive_get_fields(conf)
expected = set(['f1', 'f2', 'f3', 'f4', 'f5'])
self.assertEqual(5, len(fields))
for field in fields:
expected.remove(field)
self.assertEqual(0, len(expected))
def test_metrics_and_post_aggs_tree(self):
metrics = ['A', 'B', 'm1', 'm2']
metrics_dict = {}
for i in range(ord('A'), ord('K') + 1):
emplace(metrics_dict, chr(i), True)
for i in range(1, 10):
emplace(metrics_dict, 'm' + str(i), False)
def depends_on(index, fields):
dependents = fields if isinstance(fields, list) else [fields]
metrics_dict[index].json_obj = {'fieldNames': dependents}
depends_on('A', ['m1', 'D', 'C'])
depends_on('B', ['B', 'C', 'E', 'F', 'm3'])
depends_on('C', ['H', 'I'])
depends_on('D', ['m2', 'm5', 'G', 'C'])
depends_on('E', ['H', 'I', 'J'])
depends_on('F', ['J', 'm5'])
depends_on('G', ['m4', 'm7', 'm6', 'A'])
depends_on('H', ['A', 'm4', 'I'])
depends_on('I', ['H', 'K'])
depends_on('J', 'K')
depends_on('K', ['m8', 'm9'])
all_metrics, postaggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
expected_metrics = set(all_metrics)
self.assertEqual(9, len(all_metrics))
for i in range(1, 10):
expected_metrics.remove('m' + str(i))
self.assertEqual(0, len(expected_metrics))
self.assertEqual(11, len(postaggs))
for i in range(ord('A'), ord('K') + 1):
del postaggs[chr(i)]
self.assertEqual(0, len(postaggs))
def test_metrics_and_post_aggs(self):
"""
Test generation of metrics and post-aggregations from an initial list
of superset metrics (which may include the results of either). This
primarily tests that specifying a post-aggregator metric will also
require the raw aggregation of the associated druid metric column.
"""
metrics_dict = {
'unused_count': DruidMetric(
metric_name='unused_count',
verbose_name='COUNT(*)',
metric_type='count',
json=json.dumps({'type': 'count', 'name': 'unused_count'}),
),
'some_sum': DruidMetric(
metric_name='some_sum',
verbose_name='SUM(*)',
metric_type='sum',
json=json.dumps({'type': 'sum', 'name': 'sum'}),
),
'a_histogram': DruidMetric(
metric_name='a_histogram',
verbose_name='APPROXIMATE_HISTOGRAM(*)',
metric_type='approxHistogramFold',
json=json.dumps(
{'type': 'approxHistogramFold', 'name': 'a_histogram'},
),
),
'aCustomMetric': DruidMetric(
metric_name='aCustomMetric',
verbose_name='MY_AWESOME_METRIC(*)',
metric_type='aCustomType',
json=json.dumps(
{'type': 'customMetric', 'name': 'aCustomMetric'},
),
),
'quantile_p95': DruidMetric(
metric_name='quantile_p95',
verbose_name='P95(*)',
metric_type='postagg',
json=json.dumps({
'type': 'quantile',
'probability': 0.95,
'name': 'p95',
'fieldName': 'a_histogram',
}),
),
'aCustomPostAgg': DruidMetric(
metric_name='aCustomPostAgg',
verbose_name='CUSTOM_POST_AGG(*)',
metric_type='postagg',
json=json.dumps({
'type': 'customPostAgg',
'name': 'aCustomPostAgg',
'field': {
'type': 'fieldAccess',
'fieldName': 'aCustomMetric',
},
}),
),
}
metrics = ['some_sum']
all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
assert all_metrics == ['some_sum']
assert post_aggs == {}
metrics = ['quantile_p95']
all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
result_postaggs = set(['quantile_p95'])
assert all_metrics == ['a_histogram']
assert set(post_aggs.keys()) == result_postaggs
metrics = ['aCustomPostAgg']
all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
result_postaggs = set(['aCustomPostAgg'])
assert all_metrics == ['aCustomMetric']
assert set(post_aggs.keys()) == result_postaggs

View File

@ -12,7 +12,7 @@ from mock import Mock, patch
from superset import db, security, sm
from superset.connectors.druid.models import (
DruidCluster, DruidDatasource, DruidMetric,
DruidCluster, DruidDatasource,
)
from .base_tests import SupersetTestCase
@ -328,91 +328,6 @@ class DruidTests(SupersetTestCase):
permission=permission, view_menu=view_menu).first()
assert pv is not None
def test_metrics_and_post_aggs(self):
"""
Test generation of metrics and post-aggregations from an initial list
of superset metrics (which may include the results of either). This
primarily tests that specifying a post-aggregator metric will also
require the raw aggregation of the associated druid metric column.
"""
metrics_dict = {
'unused_count': DruidMetric(
metric_name='unused_count',
verbose_name='COUNT(*)',
metric_type='count',
json=json.dumps({'type': 'count', 'name': 'unused_count'}),
),
'some_sum': DruidMetric(
metric_name='some_sum',
verbose_name='SUM(*)',
metric_type='sum',
json=json.dumps({'type': 'sum', 'name': 'sum'}),
),
'a_histogram': DruidMetric(
metric_name='a_histogram',
verbose_name='APPROXIMATE_HISTOGRAM(*)',
metric_type='approxHistogramFold',
json=json.dumps(
{'type': 'approxHistogramFold', 'name': 'a_histogram'},
),
),
'aCustomMetric': DruidMetric(
metric_name='aCustomMetric',
verbose_name='MY_AWESOME_METRIC(*)',
metric_type='aCustomType',
json=json.dumps(
{'type': 'customMetric', 'name': 'aCustomMetric'},
),
),
'quantile_p95': DruidMetric(
metric_name='quantile_p95',
verbose_name='P95(*)',
metric_type='postagg',
json=json.dumps({
'type': 'quantile',
'probability': 0.95,
'name': 'p95',
'fieldName': 'a_histogram',
}),
),
'aCustomPostAgg': DruidMetric(
metric_name='aCustomPostAgg',
verbose_name='CUSTOM_POST_AGG(*)',
metric_type='postagg',
json=json.dumps({
'type': 'customPostAgg',
'name': 'aCustomPostAgg',
'field': {
'type': 'fieldAccess',
'fieldName': 'aCustomMetric',
},
}),
),
}
metrics = ['some_sum']
all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
metrics, metrics_dict)
assert all_metrics == ['some_sum']
assert post_aggs == {}
metrics = ['quantile_p95']
all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
metrics, metrics_dict)
result_postaggs = set(['quantile_p95'])
assert all_metrics == ['a_histogram']
assert set(post_aggs.keys()) == result_postaggs
metrics = ['aCustomPostAgg']
all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
metrics, metrics_dict)
result_postaggs = set(['aCustomPostAgg'])
assert all_metrics == ['aCustomMetric']
assert set(post_aggs.keys()) == result_postaggs
if __name__ == '__main__':
unittest.main()