From cb7c5aa70c3729d8f1fe0c310d30e620f9e9a581 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Wed, 6 Dec 2017 21:55:43 -0800 Subject: [PATCH] Fixed finding postaggregations (#4017) --- superset/connectors/druid/models.py | 174 +++++++++++------ tests/druid_func_tests.py | 284 ++++++++++++++++++++++++++++ tests/druid_tests.py | 87 +-------- 3 files changed, 397 insertions(+), 148 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index bf7e17643..acb1951c9 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -786,73 +786,123 @@ class DruidDatasource(Model, BaseDatasource): return granularity @staticmethod - def _metrics_and_post_aggs(metrics, metrics_dict): - all_metrics = [] - post_aggs = {} + def get_post_agg(mconf): + """ + For a metric specified as `postagg` returns the + kind of post aggregation for pydruid. + """ + if mconf.get('type') == 'javascript': + return JavascriptPostAggregator( + name=mconf.get('name', ''), + field_names=mconf.get('fieldNames', []), + function=mconf.get('function', '')) + elif mconf.get('type') == 'quantile': + return Quantile( + mconf.get('name', ''), + mconf.get('probability', ''), + ) + elif mconf.get('type') == 'quantiles': + return Quantiles( + mconf.get('name', ''), + mconf.get('probabilities', ''), + ) + elif mconf.get('type') == 'fieldAccess': + return Field(mconf.get('name')) + elif mconf.get('type') == 'constant': + return Const( + mconf.get('value'), + output_name=mconf.get('name', ''), + ) + elif mconf.get('type') == 'hyperUniqueCardinality': + return HyperUniqueCardinality( + mconf.get('name'), + ) + elif mconf.get('type') == 'arithmetic': + return Postaggregator( + mconf.get('fn', '/'), + mconf.get('fields', []), + mconf.get('name', '')) + else: + return CustomPostAggregator( + mconf.get('name', ''), + mconf) - def recursive_get_fields(_conf): - _type = _conf.get('type') - _field = _conf.get('field') - _fields = _conf.get('fields') + @staticmethod + def find_postaggs_for(postagg_names, metrics_dict): + """Return a list of metrics that are post aggregations""" + postagg_metrics = [ + metrics_dict[name] for name in postagg_names + if metrics_dict[name].metric_type == 'postagg' + ] + # Remove post aggregations that were found + for postagg in postagg_metrics: + postagg_names.remove(postagg.metric_name) + return postagg_metrics - field_names = [] - if _type in ['fieldAccess', 'hyperUniqueCardinality', - 'quantile', 'quantiles']: - field_names.append(_conf.get('fieldName', '')) + @staticmethod + def recursive_get_fields(_conf): + _type = _conf.get('type') + _field = _conf.get('field') + _fields = _conf.get('fields') + field_names = [] + if _type in ['fieldAccess', 'hyperUniqueCardinality', + 'quantile', 'quantiles']: + field_names.append(_conf.get('fieldName', '')) + if _field: + field_names += DruidDatasource.recursive_get_fields(_field) + if _fields: + for _f in _fields: + field_names += DruidDatasource.recursive_get_fields(_f) + return list(set(field_names)) - if _field: - field_names += recursive_get_fields(_field) - - if _fields: - for _f in _fields: - field_names += recursive_get_fields(_f) - - return list(set(field_names)) + @staticmethod + def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict): + mconf = postagg.json_obj + required_fields = set( + DruidDatasource.recursive_get_fields(mconf) + + mconf.get('fieldNames', [])) + # Check if the fields are already in aggs + # or is a previous postagg + required_fields = set([ + field for field in required_fields + if field not in visited_postaggs and field not in agg_names + ]) + # First try to find postaggs that match + if len(required_fields) > 0: + missing_postaggs = DruidDatasource.find_postaggs_for( + required_fields, metrics_dict) + for missing_metric in required_fields: + agg_names.add(missing_metric) + for missing_postagg in missing_postaggs: + # Add to visited first to avoid infinite recursion + # if post aggregations are cyclicly dependent + visited_postaggs.add(missing_postagg.metric_name) + for missing_postagg in missing_postaggs: + DruidDatasource.resolve_postagg( + missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict) + post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj) + @staticmethod + def metrics_and_post_aggs(metrics, metrics_dict): + # Separate metrics into those that are aggregations + # and those that are post aggregations + agg_names = set() + postagg_names = [] for metric_name in metrics: - metric = metrics_dict[metric_name] - if metric.metric_type != 'postagg': - all_metrics.append(metric_name) + if metrics_dict[metric_name].metric_type != 'postagg': + agg_names.add(metric_name) else: - mconf = metric.json_obj - all_metrics += recursive_get_fields(mconf) - all_metrics += mconf.get('fieldNames', []) - if mconf.get('type') == 'javascript': - post_aggs[metric_name] = JavascriptPostAggregator( - name=mconf.get('name', ''), - field_names=mconf.get('fieldNames', []), - function=mconf.get('function', '')) - elif mconf.get('type') == 'quantile': - post_aggs[metric_name] = Quantile( - mconf.get('name', ''), - mconf.get('probability', ''), - ) - elif mconf.get('type') == 'quantiles': - post_aggs[metric_name] = Quantiles( - mconf.get('name', ''), - mconf.get('probabilities', ''), - ) - elif mconf.get('type') == 'fieldAccess': - post_aggs[metric_name] = Field(mconf.get('name')) - elif mconf.get('type') == 'constant': - post_aggs[metric_name] = Const( - mconf.get('value'), - output_name=mconf.get('name', ''), - ) - elif mconf.get('type') == 'hyperUniqueCardinality': - post_aggs[metric_name] = HyperUniqueCardinality( - mconf.get('name'), - ) - elif mconf.get('type') == 'arithmetic': - post_aggs[metric_name] = Postaggregator( - mconf.get('fn', '/'), - mconf.get('fields', []), - mconf.get('name', '')) - else: - post_aggs[metric_name] = CustomPostAggregator( - mconf.get('name', ''), - mconf) - return all_metrics, post_aggs + postagg_names.append(metric_name) + # Create the post aggregations, maintain order since postaggs + # may depend on previous ones + post_aggs = OrderedDict() + visited_postaggs = set() + for postagg_name in postagg_names: + postagg = metrics_dict[postagg_name] + visited_postaggs.add(postagg_name) + DruidDatasource.resolve_postagg( + postagg, post_aggs, agg_names, visited_postaggs, metrics_dict) + return list(agg_names), post_aggs def values_for_column(self, column_name, @@ -940,7 +990,7 @@ class DruidDatasource(Model, BaseDatasource): columns_dict = {c.column_name: c for c in self.columns} - all_metrics, post_aggs = self._metrics_and_post_aggs( + all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index ba1f49793..4c047dff1 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -2,12 +2,26 @@ import json import unittest from mock import Mock +import pydruid.utils.postaggregator as postaggs +import superset.connectors.druid.models as models from superset.connectors.druid.models import ( DruidColumn, DruidDatasource, DruidMetric, ) +def mock_metric(metric_name, is_postagg=False): + metric = Mock() + metric.metric_name = metric_name + metric.metric_type = 'postagg' if is_postagg else 'metric' + return metric + + +def emplace(metrics_dict, metric_name, is_postagg=False): + metrics_dict[metric_name] = mock_metric(metric_name, is_postagg) + + +# Unit tests that can be run without initializing base tests class DruidFuncTestCase(unittest.TestCase): def test_get_filters_ignores_invalid_filter_objects(self): @@ -271,3 +285,273 @@ class DruidFuncTestCase(unittest.TestCase): called_args = client.groupby.call_args_list[0][1] self.assertIn('dimensions', called_args) self.assertEqual(['col1', 'col2'], called_args['dimensions']) + + def test_get_post_agg_returns_correct_agg_type(self): + get_post_agg = DruidDatasource.get_post_agg + # javascript PostAggregators + function = 'function(field1, field2) { return field1 + field2; }' + conf = { + 'type': 'javascript', + 'name': 'postagg_name', + 'fieldNames': ['field1', 'field2'], + 'function': function, + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator)) + self.assertEqual(postagg.name, 'postagg_name') + self.assertEqual(postagg.post_aggregator['type'], 'javascript') + self.assertEqual(postagg.post_aggregator['fieldNames'], ['field1', 'field2']) + self.assertEqual(postagg.post_aggregator['name'], 'postagg_name') + self.assertEqual(postagg.post_aggregator['function'], function) + # Quantile + conf = { + 'type': 'quantile', + 'name': 'postagg_name', + 'probability': '0.5', + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Quantile)) + self.assertEqual(postagg.name, 'postagg_name') + self.assertEqual(postagg.post_aggregator['probability'], '0.5') + # Quantiles + conf = { + 'type': 'quantiles', + 'name': 'postagg_name', + 'probabilities': '0.4,0.5,0.6', + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Quantiles)) + self.assertEqual(postagg.name, 'postagg_name') + self.assertEqual(postagg.post_aggregator['probabilities'], '0.4,0.5,0.6') + # FieldAccess + conf = { + 'type': 'fieldAccess', + 'name': 'field_name', + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Field)) + self.assertEqual(postagg.name, 'field_name') + # constant + conf = { + 'type': 'constant', + 'value': 1234, + 'name': 'postagg_name', + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Const)) + self.assertEqual(postagg.name, 'postagg_name') + self.assertEqual(postagg.post_aggregator['value'], 1234) + # hyperUniqueCardinality + conf = { + 'type': 'hyperUniqueCardinality', + 'name': 'unique_name', + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality)) + self.assertEqual(postagg.name, 'unique_name') + # arithmetic + conf = { + 'type': 'arithmetic', + 'fn': '+', + 'fields': ['field1', 'field2'], + 'name': 'postagg_name', + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Postaggregator)) + self.assertEqual(postagg.name, 'postagg_name') + self.assertEqual(postagg.post_aggregator['fn'], '+') + self.assertEqual(postagg.post_aggregator['fields'], ['field1', 'field2']) + # custom post aggregator + conf = { + 'type': 'custom', + 'name': 'custom_name', + 'stuff': 'more_stuff', + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, models.CustomPostAggregator)) + self.assertEqual(postagg.name, 'custom_name') + self.assertEqual(postagg.post_aggregator['stuff'], 'more_stuff') + + def test_find_postaggs_for_returns_postaggs_and_removes(self): + find_postaggs_for = DruidDatasource.find_postaggs_for + postagg_names = set(['pa2', 'pa3', 'pa4', 'm1', 'm2', 'm3', 'm4']) + + metrics = {} + for i in range(1, 6): + emplace(metrics, 'pa' + str(i), True) + emplace(metrics, 'm' + str(i), False) + postagg_list = find_postaggs_for(postagg_names, metrics) + self.assertEqual(3, len(postagg_list)) + self.assertEqual(4, len(postagg_names)) + expected_metrics = ['m1', 'm2', 'm3', 'm4'] + expected_postaggs = set(['pa2', 'pa3', 'pa4']) + for postagg in postagg_list: + expected_postaggs.remove(postagg.metric_name) + for metric in expected_metrics: + postagg_names.remove(metric) + self.assertEqual(0, len(expected_postaggs)) + self.assertEqual(0, len(postagg_names)) + + def test_recursive_get_fields(self): + conf = { + 'type': 'quantile', + 'fieldName': 'f1', + 'field': { + 'type': 'custom', + 'fields': [{ + 'type': 'fieldAccess', + 'fieldName': 'f2', + }, { + 'type': 'fieldAccess', + 'fieldName': 'f3', + }, { + 'type': 'quantiles', + 'fieldName': 'f4', + 'field': { + 'type': 'custom', + }, + }, { + 'type': 'custom', + 'fields': [{ + 'type': 'fieldAccess', + 'fieldName': 'f5', + }, { + 'type': 'fieldAccess', + 'fieldName': 'f2', + 'fields': [{ + 'type': 'fieldAccess', + 'fieldName': 'f3', + }, { + 'type': 'fieldIgnoreMe', + 'fieldName': 'f6', + }], + }], + }], + }, + } + fields = DruidDatasource.recursive_get_fields(conf) + expected = set(['f1', 'f2', 'f3', 'f4', 'f5']) + self.assertEqual(5, len(fields)) + for field in fields: + expected.remove(field) + self.assertEqual(0, len(expected)) + + def test_metrics_and_post_aggs_tree(self): + metrics = ['A', 'B', 'm1', 'm2'] + metrics_dict = {} + for i in range(ord('A'), ord('K') + 1): + emplace(metrics_dict, chr(i), True) + for i in range(1, 10): + emplace(metrics_dict, 'm' + str(i), False) + + def depends_on(index, fields): + dependents = fields if isinstance(fields, list) else [fields] + metrics_dict[index].json_obj = {'fieldNames': dependents} + + depends_on('A', ['m1', 'D', 'C']) + depends_on('B', ['B', 'C', 'E', 'F', 'm3']) + depends_on('C', ['H', 'I']) + depends_on('D', ['m2', 'm5', 'G', 'C']) + depends_on('E', ['H', 'I', 'J']) + depends_on('F', ['J', 'm5']) + depends_on('G', ['m4', 'm7', 'm6', 'A']) + depends_on('H', ['A', 'm4', 'I']) + depends_on('I', ['H', 'K']) + depends_on('J', 'K') + depends_on('K', ['m8', 'm9']) + all_metrics, postaggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict) + expected_metrics = set(all_metrics) + self.assertEqual(9, len(all_metrics)) + for i in range(1, 10): + expected_metrics.remove('m' + str(i)) + self.assertEqual(0, len(expected_metrics)) + self.assertEqual(11, len(postaggs)) + for i in range(ord('A'), ord('K') + 1): + del postaggs[chr(i)] + self.assertEqual(0, len(postaggs)) + + def test_metrics_and_post_aggs(self): + """ + Test generation of metrics and post-aggregations from an initial list + of superset metrics (which may include the results of either). This + primarily tests that specifying a post-aggregator metric will also + require the raw aggregation of the associated druid metric column. + """ + metrics_dict = { + 'unused_count': DruidMetric( + metric_name='unused_count', + verbose_name='COUNT(*)', + metric_type='count', + json=json.dumps({'type': 'count', 'name': 'unused_count'}), + ), + 'some_sum': DruidMetric( + metric_name='some_sum', + verbose_name='SUM(*)', + metric_type='sum', + json=json.dumps({'type': 'sum', 'name': 'sum'}), + ), + 'a_histogram': DruidMetric( + metric_name='a_histogram', + verbose_name='APPROXIMATE_HISTOGRAM(*)', + metric_type='approxHistogramFold', + json=json.dumps( + {'type': 'approxHistogramFold', 'name': 'a_histogram'}, + ), + ), + 'aCustomMetric': DruidMetric( + metric_name='aCustomMetric', + verbose_name='MY_AWESOME_METRIC(*)', + metric_type='aCustomType', + json=json.dumps( + {'type': 'customMetric', 'name': 'aCustomMetric'}, + ), + ), + 'quantile_p95': DruidMetric( + metric_name='quantile_p95', + verbose_name='P95(*)', + metric_type='postagg', + json=json.dumps({ + 'type': 'quantile', + 'probability': 0.95, + 'name': 'p95', + 'fieldName': 'a_histogram', + }), + ), + 'aCustomPostAgg': DruidMetric( + metric_name='aCustomPostAgg', + verbose_name='CUSTOM_POST_AGG(*)', + metric_type='postagg', + json=json.dumps({ + 'type': 'customPostAgg', + 'name': 'aCustomPostAgg', + 'field': { + 'type': 'fieldAccess', + 'fieldName': 'aCustomMetric', + }, + }), + ), + } + + metrics = ['some_sum'] + all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict) + + assert all_metrics == ['some_sum'] + assert post_aggs == {} + + metrics = ['quantile_p95'] + all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict) + + result_postaggs = set(['quantile_p95']) + assert all_metrics == ['a_histogram'] + assert set(post_aggs.keys()) == result_postaggs + + metrics = ['aCustomPostAgg'] + all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict) + + result_postaggs = set(['aCustomPostAgg']) + assert all_metrics == ['aCustomMetric'] + assert set(post_aggs.keys()) == result_postaggs diff --git a/tests/druid_tests.py b/tests/druid_tests.py index c9dce339d..c280da790 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -12,7 +12,7 @@ from mock import Mock, patch from superset import db, security, sm from superset.connectors.druid.models import ( - DruidCluster, DruidDatasource, DruidMetric, + DruidCluster, DruidDatasource, ) from .base_tests import SupersetTestCase @@ -328,91 +328,6 @@ class DruidTests(SupersetTestCase): permission=permission, view_menu=view_menu).first() assert pv is not None - def test_metrics_and_post_aggs(self): - """ - Test generation of metrics and post-aggregations from an initial list - of superset metrics (which may include the results of either). This - primarily tests that specifying a post-aggregator metric will also - require the raw aggregation of the associated druid metric column. - """ - metrics_dict = { - 'unused_count': DruidMetric( - metric_name='unused_count', - verbose_name='COUNT(*)', - metric_type='count', - json=json.dumps({'type': 'count', 'name': 'unused_count'}), - ), - 'some_sum': DruidMetric( - metric_name='some_sum', - verbose_name='SUM(*)', - metric_type='sum', - json=json.dumps({'type': 'sum', 'name': 'sum'}), - ), - 'a_histogram': DruidMetric( - metric_name='a_histogram', - verbose_name='APPROXIMATE_HISTOGRAM(*)', - metric_type='approxHistogramFold', - json=json.dumps( - {'type': 'approxHistogramFold', 'name': 'a_histogram'}, - ), - ), - 'aCustomMetric': DruidMetric( - metric_name='aCustomMetric', - verbose_name='MY_AWESOME_METRIC(*)', - metric_type='aCustomType', - json=json.dumps( - {'type': 'customMetric', 'name': 'aCustomMetric'}, - ), - ), - 'quantile_p95': DruidMetric( - metric_name='quantile_p95', - verbose_name='P95(*)', - metric_type='postagg', - json=json.dumps({ - 'type': 'quantile', - 'probability': 0.95, - 'name': 'p95', - 'fieldName': 'a_histogram', - }), - ), - 'aCustomPostAgg': DruidMetric( - metric_name='aCustomPostAgg', - verbose_name='CUSTOM_POST_AGG(*)', - metric_type='postagg', - json=json.dumps({ - 'type': 'customPostAgg', - 'name': 'aCustomPostAgg', - 'field': { - 'type': 'fieldAccess', - 'fieldName': 'aCustomMetric', - }, - }), - ), - } - - metrics = ['some_sum'] - all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs( - metrics, metrics_dict) - - assert all_metrics == ['some_sum'] - assert post_aggs == {} - - metrics = ['quantile_p95'] - all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs( - metrics, metrics_dict) - - result_postaggs = set(['quantile_p95']) - assert all_metrics == ['a_histogram'] - assert set(post_aggs.keys()) == result_postaggs - - metrics = ['aCustomPostAgg'] - all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs( - metrics, metrics_dict) - - result_postaggs = set(['aCustomPostAgg']) - assert all_metrics == ['aCustomMetric'] - assert set(post_aggs.keys()) == result_postaggs - if __name__ == '__main__': unittest.main()