Import / export of the dashboards. (#1197)

* Implement import / export dashboard functionality.

* Address comments from discussion.

* Add function descriptions.

* Minor fixes

* Fix tests for python 3.

* Export datasources.

* Implement tables import.

* Json.loads does not support trailing commas.

* Improve alter_dict func

* Resolve comments.

* Refactor tests

* Move params_dict and alter_params to the ImportMixin

* Fix flask menues.
This commit is contained in:
Bogdan 2016-10-11 17:54:40 -07:00 committed by GitHub
parent cd2ab42abc
commit 73cd2ea3b1
10 changed files with 827 additions and 26 deletions

View File

@ -82,6 +82,12 @@ if app.config.get('ENABLE_CORS'):
if app.config.get('ENABLE_PROXY_FIX'):
app.wsgi_app = ProxyFix(app.wsgi_app)
if app.config.get('UPLOAD_FOLDER'):
try:
os.makedirs(app.config.get('UPLOAD_FOLDER'))
except OSError:
pass
class MyIndexView(IndexView):
@expose('/')

View File

@ -1058,7 +1058,7 @@ def load_multiformat_time_series_data():
'string0': ['%Y-%m-%d %H:%M:%S.%f', None],
'string3': ['%Y/%m/%d%H:%M:%S.%f', None],
}
for col in obj.table_columns:
for col in obj.columns:
dttm_and_expr = dttm_and_expr_dict[col.column_name]
col.python_date_format = dttm_and_expr[0]
col.dbatabase_expr = dttm_and_expr[1]
@ -1069,7 +1069,7 @@ def load_multiformat_time_series_data():
tbl = obj
print("Creating some slices")
for i, col in enumerate(tbl.table_columns):
for i, col in enumerate(tbl.columns):
slice_data = {
"granularity_sqla": col.column_name,
"datasource_id": "8",

View File

@ -0,0 +1,28 @@
"""Add json_metadata to the tables table.
Revision ID: b46fa1b0b39e
Revises: ef8843b41dac
Create Date: 2016-10-05 11:30:31.748238
"""
# revision identifiers, used by Alembic.
revision = 'b46fa1b0b39e'
down_revision = 'ef8843b41dac'
from alembic import op
import logging
import sqlalchemy as sa
def upgrade():
op.add_column('tables',
sa.Column('params', sa.Text(), nullable=True))
def downgrade():
try:
op.drop_column('tables', 'params')
except Exception as e:
logging.warning(str(e))

View File

@ -7,6 +7,7 @@ from __future__ import unicode_literals
import functools
import json
import logging
import pickle
import re
import textwrap
from collections import namedtuple
@ -18,6 +19,8 @@ import pandas as pd
import requests
import sqlalchemy as sqla
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm import subqueryload
import sqlparse
from dateutil.parser import parse
@ -41,6 +44,7 @@ from sqlalchemy import (
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm.session import make_transient
from sqlalchemy.sql import table, literal_column, text, column
from sqlalchemy.sql.expression import ColumnClause, TextAsFrom
from sqlalchemy_utils import EncryptedType
@ -70,6 +74,31 @@ class JavascriptPostAggregator(Postaggregator):
self.name = name
class ImportMixin(object):
def override(self, obj):
"""Overrides the plain fields of the dashboard."""
for field in obj.__class__.export_fields:
setattr(self, field, getattr(obj, field))
def copy(self):
"""Creates a copy of the dashboard without relationships."""
new_obj = self.__class__()
new_obj.override(self)
return new_obj
def alter_params(self, **kwargs):
d = self.params_dict
d.update(kwargs)
self.params = json.dumps(d)
@property
def params_dict(self):
if self.params:
return json.loads(self.params)
else:
return {}
class AuditMixinNullable(AuditMixin):
"""Altering the AuditMixin to use nullable fields
@ -149,7 +178,7 @@ slice_user = Table('slice_user', Model.metadata,
)
class Slice(Model, AuditMixinNullable):
class Slice(Model, AuditMixinNullable, ImportMixin):
"""A slice is essentially a report or a view on data"""
@ -166,6 +195,9 @@ class Slice(Model, AuditMixinNullable):
perm = Column(String(2000))
owners = relationship("User", secondary=slice_user)
export_fields = ('slice_name', 'datasource_type', 'datasource_name',
'viz_type', 'params', 'cache_timeout')
def __repr__(self):
return self.slice_name
@ -283,6 +315,42 @@ class Slice(Model, AuditMixinNullable):
slice_=self
)
@classmethod
def import_obj(cls, slc_to_import, import_time=None):
"""Inserts or overrides slc in the database.
remote_id and import_time fields in params_dict are set to track the
slice origin and ensure correct overrides for multiple imports.
Slice.perm is used to find the datasources and connect them.
"""
session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(
remote_id=slc_to_import.id, import_time=import_time)
# find if the slice was already imported
slc_to_override = None
for slc in session.query(Slice).all():
if ('remote_id' in slc.params_dict and
slc.params_dict['remote_id'] == slc_to_import.id):
slc_to_override = slc
slc_to_import.id = None
params = slc_to_import.params_dict
slc_to_import.datasource_id = SourceRegistry.get_datasource_by_name(
session, slc_to_import.datasource_type, params['datasource_name'],
params['schema'], params['database_name']).id
if slc_to_override:
slc_to_override.override(slc_to_import)
session.flush()
return slc_to_override.id
else:
session.add(slc_to_import)
logging.info('Final slice: {}'.format(slc_to_import.to_json()))
session.flush()
return slc_to_import.id
def set_perm(mapper, connection, target): # noqa
src_class = target.cls_model
@ -309,7 +377,7 @@ dashboard_user = Table(
)
class Dashboard(Model, AuditMixinNullable):
class Dashboard(Model, AuditMixinNullable, ImportMixin):
"""The dashboard object!"""
@ -325,6 +393,9 @@ class Dashboard(Model, AuditMixinNullable):
'Slice', secondary=dashboard_slices, backref='dashboards')
owners = relationship("User", secondary=dashboard_user)
export_fields = ('dashboard_title', 'position_json', 'json_metadata',
'description', 'css', 'slug', 'slices')
def __repr__(self):
return self.dashboard_title
@ -340,13 +411,6 @@ class Dashboard(Model, AuditMixinNullable):
def datasources(self):
return {slc.datasource for slc in self.slices}
@property
def metadata_dejson(self):
if self.json_metadata:
return json.loads(self.json_metadata)
else:
return {}
@property
def sqla_metadata(self):
metadata = MetaData(bind=self.get_sqla_engine())
@ -361,7 +425,7 @@ class Dashboard(Model, AuditMixinNullable):
def json_data(self):
d = {
'id': self.id,
'metadata': self.metadata_dejson,
'metadata': self.params_dict,
'dashboard_title': self.dashboard_title,
'slug': self.slug,
'slices': [slc.data for slc in self.slices],
@ -369,6 +433,107 @@ class Dashboard(Model, AuditMixinNullable):
}
return json.dumps(d)
@property
def params(self):
return self.json_metadata
@params.setter
def params(self, value):
self.json_metadata = value
@classmethod
def import_obj(cls, dashboard_to_import, import_time=None):
"""Imports the dashboard from the object to the database.
Once dashboard is imported, json_metadata field is extended and stores
remote_id and import_time. It helps to decide if the dashboard has to
be overridden or just copies over. Slices that belong to this
dashboard will be wired to existing tables. This function can be used
to import/export dashboards between multiple caravel instances.
Audit metadata isn't copies over.
"""
logging.info('Started import of the dashboard: {}'
.format(dashboard_to_import.to_json()))
session = db.session
logging.info('Dashboard has {} slices'
.format(len(dashboard_to_import.slices)))
# copy slices object as Slice.import_slice will mutate the slice
# and will remove the existing dashboard - slice association
slices = copy(dashboard_to_import.slices)
slice_ids = set()
for slc in slices:
logging.info('Importing slice {} from the dashboard: {}'.format(
slc.to_json(), dashboard_to_import.dashboard_title))
slice_ids.add(Slice.import_obj(slc, import_time=import_time))
# override the dashboard
existing_dashboard = None
for dash in session.query(Dashboard).all():
if ('remote_id' in dash.params_dict and
dash.params_dict['remote_id'] ==
dashboard_to_import.id):
existing_dashboard = dash
dashboard_to_import.id = None
dashboard_to_import.alter_params(import_time=import_time)
new_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
session.flush()
return existing_dashboard.id
else:
# session.add(dashboard_to_import) causes sqlachemy failures
# related to the attached users / slices. Creating new object
# allows to avoid conflicts in the sql alchemy state.
copied_dash = dashboard_to_import.copy()
copied_dash.slices = new_slices
session.add(copied_dash)
session.flush()
return copied_dash.id
@classmethod
def export_dashboards(cls, dashboard_ids):
copied_dashboards = []
datasource_ids = set()
for dashboard_id in dashboard_ids:
# make sure that dashboard_id is an integer
dashboard_id = int(dashboard_id)
copied_dashboard = (
db.session.query(Dashboard)
.options(subqueryload(Dashboard.slices))
.filter_by(id=dashboard_id).first()
)
make_transient(copied_dashboard)
for slc in copied_dashboard.slices:
datasource_ids.add((slc.datasource_id, slc.datasource_type))
# add extra params for the import
slc.alter_params(
remote_id=slc.id,
datasource_name=slc.datasource.name,
schema=slc.datasource.name,
database_name=slc.datasource.database.database_name,
)
copied_dashboard.alter_params(remote_id=dashboard_id)
copied_dashboards.append(copied_dashboard)
eager_datasources = []
for dashboard_id, dashboard_type in datasource_ids:
eager_datasource = SourceRegistry.get_eager_datasource(
db.session, dashboard_type, dashboard_id)
eager_datasource.alter_params(
remote_id=eager_datasource.id,
database_name=eager_datasource.database.database_name,
)
make_transient(eager_datasource)
eager_datasources.append(eager_datasource)
return pickle.dumps({
'dashboards': copied_dashboards,
'datasources': eager_datasources,
})
class Queryable(object):
@ -433,6 +598,10 @@ class Database(Model, AuditMixinNullable):
def __repr__(self):
return self.database_name
@property
def name(self):
return self.database_name
@property
def backend(self):
url = make_url(self.sqlalchemy_uri_decrypted)
@ -665,7 +834,7 @@ class Database(Model, AuditMixinNullable):
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
class SqlaTable(Model, Queryable, AuditMixinNullable):
class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
"""An ORM object for SqlAlchemy table references"""
@ -689,9 +858,13 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
cache_timeout = Column(Integer)
schema = Column(String(255))
sql = Column(Text)
table_columns = relationship("TableColumn", back_populates="table")
params = Column(Text)
baselink = "tablemodelview"
export_fields = (
'table_name', 'main_dttm_col', 'description', 'default_endpoint',
'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
'sql', 'params')
__table_args__ = (
sqla.UniqueConstraint(
@ -773,7 +946,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
}
def get_col(self, col_name):
columns = self.table_columns
columns = self.columns
for col in columns:
if col_name == col.column_name:
return col
@ -1062,8 +1235,67 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
@classmethod
def import_obj(cls, datasource_to_import, import_time=None):
"""Imports the datasource from the object to the database.
class SqlMetric(Model, AuditMixinNullable):
Metrics and columns and datasource will be overrided if exists.
This function can be used to import/export dashboards between multiple
caravel instances. Audit metadata isn't copies over.
"""
session = db.session
make_transient(datasource_to_import)
logging.info('Started import of the datasource: {}'
.format(datasource_to_import.to_json()))
datasource_to_import.id = None
database_name = datasource_to_import.params_dict['database_name']
datasource_to_import.database_id = session.query(Database).filter_by(
database_name=database_name).one().id
datasource_to_import.alter_params(import_time=import_time)
# override the datasource
datasource = (
session.query(SqlaTable).join(Database)
.filter(
SqlaTable.table_name == datasource_to_import.table_name,
SqlaTable.schema == datasource_to_import.schema,
Database.id == datasource_to_import.database_id,
)
.first()
)
if datasource:
datasource.override(datasource_to_import)
session.flush()
else:
datasource = datasource_to_import.copy()
session.add(datasource)
session.flush()
for m in datasource_to_import.metrics:
new_m = m.copy()
new_m.table_id = datasource.id
logging.info('Importing metric {} from the datasource: {}'.format(
new_m.to_json(), datasource_to_import.full_name))
imported_m = SqlMetric.import_obj(new_m)
if imported_m not in datasource.metrics:
datasource.metrics.append(imported_m)
for c in datasource_to_import.columns:
new_c = c.copy()
new_c.table_id = datasource.id
logging.info('Importing column {} from the datasource: {}'.format(
new_c.to_json(), datasource_to_import.full_name))
imported_c = TableColumn.import_obj(new_c)
if imported_c not in datasource.columns:
datasource.columns.append(imported_c)
db.session.flush()
return datasource.id
class SqlMetric(Model, AuditMixinNullable, ImportMixin):
"""ORM object for metrics, each table can have multiple metrics"""
@ -1082,6 +1314,10 @@ class SqlMetric(Model, AuditMixinNullable):
is_restricted = Column(Boolean, default=False, nullable=True)
d3format = Column(String(128))
export_fields = (
'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
'description', 'is_restricted', 'd3format')
@property
def sqla_col(self):
name = self.metric_name
@ -1094,8 +1330,28 @@ class SqlMetric(Model, AuditMixinNullable):
).format(obj=self,
parent_name=self.table.full_name) if self.table else None
@classmethod
def import_obj(cls, metric_to_import):
session = db.session
make_transient(metric_to_import)
metric_to_import.id = None
class TableColumn(Model, AuditMixinNullable):
# find if the column was already imported
existing_metric = session.query(SqlMetric).filter(
SqlMetric.table_id == metric_to_import.table_id,
SqlMetric.metric_name == metric_to_import.metric_name).first()
metric_to_import.table = None
if existing_metric:
existing_metric.override(metric_to_import)
session.flush()
return existing_metric
session.add(metric_to_import)
session.flush()
return metric_to_import
class TableColumn(Model, AuditMixinNullable, ImportMixin):
"""ORM object for table columns, each table can have multiple columns"""
@ -1125,6 +1381,12 @@ class TableColumn(Model, AuditMixinNullable):
num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG')
date_types = ('DATE', 'TIME')
str_types = ('VARCHAR', 'STRING', 'CHAR')
export_fields = (
'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
'type', 'groupby', 'count_distinct', 'sum', 'max', 'min',
'filterable', 'expression', 'description', 'python_date_format',
'database_expression'
)
def __repr__(self):
return self.column_name
@ -1150,6 +1412,27 @@ class TableColumn(Model, AuditMixinNullable):
col = literal_column(self.expression).label(name)
return col
@classmethod
def import_obj(cls, column_to_import):
session = db.session
make_transient(column_to_import)
column_to_import.id = None
column_to_import.table = None
# find if the column was already imported
existing_column = session.query(TableColumn).filter(
TableColumn.table_id == column_to_import.table_id,
TableColumn.column_name == column_to_import.column_name).first()
column_to_import.table = None
if existing_column:
existing_column.override(column_to_import)
session.flush()
return existing_column
session.add(column_to_import)
session.flush()
return column_to_import
def dttm_sql_literal(self, dttm):
"""Convert datetime object to string
@ -1234,6 +1517,10 @@ class DruidCluster(Model, AuditMixinNullable):
def perm(self):
return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
@property
def name(self):
return self.cluster_name
class DruidDatasource(Model, AuditMixinNullable, Queryable):
@ -1262,6 +1549,10 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
offset = Column(Integer, default=0)
cache_timeout = Column(Integer)
@property
def database(self):
return self.cluster
@property
def metrics_combo(self):
return sorted(

View File

@ -1,3 +1,4 @@
from sqlalchemy.orm import subqueryload
class SourceRegistry(object):
@ -20,3 +21,30 @@ class SourceRegistry(object):
.filter_by(id=datasource_id)
.one()
)
@classmethod
def get_datasource_by_name(cls, session, datasource_type, datasource_name,
schema, database_name):
datasource_class = SourceRegistry.sources[datasource_type]
datasources = session.query(datasource_class).all()
db_ds = [d for d in datasources if d.database.name == database_name and
d.name == datasource_name and schema == schema]
return db_ds[0]
@classmethod
def get_eager_datasource(cls, session, datasource_type, datasource_id):
"""Returns datasource with columns and metrics."""
datasource_class = SourceRegistry.sources[datasource_type]
if datasource_type == 'table':
return (
session.query(datasource_class)
.options(
subqueryload(datasource_class.columns),
subqueryload(datasource_class.metrics)
)
.filter_by(id=datasource_id)
.one()
)
# TODO: support druid datasources.
return session.query(datasource_class).filter_by(
id=datasource_id).first()

View File

@ -0,0 +1,6 @@
<script>
window.onload = function() {
window.open(window.location += '&action=go');
window.location = '{{ dashboards_url }}';
};
</script>

View File

@ -0,0 +1,23 @@
{% extends "caravel/basic.html" %}
# TODO: move the libs required by flask into the common.js from welcome.js.
{% block head_js %}
{{ super() }}
{% with filename="welcome" %}
{% include "caravel/partials/_script_tag.html" %}
{% endwith %}
{% endblock %}
{% block title %}{{ _("Import") }}{% endblock %}
{% block body %}
{% include "caravel/flash_wrapper.html" %}
<div class="container">
<title>Import the dashboards.</title>
<h1>Import the dashboards.</h1>
<form method=post enctype=multipart/form-data>
<p><input type=file name=file>
<input type=submit value=Upload>
</p>
</form>
</div>
{% endblock %}

View File

@ -5,6 +5,8 @@ from __future__ import unicode_literals
import json
import logging
import os
import pickle
import re
import sys
import time
@ -15,7 +17,8 @@ import functools
import sqlalchemy as sqla
from flask import (
g, request, redirect, flash, Response, render_template, Markup)
g, request, make_response, redirect, flash, Response, render_template,
Markup, url_for)
from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose
from flask_appbuilder.actions import action
from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -26,8 +29,9 @@ from flask_babel import lazy_gettext as _
from flask_appbuilder.models.sqla.filters import BaseFilter
from sqlalchemy import create_engine
from werkzeug.routing import BaseConverter
from werkzeug import secure_filename
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.routing import BaseConverter
from wtforms.validators import ValidationError
import caravel
@ -533,6 +537,16 @@ class DatabaseView(CaravelModelView, DeleteMixin): # noqa
self.pre_add(db)
appbuilder.add_link(
'Import Dashboards',
label=__("Import Dashboards"),
href='/caravel/import_dashboards',
icon="fa-cloud-upload",
category='Manage',
category_label=__("Manage"),
category_icon='fa-wrench',)
appbuilder.add_view(
DatabaseView,
"Databases",
@ -658,7 +672,6 @@ appbuilder.add_view(
category_label=__("Security"),
icon='fa-table',)
appbuilder.add_separator("Sources")
@ -867,13 +880,32 @@ class DashboardModelView(CaravelModelView, DeleteMixin): # noqa
def pre_delete(self, obj):
check_ownership(obj)
@action("mulexport", "Export", "Export dashboards?", "fa-database")
def mulexport(self, items):
ids = ''.join('&id={}'.format(d.id) for d in items)
return redirect(
'/dashboardmodelview/export_dashboards_form?{}'.format(ids[1:]))
@expose("/export_dashboards_form")
def download_dashboards(self):
if request.args.get('action') == 'go':
ids = request.args.getlist('id')
return Response(
models.Dashboard.export_dashboards(ids),
headers=generate_download_headers("pickle"),
mimetype="application/text")
return self.render_template(
'caravel/export_dashboards.html',
dashboards_url='/dashboardmodelview/list'
)
appbuilder.add_view(
DashboardModelView,
"Dashboards",
label=__("Dashboards"),
icon="fa-dashboard",
category="",
category='',
category_icon='',)
@ -1053,9 +1085,8 @@ class Caravel(BaseCaravelView):
role_to_extend = request.args.get('role_to_extend')
session = db.session
datasource_class = SourceRegistry.sources[datasource_type]
datasource = session.query(datasource_class).filter_by(
id=datasource_id).first()
datasource = SourceRegistry.get_datasource(
datasource_type, datasource_id, session)
if not datasource:
flash(DATASOURCE_MISSING_ERR, "alert")
@ -1149,6 +1180,27 @@ class Caravel(BaseCaravelView):
status=200,
mimetype="application/json")
@expose("/import_dashboards", methods=['GET', 'POST'])
@log_this
def import_dashboards(self):
"""Overrides the dashboards using pickled instances from the file."""
f = request.files.get('file')
if request.method == 'POST' and f:
filename = secure_filename(f.filename)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
f.save(filepath)
current_tt = int(time.time())
data = pickle.load(open(filepath, 'rb'))
for table in data['datasources']:
models.SqlaTable.import_obj(table, import_time=current_tt)
for dashboard in data['dashboards']:
models.Dashboard.import_obj(
dashboard, import_time=current_tt)
os.remove(filepath)
db.session.commit()
return redirect('/dashboardmodelview/list/')
return self.render_template('caravel/import_dashboards.html')
@log_this
@has_access
@expose("/explore/<datasource_type>/<datasource_id>/")
@ -1478,7 +1530,7 @@ class Caravel(BaseCaravelView):
dash.slices = [o for o in dash.slices if o.id in slice_ids]
positions = sorted(data['positions'], key=lambda x: int(x['slice_id']))
dash.position_json = json.dumps(positions, indent=4, sort_keys=True)
md = dash.metadata_dejson
md = dash.params_dict
if 'filter_immune_slices' not in md:
md['filter_immune_slices'] = []
if 'filter_immune_slice_fields' not in md:

View File

@ -452,6 +452,5 @@ class CoreTests(CaravelTestCase):
db.session.commit()
self.test_save_dash('alpha')
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,368 @@
"""Unit tests for Caravel"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from sqlalchemy.orm.session import make_transient
import json
import pickle
import unittest
from caravel import db, models
from .base_tests import CaravelTestCase
class ImportExportTests(CaravelTestCase):
"""Testing export import functionality for dashboards"""
def __init__(self, *args, **kwargs):
super(ImportExportTests, self).__init__(*args, **kwargs)
@classmethod
def delete_imports(cls):
# Imported data clean up
session = db.session
for slc in session.query(models.Slice):
if 'remote_id' in slc.params_dict:
session.delete(slc)
for dash in session.query(models.Dashboard):
if 'remote_id' in dash.params_dict:
session.delete(dash)
for table in session.query(models.SqlaTable):
if 'remote_id' in table.params_dict:
session.delete(table)
session.commit()
@classmethod
def setUpClass(cls):
cls.delete_imports()
@classmethod
def tearDownClass(cls):
cls.delete_imports()
def create_slice(self, name, ds_id=None, id=None, db_name='main',
table_name='wb_health_population'):
params = {
'num_period_compare': '10',
'remote_id': id,
'datasource_name': table_name,
'database_name': db_name,
'schema': '',
}
if table_name and not ds_id:
table = self.get_table_by_name(table_name)
if table:
ds_id = table.id
return models.Slice(
slice_name=name,
datasource_type='table',
viz_type='bubble',
params=json.dumps(params),
datasource_id=ds_id,
id=id
)
def create_dashboard(self, title, id=0, slcs=[]):
json_metadata = {'remote_id': id}
return models.Dashboard(
id=id,
dashboard_title=title,
slices=slcs,
position_json='{"size_y": 2, "size_x": 2}',
slug='{}_imported'.format(title.lower()),
json_metadata=json.dumps(json_metadata)
)
def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]):
params = {'remote_id': id, 'database_name': 'main'}
table = models.SqlaTable(
id=id,
schema=schema,
table_name=name,
params=json.dumps(params)
)
for col_name in cols_names:
table.columns.append(
models.TableColumn(column_name=col_name))
for metric_name in metric_names:
table.metrics.append(models.SqlMetric(metric_name=metric_name))
return table
def get_slice(self, slc_id):
return db.session.query(models.Slice).filter_by(id=slc_id).first()
def get_dash(self, dash_id):
return db.session.query(models.Dashboard).filter_by(
id=dash_id).first()
def get_dash_by_slug(self, dash_slug):
return db.session.query(models.Dashboard).filter_by(
slug=dash_slug).first()
def get_table(self, table_id):
return db.session.query(models.SqlaTable).filter_by(
id=table_id).first()
def get_table_by_name(self, name):
return db.session.query(models.SqlaTable).filter_by(
table_name=name).first()
def assert_dash_equals(self, expected_dash, actual_dash):
self.assertEquals(expected_dash.slug, actual_dash.slug)
self.assertEquals(
expected_dash.dashboard_title, actual_dash.dashboard_title)
self.assertEquals(
expected_dash.position_json, actual_dash.position_json)
self.assertEquals(
len(expected_dash.slices), len(actual_dash.slices))
expected_slices = sorted(
expected_dash.slices, key=lambda s: s.slice_name)
actual_slices = sorted(
actual_dash.slices, key=lambda s: s.slice_name)
for e_slc, a_slc in zip(expected_slices, actual_slices):
self.assert_slice_equals(e_slc, a_slc)
def assert_table_equals(self, expected_ds, actual_ds):
self.assertEquals(expected_ds.table_name, actual_ds.table_name)
self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
self.assertEquals(expected_ds.schema, actual_ds.schema)
self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEquals(len(expected_ds.columns), len(actual_ds.columns))
self.assertEquals(
set([c.column_name for c in expected_ds.columns]),
set([c.column_name for c in actual_ds.columns]))
self.assertEquals(
set([m.metric_name for m in expected_ds.metrics]),
set([m.metric_name for m in actual_ds.metrics]))
def assert_slice_equals(self, expected_slc, actual_slc):
self.assertEquals(actual_slc.datasource.perm, actual_slc.perm)
self.assertEquals(expected_slc.slice_name, actual_slc.slice_name)
self.assertEquals(
expected_slc.datasource_type, actual_slc.datasource_type)
self.assertEquals(expected_slc.viz_type, actual_slc.viz_type)
self.assertEquals(
json.loads(expected_slc.params), json.loads(actual_slc.params))
def test_export_1_dashboard(self):
birth_dash = self.get_dash_by_slug('births')
export_dash_url = (
'/dashboardmodelview/export_dashboards_form?id={}&action=go'
.format(birth_dash.id)
)
resp = self.client.get(export_dash_url)
exported_dashboards = pickle.loads(resp.data)['dashboards']
self.assert_dash_equals(birth_dash, exported_dashboards[0])
self.assertEquals(
birth_dash.id,
json.loads(exported_dashboards[0].json_metadata)['remote_id'])
exported_tables = pickle.loads(resp.data)['datasources']
self.assertEquals(1, len(exported_tables))
self.assert_table_equals(
self.get_table_by_name('birth_names'), exported_tables[0])
def test_export_2_dashboards(self):
birth_dash = self.get_dash_by_slug('births')
world_health_dash = self.get_dash_by_slug('world_health')
export_dash_url = (
'/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go'
.format(birth_dash.id, world_health_dash.id))
resp = self.client.get(export_dash_url)
exported_dashboards = sorted(pickle.loads(resp.data)['dashboards'],
key=lambda d: d.dashboard_title)
self.assertEquals(2, len(exported_dashboards))
self.assert_dash_equals(birth_dash, exported_dashboards[0])
self.assertEquals(
birth_dash.id,
json.loads(exported_dashboards[0].json_metadata)['remote_id']
)
self.assert_dash_equals(world_health_dash, exported_dashboards[1])
self.assertEquals(
world_health_dash.id,
json.loads(exported_dashboards[1].json_metadata)['remote_id']
)
exported_tables = sorted(
pickle.loads(resp.data)['datasources'], key=lambda t: t.table_name)
self.assertEquals(2, len(exported_tables))
self.assert_table_equals(
self.get_table_by_name('birth_names'), exported_tables[0])
self.assert_table_equals(
self.get_table_by_name('wb_health_population'), exported_tables[1])
def test_import_1_slice(self):
expected_slice = self.create_slice('Import Me', id=10001);
slc_id = models.Slice.import_obj(expected_slice, import_time=1989)
self.assert_slice_equals(expected_slice, self.get_slice(slc_id))
table_id = self.get_table_by_name('wb_health_population').id
self.assertEquals(table_id, self.get_slice(slc_id).datasource_id)
def test_import_2_slices_for_same_table(self):
table_id = self.get_table_by_name('wb_health_population').id
# table_id != 666, import func will have to find the table
slc_1 = self.create_slice('Import Me 1', ds_id=666, id=10002)
slc_id_1 = models.Slice.import_obj(slc_1)
slc_2 = self.create_slice('Import Me 2', ds_id=666, id=10003)
slc_id_2 = models.Slice.import_obj(slc_2)
imported_slc_1 = self.get_slice(slc_id_1)
imported_slc_2 = self.get_slice(slc_id_2)
self.assertEquals(table_id, imported_slc_1.datasource_id)
self.assert_slice_equals(slc_1, imported_slc_1)
self.assertEquals(table_id, imported_slc_2.datasource_id)
self.assert_slice_equals(slc_2, imported_slc_2)
def test_import_slices_for_non_existent_table(self):
with self.assertRaises(IndexError):
models.Slice.import_obj(self.create_slice(
'Import Me 3', id=10004, table_name='non_existent'))
def test_import_slices_override(self):
slc = self.create_slice('Import Me New', id=10005)
slc_1_id = models.Slice.import_obj(slc, import_time=1990)
slc.slice_name = 'Import Me New'
slc_2_id = models.Slice.import_obj(
self.create_slice('Import Me New', id=10005), import_time=1990)
self.assertEquals(slc_1_id, slc_2_id)
imported_slc = self.get_slice(slc_2_id)
self.assert_slice_equals(slc, imported_slc)
def test_import_empty_dashboard(self):
empty_dash = self.create_dashboard('empty_dashboard', id=10001)
imported_dash_id = models.Dashboard.import_obj(
empty_dash, import_time=1989)
imported_dash = self.get_dash(imported_dash_id)
self.assert_dash_equals(empty_dash, imported_dash)
def test_import_dashboard_1_slice(self):
slc = self.create_slice('health_slc', id=10006)
dash_with_1_slice = self.create_dashboard(
'dash_with_1_slice', slcs=[slc], id=10002)
imported_dash_id = models.Dashboard.import_obj(
dash_with_1_slice, import_time=1990)
imported_dash = self.get_dash(imported_dash_id)
expected_dash = self.create_dashboard(
'dash_with_1_slice', slcs=[slc], id=10002)
make_transient(expected_dash)
self.assert_dash_equals(expected_dash, imported_dash)
self.assertEquals({"remote_id": 10002, "import_time": 1990},
json.loads(imported_dash.json_metadata))
def test_import_dashboard_2_slices(self):
e_slc = self.create_slice('e_slc', id=10007, table_name='energy_usage')
b_slc = self.create_slice('b_slc', id=10008, table_name='birth_names')
dash_with_2_slices = self.create_dashboard(
'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
imported_dash_id = models.Dashboard.import_obj(
dash_with_2_slices, import_time=1991)
imported_dash = self.get_dash(imported_dash_id)
expected_dash = self.create_dashboard(
'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
make_transient(expected_dash)
self.assert_dash_equals(imported_dash, expected_dash)
self.assertEquals({"remote_id": 10003, "import_time": 1991},
json.loads(imported_dash.json_metadata))
def test_import_override_dashboard_2_slices(self):
e_slc = self.create_slice('e_slc', id=10009, table_name='energy_usage')
b_slc = self.create_slice('b_slc', id=10010, table_name='birth_names')
dash_to_import = self.create_dashboard(
'override_dashboard', slcs=[e_slc, b_slc], id=10004)
imported_dash_id_1 = models.Dashboard.import_obj(
dash_to_import, import_time=1992)
# create new instances of the slices
e_slc = self.create_slice(
'e_slc', id=10009, table_name='energy_usage')
b_slc = self.create_slice(
'b_slc', id=10010, table_name='birth_names')
c_slc = self.create_slice('c_slc', id=10011, table_name='birth_names')
dash_to_import_override = self.create_dashboard(
'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004)
imported_dash_id_2 = models.Dashboard.import_obj(
dash_to_import_override, import_time=1992)
# override doesn't change the id
self.assertEquals(imported_dash_id_1, imported_dash_id_2)
expected_dash = self.create_dashboard(
'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004)
make_transient(expected_dash)
imported_dash = self.get_dash(imported_dash_id_2)
self.assert_dash_equals(expected_dash, imported_dash)
self.assertEquals({"remote_id": 10004, "import_time": 1992},
json.loads(imported_dash.json_metadata))
def test_import_table_no_metadata(self):
table = self.create_table('pure_table', id=10001)
imported_t_id = models.SqlaTable.import_obj(table, import_time=1989)
imported_table = self.get_table(imported_t_id)
self.assert_table_equals(table, imported_table)
def test_import_table_1_col_1_met(self):
table = self.create_table(
'table_1_col_1_met', id=10002,
cols_names=["col1"], metric_names=["metric1"])
imported_t_id = models.SqlaTable.import_obj(table, import_time=1990)
imported_table = self.get_table(imported_t_id)
self.assert_table_equals(table, imported_table)
self.assertEquals(
{'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'},
json.loads(imported_table.params))
def test_import_table_2_col_2_met(self):
table = self.create_table(
'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
metric_names=['m1', 'm2'])
imported_t_id = models.SqlaTable.import_obj(table, import_time=1991)
imported_table = self.get_table(imported_t_id)
self.assert_table_equals(table, imported_table)
def test_import_table_override(self):
table = self.create_table(
'table_override', id=10003, cols_names=['col1'],
metric_names=['m1'])
imported_t_id = models.SqlaTable.import_obj(table, import_time=1991)
table_over = self.create_table(
'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_table_over_id = models.SqlaTable.import_obj(
table_over, import_time=1992)
imported_table_over = self.get_table(imported_table_over_id)
self.assertEquals(imported_t_id, imported_table_over.id)
expected_table = self.create_table(
'table_override', id=10003, metric_names=['new_metric1', 'm1'],
cols_names=['col1', 'new_col1', 'col2', 'col3'])
self.assert_table_equals(expected_table, imported_table_over)
def test_import_table_override_idential(self):
table = self.create_table(
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_t_id = models.SqlaTable.import_obj(table, import_time=1993)
copy_table = self.create_table(
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_t_id_copy = models.SqlaTable.import_obj(
copy_table, import_time=1994)
self.assertEquals(imported_t_id, imported_t_id_copy)
self.assert_table_equals(copy_table, self.get_table(imported_t_id))
if __name__ == '__main__':
unittest.main()