diff --git a/superset/cli.py b/superset/cli.py index 48db7394b..95f6df7b9 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -166,10 +166,9 @@ def load_examples(load_test_data): ) @manager.option( '-m', '--merge', - help=( - "Specify using 'merge' property during operation. " - 'Default value is False ' - ), + action='store_true', + help="Specify using 'merge' property during operation.", + default=False, ) def refresh_druid(datasource, merge): """Refresh druid datasources""" diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 74eb575d0..f4f137ef2 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -29,7 +29,7 @@ import requests from six import string_types import sqlalchemy as sa from sqlalchemy import ( - Boolean, Column, DateTime, ForeignKey, Integer, or_, String, Text, UniqueConstraint, + Boolean, Column, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, ) from sqlalchemy.orm import backref, relationship @@ -200,33 +200,31 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): col_objs_list = ( session.query(DruidColumn) .filter(DruidColumn.datasource_id == datasource.id) - .filter(or_(DruidColumn.column_name == col for col in cols)) + .filter(DruidColumn.column_name.in_(cols.keys())) ) col_objs = {col.column_name: col for col in col_objs_list} for col in cols: if col == '__time': # skip the time column continue - col_obj = col_objs.get(col, None) + col_obj = col_objs.get(col) if not col_obj: col_obj = DruidColumn( datasource_id=datasource.id, column_name=col) with session.no_autoflush: session.add(col_obj) - datatype = cols[col]['type'] - if datatype == 'STRING': + col_obj.type = cols[col]['type'] + col_obj.datasource = datasource + if col_obj.type == 'STRING': col_obj.groupby = True col_obj.filterable = True - if datatype == 'hyperUnique' or datatype == 'thetaSketch': + if col_obj.type == 'hyperUnique' or col_obj.type == 'thetaSketch': 
def refresh_metrics(self):
    """Synchronise this column's generated metrics with the stored ones.

    ``self.get_metrics()`` yields a mapping whose values are ``DruidMetric``
    objects.  Each of them is matched, by ``metric_name``, against the
    metrics already stored for ``self.datasource_id``:

    * when a stored metric exists, its ``json``, ``metric_type`` and
      ``verbose_name`` are overwritten with the freshly generated values;
    * otherwise the generated metric is bound to this column's datasource
      and added to the session.

    This method does not commit the session.
    """
    metrics = self.get_metrics()
    if not metrics:
        # Nothing to synchronise; skip the query rather than emit an
        # empty ``IN ()`` predicate.
        return

    stored = (
        db.session.query(DruidMetric)
        .filter(DruidMetric.datasource_id == self.datasource_id)
        .filter(DruidMetric.metric_name.in_(list(metrics.keys())))
    )
    stored_by_name = {m.metric_name: m for m in stored}

    for metric in metrics.values():
        existing = stored_by_name.get(metric.metric_name)
        if existing:
            # Update in place so the row keeps its identity.
            for attr in ('json', 'metric_type', 'verbose_name'):
                setattr(existing, attr, getattr(metric, attr))
        else:
            # no_autoflush: avoid flushing the half-initialised metric while
            # it is being attached to the datasource.
            with db.session.no_autoflush:
                metric.datasource_id = self.datasource_id
                db.session.add(metric)
def refresh_metrics(self):
    """Refresh the metrics of every column attached to this datasource.

    Each element of ``self.columns`` is asked, in order, to run its own
    ``refresh_metrics()``.
    """
    for column in self.columns:
        column.refresh_metrics()
"""Shrink metrics.metric_name and add name/datasource uniqueness constraints

Revision ID: f231d82b9b26
Revises: e68c4473c581
Create Date: 2018-03-20 19:47:54.991259

"""

# revision identifiers, used by Alembic.
revision = 'f231d82b9b26'
down_revision = 'e68c4473c581'

from alembic import op
import sqlalchemy as sa
from sqlalchemy.exc import OperationalError

from superset.utils import generic_find_uq_constraint_name

# Naming convention handed to batch_alter_table so that unique constraints
# get a deterministic name.
conv = {
    'uq': 'uq_%(table_name)s_%(column_0_name)s',
}

# table name -> name column that must be unique together with datasource_id.
names = {
    'columns': 'column_name',
    'metrics': 'metric_name',
}


def upgrade():
    """Shrink ``metrics.metric_name`` and add the uniqueness constraints."""

    # Reduce the size of the metric_name column for constraint viability.
    with op.batch_alter_table('metrics', naming_convention=conv) as batch_op:
        batch_op.alter_column(
            'metric_name',
            existing_type=sa.String(length=512),
            type_=sa.String(length=255),
            existing_nullable=True,
        )

    # Add the missing uniqueness constraints.
    for table, column in names.items():
        with op.batch_alter_table(table, naming_convention=conv) as batch_op:
            batch_op.create_unique_constraint(
                'uq_{}_{}'.format(table, column),
                [column, 'datasource_id'],
            )


def downgrade():
    """Restore ``metrics.metric_name`` and drop the uniqueness constraints."""

    # Resolve the bind and inspector here rather than at import time: a bind
    # is only available while a migration is actually being run.
    bind = op.get_bind()
    insp = sa.engine.reflection.Inspector.from_engine(bind)

    # Restore the size of the metric_name column.
    with op.batch_alter_table('metrics', naming_convention=conv) as batch_op:
        batch_op.alter_column(
            'metric_name',
            existing_type=sa.String(length=255),
            type_=sa.String(length=512),
            existing_nullable=True,
        )

    # Remove the previous missing uniqueness constraints.
    for table, column in names.items():
        with op.batch_alter_table(table, naming_convention=conv) as batch_op:
            batch_op.drop_constraint(
                # Use the name found via the inspector when the helper
                # returns one, else fall back to the name used in upgrade().
                generic_find_uq_constraint_name(
                    table,
                    {column, 'datasource_id'},
                    insp,
                ) or 'uq_{}_{}'.format(table, column),
                type_='unique',
            )
@patch('superset.connectors.druid.models.PyDruid')
def test_refresh_metadata(self, PyDruid):
    """Refreshing a cluster creates column metrics and later updates them.

    First refresh: every stored column must appear in the segment metadata
    and ``metric1`` must get ``double*`` sum/min/max metrics.  The segment
    metadata is then changed so ``metric1`` is a ``LONG`` and, after a second
    refresh, the same metrics must have become ``long*``.
    """
    # Local import: deepcopy is only needed by this test.
    from copy import deepcopy

    def metric1_metrics(datasource):
        """Query the stored metrics derived from column ``metric1``."""
        return (
            db.session.query(DruidMetric)
            .filter(DruidMetric.datasource_id == datasource.id)
            .filter(DruidMetric.metric_name.like('%__metric1'))
        )

    def assert_metric1_types(prefix):
        """Check each metric1 metric has type ``<prefix><Aggregation>``."""
        for datasource in cluster.datasources:
            for metric in metric1_metrics(datasource):
                agg, _ = metric.metric_name.split('__')

                self.assertEqual(
                    json.loads(metric.json)['type'],
                    '{}{}'.format(prefix, agg.capitalize()),
                )

    self.login(username='admin')
    cluster = self.get_cluster(PyDruid)
    cluster.refresh_datasources()

    for i, datasource in enumerate(cluster.datasources):
        cols = (
            db.session.query(DruidColumn)
            .filter(DruidColumn.datasource_id == datasource.id)
        )

        for col in cols:
            self.assertIn(
                col.column_name,
                SEGMENT_METADATA[i]['columns'].keys(),
            )

        self.assertEqual(
            {metric.metric_name for metric in metric1_metrics(datasource)},
            {'max__metric1', 'min__metric1', 'sum__metric1'},
        )

    assert_metric1_types('double')

    # Augment a metric.  Deep-copy so the shared module-level fixture is not
    # mutated for the tests that run afterwards.
    metadata = deepcopy(SEGMENT_METADATA)
    metadata[0]['columns']['metric1']['type'] = 'LONG'
    instance = PyDruid.return_value
    instance.segment_metadata.return_value = metadata
    cluster.refresh_datasources()

    assert_metric1_types('long')