Druid refresh metadata performance improvements (#3527)

* parallelized Druid metadata refresh

* fixed code style errors

* fixed code for Python 3

* added option to only scan for new Druid datasources

* increased code coverage
Jeff Niu authored on 2017-09-25 18:00:46 -07:00; committed by Maxime Beauchemin
parent 3949d39478
commit cf0b670932
3 changed files with 220 additions and 131 deletions
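The core of the change: `DruidCluster.refresh_datasources` previously synced datasources one at a time; it now builds the list of datasources to refresh and fans the slow `latest_metadata()` calls out to a `multiprocessing.Pool`. Because bound methods cannot be pickled, the commit adds a module-level wrapper (`_fetch_metadata_for`) to hand to the pool. A minimal sketch of that pattern, assuming a list of objects exposing `latest_metadata()` (the `fetch_all_metadata` helper is illustrative, not part of the commit):

```python
from multiprocessing import Pool

def _fetch_metadata_for(datasource):
    # Module-level wrapper: bound methods cannot be pickled, so the
    # Pool workers receive this function plus the datasource object.
    return datasource.latest_metadata()

def fetch_all_metadata(datasources):
    pool = Pool()
    try:
        # map() preserves input order, so results[i] corresponds to
        # datasources[i] when merging metadata back into the database.
        return pool.map(_fetch_metadata_for, datasources)
    finally:
        pool.close()
        pool.join()
```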

superset/connectors/druid/models.py

@@ -5,12 +5,13 @@ import logging
from copy import deepcopy
from datetime import datetime, timedelta
from six import string_types
from multiprocessing import Pool
import requests
import sqlalchemy as sa
from sqlalchemy import (
Column, Integer, String, ForeignKey, Text, Boolean,
DateTime,
DateTime, or_, and_,
)
from sqlalchemy.orm import backref, relationship
from dateutil.parser import parse as dparse
@@ -39,6 +40,12 @@ from superset.models.helpers import AuditMixinNullable, QueryResult, set_perm
DRUID_TZ = conf.get("DRUID_TZ")
# Function wrapper because bound methods cannot
# be passed to processes
def _fetch_metadata_for(datasource):
return datasource.latest_metadata()
class JavascriptPostAggregator(Postaggregator):
def __init__(self, name, field_names, function):
self.post_aggregator = {
@@ -101,15 +108,99 @@ class DruidCluster(Model, AuditMixinNullable):
).format(obj=self)
return json.loads(requests.get(endpoint).text)['version']
def refresh_datasources(self, datasource_name=None, merge_flag=False):
def refresh_datasources(
self,
datasource_name=None,
merge_flag=True,
refreshAll=True):
"""Refresh metadata of all datasources in the cluster
If ``datasource_name`` is specified, only that datasource is updated
"""
self.druid_version = self.get_druid_version()
for datasource in self.get_datasources():
if datasource not in conf.get('DRUID_DATA_SOURCE_BLACKLIST', []):
if not datasource_name or datasource_name == datasource:
DruidDatasource.sync_to_db(datasource, self, merge_flag)
ds_list = self.get_datasources()
blacklist = conf.get('DRUID_DATA_SOURCE_BLACKLIST', [])
ds_refresh = []
if not datasource_name:
ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list))
elif datasource_name not in blacklist and datasource_name in ds_list:
ds_refresh.append(datasource_name)
else:
return
self.refresh_async(ds_refresh, merge_flag, refreshAll)
def refresh_async(self, datasource_names, merge_flag, refreshAll):
"""
Fetches metadata for the specified datasources and
merges them into the Superset database
"""
session = db.session
ds_list = (
session.query(DruidDatasource)
.filter(or_(DruidDatasource.datasource_name == name
for name in datasource_names))
)
ds_map = {ds.name: ds for ds in ds_list}
for ds_name in datasource_names:
datasource = ds_map.get(ds_name, None)
if not datasource:
datasource = DruidDatasource(datasource_name=ds_name)
with session.no_autoflush:
session.add(datasource)
flasher(
"Adding new datasource [{}]".format(ds_name), 'success')
ds_map[ds_name] = datasource
elif refreshAll:
flasher(
"Refreshing datasource [{}]".format(ds_name), 'info')
else:
del ds_map[ds_name]
continue
datasource.cluster = self
datasource.merge_flag = merge_flag
session.flush()
# Prepare multiprocess execution
pool = Pool()
ds_refresh = list(ds_map.values())
metadata = pool.map(_fetch_metadata_for, ds_refresh)
pool.close()
pool.join()
for i in range(0, len(ds_refresh)):
datasource = ds_refresh[i]
cols = metadata[i]
col_objs_list = (
session.query(DruidColumn)
.filter(DruidColumn.datasource_name == datasource.datasource_name)
.filter(or_(DruidColumn.column_name == col for col in cols))
)
col_objs = {col.column_name: col for col in col_objs_list}
for col in cols:
if col == '__time': # skip the time column
continue
col_obj = col_objs.get(col, None)
if not col_obj:
col_obj = DruidColumn(
datasource_name=datasource.datasource_name,
column_name=col)
with session.no_autoflush:
session.add(col_obj)
datatype = cols[col]['type']
if datatype == 'STRING':
col_obj.groupby = True
col_obj.filterable = True
if datatype == 'hyperUnique' or datatype == 'thetaSketch':
col_obj.count_distinct = True
# Allow sum/min/max for long or double
if datatype == 'LONG' or datatype == 'DOUBLE':
col_obj.sum = True
col_obj.min = True
col_obj.max = True
col_obj.type = datatype
col_obj.datasource = datasource
datasource.generate_metrics_for(col_objs_list)
session.commit()
@property
def perm(self):
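A pattern worth calling out in the hunk above, since it recurs throughout this file: the old code issued one SELECT per column or metric, while the new code fetches every matching row in a single query with an `or_()` filter and indexes the results by name for O(1) lookups. A hedged sketch of that batch-lookup idiom (model names as in the diff; the standalone helper is illustrative):

```python
from sqlalchemy import or_

def load_columns_by_name(session, datasource_name, column_names):
    """One query instead of len(column_names) queries: fetch all
    matching DruidColumn rows and key them by column_name."""
    if not column_names:
        return {}
    rows = (
        session.query(DruidColumn)
        .filter(DruidColumn.datasource_name == datasource_name)
        .filter(or_(*(DruidColumn.column_name == name
                      for name in column_names)))
    )
    return {row.column_name: row for row in rows}
```

The diff passes a generator straight to `or_()`; unpacking with `*` as above is equivalent and avoids relying on `or_()` coercing a single iterable argument.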
@@ -160,16 +251,14 @@ class DruidColumn(Model, BaseColumn):
if self.dimension_spec_json:
return json.loads(self.dimension_spec_json)
def generate_metrics(self):
"""Generate metrics based on the column metadata"""
M = DruidMetric # noqa
metrics = []
metrics.append(DruidMetric(
def get_metrics(self):
metrics = {}
metrics['count'] = DruidMetric(
metric_name='count',
verbose_name='COUNT(*)',
metric_type='count',
json=json.dumps({'type': 'count', 'name': 'count'})
))
)
# Somehow we need to reassign this for UDAFs
if self.type in ('DOUBLE', 'FLOAT'):
corrected_type = 'DOUBLE'
@@ -179,49 +268,49 @@ class DruidColumn(Model, BaseColumn):
if self.sum and self.is_num:
mt = corrected_type.lower() + 'Sum'
name = 'sum__' + self.column_name
metrics.append(DruidMetric(
metrics[name] = DruidMetric(
metric_name=name,
metric_type='sum',
verbose_name='SUM({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
)
if self.avg and self.is_num:
mt = corrected_type.lower() + 'Avg'
name = 'avg__' + self.column_name
metrics.append(DruidMetric(
metrics[name] = DruidMetric(
metric_name=name,
metric_type='avg',
verbose_name='AVG({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
)
if self.min and self.is_num:
mt = corrected_type.lower() + 'Min'
name = 'min__' + self.column_name
metrics.append(DruidMetric(
metrics[name] = DruidMetric(
metric_name=name,
metric_type='min',
verbose_name='MIN({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
)
if self.max and self.is_num:
mt = corrected_type.lower() + 'Max'
name = 'max__' + self.column_name
metrics.append(DruidMetric(
metrics[name] = DruidMetric(
metric_name=name,
metric_type='max',
verbose_name='MAX({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
)
if self.count_distinct:
name = 'count_distinct__' + self.column_name
if self.type == 'hyperUnique' or self.type == 'thetaSketch':
metrics.append(DruidMetric(
metrics[name] = DruidMetric(
metric_name=name,
verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
metric_type=self.type,
@@ -230,10 +319,9 @@ class DruidColumn(Model, BaseColumn):
'name': name,
'fieldName': self.column_name
})
))
)
else:
mt = 'count_distinct'
metrics.append(DruidMetric(
metrics[name] = DruidMetric(
metric_name=name,
verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
metric_type='count_distinct',
@@ -241,22 +329,25 @@ class DruidColumn(Model, BaseColumn):
'type': 'cardinality',
'name': name,
'fieldNames': [self.column_name]})
))
session = get_session()
new_metrics = []
for metric in metrics:
m = (
session.query(M)
.filter(M.metric_name == metric.metric_name)
.filter(M.datasource_name == self.datasource_name)
.filter(DruidCluster.cluster_name == self.datasource.cluster_name)
.first()
)
)
return metrics
def generate_metrics(self):
"""Generate metrics based on the column metadata"""
metrics = self.get_metrics()
dbmetrics = (
db.session.query(DruidMetric)
.filter(DruidCluster.cluster_name == self.datasource.cluster_name)
.filter(DruidMetric.datasource_name == self.datasource_name)
.filter(or_(
DruidMetric.metric_name == m for m in metrics
))
)
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics.values():
metric.datasource_name = self.datasource_name
if not m:
new_metrics.append(metric)
session.add(metric)
session.flush()
if not dbmetrics.get(metric.metric_name, None):
db.session.add(metric)
@classmethod
def import_obj(cls, i_column):
@@ -474,6 +565,7 @@ class DruidDatasource(Model, BaseDatasource):
def latest_metadata(self):
"""Returns segment metadata from the latest segment"""
logging.info("Syncing datasource [{}]".format(self.datasource_name))
client = self.cluster.get_pydruid_client()
results = client.time_boundary(datasource=self.datasource_name)
if not results:
@@ -485,31 +577,33 @@ class DruidDatasource(Model, BaseDatasource):
# realtime segments, which triggered a bug (fixed in druid 0.8.2).
# https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ
lbound = (max_time - timedelta(days=7)).isoformat()
rbound = max_time.isoformat()
if not self.version_higher(self.cluster.druid_version, '0.8.2'):
rbound = (max_time - timedelta(1)).isoformat()
else:
rbound = max_time.isoformat()
segment_metadata = None
try:
segment_metadata = client.segment_metadata(
datasource=self.datasource_name,
intervals=lbound + '/' + rbound,
merge=self.merge_flag,
analysisTypes=conf.get('DRUID_ANALYSIS_TYPES'))
analysisTypes=[])
except Exception as e:
logging.warning("Failed first attempt to get latest segment")
logging.exception(e)
if not segment_metadata:
# if no segments in the past 7 days, look at all segments
lbound = datetime(1901, 1, 1).isoformat()[:10]
rbound = datetime(2050, 1, 1).isoformat()[:10]
if not self.version_higher(self.cluster.druid_version, '0.8.2'):
rbound = datetime.now().isoformat()
else:
rbound = datetime(2050, 1, 1).isoformat()[:10]
try:
segment_metadata = client.segment_metadata(
datasource=self.datasource_name,
intervals=lbound + '/' + rbound,
merge=self.merge_flag,
analysisTypes=conf.get('DRUID_ANALYSIS_TYPES'))
analysisTypes=[])
except Exception as e:
logging.warning("Failed 2nd attempt to get latest segment")
logging.exception(e)
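One behavioral detail in `latest_metadata` above: the segment-metadata query first looks at the trailing seven days and, on Druid older than 0.8.2, backs the right bound off by a day to avoid scanning realtime segments (a bug fixed in Druid 0.8.2); only if that window returns nothing does it widen to all time. A compact sketch of the window selection, following the diff (pure date logic; `max_time` is assumed to be a `datetime`):

```python
from datetime import datetime, timedelta

def metadata_intervals(max_time, druid_version_below_0_8_2):
    """Return (recent_window, full_window) interval strings in the
    lbound/rbound form the diff passes to segment_metadata()."""
    lbound = (max_time - timedelta(days=7)).isoformat()
    if druid_version_below_0_8_2:
        # Back off a day to dodge the realtime-segment bug.
        rbound = (max_time - timedelta(1)).isoformat()
    else:
        rbound = max_time.isoformat()
    recent = lbound + '/' + rbound

    # Fallback window if the recent one holds no segments.
    lbound = datetime(1901, 1, 1).isoformat()[:10]
    if druid_version_below_0_8_2:
        rbound = datetime.now().isoformat()
    else:
        rbound = datetime(2050, 1, 1).isoformat()[:10]
    full = lbound + '/' + rbound
    return recent, full
```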
@@ -517,17 +611,37 @@ class DruidDatasource(Model, BaseDatasource):
return segment_metadata[-1]['columns']
def generate_metrics(self):
for col in self.columns:
col.generate_metrics()
self.generate_metrics_for(self.columns)
def generate_metrics_for(self, columns):
metrics = {}
for col in columns:
metrics.update(col.get_metrics())
dbmetrics = (
db.session.query(DruidMetric)
.filter(DruidCluster.cluster_name == self.cluster_name)
.filter(DruidMetric.datasource_name == self.datasource_name)
.filter(or_(DruidMetric.metric_name == m for m in metrics))
)
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics.values():
metric.datasource_name = self.datasource_name
if not dbmetrics.get(metric.metric_name, None):
with db.session.no_autoflush:
db.session.add(metric)
@classmethod
def sync_to_db_from_config(cls, druid_config, user, cluster):
def sync_to_db_from_config(
cls,
druid_config,
user,
cluster,
refresh=True):
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session()
session = db.session
datasource = (
session.query(cls)
.filter_by(
datasource_name=druid_config['name'])
.filter_by(datasource_name=druid_config['name'])
.first()
)
# Create a new datasource.
@@ -540,16 +654,18 @@ class DruidDatasource(Model, BaseDatasource):
created_by_fk=user.id,
)
session.add(datasource)
elif not refresh:
return
dimensions = druid_config['dimensions']
col_objs = (
session.query(DruidColumn)
.filter(DruidColumn.datasource_name == druid_config['name'])
.filter(or_(DruidColumn.column_name == dim for dim in dimensions))
)
col_objs = {col.column_name: col for col in col_objs}
for dim in dimensions:
col_obj = (
session.query(DruidColumn)
.filter_by(
datasource_name=druid_config['name'],
column_name=dim)
.first()
)
col_obj = col_objs.get(dim, None)
if not col_obj:
col_obj = DruidColumn(
datasource_name=druid_config['name'],
@@ -562,6 +678,13 @@ class DruidDatasource(Model, BaseDatasource):
)
session.add(col_obj)
# Import Druid metrics
metric_objs = (
session.query(DruidMetric)
.filter(DruidMetric.datasource_name == druid_config['name'])
.filter(or_(DruidMetric.metric_name == spec['name']
for spec in druid_config["metrics_spec"]))
)
metric_objs = {metric.metric_name: metric for metric in metric_objs}
for metric_spec in druid_config["metrics_spec"]:
metric_name = metric_spec["name"]
metric_type = metric_spec["type"]
@@ -575,12 +698,7 @@ class DruidDatasource(Model, BaseDatasource):
"fieldName": metric_name,
})
metric_obj = (
session.query(DruidMetric)
.filter_by(
datasource_name=druid_config['name'],
metric_name=metric_name)
).first()
metric_obj = metric_objs.get(metric_name, None)
if not metric_obj:
metric_obj = DruidMetric(
metric_name=metric_name,
@@ -595,58 +713,6 @@ class DruidDatasource(Model, BaseDatasource):
session.add(metric_obj)
session.commit()
@classmethod
def sync_to_db(cls, name, cluster, merge):
"""Fetches metadata for that datasource and merges the Superset db"""
logging.info("Syncing Druid datasource [{}]".format(name))
session = get_session()
datasource = session.query(cls).filter_by(datasource_name=name).first()
if not datasource:
datasource = cls(datasource_name=name)
session.add(datasource)
flasher("Adding new datasource [{}]".format(name), "success")
else:
flasher("Refreshing datasource [{}]".format(name), "info")
session.flush()
datasource.cluster = cluster
datasource.merge_flag = merge
session.flush()
cols = datasource.latest_metadata()
if not cols:
logging.error("Failed at fetching the latest segment")
return
for col in cols:
# Skip the time column
if col == "__time":
continue
col_obj = (
session
.query(DruidColumn)
.filter_by(datasource_name=name, column_name=col)
.first()
)
datatype = cols[col]['type']
if not col_obj:
col_obj = DruidColumn(datasource_name=name, column_name=col)
session.add(col_obj)
if datatype == "STRING":
col_obj.groupby = True
col_obj.filterable = True
if datatype == "hyperUnique" or datatype == "thetaSketch":
col_obj.count_distinct = True
# If long or double, allow sum/min/max
if datatype == "LONG" or datatype == "DOUBLE":
col_obj.sum = True
col_obj.min = True
col_obj.max = True
if col_obj:
col_obj.type = cols[col]['type']
session.flush()
col_obj.datasource = datasource
col_obj.generate_metrics()
session.flush()
@staticmethod
def time_offset(granularity):
if granularity == 'week_ending_saturday':
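Also in this file, metric generation is split in two: `DruidColumn.get_metrics` is now a pure builder returning a `{metric_name: DruidMetric}` dict, while `generate_metrics` and `DruidDatasource.generate_metrics_for` handle persistence, reconciling a whole batch against the database in one query. A condensed sketch of that merge step, following the shape of the diff (the standalone `merge_metrics` helper is an assumption for illustration):

```python
from sqlalchemy import or_

def merge_metrics(session, datasource_name, metrics):
    """Persist only the metrics not already stored; `metrics` is the
    {metric_name: DruidMetric} dict produced by get_metrics()."""
    if not metrics:
        return
    existing = (
        session.query(DruidMetric)
        .filter(DruidMetric.datasource_name == datasource_name)
        .filter(or_(*(DruidMetric.metric_name == name
                      for name in metrics)))
    )
    existing_names = {m.metric_name for m in existing}
    for name, metric in metrics.items():
        if name not in existing_names:
            metric.datasource_name = datasource_name
            session.add(metric)
```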

superset/connectors/druid/views.py

@@ -235,17 +235,17 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin):  # noqa
}
def pre_add(self, datasource):
number_of_existing_datasources = db.session.query(
sqla.func.count('*')).filter(
models.DruidDatasource.datasource_name ==
datasource.datasource_name,
models.DruidDatasource.cluster_name == datasource.cluster.id
).scalar()
# table object is already added to the session
if number_of_existing_datasources > 1:
raise Exception(get_datasource_exist_error_mgs(
datasource.full_name))
with db.session.no_autoflush:
query = (
db.session.query(models.DruidDatasource)
.filter(models.DruidDatasource.datasource_name ==
datasource.datasource_name,
models.DruidDatasource.cluster_name ==
datasource.cluster.id)
)
if db.session.query(query.exists()).scalar():
raise Exception(get_datasource_exist_error_mgs(
datasource.full_name))
def post_add(self, datasource):
datasource.generate_metrics()
@@ -273,14 +273,14 @@ class Druid(BaseSupersetView):
@has_access
@expose("/refresh_datasources/")
def refresh_datasources(self):
def refresh_datasources(self, refreshAll=True):
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
for cluster in session.query(DruidCluster).all():
cluster_name = cluster.cluster_name
try:
cluster.refresh_datasources()
cluster.refresh_datasources(refreshAll=refreshAll)
except Exception as e:
flash(
"Error while processing cluster '{}'\n{}".format(
@@ -296,8 +296,25 @@
session.commit()
return redirect("/druiddatasourcemodelview/list/")
@has_access
@expose("/scan_new_datasources/")
def scan_new_datasources(self):
"""
Calling this endpoint scans for new
datasources only and adds them.
"""
return self.refresh_datasources(refreshAll=False)
appbuilder.add_view_no_menu(Druid)
appbuilder.add_link(
"Scan New Datasources",
label=__("Scan New Datasources"),
href='/druid/scan_new_datasources/',
category='Sources',
category_label=__("Sources"),
category_icon='fa-database',
icon="fa-refresh")
appbuilder.add_link(
"Refresh Druid Metadata",
label=__("Refresh Druid Metadata"),

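On the view side, `/druid/refresh_datasources/` keeps its full-refresh behavior, while the new `/druid/scan_new_datasources/` endpoint routes through the same code path with `refreshAll=False`, so only datasources Superset has not yet stored are added. A hedged sketch of exercising both endpoints (host, port, and authentication are deployment-specific assumptions):

```python
import requests

BASE = "http://localhost:8088"  # assumed local Superset instance
session = requests.Session()
# Log in first; authentication setup is omitted here.

# Full refresh: re-sync every non-blacklisted datasource.
session.get(BASE + "/druid/refresh_datasources/")

# Scan only: add datasources not yet known to Superset,
# leaving already-registered ones untouched.
session.get(BASE + "/druid/scan_new_datasources/")
```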
tests/druid_tests.py

@@ -16,6 +16,9 @@ from superset.connectors.druid.models import PyDruid, Quantile, Postaggregator
from .base_tests import SupersetTestCase
class PickableMock(Mock):
def __reduce__(self):
return (Mock, ())
SEGMENT_METADATA = [{
"id": "some_id",
@@ -98,8 +101,8 @@ class DruidTests(SupersetTestCase):
metadata_last_refreshed=datetime.now())
db.session.add(cluster)
cluster.get_datasources = Mock(return_value=['test_datasource'])
cluster.get_druid_version = Mock(return_value='0.9.1')
cluster.get_datasources = PickableMock(return_value=['test_datasource'])
cluster.get_druid_version = PickableMock(return_value='0.9.1')
cluster.refresh_datasources()
cluster.refresh_datasources(merge_flag=True)
datasource_id = cluster.datasources[0].id
@@ -303,11 +306,14 @@ class DruidTests(SupersetTestCase):
metadata_last_refreshed=datetime.now())
db.session.add(cluster)
cluster.get_datasources = Mock(return_value=['test_datasource'])
cluster.get_druid_version = Mock(return_value='0.9.1')
cluster.get_datasources = PickableMock(return_value=['test_datasource'])
cluster.get_druid_version = PickableMock(return_value='0.9.1')
cluster.refresh_datasources()
datasource_id = cluster.datasources[0].id
cluster.datasources[0].merge_flag = True
metadata = cluster.datasources[0].latest_metadata()
self.assertEqual(len(metadata), 4)
db.session.commit()
view_menu_name = cluster.datasources[0].get_perm()
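The tests swap `Mock` for `PickableMock` because `multiprocessing.Pool` pickles whatever it sends to worker processes, and a plain `Mock` does not survive pickling; `PickableMock.__reduce__` tells pickle to rebuild a fresh `Mock` on the other side. A small standalone demonstration of the round trip (note that configured return values are not preserved, which is acceptable for these tests):

```python
import pickle
from unittest.mock import Mock

class PickableMock(Mock):
    def __reduce__(self):
        # On unpickling, reconstruct as a plain Mock; recorded calls
        # and configured return values are not carried across.
        return (Mock, ())

m = PickableMock(return_value=['test_datasource'])
restored = pickle.loads(pickle.dumps(m))
print(type(restored))  # <class 'unittest.mock.Mock'>
```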