Generalize switch between different datasources (#1078)

* Generalize switch between different datasources.

* Fix previous migration since slice model changed

* Fix warm up cache and other small stuff

* Adding modules and datasources through config

* Replace tabs w/ spaces

* Fix other style issues

* Change add method for SliceModelView to pick the first non-empty ds

* Remove tests on slice add redirect

* Change way of db migration

* Fix styling

* Fix create slice

* Small fixes

* Fix code climate check

* Adding notes on how to create new datasource in CONTRIBUTING.md

* Fix last merge

* A commit just to trigger travis build again

* Add migration to merge two heads

* Fix codeclimate

* Simplify source_registry

* Fix codeclimate

* Remove all getter methods
This commit is contained in:
ShengyaoQian 2016-09-21 09:52:05 -07:00 committed by Maxime Beauchemin
parent ed2feaf84b
commit 5a0e06e7a2
13 changed files with 223 additions and 110 deletions

View File

@ -251,3 +251,20 @@ You can then translate the strings gathered in files located under
to take effect, they need to be compiled using this command: to take effect, they need to be compiled using this command:
fabmanager babel-compile --target caravel/translations/ fabmanager babel-compile --target caravel/translations/
## Adding new datasources
1. Create Models and Views for the datasource, add them under caravel folder, like a new my_models.py
with models for cluster, datasources, columns and metrics and my_views.py with clustermodelview
and datasourcemodelview.
2. Create db migration files for the new models
3. Specify this variable to add the datasource model and from which module it is from in config.py:
For example:
`ADDITIONAL_MODULE_DS_MAP = {'caravel.my_models': ['MyDatasource', 'MyOtherDatasource']}`
This means it'll register MyDatasource and MyOtherDatasource in caravel.my_models module in the source registry.

View File

@ -14,6 +14,7 @@ from sqlalchemy import event, exc
from flask_appbuilder.baseviews import expose from flask_appbuilder.baseviews import expose
from flask_cache import Cache from flask_cache import Cache
from flask_migrate import Migrate from flask_migrate import Migrate
from caravel import source_registry
from werkzeug.contrib.fixers import ProxyFix from werkzeug.contrib.fixers import ProxyFix
@ -95,5 +96,7 @@ appbuilder = AppBuilder(
sm = appbuilder.sm sm = appbuilder.sm
src_registry = source_registry.SourceRegistry()
get_session = appbuilder.get_session get_session = appbuilder.get_session
from caravel import config, views # noqa from caravel import views, config # noqa

View File

@ -20,6 +20,14 @@ config = app.config
manager = Manager(app) manager = Manager(app)
manager.add_command('db', MigrateCommand) manager.add_command('db', MigrateCommand)
module_datasource_map = config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(config.get("ADDITIONAL_MODULE_DS_MAP"))
datasources = {}
for module in module_datasource_map:
datasources[module] = __import__(module, fromlist=module_datasource_map[module])
utils.register_sources(datasources, module_datasource_map, caravel.src_registry)
@manager.option( @manager.option(

View File

@ -164,6 +164,13 @@ VIZ_TYPE_BLACKLIST = []
DRUID_DATA_SOURCE_BLACKLIST = [] DRUID_DATA_SOURCE_BLACKLIST = []
# --------------------------------------------------
# Modules and datasources to be registered
# --------------------------------------------------
DEFAULT_MODULE_DS_MAP = {'caravel.models': ['DruidDatasource', 'SqlaTable']}
ADDITIONAL_MODULE_DS_MAP = {}
""" """
1) http://docs.python-guide.org/en/latest/writing/logging/ 1) http://docs.python-guide.org/en/latest/writing/logging/
2) https://docs.python.org/2/library/logging.config.html 2) https://docs.python.org/2/library/logging.config.html

View File

@ -75,7 +75,7 @@ def load_energy():
slice_name="Energy Sankey", slice_name="Energy Sankey",
viz_type='sankey', viz_type='sankey',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=textwrap.dedent("""\ params=textwrap.dedent("""\
{ {
"collapsed_fieldsets": "", "collapsed_fieldsets": "",
@ -105,7 +105,7 @@ def load_energy():
slice_name="Energy Force Layout", slice_name="Energy Force Layout",
viz_type='directed_force', viz_type='directed_force',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=textwrap.dedent("""\ params=textwrap.dedent("""\
{ {
"charge": "-500", "charge": "-500",
@ -136,7 +136,7 @@ def load_energy():
slice_name="Heatmap", slice_name="Heatmap",
viz_type='heatmap', viz_type='heatmap',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=textwrap.dedent("""\ params=textwrap.dedent("""\
{ {
"all_columns_x": "source", "all_columns_x": "source",
@ -224,7 +224,7 @@ def load_world_bank_health_n_pop():
slice_name="Region Filter", slice_name="Region Filter",
viz_type='filter_box', viz_type='filter_box',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='filter_box', viz_type='filter_box',
@ -233,7 +233,7 @@ def load_world_bank_health_n_pop():
slice_name="World's Population", slice_name="World's Population",
viz_type='big_number', viz_type='big_number',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since='2000', since='2000',
@ -245,7 +245,7 @@ def load_world_bank_health_n_pop():
slice_name="Most Populated Countries", slice_name="Most Populated Countries",
viz_type='table', viz_type='table',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='table', viz_type='table',
@ -255,7 +255,7 @@ def load_world_bank_health_n_pop():
slice_name="Growth Rate", slice_name="Growth Rate",
viz_type='line', viz_type='line',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='line', viz_type='line',
@ -267,7 +267,7 @@ def load_world_bank_health_n_pop():
slice_name="% Rural", slice_name="% Rural",
viz_type='world_map', viz_type='world_map',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='world_map', viz_type='world_map',
@ -277,7 +277,7 @@ def load_world_bank_health_n_pop():
slice_name="Life Expectancy VS Rural %", slice_name="Life Expectancy VS Rural %",
viz_type='bubble', viz_type='bubble',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='bubble', viz_type='bubble',
@ -298,7 +298,7 @@ def load_world_bank_health_n_pop():
slice_name="Rural Breakdown", slice_name="Rural Breakdown",
viz_type='sunburst', viz_type='sunburst',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='sunburst', viz_type='sunburst',
@ -310,7 +310,7 @@ def load_world_bank_health_n_pop():
slice_name="World's Pop Growth", slice_name="World's Pop Growth",
viz_type='area', viz_type='area',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since="1960-01-01", since="1960-01-01",
@ -321,7 +321,7 @@ def load_world_bank_health_n_pop():
slice_name="Box plot", slice_name="Box plot",
viz_type='box_plot', viz_type='box_plot',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since="1960-01-01", since="1960-01-01",
@ -333,7 +333,7 @@ def load_world_bank_health_n_pop():
slice_name="Treemap", slice_name="Treemap",
viz_type='treemap', viz_type='treemap',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since="1960-01-01", since="1960-01-01",
@ -345,7 +345,7 @@ def load_world_bank_health_n_pop():
slice_name="Parallel Coordinates", slice_name="Parallel Coordinates",
viz_type='para', viz_type='para',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since="2011-01-01", since="2011-01-01",
@ -615,7 +615,7 @@ def load_birth_names():
slice_name="Girls", slice_name="Girls",
viz_type='table', viz_type='table',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
groupby=['name'], groupby=['name'],
@ -625,7 +625,7 @@ def load_birth_names():
slice_name="Boys", slice_name="Boys",
viz_type='table', viz_type='table',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
groupby=['name'], groupby=['name'],
@ -636,7 +636,7 @@ def load_birth_names():
slice_name="Participants", slice_name="Participants",
viz_type='big_number', viz_type='big_number',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type="big_number", granularity="ds", viz_type="big_number", granularity="ds",
@ -645,7 +645,7 @@ def load_birth_names():
slice_name="Genders", slice_name="Genders",
viz_type='pie', viz_type='pie',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type="pie", groupby=['gender'])), viz_type="pie", groupby=['gender'])),
@ -653,7 +653,7 @@ def load_birth_names():
slice_name="Genders by State", slice_name="Genders by State",
viz_type='dist_bar', viz_type='dist_bar',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
flt_eq_1="other", viz_type="dist_bar", flt_eq_1="other", viz_type="dist_bar",
@ -663,7 +663,7 @@ def load_birth_names():
slice_name="Trends", slice_name="Trends",
viz_type='line', viz_type='line',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type="line", groupby=['name'], viz_type="line", groupby=['name'],
@ -672,7 +672,7 @@ def load_birth_names():
slice_name="Title", slice_name="Title",
viz_type='markup', viz_type='markup',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type="markup", markup_type="html", viz_type="markup", markup_type="html",
@ -690,7 +690,7 @@ def load_birth_names():
slice_name="Name Cloud", slice_name="Name Cloud",
viz_type='word_cloud', viz_type='word_cloud',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type="word_cloud", size_from="10", viz_type="word_cloud", size_from="10",
@ -700,7 +700,7 @@ def load_birth_names():
slice_name="Pivot Table", slice_name="Pivot Table",
viz_type='pivot_table', viz_type='pivot_table',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type="pivot_table", metrics=['sum__num'], viz_type="pivot_table", metrics=['sum__num'],
@ -709,7 +709,7 @@ def load_birth_names():
slice_name="Number of Girls", slice_name="Number of Girls",
viz_type='big_number_total', viz_type='big_number_total',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type="big_number_total", granularity="ds", viz_type="big_number_total", granularity="ds",
@ -862,7 +862,7 @@ def load_unicode_test_data():
slice_name="Unicode Cloud", slice_name="Unicode Cloud",
viz_type='word_cloud', viz_type='word_cloud',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
@ -935,7 +935,7 @@ def load_random_time_series_data():
slice_name="Calendar Heatmap", slice_name="Calendar Heatmap",
viz_type='cal_heatmap', viz_type='cal_heatmap',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
@ -1005,7 +1005,7 @@ def load_long_lat_data():
slice_name="Mapbox Long/Lat", slice_name="Mapbox Long/Lat",
viz_type='mapbox', viz_type='mapbox',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
@ -1084,7 +1084,7 @@ def load_multiformat_time_series_data():
slice_name="Calendar Heatmap multiformat" + str(i), slice_name="Calendar Heatmap multiformat" + str(i),
viz_type='cal_heatmap', viz_type='cal_heatmap',
datasource_type='table', datasource_type='table',
table=tbl, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)

View File

@ -11,15 +11,34 @@ revision = '27ae655e4247'
down_revision = 'd8bc074f7aad' down_revision = 'd8bc074f7aad'
from alembic import op from alembic import op
from caravel import db, models from caravel import db
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from flask_appbuilder import Model
from sqlalchemy import (
Column, Integer, ForeignKey, Table)
Base = declarative_base()
class Slice(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'slices'
id = Column(Integer, primary_key=True)
class Dashboard(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'dashboards'
id = Column(Integer, primary_key=True)
def upgrade(): def upgrade():
bind = op.get_bind() bind = op.get_bind()
session = db.Session(bind=bind) session = db.Session(bind=bind)
objects = session.query(models.Slice).all() objects = session.query(Slice).all()
objects += session.query(models.Dashboard).all() objects += session.query(Dashboard).all()
for obj in objects: for obj in objects:
if obj.created_by and obj.created_by not in obj.owners: if obj.created_by and obj.created_by not in obj.owners:
obj.owners.append(obj.created_by) obj.owners.append(obj.created_by)

View File

@ -0,0 +1,59 @@
from alembic import op
import sqlalchemy as sa
from caravel import db
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (
Column, Integer, String)
"""update slice model
Revision ID: 33d996bcc382
Revises: 41f6a59a61f2
Create Date: 2016-09-07 23:50:59.366779
"""
# revision identifiers, used by Alembic.
revision = '33d996bcc382'
down_revision = '41f6a59a61f2'
Base = declarative_base()
class Slice(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'slices'
id = Column(Integer, primary_key=True)
datasource_id = Column(Integer)
druid_datasource_id = Column(Integer)
table_id = Column(Integer)
datasource_type = Column(String(200))
def upgrade():
bind = op.get_bind()
op.add_column('slices', sa.Column('datasource_id', sa.Integer()))
session = db.Session(bind=bind)
for slc in session.query(Slice).all():
if slc.druid_datasource_id:
slc.datasource_id = slc.druid_datasource_id
if slc.table_id:
slc.datasource_id = slc.table_id
session.merge(slc)
session.commit()
session.close()
def downgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
for slc in session.query(Slice).all():
if slc.datasource_type == 'druid':
slc.druid_datasource_id = slc.datasource_id
if slc.datasource_type == 'table':
slc.table_id = slc.datasource_id
session.merge(slc)
session.commit()
session.close()
op.drop_column('slices', 'datasource_id')

View File

@ -0,0 +1,19 @@
"""empty message
Revision ID: b347b202819b
Revises: ('33d996bcc382', '65903709c321')
Create Date: 2016-09-19 17:22:40.138601
"""
# revision identifiers, used by Alembic.
revision = 'b347b202819b'
down_revision = ('33d996bcc382', '65903709c321')
def upgrade():
pass
def downgrade():
pass

View File

@ -49,7 +49,7 @@ from sqlalchemy_utils import EncryptedType
from werkzeug.datastructures import ImmutableMultiDict from werkzeug.datastructures import ImmutableMultiDict
import caravel import caravel
from caravel import app, db, get_session, utils, sm from caravel import app, db, get_session, utils, sm, src_registry
from caravel.viz import viz_types from caravel.viz import viz_types
from caravel.utils import flasher, MetricPermException, DimSelector from caravel.utils import flasher, MetricPermException, DimSelector
@ -156,8 +156,7 @@ class Slice(Model, AuditMixinNullable):
__tablename__ = 'slices' __tablename__ = 'slices'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
slice_name = Column(String(250)) slice_name = Column(String(250))
druid_datasource_id = Column(Integer, ForeignKey('datasources.id')) datasource_id = Column(Integer)
table_id = Column(Integer, ForeignKey('tables.id'))
datasource_type = Column(String(200)) datasource_type = Column(String(200))
datasource_name = Column(String(2000)) datasource_name = Column(String(2000))
viz_type = Column(String(250)) viz_type = Column(String(250))
@ -165,33 +164,34 @@ class Slice(Model, AuditMixinNullable):
description = Column(Text) description = Column(Text)
cache_timeout = Column(Integer) cache_timeout = Column(Integer)
perm = Column(String(2000)) perm = Column(String(2000))
table = relationship(
'SqlaTable', foreign_keys=[table_id], backref='slices')
druid_datasource = relationship(
'DruidDatasource', foreign_keys=[druid_datasource_id], backref='slices')
owners = relationship("User", secondary=slice_user) owners = relationship("User", secondary=slice_user)
def __repr__(self): def __repr__(self):
return self.slice_name return self.slice_name
@property
def cls_model(self):
return src_registry.sources[self.datasource_type]
@property @property
def datasource(self): def datasource(self):
return self.table or self.druid_datasource return self.get_datasource
@datasource.getter
@utils.memoized
def get_datasource(self):
ds = db.session.query(
self.cls_model).filter_by(
id=self.datasource_id).first()
return ds
@renders('datasource_name') @renders('datasource_name')
def datasource_link(self): def datasource_link(self):
if self.table: return self.datasource.link
return self.table.link
elif self.druid_datasource:
return self.druid_datasource.link
@property @property
def datasource_edit_url(self): def datasource_edit_url(self):
if self.table: self.datasource.url
return self.table.url
elif self.druid_datasource:
return self.druid_datasource.url
@property @property
@utils.memoized @utils.memoized
@ -204,10 +204,6 @@ class Slice(Model, AuditMixinNullable):
def description_markeddown(self): def description_markeddown(self):
return utils.markdown(self.description) return utils.markdown(self.description)
@property
def datasource_id(self):
return self.table_id or self.druid_datasource_id
@property @property
def data(self): def data(self):
"""Data used to render slice in templates""" """Data used to render slice in templates"""
@ -283,12 +279,8 @@ class Slice(Model, AuditMixinNullable):
def set_perm(mapper, connection, target): # noqa def set_perm(mapper, connection, target): # noqa
if target.table_id: src_class = target.cls_model
src_class = SqlaTable id_ = target.datasource_id
id_ = target.table_id
elif target.druid_datasource_id:
src_class = DruidDatasource
id_ = target.druid_datasource_id
ds = db.session.query(src_class).filter_by(id=int(id_)).first() ds = db.session.query(src_class).filter_by(id=int(id_)).first()
target.perm = ds.perm target.perm = ds.perm

View File

@ -0,0 +1,15 @@
from flask import flash
class SourceRegistry(object):
""" Central Registry for all available datasource engines"""
sources = {}
def add_source(self, ds_type, cls_model):
if ds_type not in self.sources:
self.sources[ds_type] = cls_model
if self.sources[ds_type] is not cls_model:
raise Exception(
'source type: {} is already associated with Model: {}'.format(
ds_type, self.sources[ds_type]))

View File

@ -410,6 +410,14 @@ def readfile(filepath):
return content return content
def register_sources(datasources, module_datasource_map, registry):
for m in datasources:
datasource_list = module_datasource_map[m]
for ds in datasource_list:
ds_class = getattr(datasources[m], ds)
registry.add_source(ds_class.type, ds_class)
def generic_find_constraint_name(table, columns, referenced, db): def generic_find_constraint_name(table, columns, referenced, db):
"""Utility to find a constraint name in alembic migrations""" """Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)

View File

@ -33,7 +33,8 @@ from wtforms.validators import ValidationError
import caravel import caravel
from caravel import ( from caravel import (
appbuilder, cache, db, models, viz, utils, app, sm, ascii_art, sql_lab appbuilder, cache, db, models, viz, utils, app,
sm, ascii_art, sql_lab, src_registry
) )
config = app.config config = app.config
@ -675,8 +676,7 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
list_columns = [ list_columns = [
'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified'] 'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified']
edit_columns = [ edit_columns = [
'slice_name', 'description', 'viz_type', 'druid_datasource', 'slice_name', 'description', 'viz_type', 'owners', 'dashboards', 'params', 'cache_timeout']
'table', 'owners', 'dashboards', 'params', 'cache_timeout']
base_order = ('changed_on', 'desc') base_order = ('changed_on', 'desc')
description_columns = { description_columns = {
'description': Markup( 'description': Markup(
@ -722,18 +722,13 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
if not widget: if not widget:
return redirect(self.get_redirect()) return redirect(self.get_redirect())
a_druid_datasource = db.session.query(models.DruidDatasource).first() sources = src_registry.sources
if a_druid_datasource is not None: for source in sources:
url = "/druiddatasourcemodelview/list/" ds = db.session.query(src_registry.sources[source]).first()
msg = _( if ds is not None:
"Click on a datasource link to create a Slice, " url = "/{}/list/".format(ds.baselink)
"or click on a table link " msg = _("Click on a {} link to create a Slice".format(source))
"<a href='/tablemodelview/list/'>here</a> " break
"to create a Slice for a table"
)
else:
url = "/tablemodelview/list/"
msg = _("Click on a table link to create a Slice")
redirect_url = "/r/msg/?url={}&msg={}".format(url, msg) redirect_url = "/r/msg/?url={}&msg={}".format(url, msg)
return redirect(redirect_url) return redirect(redirect_url)
@ -978,8 +973,8 @@ class Caravel(BaseCaravelView):
@log_this @log_this
def explore(self, datasource_type, datasource_id, slice_id=None): def explore(self, datasource_type, datasource_id, slice_id=None):
error_redirect = '/slicemodelview/list/' error_redirect = '/slicemodelview/list/'
datasource_class = models.SqlaTable \ datasource_class = src_registry.sources[datasource_type]
if datasource_type == "table" else models.DruidDatasource
datasources = ( datasources = (
db.session db.session
.query(datasource_class) .query(datasource_class)
@ -1093,12 +1088,8 @@ class Caravel(BaseCaravelView):
if k not in as_list and isinstance(v, list): if k not in as_list and isinstance(v, list):
d[k] = v[0] d[k] = v[0]
table_id = druid_datasource_id = None
datasource_type = args.get('datasource_type') datasource_type = args.get('datasource_type')
if datasource_type in ('datasource', 'druid'): datasource_id = args.get('datasource_id')
druid_datasource_id = args.get('datasource_id')
elif datasource_type == 'table':
table_id = args.get('datasource_id')
if action in ('saveas'): if action in ('saveas'):
d.pop('slice_id') # don't save old slice_id d.pop('slice_id') # don't save old slice_id
@ -1107,9 +1098,8 @@ class Caravel(BaseCaravelView):
slc.params = json.dumps(d, indent=4, sort_keys=True) slc.params = json.dumps(d, indent=4, sort_keys=True)
slc.datasource_name = args.get('datasource_name') slc.datasource_name = args.get('datasource_name')
slc.viz_type = args.get('viz_type') slc.viz_type = args.get('viz_type')
slc.druid_datasource_id = druid_datasource_id
slc.table_id = table_id
slc.datasource_type = datasource_type slc.datasource_type = datasource_type
slc.datasource_id = datasource_id
slc.slice_name = slice_name slc.slice_name = slice_name
if action in ('saveas') and slice_add_perm: if action in ('saveas') and slice_add_perm:
@ -1330,7 +1320,9 @@ class Caravel(BaseCaravelView):
json_error_response(__( json_error_response(__(
"Table %(t)s wasn't found in the database %(d)s", "Table %(t)s wasn't found in the database %(d)s",
t=table_name, s=db_name), status=404) t=table_name, s=db_name), status=404)
slices = table.slices slices = session.query(models.Slice).filter_by(
datasource_id=table.id,
datasource_type=table.type).all()
for slice in slices: for slice in slices:
try: try:

View File

@ -210,32 +210,6 @@ class CoreTests(CaravelTestCase):
assert new_slice in dash.slices assert new_slice in dash.slices
assert len(set(dash.slices)) == len(dash.slices) assert len(set(dash.slices)) == len(dash.slices)
def test_add_slice_redirect_to_sqla(self, username='admin'):
self.login(username=username)
url = '/slicemodelview/add'
resp = self.client.get(url, follow_redirects=True)
assert (
"Click on a table link to create a Slice" in
resp.data.decode('utf-8')
)
def test_add_slice_redirect_to_druid(self, username='admin'):
datasource = DruidDatasource(
datasource_name="datasource_name",
)
db.session.add(datasource)
db.session.commit()
self.login(username=username)
url = '/slicemodelview/add'
resp = self.client.get(url, follow_redirects=True)
assert (
"Click on a datasource link to create a Slice"
in resp.data.decode('utf-8')
)
db.session.delete(datasource)
db.session.commit()
def test_druid_sync_from_config(self): def test_druid_sync_from_config(self):
cluster = models.DruidCluster(cluster_name="new_druid") cluster = models.DruidCluster(cluster_name="new_druid")