Import / export of the dashboards. (#1197)
* Implement import / export dashboard functionality. * Address comments from discussion. * Add function descriptions. * Minor fixes * Fix tests for python 3. * Export datasources. * Implement tables import. * Json.loads does not support trailing commas. * Improve alter_dict func * Resolve comments. * Refactor tests * Move params_dict and alter_params to the ImportMixin * Fix flask menues.
This commit is contained in:
parent
cd2ab42abc
commit
73cd2ea3b1
|
|
@ -82,6 +82,12 @@ if app.config.get('ENABLE_CORS'):
|
|||
if app.config.get('ENABLE_PROXY_FIX'):
|
||||
app.wsgi_app = ProxyFix(app.wsgi_app)
|
||||
|
||||
if app.config.get('UPLOAD_FOLDER'):
|
||||
try:
|
||||
os.makedirs(app.config.get('UPLOAD_FOLDER'))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
class MyIndexView(IndexView):
|
||||
@expose('/')
|
||||
|
|
|
|||
|
|
@ -1058,7 +1058,7 @@ def load_multiformat_time_series_data():
|
|||
'string0': ['%Y-%m-%d %H:%M:%S.%f', None],
|
||||
'string3': ['%Y/%m/%d%H:%M:%S.%f', None],
|
||||
}
|
||||
for col in obj.table_columns:
|
||||
for col in obj.columns:
|
||||
dttm_and_expr = dttm_and_expr_dict[col.column_name]
|
||||
col.python_date_format = dttm_and_expr[0]
|
||||
col.dbatabase_expr = dttm_and_expr[1]
|
||||
|
|
@ -1069,7 +1069,7 @@ def load_multiformat_time_series_data():
|
|||
tbl = obj
|
||||
|
||||
print("Creating some slices")
|
||||
for i, col in enumerate(tbl.table_columns):
|
||||
for i, col in enumerate(tbl.columns):
|
||||
slice_data = {
|
||||
"granularity_sqla": col.column_name,
|
||||
"datasource_id": "8",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
"""Add json_metadata to the tables table.
|
||||
|
||||
Revision ID: b46fa1b0b39e
|
||||
Revises: ef8843b41dac
|
||||
Create Date: 2016-10-05 11:30:31.748238
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b46fa1b0b39e'
|
||||
down_revision = 'ef8843b41dac'
|
||||
|
||||
from alembic import op
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column('tables',
|
||||
sa.Column('params', sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
try:
|
||||
op.drop_column('tables', 'params')
|
||||
except Exception as e:
|
||||
logging.warning(str(e))
|
||||
|
||||
|
|
@ -7,6 +7,7 @@ from __future__ import unicode_literals
|
|||
import functools
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import re
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
|
|
@ -18,6 +19,8 @@ import pandas as pd
|
|||
import requests
|
||||
import sqlalchemy as sqla
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
import sqlparse
|
||||
from dateutil.parser import parse
|
||||
|
||||
|
|
@ -41,6 +44,7 @@ from sqlalchemy import (
|
|||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from sqlalchemy.orm import backref, relationship
|
||||
from sqlalchemy.orm.session import make_transient
|
||||
from sqlalchemy.sql import table, literal_column, text, column
|
||||
from sqlalchemy.sql.expression import ColumnClause, TextAsFrom
|
||||
from sqlalchemy_utils import EncryptedType
|
||||
|
|
@ -70,6 +74,31 @@ class JavascriptPostAggregator(Postaggregator):
|
|||
self.name = name
|
||||
|
||||
|
||||
class ImportMixin(object):
|
||||
def override(self, obj):
|
||||
"""Overrides the plain fields of the dashboard."""
|
||||
for field in obj.__class__.export_fields:
|
||||
setattr(self, field, getattr(obj, field))
|
||||
|
||||
def copy(self):
|
||||
"""Creates a copy of the dashboard without relationships."""
|
||||
new_obj = self.__class__()
|
||||
new_obj.override(self)
|
||||
return new_obj
|
||||
|
||||
def alter_params(self, **kwargs):
|
||||
d = self.params_dict
|
||||
d.update(kwargs)
|
||||
self.params = json.dumps(d)
|
||||
|
||||
@property
|
||||
def params_dict(self):
|
||||
if self.params:
|
||||
return json.loads(self.params)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class AuditMixinNullable(AuditMixin):
|
||||
|
||||
"""Altering the AuditMixin to use nullable fields
|
||||
|
|
@ -149,7 +178,7 @@ slice_user = Table('slice_user', Model.metadata,
|
|||
)
|
||||
|
||||
|
||||
class Slice(Model, AuditMixinNullable):
|
||||
class Slice(Model, AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""A slice is essentially a report or a view on data"""
|
||||
|
||||
|
|
@ -166,6 +195,9 @@ class Slice(Model, AuditMixinNullable):
|
|||
perm = Column(String(2000))
|
||||
owners = relationship("User", secondary=slice_user)
|
||||
|
||||
export_fields = ('slice_name', 'datasource_type', 'datasource_name',
|
||||
'viz_type', 'params', 'cache_timeout')
|
||||
|
||||
def __repr__(self):
|
||||
return self.slice_name
|
||||
|
||||
|
|
@ -283,6 +315,42 @@ class Slice(Model, AuditMixinNullable):
|
|||
slice_=self
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, slc_to_import, import_time=None):
|
||||
"""Inserts or overrides slc in the database.
|
||||
|
||||
remote_id and import_time fields in params_dict are set to track the
|
||||
slice origin and ensure correct overrides for multiple imports.
|
||||
Slice.perm is used to find the datasources and connect them.
|
||||
"""
|
||||
session = db.session
|
||||
make_transient(slc_to_import)
|
||||
slc_to_import.dashboards = []
|
||||
slc_to_import.alter_params(
|
||||
remote_id=slc_to_import.id, import_time=import_time)
|
||||
|
||||
# find if the slice was already imported
|
||||
slc_to_override = None
|
||||
for slc in session.query(Slice).all():
|
||||
if ('remote_id' in slc.params_dict and
|
||||
slc.params_dict['remote_id'] == slc_to_import.id):
|
||||
slc_to_override = slc
|
||||
|
||||
slc_to_import.id = None
|
||||
params = slc_to_import.params_dict
|
||||
slc_to_import.datasource_id = SourceRegistry.get_datasource_by_name(
|
||||
session, slc_to_import.datasource_type, params['datasource_name'],
|
||||
params['schema'], params['database_name']).id
|
||||
if slc_to_override:
|
||||
slc_to_override.override(slc_to_import)
|
||||
session.flush()
|
||||
return slc_to_override.id
|
||||
else:
|
||||
session.add(slc_to_import)
|
||||
logging.info('Final slice: {}'.format(slc_to_import.to_json()))
|
||||
session.flush()
|
||||
return slc_to_import.id
|
||||
|
||||
|
||||
def set_perm(mapper, connection, target): # noqa
|
||||
src_class = target.cls_model
|
||||
|
|
@ -309,7 +377,7 @@ dashboard_user = Table(
|
|||
)
|
||||
|
||||
|
||||
class Dashboard(Model, AuditMixinNullable):
|
||||
class Dashboard(Model, AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""The dashboard object!"""
|
||||
|
||||
|
|
@ -325,6 +393,9 @@ class Dashboard(Model, AuditMixinNullable):
|
|||
'Slice', secondary=dashboard_slices, backref='dashboards')
|
||||
owners = relationship("User", secondary=dashboard_user)
|
||||
|
||||
export_fields = ('dashboard_title', 'position_json', 'json_metadata',
|
||||
'description', 'css', 'slug', 'slices')
|
||||
|
||||
def __repr__(self):
|
||||
return self.dashboard_title
|
||||
|
||||
|
|
@ -340,13 +411,6 @@ class Dashboard(Model, AuditMixinNullable):
|
|||
def datasources(self):
|
||||
return {slc.datasource for slc in self.slices}
|
||||
|
||||
@property
|
||||
def metadata_dejson(self):
|
||||
if self.json_metadata:
|
||||
return json.loads(self.json_metadata)
|
||||
else:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def sqla_metadata(self):
|
||||
metadata = MetaData(bind=self.get_sqla_engine())
|
||||
|
|
@ -361,7 +425,7 @@ class Dashboard(Model, AuditMixinNullable):
|
|||
def json_data(self):
|
||||
d = {
|
||||
'id': self.id,
|
||||
'metadata': self.metadata_dejson,
|
||||
'metadata': self.params_dict,
|
||||
'dashboard_title': self.dashboard_title,
|
||||
'slug': self.slug,
|
||||
'slices': [slc.data for slc in self.slices],
|
||||
|
|
@ -369,6 +433,107 @@ class Dashboard(Model, AuditMixinNullable):
|
|||
}
|
||||
return json.dumps(d)
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
return self.json_metadata
|
||||
|
||||
@params.setter
|
||||
def params(self, value):
|
||||
self.json_metadata = value
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, dashboard_to_import, import_time=None):
|
||||
"""Imports the dashboard from the object to the database.
|
||||
|
||||
Once dashboard is imported, json_metadata field is extended and stores
|
||||
remote_id and import_time. It helps to decide if the dashboard has to
|
||||
be overridden or just copies over. Slices that belong to this
|
||||
dashboard will be wired to existing tables. This function can be used
|
||||
to import/export dashboards between multiple caravel instances.
|
||||
Audit metadata isn't copies over.
|
||||
"""
|
||||
logging.info('Started import of the dashboard: {}'
|
||||
.format(dashboard_to_import.to_json()))
|
||||
session = db.session
|
||||
logging.info('Dashboard has {} slices'
|
||||
.format(len(dashboard_to_import.slices)))
|
||||
# copy slices object as Slice.import_slice will mutate the slice
|
||||
# and will remove the existing dashboard - slice association
|
||||
slices = copy(dashboard_to_import.slices)
|
||||
slice_ids = set()
|
||||
for slc in slices:
|
||||
logging.info('Importing slice {} from the dashboard: {}'.format(
|
||||
slc.to_json(), dashboard_to_import.dashboard_title))
|
||||
slice_ids.add(Slice.import_obj(slc, import_time=import_time))
|
||||
|
||||
# override the dashboard
|
||||
existing_dashboard = None
|
||||
for dash in session.query(Dashboard).all():
|
||||
if ('remote_id' in dash.params_dict and
|
||||
dash.params_dict['remote_id'] ==
|
||||
dashboard_to_import.id):
|
||||
existing_dashboard = dash
|
||||
|
||||
dashboard_to_import.id = None
|
||||
dashboard_to_import.alter_params(import_time=import_time)
|
||||
new_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
|
||||
|
||||
if existing_dashboard:
|
||||
existing_dashboard.override(dashboard_to_import)
|
||||
existing_dashboard.slices = new_slices
|
||||
session.flush()
|
||||
return existing_dashboard.id
|
||||
else:
|
||||
# session.add(dashboard_to_import) causes sqlachemy failures
|
||||
# related to the attached users / slices. Creating new object
|
||||
# allows to avoid conflicts in the sql alchemy state.
|
||||
copied_dash = dashboard_to_import.copy()
|
||||
copied_dash.slices = new_slices
|
||||
session.add(copied_dash)
|
||||
session.flush()
|
||||
return copied_dash.id
|
||||
|
||||
@classmethod
|
||||
def export_dashboards(cls, dashboard_ids):
|
||||
copied_dashboards = []
|
||||
datasource_ids = set()
|
||||
for dashboard_id in dashboard_ids:
|
||||
# make sure that dashboard_id is an integer
|
||||
dashboard_id = int(dashboard_id)
|
||||
copied_dashboard = (
|
||||
db.session.query(Dashboard)
|
||||
.options(subqueryload(Dashboard.slices))
|
||||
.filter_by(id=dashboard_id).first()
|
||||
)
|
||||
make_transient(copied_dashboard)
|
||||
for slc in copied_dashboard.slices:
|
||||
datasource_ids.add((slc.datasource_id, slc.datasource_type))
|
||||
# add extra params for the import
|
||||
slc.alter_params(
|
||||
remote_id=slc.id,
|
||||
datasource_name=slc.datasource.name,
|
||||
schema=slc.datasource.name,
|
||||
database_name=slc.datasource.database.database_name,
|
||||
)
|
||||
copied_dashboard.alter_params(remote_id=dashboard_id)
|
||||
copied_dashboards.append(copied_dashboard)
|
||||
|
||||
eager_datasources = []
|
||||
for dashboard_id, dashboard_type in datasource_ids:
|
||||
eager_datasource = SourceRegistry.get_eager_datasource(
|
||||
db.session, dashboard_type, dashboard_id)
|
||||
eager_datasource.alter_params(
|
||||
remote_id=eager_datasource.id,
|
||||
database_name=eager_datasource.database.database_name,
|
||||
)
|
||||
make_transient(eager_datasource)
|
||||
eager_datasources.append(eager_datasource)
|
||||
|
||||
return pickle.dumps({
|
||||
'dashboards': copied_dashboards,
|
||||
'datasources': eager_datasources,
|
||||
})
|
||||
|
||||
|
||||
class Queryable(object):
|
||||
|
||||
|
|
@ -433,6 +598,10 @@ class Database(Model, AuditMixinNullable):
|
|||
def __repr__(self):
|
||||
return self.database_name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.database_name
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
url = make_url(self.sqlalchemy_uri_decrypted)
|
||||
|
|
@ -665,7 +834,7 @@ class Database(Model, AuditMixinNullable):
|
|||
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
|
||||
|
||||
|
||||
class SqlaTable(Model, Queryable, AuditMixinNullable):
|
||||
class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""An ORM object for SqlAlchemy table references"""
|
||||
|
||||
|
|
@ -689,9 +858,13 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
|
|||
cache_timeout = Column(Integer)
|
||||
schema = Column(String(255))
|
||||
sql = Column(Text)
|
||||
table_columns = relationship("TableColumn", back_populates="table")
|
||||
params = Column(Text)
|
||||
|
||||
baselink = "tablemodelview"
|
||||
export_fields = (
|
||||
'table_name', 'main_dttm_col', 'description', 'default_endpoint',
|
||||
'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
|
||||
'sql', 'params')
|
||||
|
||||
__table_args__ = (
|
||||
sqla.UniqueConstraint(
|
||||
|
|
@ -773,7 +946,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
|
|||
}
|
||||
|
||||
def get_col(self, col_name):
|
||||
columns = self.table_columns
|
||||
columns = self.columns
|
||||
for col in columns:
|
||||
if col_name == col.column_name:
|
||||
return col
|
||||
|
|
@ -1062,8 +1235,67 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
|
|||
if not self.main_dttm_col:
|
||||
self.main_dttm_col = any_date_col
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, datasource_to_import, import_time=None):
|
||||
"""Imports the datasource from the object to the database.
|
||||
|
||||
class SqlMetric(Model, AuditMixinNullable):
|
||||
Metrics and columns and datasource will be overrided if exists.
|
||||
This function can be used to import/export dashboards between multiple
|
||||
caravel instances. Audit metadata isn't copies over.
|
||||
"""
|
||||
session = db.session
|
||||
make_transient(datasource_to_import)
|
||||
logging.info('Started import of the datasource: {}'
|
||||
.format(datasource_to_import.to_json()))
|
||||
|
||||
datasource_to_import.id = None
|
||||
database_name = datasource_to_import.params_dict['database_name']
|
||||
datasource_to_import.database_id = session.query(Database).filter_by(
|
||||
database_name=database_name).one().id
|
||||
datasource_to_import.alter_params(import_time=import_time)
|
||||
|
||||
# override the datasource
|
||||
datasource = (
|
||||
session.query(SqlaTable).join(Database)
|
||||
.filter(
|
||||
SqlaTable.table_name == datasource_to_import.table_name,
|
||||
SqlaTable.schema == datasource_to_import.schema,
|
||||
Database.id == datasource_to_import.database_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if datasource:
|
||||
datasource.override(datasource_to_import)
|
||||
session.flush()
|
||||
else:
|
||||
datasource = datasource_to_import.copy()
|
||||
session.add(datasource)
|
||||
session.flush()
|
||||
|
||||
for m in datasource_to_import.metrics:
|
||||
new_m = m.copy()
|
||||
new_m.table_id = datasource.id
|
||||
logging.info('Importing metric {} from the datasource: {}'.format(
|
||||
new_m.to_json(), datasource_to_import.full_name))
|
||||
imported_m = SqlMetric.import_obj(new_m)
|
||||
if imported_m not in datasource.metrics:
|
||||
datasource.metrics.append(imported_m)
|
||||
|
||||
for c in datasource_to_import.columns:
|
||||
new_c = c.copy()
|
||||
new_c.table_id = datasource.id
|
||||
logging.info('Importing column {} from the datasource: {}'.format(
|
||||
new_c.to_json(), datasource_to_import.full_name))
|
||||
imported_c = TableColumn.import_obj(new_c)
|
||||
if imported_c not in datasource.columns:
|
||||
datasource.columns.append(imported_c)
|
||||
db.session.flush()
|
||||
|
||||
return datasource.id
|
||||
|
||||
|
||||
class SqlMetric(Model, AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""ORM object for metrics, each table can have multiple metrics"""
|
||||
|
||||
|
|
@ -1082,6 +1314,10 @@ class SqlMetric(Model, AuditMixinNullable):
|
|||
is_restricted = Column(Boolean, default=False, nullable=True)
|
||||
d3format = Column(String(128))
|
||||
|
||||
export_fields = (
|
||||
'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
|
||||
'description', 'is_restricted', 'd3format')
|
||||
|
||||
@property
|
||||
def sqla_col(self):
|
||||
name = self.metric_name
|
||||
|
|
@ -1094,8 +1330,28 @@ class SqlMetric(Model, AuditMixinNullable):
|
|||
).format(obj=self,
|
||||
parent_name=self.table.full_name) if self.table else None
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, metric_to_import):
|
||||
session = db.session
|
||||
make_transient(metric_to_import)
|
||||
metric_to_import.id = None
|
||||
|
||||
class TableColumn(Model, AuditMixinNullable):
|
||||
# find if the column was already imported
|
||||
existing_metric = session.query(SqlMetric).filter(
|
||||
SqlMetric.table_id == metric_to_import.table_id,
|
||||
SqlMetric.metric_name == metric_to_import.metric_name).first()
|
||||
metric_to_import.table = None
|
||||
if existing_metric:
|
||||
existing_metric.override(metric_to_import)
|
||||
session.flush()
|
||||
return existing_metric
|
||||
|
||||
session.add(metric_to_import)
|
||||
session.flush()
|
||||
return metric_to_import
|
||||
|
||||
|
||||
class TableColumn(Model, AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""ORM object for table columns, each table can have multiple columns"""
|
||||
|
||||
|
|
@ -1125,6 +1381,12 @@ class TableColumn(Model, AuditMixinNullable):
|
|||
num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG')
|
||||
date_types = ('DATE', 'TIME')
|
||||
str_types = ('VARCHAR', 'STRING', 'CHAR')
|
||||
export_fields = (
|
||||
'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
|
||||
'type', 'groupby', 'count_distinct', 'sum', 'max', 'min',
|
||||
'filterable', 'expression', 'description', 'python_date_format',
|
||||
'database_expression'
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return self.column_name
|
||||
|
|
@ -1150,6 +1412,27 @@ class TableColumn(Model, AuditMixinNullable):
|
|||
col = literal_column(self.expression).label(name)
|
||||
return col
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, column_to_import):
|
||||
session = db.session
|
||||
make_transient(column_to_import)
|
||||
column_to_import.id = None
|
||||
column_to_import.table = None
|
||||
|
||||
# find if the column was already imported
|
||||
existing_column = session.query(TableColumn).filter(
|
||||
TableColumn.table_id == column_to_import.table_id,
|
||||
TableColumn.column_name == column_to_import.column_name).first()
|
||||
column_to_import.table = None
|
||||
if existing_column:
|
||||
existing_column.override(column_to_import)
|
||||
session.flush()
|
||||
return existing_column
|
||||
|
||||
session.add(column_to_import)
|
||||
session.flush()
|
||||
return column_to_import
|
||||
|
||||
def dttm_sql_literal(self, dttm):
|
||||
"""Convert datetime object to string
|
||||
|
||||
|
|
@ -1234,6 +1517,10 @@ class DruidCluster(Model, AuditMixinNullable):
|
|||
def perm(self):
|
||||
return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.cluster_name
|
||||
|
||||
|
||||
class DruidDatasource(Model, AuditMixinNullable, Queryable):
|
||||
|
||||
|
|
@ -1262,6 +1549,10 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
|
|||
offset = Column(Integer, default=0)
|
||||
cache_timeout = Column(Integer)
|
||||
|
||||
@property
|
||||
def database(self):
|
||||
return self.cluster
|
||||
|
||||
@property
|
||||
def metrics_combo(self):
|
||||
return sorted(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
|
||||
class SourceRegistry(object):
|
||||
|
|
@ -20,3 +21,30 @@ class SourceRegistry(object):
|
|||
.filter_by(id=datasource_id)
|
||||
.one()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_datasource_by_name(cls, session, datasource_type, datasource_name,
|
||||
schema, database_name):
|
||||
datasource_class = SourceRegistry.sources[datasource_type]
|
||||
datasources = session.query(datasource_class).all()
|
||||
db_ds = [d for d in datasources if d.database.name == database_name and
|
||||
d.name == datasource_name and schema == schema]
|
||||
return db_ds[0]
|
||||
|
||||
@classmethod
|
||||
def get_eager_datasource(cls, session, datasource_type, datasource_id):
|
||||
"""Returns datasource with columns and metrics."""
|
||||
datasource_class = SourceRegistry.sources[datasource_type]
|
||||
if datasource_type == 'table':
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.options(
|
||||
subqueryload(datasource_class.columns),
|
||||
subqueryload(datasource_class.metrics)
|
||||
)
|
||||
.filter_by(id=datasource_id)
|
||||
.one()
|
||||
)
|
||||
# TODO: support druid datasources.
|
||||
return session.query(datasource_class).filter_by(
|
||||
id=datasource_id).first()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
<script>
|
||||
window.onload = function() {
|
||||
window.open(window.location += '&action=go');
|
||||
window.location = '{{ dashboards_url }}';
|
||||
};
|
||||
</script>
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
{% extends "caravel/basic.html" %}
|
||||
|
||||
# TODO: move the libs required by flask into the common.js from welcome.js.
|
||||
{% block head_js %}
|
||||
{{ super() }}
|
||||
{% with filename="welcome" %}
|
||||
{% include "caravel/partials/_script_tag.html" %}
|
||||
{% endwith %}
|
||||
{% endblock %}
|
||||
|
||||
{% block title %}{{ _("Import") }}{% endblock %}
|
||||
{% block body %}
|
||||
{% include "caravel/flash_wrapper.html" %}
|
||||
<div class="container">
|
||||
<title>Import the dashboards.</title>
|
||||
<h1>Import the dashboards.</h1>
|
||||
<form method=post enctype=multipart/form-data>
|
||||
<p><input type=file name=file>
|
||||
<input type=submit value=Upload>
|
||||
</p>
|
||||
</form>
|
||||
</div>
|
||||
{% endblock %}
|
||||
|
|
@ -5,6 +5,8 @@ from __future__ import unicode_literals
|
|||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
|
@ -15,7 +17,8 @@ import functools
|
|||
import sqlalchemy as sqla
|
||||
|
||||
from flask import (
|
||||
g, request, redirect, flash, Response, render_template, Markup)
|
||||
g, request, make_response, redirect, flash, Response, render_template,
|
||||
Markup, url_for)
|
||||
from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose
|
||||
from flask_appbuilder.actions import action
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
|
|
@ -26,8 +29,9 @@ from flask_babel import lazy_gettext as _
|
|||
from flask_appbuilder.models.sqla.filters import BaseFilter
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from werkzeug.routing import BaseConverter
|
||||
from werkzeug import secure_filename
|
||||
from werkzeug.datastructures import ImmutableMultiDict
|
||||
from werkzeug.routing import BaseConverter
|
||||
from wtforms.validators import ValidationError
|
||||
|
||||
import caravel
|
||||
|
|
@ -533,6 +537,16 @@ class DatabaseView(CaravelModelView, DeleteMixin): # noqa
|
|||
self.pre_add(db)
|
||||
|
||||
|
||||
appbuilder.add_link(
|
||||
'Import Dashboards',
|
||||
label=__("Import Dashboards"),
|
||||
href='/caravel/import_dashboards',
|
||||
icon="fa-cloud-upload",
|
||||
category='Manage',
|
||||
category_label=__("Manage"),
|
||||
category_icon='fa-wrench',)
|
||||
|
||||
|
||||
appbuilder.add_view(
|
||||
DatabaseView,
|
||||
"Databases",
|
||||
|
|
@ -658,7 +672,6 @@ appbuilder.add_view(
|
|||
category_label=__("Security"),
|
||||
icon='fa-table',)
|
||||
|
||||
|
||||
appbuilder.add_separator("Sources")
|
||||
|
||||
|
||||
|
|
@ -867,13 +880,32 @@ class DashboardModelView(CaravelModelView, DeleteMixin): # noqa
|
|||
def pre_delete(self, obj):
|
||||
check_ownership(obj)
|
||||
|
||||
@action("mulexport", "Export", "Export dashboards?", "fa-database")
|
||||
def mulexport(self, items):
|
||||
ids = ''.join('&id={}'.format(d.id) for d in items)
|
||||
return redirect(
|
||||
'/dashboardmodelview/export_dashboards_form?{}'.format(ids[1:]))
|
||||
|
||||
@expose("/export_dashboards_form")
|
||||
def download_dashboards(self):
|
||||
if request.args.get('action') == 'go':
|
||||
ids = request.args.getlist('id')
|
||||
return Response(
|
||||
models.Dashboard.export_dashboards(ids),
|
||||
headers=generate_download_headers("pickle"),
|
||||
mimetype="application/text")
|
||||
return self.render_template(
|
||||
'caravel/export_dashboards.html',
|
||||
dashboards_url='/dashboardmodelview/list'
|
||||
)
|
||||
|
||||
|
||||
appbuilder.add_view(
|
||||
DashboardModelView,
|
||||
"Dashboards",
|
||||
label=__("Dashboards"),
|
||||
icon="fa-dashboard",
|
||||
category="",
|
||||
category='',
|
||||
category_icon='',)
|
||||
|
||||
|
||||
|
|
@ -1053,9 +1085,8 @@ class Caravel(BaseCaravelView):
|
|||
role_to_extend = request.args.get('role_to_extend')
|
||||
|
||||
session = db.session
|
||||
datasource_class = SourceRegistry.sources[datasource_type]
|
||||
datasource = session.query(datasource_class).filter_by(
|
||||
id=datasource_id).first()
|
||||
datasource = SourceRegistry.get_datasource(
|
||||
datasource_type, datasource_id, session)
|
||||
|
||||
if not datasource:
|
||||
flash(DATASOURCE_MISSING_ERR, "alert")
|
||||
|
|
@ -1149,6 +1180,27 @@ class Caravel(BaseCaravelView):
|
|||
status=200,
|
||||
mimetype="application/json")
|
||||
|
||||
@expose("/import_dashboards", methods=['GET', 'POST'])
|
||||
@log_this
|
||||
def import_dashboards(self):
|
||||
"""Overrides the dashboards using pickled instances from the file."""
|
||||
f = request.files.get('file')
|
||||
if request.method == 'POST' and f:
|
||||
filename = secure_filename(f.filename)
|
||||
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
||||
f.save(filepath)
|
||||
current_tt = int(time.time())
|
||||
data = pickle.load(open(filepath, 'rb'))
|
||||
for table in data['datasources']:
|
||||
models.SqlaTable.import_obj(table, import_time=current_tt)
|
||||
for dashboard in data['dashboards']:
|
||||
models.Dashboard.import_obj(
|
||||
dashboard, import_time=current_tt)
|
||||
os.remove(filepath)
|
||||
db.session.commit()
|
||||
return redirect('/dashboardmodelview/list/')
|
||||
return self.render_template('caravel/import_dashboards.html')
|
||||
|
||||
@log_this
|
||||
@has_access
|
||||
@expose("/explore/<datasource_type>/<datasource_id>/")
|
||||
|
|
@ -1478,7 +1530,7 @@ class Caravel(BaseCaravelView):
|
|||
dash.slices = [o for o in dash.slices if o.id in slice_ids]
|
||||
positions = sorted(data['positions'], key=lambda x: int(x['slice_id']))
|
||||
dash.position_json = json.dumps(positions, indent=4, sort_keys=True)
|
||||
md = dash.metadata_dejson
|
||||
md = dash.params_dict
|
||||
if 'filter_immune_slices' not in md:
|
||||
md['filter_immune_slices'] = []
|
||||
if 'filter_immune_slice_fields' not in md:
|
||||
|
|
|
|||
|
|
@ -452,6 +452,5 @@ class CoreTests(CaravelTestCase):
|
|||
db.session.commit()
|
||||
self.test_save_dash('alpha')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,368 @@
|
|||
"""Unit tests for Caravel"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from sqlalchemy.orm.session import make_transient
|
||||
|
||||
import json
|
||||
import pickle
|
||||
import unittest
|
||||
|
||||
from caravel import db, models
|
||||
|
||||
from .base_tests import CaravelTestCase
|
||||
|
||||
|
||||
class ImportExportTests(CaravelTestCase):
|
||||
"""Testing export import functionality for dashboards"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ImportExportTests, self).__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def delete_imports(cls):
|
||||
# Imported data clean up
|
||||
session = db.session
|
||||
for slc in session.query(models.Slice):
|
||||
if 'remote_id' in slc.params_dict:
|
||||
session.delete(slc)
|
||||
for dash in session.query(models.Dashboard):
|
||||
if 'remote_id' in dash.params_dict:
|
||||
session.delete(dash)
|
||||
for table in session.query(models.SqlaTable):
|
||||
if 'remote_id' in table.params_dict:
|
||||
session.delete(table)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.delete_imports()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.delete_imports()
|
||||
|
||||
def create_slice(self, name, ds_id=None, id=None, db_name='main',
|
||||
table_name='wb_health_population'):
|
||||
params = {
|
||||
'num_period_compare': '10',
|
||||
'remote_id': id,
|
||||
'datasource_name': table_name,
|
||||
'database_name': db_name,
|
||||
'schema': '',
|
||||
}
|
||||
|
||||
if table_name and not ds_id:
|
||||
table = self.get_table_by_name(table_name)
|
||||
if table:
|
||||
ds_id = table.id
|
||||
|
||||
return models.Slice(
|
||||
slice_name=name,
|
||||
datasource_type='table',
|
||||
viz_type='bubble',
|
||||
params=json.dumps(params),
|
||||
datasource_id=ds_id,
|
||||
id=id
|
||||
)
|
||||
|
||||
def create_dashboard(self, title, id=0, slcs=[]):
|
||||
json_metadata = {'remote_id': id}
|
||||
return models.Dashboard(
|
||||
id=id,
|
||||
dashboard_title=title,
|
||||
slices=slcs,
|
||||
position_json='{"size_y": 2, "size_x": 2}',
|
||||
slug='{}_imported'.format(title.lower()),
|
||||
json_metadata=json.dumps(json_metadata)
|
||||
)
|
||||
|
||||
def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]):
|
||||
params = {'remote_id': id, 'database_name': 'main'}
|
||||
table = models.SqlaTable(
|
||||
id=id,
|
||||
schema=schema,
|
||||
table_name=name,
|
||||
params=json.dumps(params)
|
||||
)
|
||||
for col_name in cols_names:
|
||||
table.columns.append(
|
||||
models.TableColumn(column_name=col_name))
|
||||
for metric_name in metric_names:
|
||||
table.metrics.append(models.SqlMetric(metric_name=metric_name))
|
||||
return table
|
||||
|
||||
def get_slice(self, slc_id):
|
||||
return db.session.query(models.Slice).filter_by(id=slc_id).first()
|
||||
|
||||
def get_dash(self, dash_id):
|
||||
return db.session.query(models.Dashboard).filter_by(
|
||||
id=dash_id).first()
|
||||
|
||||
def get_dash_by_slug(self, dash_slug):
|
||||
return db.session.query(models.Dashboard).filter_by(
|
||||
slug=dash_slug).first()
|
||||
|
||||
def get_table(self, table_id):
|
||||
return db.session.query(models.SqlaTable).filter_by(
|
||||
id=table_id).first()
|
||||
|
||||
def get_table_by_name(self, name):
|
||||
return db.session.query(models.SqlaTable).filter_by(
|
||||
table_name=name).first()
|
||||
|
||||
def assert_dash_equals(self, expected_dash, actual_dash):
|
||||
self.assertEquals(expected_dash.slug, actual_dash.slug)
|
||||
self.assertEquals(
|
||||
expected_dash.dashboard_title, actual_dash.dashboard_title)
|
||||
self.assertEquals(
|
||||
expected_dash.position_json, actual_dash.position_json)
|
||||
self.assertEquals(
|
||||
len(expected_dash.slices), len(actual_dash.slices))
|
||||
expected_slices = sorted(
|
||||
expected_dash.slices, key=lambda s: s.slice_name)
|
||||
actual_slices = sorted(
|
||||
actual_dash.slices, key=lambda s: s.slice_name)
|
||||
for e_slc, a_slc in zip(expected_slices, actual_slices):
|
||||
self.assert_slice_equals(e_slc, a_slc)
|
||||
|
||||
def assert_table_equals(self, expected_ds, actual_ds):
|
||||
self.assertEquals(expected_ds.table_name, actual_ds.table_name)
|
||||
self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
|
||||
self.assertEquals(expected_ds.schema, actual_ds.schema)
|
||||
self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics))
|
||||
self.assertEquals(len(expected_ds.columns), len(actual_ds.columns))
|
||||
self.assertEquals(
|
||||
set([c.column_name for c in expected_ds.columns]),
|
||||
set([c.column_name for c in actual_ds.columns]))
|
||||
self.assertEquals(
|
||||
set([m.metric_name for m in expected_ds.metrics]),
|
||||
set([m.metric_name for m in actual_ds.metrics]))
|
||||
|
||||
def assert_slice_equals(self, expected_slc, actual_slc):
|
||||
self.assertEquals(actual_slc.datasource.perm, actual_slc.perm)
|
||||
self.assertEquals(expected_slc.slice_name, actual_slc.slice_name)
|
||||
self.assertEquals(
|
||||
expected_slc.datasource_type, actual_slc.datasource_type)
|
||||
self.assertEquals(expected_slc.viz_type, actual_slc.viz_type)
|
||||
self.assertEquals(
|
||||
json.loads(expected_slc.params), json.loads(actual_slc.params))
|
||||
|
||||
def test_export_1_dashboard(self):
|
||||
birth_dash = self.get_dash_by_slug('births')
|
||||
export_dash_url = (
|
||||
'/dashboardmodelview/export_dashboards_form?id={}&action=go'
|
||||
.format(birth_dash.id)
|
||||
)
|
||||
resp = self.client.get(export_dash_url)
|
||||
exported_dashboards = pickle.loads(resp.data)['dashboards']
|
||||
self.assert_dash_equals(birth_dash, exported_dashboards[0])
|
||||
self.assertEquals(
|
||||
birth_dash.id,
|
||||
json.loads(exported_dashboards[0].json_metadata)['remote_id'])
|
||||
|
||||
exported_tables = pickle.loads(resp.data)['datasources']
|
||||
self.assertEquals(1, len(exported_tables))
|
||||
self.assert_table_equals(
|
||||
self.get_table_by_name('birth_names'), exported_tables[0])
|
||||
|
||||
def test_export_2_dashboards(self):
|
||||
birth_dash = self.get_dash_by_slug('births')
|
||||
world_health_dash = self.get_dash_by_slug('world_health')
|
||||
export_dash_url = (
|
||||
'/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go'
|
||||
.format(birth_dash.id, world_health_dash.id))
|
||||
resp = self.client.get(export_dash_url)
|
||||
exported_dashboards = sorted(pickle.loads(resp.data)['dashboards'],
|
||||
key=lambda d: d.dashboard_title)
|
||||
self.assertEquals(2, len(exported_dashboards))
|
||||
self.assert_dash_equals(birth_dash, exported_dashboards[0])
|
||||
self.assertEquals(
|
||||
birth_dash.id,
|
||||
json.loads(exported_dashboards[0].json_metadata)['remote_id']
|
||||
)
|
||||
|
||||
self.assert_dash_equals(world_health_dash, exported_dashboards[1])
|
||||
self.assertEquals(
|
||||
world_health_dash.id,
|
||||
json.loads(exported_dashboards[1].json_metadata)['remote_id']
|
||||
)
|
||||
|
||||
exported_tables = sorted(
|
||||
pickle.loads(resp.data)['datasources'], key=lambda t: t.table_name)
|
||||
self.assertEquals(2, len(exported_tables))
|
||||
self.assert_table_equals(
|
||||
self.get_table_by_name('birth_names'), exported_tables[0])
|
||||
self.assert_table_equals(
|
||||
self.get_table_by_name('wb_health_population'), exported_tables[1])
|
||||
|
||||
def test_import_1_slice(self):
|
||||
expected_slice = self.create_slice('Import Me', id=10001);
|
||||
slc_id = models.Slice.import_obj(expected_slice, import_time=1989)
|
||||
self.assert_slice_equals(expected_slice, self.get_slice(slc_id))
|
||||
|
||||
table_id = self.get_table_by_name('wb_health_population').id
|
||||
self.assertEquals(table_id, self.get_slice(slc_id).datasource_id)
|
||||
|
||||
def test_import_2_slices_for_same_table(self):
|
||||
table_id = self.get_table_by_name('wb_health_population').id
|
||||
# table_id != 666, import func will have to find the table
|
||||
slc_1 = self.create_slice('Import Me 1', ds_id=666, id=10002)
|
||||
slc_id_1 = models.Slice.import_obj(slc_1)
|
||||
slc_2 = self.create_slice('Import Me 2', ds_id=666, id=10003)
|
||||
slc_id_2 = models.Slice.import_obj(slc_2)
|
||||
|
||||
imported_slc_1 = self.get_slice(slc_id_1)
|
||||
imported_slc_2 = self.get_slice(slc_id_2)
|
||||
self.assertEquals(table_id, imported_slc_1.datasource_id)
|
||||
self.assert_slice_equals(slc_1, imported_slc_1)
|
||||
|
||||
self.assertEquals(table_id, imported_slc_2.datasource_id)
|
||||
self.assert_slice_equals(slc_2, imported_slc_2)
|
||||
|
||||
def test_import_slices_for_non_existent_table(self):
|
||||
with self.assertRaises(IndexError):
|
||||
models.Slice.import_obj(self.create_slice(
|
||||
'Import Me 3', id=10004, table_name='non_existent'))
|
||||
|
||||
def test_import_slices_override(self):
|
||||
slc = self.create_slice('Import Me New', id=10005)
|
||||
slc_1_id = models.Slice.import_obj(slc, import_time=1990)
|
||||
slc.slice_name = 'Import Me New'
|
||||
slc_2_id = models.Slice.import_obj(
|
||||
self.create_slice('Import Me New', id=10005), import_time=1990)
|
||||
self.assertEquals(slc_1_id, slc_2_id)
|
||||
imported_slc = self.get_slice(slc_2_id)
|
||||
self.assert_slice_equals(slc, imported_slc)
|
||||
|
||||
def test_import_empty_dashboard(self):
|
||||
empty_dash = self.create_dashboard('empty_dashboard', id=10001)
|
||||
imported_dash_id = models.Dashboard.import_obj(
|
||||
empty_dash, import_time=1989)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assert_dash_equals(empty_dash, imported_dash)
|
||||
|
||||
def test_import_dashboard_1_slice(self):
|
||||
slc = self.create_slice('health_slc', id=10006)
|
||||
dash_with_1_slice = self.create_dashboard(
|
||||
'dash_with_1_slice', slcs=[slc], id=10002)
|
||||
imported_dash_id = models.Dashboard.import_obj(
|
||||
dash_with_1_slice, import_time=1990)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
|
||||
expected_dash = self.create_dashboard(
|
||||
'dash_with_1_slice', slcs=[slc], id=10002)
|
||||
make_transient(expected_dash)
|
||||
self.assert_dash_equals(expected_dash, imported_dash)
|
||||
self.assertEquals({"remote_id": 10002, "import_time": 1990},
|
||||
json.loads(imported_dash.json_metadata))
|
||||
|
||||
def test_import_dashboard_2_slices(self):
|
||||
e_slc = self.create_slice('e_slc', id=10007, table_name='energy_usage')
|
||||
b_slc = self.create_slice('b_slc', id=10008, table_name='birth_names')
|
||||
dash_with_2_slices = self.create_dashboard(
|
||||
'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
|
||||
imported_dash_id = models.Dashboard.import_obj(
|
||||
dash_with_2_slices, import_time=1991)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
|
||||
expected_dash = self.create_dashboard(
|
||||
'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
|
||||
make_transient(expected_dash)
|
||||
self.assert_dash_equals(imported_dash, expected_dash)
|
||||
self.assertEquals({"remote_id": 10003, "import_time": 1991},
|
||||
json.loads(imported_dash.json_metadata))
|
||||
|
||||
def test_import_override_dashboard_2_slices(self):
|
||||
e_slc = self.create_slice('e_slc', id=10009, table_name='energy_usage')
|
||||
b_slc = self.create_slice('b_slc', id=10010, table_name='birth_names')
|
||||
dash_to_import = self.create_dashboard(
|
||||
'override_dashboard', slcs=[e_slc, b_slc], id=10004)
|
||||
imported_dash_id_1 = models.Dashboard.import_obj(
|
||||
dash_to_import, import_time=1992)
|
||||
|
||||
# create new instances of the slices
|
||||
e_slc = self.create_slice(
|
||||
'e_slc', id=10009, table_name='energy_usage')
|
||||
b_slc = self.create_slice(
|
||||
'b_slc', id=10010, table_name='birth_names')
|
||||
c_slc = self.create_slice('c_slc', id=10011, table_name='birth_names')
|
||||
dash_to_import_override = self.create_dashboard(
|
||||
'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004)
|
||||
imported_dash_id_2 = models.Dashboard.import_obj(
|
||||
dash_to_import_override, import_time=1992)
|
||||
|
||||
# override doesn't change the id
|
||||
self.assertEquals(imported_dash_id_1, imported_dash_id_2)
|
||||
expected_dash = self.create_dashboard(
|
||||
'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004)
|
||||
make_transient(expected_dash)
|
||||
imported_dash = self.get_dash(imported_dash_id_2)
|
||||
self.assert_dash_equals(expected_dash, imported_dash)
|
||||
self.assertEquals({"remote_id": 10004, "import_time": 1992},
|
||||
json.loads(imported_dash.json_metadata))
|
||||
|
||||
def test_import_table_no_metadata(self):
|
||||
table = self.create_table('pure_table', id=10001)
|
||||
imported_t_id = models.SqlaTable.import_obj(table, import_time=1989)
|
||||
imported_table = self.get_table(imported_t_id)
|
||||
self.assert_table_equals(table, imported_table)
|
||||
|
||||
def test_import_table_1_col_1_met(self):
|
||||
table = self.create_table(
|
||||
'table_1_col_1_met', id=10002,
|
||||
cols_names=["col1"], metric_names=["metric1"])
|
||||
imported_t_id = models.SqlaTable.import_obj(table, import_time=1990)
|
||||
imported_table = self.get_table(imported_t_id)
|
||||
self.assert_table_equals(table, imported_table)
|
||||
self.assertEquals(
|
||||
{'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'},
|
||||
json.loads(imported_table.params))
|
||||
|
||||
def test_import_table_2_col_2_met(self):
|
||||
table = self.create_table(
|
||||
'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
|
||||
metric_names=['m1', 'm2'])
|
||||
imported_t_id = models.SqlaTable.import_obj(table, import_time=1991)
|
||||
|
||||
imported_table = self.get_table(imported_t_id)
|
||||
self.assert_table_equals(table, imported_table)
|
||||
|
||||
def test_import_table_override(self):
|
||||
table = self.create_table(
|
||||
'table_override', id=10003, cols_names=['col1'],
|
||||
metric_names=['m1'])
|
||||
imported_t_id = models.SqlaTable.import_obj(table, import_time=1991)
|
||||
|
||||
table_over = self.create_table(
|
||||
'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_table_over_id = models.SqlaTable.import_obj(
|
||||
table_over, import_time=1992)
|
||||
|
||||
imported_table_over = self.get_table(imported_table_over_id)
|
||||
self.assertEquals(imported_t_id, imported_table_over.id)
|
||||
expected_table = self.create_table(
|
||||
'table_override', id=10003, metric_names=['new_metric1', 'm1'],
|
||||
cols_names=['col1', 'new_col1', 'col2', 'col3'])
|
||||
self.assert_table_equals(expected_table, imported_table_over)
|
||||
|
||||
def test_import_table_override_idential(self):
|
||||
table = self.create_table(
|
||||
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_t_id = models.SqlaTable.import_obj(table, import_time=1993)
|
||||
|
||||
copy_table = self.create_table(
|
||||
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_t_id_copy = models.SqlaTable.import_obj(
|
||||
copy_table, import_time=1994)
|
||||
|
||||
self.assertEquals(imported_t_id, imported_t_id_copy)
|
||||
self.assert_table_equals(copy_table, self.get_table(imported_t_id))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue