Refactor Dashboard and Slice models (#8820)

* refactor dashboard and slice models

* appease various linters

* remove shortcuts & import indirection

* appease mypy

* fix bad imports

* lint

* address various issues

* ignore type issue

* remove unused imports

* lint
This commit is contained in:
David Aaron Suddjian 2019-12-18 11:40:45 -08:00 committed by Maxime Beauchemin
parent cbf860074b
commit 016f202423
34 changed files with 929 additions and 838 deletions

View File

@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,psycopg2,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,psycopg2,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

View File

@ -23,8 +23,8 @@ from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import foreign, Query, relationship
from superset.constants import NULL_STRING
from superset.models.core import Slice
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.models.slice import Slice
from superset.utils import core as utils

View File

@ -23,16 +23,16 @@ from sqlalchemy.sql import column
from superset import db, security_manager
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_database
from .helpers import (
config,
Dash,
get_example_data,
get_slice_json,
merge_slice,
misc_dash_slices,
Slice,
TBL,
update_slice_ids,
)
@ -441,10 +441,10 @@ def load_birth_names(only_metadata=False, force=False):
misc_dash_slices.add(slc.slice_name)
print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
if not dash:
dash = Dash()
dash = Dashboard()
db.session.add(dash)
dash.published = True
dash.json_metadata = textwrap.dedent(

View File

@ -22,6 +22,7 @@ from sqlalchemy.sql import column
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice
from superset.utils import core as utils
from .helpers import (
@ -29,7 +30,6 @@ from .helpers import (
get_slice_json,
merge_slice,
misc_dash_slices,
Slice,
TBL,
)

View File

@ -18,8 +18,10 @@
import json
from superset import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from .helpers import Dash, get_slice_json, merge_slice, Slice, TBL, update_slice_ids
from .helpers import get_slice_json, merge_slice, TBL, update_slice_ids
COLOR_RED = {"r": 205, "g": 0, "b": 3, "a": 0.82}
POSITION_JSON = """\
@ -509,10 +511,10 @@ def load_deck_dash():
print("Creating a dashboard")
title = "deck.gl Demo"
dash = db.session.query(Dash).filter_by(slug=slug).first()
dash = db.session.query(Dashboard).filter_by(slug=slug).first()
if not dash:
dash = Dash()
dash = Dashboard()
dash.published = True
js = POSITION_JSON
pos = json.loads(js)

View File

@ -23,9 +23,10 @@ from sqlalchemy.sql import column
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice
from superset.utils import core as utils
from .helpers import get_example_data, merge_slice, misc_dash_slices, Slice, TBL
from .helpers import get_example_data, merge_slice, misc_dash_slices, TBL
def load_energy(only_metadata=False, force=False):

View File

@ -25,13 +25,12 @@ from urllib import request
from superset import app, db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models import core as models
from superset.models.slice import Slice
BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/"
# Shortcuts
DB = models.Database
Slice = models.Slice
Dash = models.Dashboard
TBL = ConnectorRegistry.sources["table"]

View File

@ -22,6 +22,7 @@ import pandas as pd
from sqlalchemy import DateTime, Float, String
from superset import db
from superset.models.slice import Slice
from superset.utils import core as utils
from .helpers import (
@ -29,7 +30,6 @@ from .helpers import (
get_slice_json,
merge_slice,
misc_dash_slices,
Slice,
TBL,
)

View File

@ -18,8 +18,10 @@ import json
import textwrap
from superset import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from .helpers import Dash, misc_dash_slices, Slice, update_slice_ids
from .helpers import misc_dash_slices, update_slice_ids
DASH_SLUG = "misc_charts"
@ -29,10 +31,10 @@ def load_misc_dashboard():
print("Creating the dashboard")
db.session.expunge_all()
dash = db.session.query(Dash).filter_by(slug=DASH_SLUG).first()
dash = db.session.query(Dashboard).filter_by(slug=DASH_SLUG).first()
if not dash:
dash = Dash()
dash = Dashboard()
js = textwrap.dedent(
"""\
{

View File

@ -17,9 +17,10 @@
import json
from superset import db
from superset.models.slice import Slice
from .birth_names import load_birth_names
from .helpers import merge_slice, misc_dash_slices, Slice
from .helpers import merge_slice, misc_dash_slices
from .world_bank import load_world_bank_health_n_pop

View File

@ -19,6 +19,7 @@ import pandas as pd
from sqlalchemy import BigInteger, Date, DateTime, String
from superset import db
from superset.models.slice import Slice
from superset.utils.core import get_example_database
from .helpers import (
@ -27,7 +28,6 @@ from .helpers import (
get_slice_json,
merge_slice,
misc_dash_slices,
Slice,
TBL,
)

View File

@ -19,9 +19,10 @@ import pandas as pd
from sqlalchemy import DateTime
from superset import db
from superset.models.slice import Slice
from superset.utils import core as utils
from .helpers import config, get_example_data, get_slice_json, merge_slice, Slice, TBL
from .helpers import config, get_example_data, get_slice_json, merge_slice, TBL
def load_random_time_series_data(only_metadata=False, force=False):

View File

@ -19,8 +19,10 @@ import json
import textwrap
from superset import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from .helpers import Dash, Slice, update_slice_ids
from .helpers import update_slice_ids
def load_tabbed_dashboard(_=False):
@ -28,10 +30,10 @@ def load_tabbed_dashboard(_=False):
print("Creating a dashboard with nested tabs")
slug = "tabbed_dash"
dash = db.session.query(Dash).filter_by(slug=slug).first()
dash = db.session.query(Dashboard).filter_by(slug=slug).first()
if not dash:
dash = Dash()
dash = Dashboard()
# reuse charts in "World's Bank Data and create
# new dashboard with nested tabs

View File

@ -22,15 +22,15 @@ import pandas as pd
from sqlalchemy import Date, Float, String
from superset import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from .helpers import (
config,
Dash,
get_example_data,
get_slice_json,
merge_slice,
Slice,
TBL,
update_slice_ids,
)
@ -109,10 +109,10 @@ def load_unicode_test_data(only_metadata=False, force=False):
merge_slice(slc)
print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(slug="unicode-test").first()
dash = db.session.query(Dashboard).filter_by(slug="unicode-test").first()
if not dash:
dash = Dash()
dash = Dashboard()
js = """\
{
"CHART-Hkx6154FEm": {

View File

@ -25,17 +25,17 @@ from sqlalchemy.sql import column
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from .helpers import (
config,
Dash,
EXAMPLES_FOLDER,
get_example_data,
get_slice_json,
merge_slice,
misc_dash_slices,
Slice,
TBL,
update_slice_ids,
)
@ -332,10 +332,10 @@ def load_world_bank_health_n_pop(
print("Creating a World's Health Bank dashboard")
dash_name = "World Bank's Data"
slug = "world_health"
dash = db.session.query(Dash).filter_by(slug=slug).first()
dash = db.session.query(Dashboard).filter_by(slug=slug).first()
if not dash:
dash = Dash()
dash = Dashboard()
dash.published = True
js = textwrap.dedent(
"""\

View File

@ -20,19 +20,16 @@ import json
import logging
import textwrap
from contextlib import closing
from copy import copy, deepcopy
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
from urllib import parse
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
import numpy
import pandas as pd
import sqlalchemy as sqla
import sqlparse
from flask import escape, g, Markup, request
from flask import g, request
from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import (
Boolean,
Column,
@ -48,27 +45,18 @@ from sqlalchemy import (
from sqlalchemy.engine import Dialect, Engine, url
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
from sqlalchemy.orm.session import make_transient
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import Select
from sqlalchemy_utils import EncryptedType
from superset import app, db, db_engine_specs, is_feature_enabled, security_manager
from superset.connectors.connector_registry import ConnectorRegistry
from superset import app, db_engine_specs, is_feature_enabled, security_manager
from superset.db_engine_specs.base import TimeGrain
from superset.legacy import update_time_range
from superset.models.dashboard import Dashboard
from superset.models.helpers import AuditMixinNullable, ImportMixin
from superset.models.tags import ChartUpdater, DashboardUpdater, FavStarUpdater
from superset.models.user_attributes import UserAttribute
from superset.models.tags import DashboardUpdater, FavStarUpdater
from superset.utils import cache as cache_util, core as utils
from superset.viz import BaseViz, viz_types
if TYPE_CHECKING:
from superset.connectors.base.models import ( # pylint: disable=unused-import
BaseDatasource,
)
config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@ -80,50 +68,6 @@ PASSWORD_MASK = "X" * 10
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
def set_related_perm(mapper, connection, target):
src_class = target.cls_model
id_ = target.datasource_id
if id_:
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
if ds:
target.perm = ds.perm
target.schema_perm = ds.schema_perm
def copy_dashboard(mapper, connection, target):
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
return
session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection)
new_user = session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
session.commit()
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
sqla.event.listen(User, "after_insert", copy_dashboard)
class Url(Model, AuditMixinNullable):
"""Used for the short url feature"""
@ -151,608 +95,6 @@ class CssTemplate(Model, AuditMixinNullable):
css = Column(Text, default="")
slice_user = Table(
"slice_user",
metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("slice_id", Integer, ForeignKey("slices.id")),
)
class Slice(
Model, AuditMixinNullable, ImportMixin
): # pylint: disable=too-many-public-methods
"""A slice is essentially a report or a view on data"""
__tablename__ = "slices"
id = Column(Integer, primary_key=True) # pylint: disable=invalid-name
slice_name = Column(String(250))
datasource_id = Column(Integer)
datasource_type = Column(String(200))
datasource_name = Column(String(2000))
viz_type = Column(String(250))
params = Column(Text)
description = Column(Text)
cache_timeout = Column(Integer)
perm = Column(String(1000))
schema_perm = Column(String(1000))
owners = relationship(security_manager.user_model, secondary=slice_user)
token = ""
export_fields = [
"slice_name",
"datasource_type",
"datasource_name",
"viz_type",
"params",
"cache_timeout",
]
def __repr__(self):
return self.slice_name or str(self.id)
@property
def cls_model(self) -> Type["BaseDatasource"]:
return ConnectorRegistry.sources[self.datasource_type]
@property
def datasource(self) -> "BaseDatasource":
return self.get_datasource
def clone(self) -> "Slice":
return Slice(
slice_name=self.slice_name,
datasource_id=self.datasource_id,
datasource_type=self.datasource_type,
datasource_name=self.datasource_name,
viz_type=self.viz_type,
params=self.params,
description=self.description,
cache_timeout=self.cache_timeout,
)
# pylint: disable=using-constant-test
@datasource.getter # type: ignore
@utils.memoized
def get_datasource(self) -> Optional["BaseDatasource"]:
return db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
@renders("datasource_name")
def datasource_link(self) -> Optional[Markup]:
# pylint: disable=no-member
datasource = self.datasource
return datasource.link if datasource else None
def datasource_name_text(self) -> Optional[str]:
# pylint: disable=no-member
datasource = self.datasource
return datasource.name if datasource else None
@property
def datasource_edit_url(self) -> Optional[str]:
# pylint: disable=no-member
datasource = self.datasource
return datasource.url if datasource else None
# pylint: enable=using-constant-test
@property # type: ignore
@utils.memoized
def viz(self) -> BaseViz:
d = json.loads(self.params)
viz_class = viz_types[self.viz_type]
return viz_class(datasource=self.datasource, form_data=d)
@property
def description_markeddown(self) -> str:
return utils.markdown(self.description)
@property
def data(self) -> Dict[str, Any]:
"""Data used to render slice in templates"""
d: Dict[str, Any] = {}
self.token = ""
try:
d = self.viz.data
self.token = d.get("token") # type: ignore
except Exception as e: # pylint: disable=broad-except
logging.exception(e)
d["error"] = str(e)
return {
"datasource": self.datasource_name,
"description": self.description,
"description_markeddown": self.description_markeddown,
"edit_url": self.edit_url,
"form_data": self.form_data,
"slice_id": self.id,
"slice_name": self.slice_name,
"slice_url": self.slice_url,
"modified": self.modified(),
"changed_on_humanized": self.changed_on_humanized,
"changed_on": self.changed_on.isoformat(),
}
@property
def json_data(self) -> str:
return json.dumps(self.data)
@property
def form_data(self) -> Dict[str, Any]:
form_data: Dict[str, Any] = {}
try:
form_data = json.loads(self.params)
except Exception as e: # pylint: disable=broad-except
logging.error("Malformed json in slice's params")
logging.exception(e)
form_data.update(
{
"slice_id": self.id,
"viz_type": self.viz_type,
"datasource": "{}__{}".format(self.datasource_id, self.datasource_type),
}
)
if self.cache_timeout:
form_data["cache_timeout"] = self.cache_timeout
update_time_range(form_data)
return form_data
def get_explore_url(
self,
base_url: str = "/superset/explore",
overrides: Optional[Dict[str, Any]] = None,
) -> str:
overrides = overrides or {}
form_data = {"slice_id": self.id}
form_data.update(overrides)
params = parse.quote(json.dumps(form_data))
return f"{base_url}/?form_data={params}"
@property
def slice_url(self) -> str:
"""Defines the url to access the slice"""
return self.get_explore_url()
@property
def explore_json_url(self) -> str:
"""Defines the url to access the slice"""
return self.get_explore_url("/superset/explore_json")
@property
def edit_url(self) -> str:
return f"/chart/edit/{self.id}"
@property
def chart(self) -> str:
return self.slice_name or "<empty>"
@property
def slice_link(self) -> Markup:
name = escape(self.chart)
return Markup(f'<a href="{self.url}">{name}</a>')
def get_viz(self, force: bool = False) -> BaseViz:
"""Creates :py:class:viz.BaseViz object from the url_params_multidict.
:return: object of the 'viz_type' type that is taken from the
url_params_multidict or self.params.
:rtype: :py:class:viz.BaseViz
"""
slice_params = json.loads(self.params)
slice_params["slice_id"] = self.id
slice_params["json"] = "false"
slice_params["slice_name"] = self.slice_name
slice_params["viz_type"] = self.viz_type if self.viz_type else "table"
return viz_types[slice_params.get("viz_type")](
self.datasource, form_data=slice_params, force=force
)
@property
def icons(self) -> str:
return f"""
<a
href="{self.datasource_edit_url}"
data-toggle="tooltip"
title="{self.datasource}">
<i class="fa fa-database"></i>
</a>
"""
@classmethod
def import_obj(
cls,
slc_to_import: "Slice",
slc_to_override: Optional["Slice"],
import_time: Optional[int] = None,
) -> int:
"""Inserts or overrides slc in the database.
remote_id and import_time fields in params_dict are set to track the
slice origin and ensure correct overrides for multiple imports.
Slice.perm is used to find the datasources and connect them.
:param Slice slc_to_import: Slice object to import
:param Slice slc_to_override: Slice to replace, id matches remote_id
:returns: The resulting id for the imported slice
:rtype: int
"""
session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
slc_to_import = slc_to_import.copy()
slc_to_import.reset_ownership()
params = slc_to_import.params_dict
slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name( # type: ignore
session,
slc_to_import.datasource_type,
params["datasource_name"],
params["schema"],
params["database_name"],
).id
if slc_to_override:
slc_to_override.override(slc_to_import)
session.flush()
return slc_to_override.id
session.add(slc_to_import)
logging.info("Final slice: %s", str(slc_to_import.to_json()))
session.flush()
return slc_to_import.id
@property
def url(self) -> str:
return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"
sqla.event.listen(Slice, "before_insert", set_related_perm)
sqla.event.listen(Slice, "before_update", set_related_perm)
dashboard_slices = Table(
"dashboard_slices",
metadata,
Column("id", Integer, primary_key=True),
Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
Column("slice_id", Integer, ForeignKey("slices.id")),
UniqueConstraint("dashboard_id", "slice_id"),
)
dashboard_user = Table(
"dashboard_user",
metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
)
class Dashboard( # pylint: disable=too-many-instance-attributes
Model, AuditMixinNullable, ImportMixin
):
"""The dashboard object!"""
__tablename__ = "dashboards"
id = Column(Integer, primary_key=True) # pylint: disable=invalid-name
dashboard_title = Column(String(500))
position_json = Column(utils.MediumText())
description = Column(Text)
css = Column(Text)
json_metadata = Column(Text)
slug = Column(String(255), unique=True)
slices = relationship("Slice", secondary=dashboard_slices, backref="dashboards")
owners = relationship(security_manager.user_model, secondary=dashboard_user)
published = Column(Boolean, default=False)
export_fields = [
"dashboard_title",
"position_json",
"json_metadata",
"description",
"css",
"slug",
]
def __repr__(self):
return self.dashboard_title or str(self.id)
@property
def table_names(self) -> str:
# pylint: disable=no-member
return ", ".join(str(s.datasource.full_name) for s in self.slices)
@property
def url(self) -> str:
if self.json_metadata:
# add default_filters to the preselect_filters of dashboard
json_metadata = json.loads(self.json_metadata)
default_filters = json_metadata.get("default_filters")
# make sure default_filters is not empty and is valid
if default_filters and default_filters != "{}":
try:
if json.loads(default_filters):
filters = parse.quote(default_filters.encode("utf8"))
return "/superset/dashboard/{}/?preselect_filters={}".format(
self.slug or self.id, filters
)
except Exception: # pylint: disable=broad-except
pass
return f"/superset/dashboard/{self.slug or self.id}/"
@property
def datasources(self) -> Set[Optional["BaseDatasource"]]:
return {slc.datasource for slc in self.slices}
@property
def charts(self) -> List[Optional["BaseDatasource"]]:
return [slc.chart for slc in self.slices]
@property
def sqla_metadata(self) -> None:
# pylint: disable=no-member
meta = MetaData(bind=self.get_sqla_engine())
meta.reflect()
@renders("dashboard_title")
def dashboard_link(self) -> Markup:
title = escape(self.dashboard_title or "<empty>")
return Markup(f'<a href="{self.url}">{title}</a>')
@property
def data(self) -> Dict[str, Any]:
positions = self.position_json
if positions:
positions = json.loads(positions)
return {
"id": self.id,
"metadata": self.params_dict,
"css": self.css,
"dashboard_title": self.dashboard_title,
"published": self.published,
"slug": self.slug,
"slices": [slc.data for slc in self.slices],
"position_json": positions,
}
@property
def params(self) -> str:
return self.json_metadata
@params.setter
def params(self, value: str) -> None:
self.json_metadata = value
@property
def position(self) -> Dict:
if self.position_json:
return json.loads(self.position_json)
return {}
@classmethod
def import_obj( # pylint: disable=too-many-locals,too-many-branches,too-many-statements
cls, dashboard_to_import: "Dashboard", import_time: Optional[int] = None
) -> int:
"""Imports the dashboard from the object to the database.
Once dashboard is imported, json_metadata field is extended and stores
remote_id and import_time. It helps to decide if the dashboard has to
be overridden or just copies over. Slices that belong to this
dashboard will be wired to existing tables. This function can be used
to import/export dashboards between multiple superset instances.
Audit metadata isn't copied over.
"""
def alter_positions(dashboard, old_to_new_slc_id_dict):
""" Updates slice_ids in the position json.
Sample position_json data:
{
"DASHBOARD_VERSION_KEY": "v2",
"DASHBOARD_ROOT_ID": {
"type": "DASHBOARD_ROOT_TYPE",
"id": "DASHBOARD_ROOT_ID",
"children": ["DASHBOARD_GRID_ID"]
},
"DASHBOARD_GRID_ID": {
"type": "DASHBOARD_GRID_TYPE",
"id": "DASHBOARD_GRID_ID",
"children": ["DASHBOARD_CHART_TYPE-2"]
},
"DASHBOARD_CHART_TYPE-2": {
"type": "DASHBOARD_CHART_TYPE",
"id": "DASHBOARD_CHART_TYPE-2",
"children": [],
"meta": {
"width": 4,
"height": 50,
"chartId": 118
}
},
}
"""
position_data = json.loads(dashboard.position_json)
position_json = position_data.values()
for value in position_json:
if (
isinstance(value, dict)
and value.get("meta")
and value.get("meta").get("chartId")
):
old_slice_id = value.get("meta").get("chartId")
if old_slice_id in old_to_new_slc_id_dict:
value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id]
dashboard.position_json = json.dumps(position_data)
logging.info(
"Started import of the dashboard: %s", dashboard_to_import.to_json()
)
session = db.session
logging.info("Dashboard has %d slices", len(dashboard_to_import.slices))
# copy slices object as Slice.import_slice will mutate the slice
# and will remove the existing dashboard - slice association
slices = copy(dashboard_to_import.slices)
old_to_new_slc_id_dict = {}
new_filter_immune_slices = []
new_filter_immune_slice_fields = {}
new_timed_refresh_immune_slices = []
new_expanded_slices = {}
i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = {
slc.params_dict["remote_id"]: slc
for slc in session.query(Slice).all()
if "remote_id" in slc.params_dict
}
for slc in slices:
logging.info(
"Importing slice %s from the dashboard: %s",
slc.to_json(),
dashboard_to_import.dashboard_title,
)
remote_slc = remote_id_slice_map.get(slc.id)
new_slc_id = Slice.import_obj(slc, remote_slc, import_time=import_time)
old_to_new_slc_id_dict[slc.id] = new_slc_id
# update json metadata that deals with slice ids
new_slc_id_str = "{}".format(new_slc_id)
old_slc_id_str = "{}".format(slc.id)
if (
"filter_immune_slices" in i_params_dict
and old_slc_id_str in i_params_dict["filter_immune_slices"]
):
new_filter_immune_slices.append(new_slc_id_str)
if (
"filter_immune_slice_fields" in i_params_dict
and old_slc_id_str in i_params_dict["filter_immune_slice_fields"]
):
new_filter_immune_slice_fields[new_slc_id_str] = i_params_dict[
"filter_immune_slice_fields"
][old_slc_id_str]
if (
"timed_refresh_immune_slices" in i_params_dict
and old_slc_id_str in i_params_dict["timed_refresh_immune_slices"]
):
new_timed_refresh_immune_slices.append(new_slc_id_str)
if (
"expanded_slices" in i_params_dict
and old_slc_id_str in i_params_dict["expanded_slices"]
):
new_expanded_slices[new_slc_id_str] = i_params_dict["expanded_slices"][
old_slc_id_str
]
# override the dashboard
existing_dashboard = None
for dash in session.query(Dashboard).all():
if (
"remote_id" in dash.params_dict
and dash.params_dict["remote_id"] == dashboard_to_import.id
):
existing_dashboard = dash
dashboard_to_import = dashboard_to_import.copy()
dashboard_to_import.id = None
dashboard_to_import.reset_ownership()
# position_json can be empty for dashboards
# with charts added from chart-edit page and without re-arranging
if dashboard_to_import.position_json:
alter_positions(dashboard_to_import, old_to_new_slc_id_dict)
dashboard_to_import.alter_params(import_time=import_time)
if new_expanded_slices:
dashboard_to_import.alter_params(expanded_slices=new_expanded_slices)
if new_filter_immune_slices:
dashboard_to_import.alter_params(
filter_immune_slices=new_filter_immune_slices
)
if new_filter_immune_slice_fields:
dashboard_to_import.alter_params(
filter_immune_slice_fields=new_filter_immune_slice_fields
)
if new_timed_refresh_immune_slices:
dashboard_to_import.alter_params(
timed_refresh_immune_slices=new_timed_refresh_immune_slices
)
new_slices = (
session.query(Slice)
.filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
.all()
)
if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
session.flush()
return existing_dashboard.id
dashboard_to_import.slices = new_slices
session.add(dashboard_to_import)
session.flush()
return dashboard_to_import.id # type: ignore
@classmethod
def export_dashboards( # pylint: disable=too-many-locals
cls, dashboard_ids: List
) -> str:
copied_dashboards = []
datasource_ids = set()
for dashboard_id in dashboard_ids:
# make sure that dashboard_id is an integer
dashboard_id = int(dashboard_id)
dashboard = (
db.session.query(Dashboard)
.options(subqueryload(Dashboard.slices))
.filter_by(id=dashboard_id)
.first()
)
# remove ids and relations (like owners, created by, slices, ...)
copied_dashboard = dashboard.copy()
for slc in dashboard.slices:
datasource_ids.add((slc.datasource_id, slc.datasource_type))
copied_slc = slc.copy()
# save original id into json
# we need it to update dashboard's json metadata on import
copied_slc.id = slc.id
# add extra params for the import
copied_slc.alter_params(
remote_id=slc.id,
datasource_name=slc.datasource.datasource_name,
schema=slc.datasource.schema,
database_name=slc.datasource.database.name,
)
# set slices without creating ORM relations
slices = copied_dashboard.__dict__.setdefault("slices", [])
slices.append(copied_slc)
copied_dashboard.alter_params(remote_id=dashboard_id)
copied_dashboards.append(copied_dashboard)
eager_datasources = []
for datasource_id, datasource_type in datasource_ids:
eager_datasource = ConnectorRegistry.get_eager_datasource(
db.session, datasource_type, datasource_id
)
copied_datasource = eager_datasource.copy()
copied_datasource.alter_params(
remote_id=eager_datasource.id,
database_name=eager_datasource.database.name,
)
datasource_class = copied_datasource.__class__
for field_name in datasource_class.export_children:
field_val = getattr(eager_datasource, field_name).copy()
# set children without creating ORM relations
copied_datasource.__dict__[field_name] = field_val
eager_datasources.append(copied_datasource)
return json.dumps(
{"dashboards": copied_dashboards, "datasources": eager_datasources},
cls=utils.DashboardEncoder,
indent=4,
)
class Database(
Model, AuditMixinNullable, ImportMixin
): # pylint: disable=too-many-public-methods
@ -1324,9 +666,6 @@ class FavStar(Model): # pylint: disable=too-few-public-methods
# events for updating tags
if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert)
sqla.event.listen(Slice, "after_update", ChartUpdater.after_update)
sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete)
sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert)
sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update)
sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete)

View File

@ -0,0 +1,437 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from copy import copy
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
from urllib import parse
import sqlalchemy as sqla
from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from flask_appbuilder.security.sqla.models import User
from markupsafe import escape, Markup
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
Integer,
MetaData,
String,
Table,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
from superset.models.helpers import AuditMixinNullable, ImportMixin
from superset.models.slice import Slice as Slice
from superset.models.tags import DashboardUpdater
from superset.models.user_attributes import UserAttribute
from superset.utils import core as utils
if TYPE_CHECKING:
# pylint: disable=unused-import
from superset.connectors.base.models import BaseDatasource
metadata = Model.metadata # pylint: disable=no-member
config = app.config
def copy_dashboard(mapper, connection, target):
# pylint: disable=unused-argument
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
return
session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection)
new_user = session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
session.commit()
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
sqla.event.listen(User, "after_insert", copy_dashboard)
dashboard_slices = Table(
"dashboard_slices",
metadata,
Column("id", Integer, primary_key=True),
Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
Column("slice_id", Integer, ForeignKey("slices.id")),
UniqueConstraint("dashboard_id", "slice_id"),
)
dashboard_user = Table(
"dashboard_user",
metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
)
class Dashboard( # pylint: disable=too-many-instance-attributes
Model, AuditMixinNullable, ImportMixin
):
"""The dashboard object!"""
__tablename__ = "dashboards"
id = Column(Integer, primary_key=True) # pylint: disable=invalid-name
dashboard_title = Column(String(500))
position_json = Column(utils.MediumText())
description = Column(Text)
css = Column(Text)
json_metadata = Column(Text)
slug = Column(String(255), unique=True)
slices = relationship("Slice", secondary=dashboard_slices, backref="dashboards")
owners = relationship(security_manager.user_model, secondary=dashboard_user)
published = Column(Boolean, default=False)
export_fields = [
"dashboard_title",
"position_json",
"json_metadata",
"description",
"css",
"slug",
]
def __repr__(self):
return self.dashboard_title or str(self.id)
@property
def table_names(self) -> str:
# pylint: disable=no-member
return ", ".join(str(s.datasource.full_name) for s in self.slices)
@property
def url(self) -> str:
if self.json_metadata:
# add default_filters to the preselect_filters of dashboard
json_metadata = json.loads(self.json_metadata)
default_filters = json_metadata.get("default_filters")
# make sure default_filters is not empty and is valid
if default_filters and default_filters != "{}":
try:
if json.loads(default_filters):
filters = parse.quote(default_filters.encode("utf8"))
return "/superset/dashboard/{}/?preselect_filters={}".format(
self.slug or self.id, filters
)
except Exception: # pylint: disable=broad-except
pass
return f"/superset/dashboard/{self.slug or self.id}/"
@property
def datasources(self) -> Set[Optional["BaseDatasource"]]:
return {slc.datasource for slc in self.slices}
@property
def charts(self) -> List[Optional["BaseDatasource"]]:
return [slc.chart for slc in self.slices]
@property
def sqla_metadata(self) -> None:
# pylint: disable=no-member
meta = MetaData(bind=self.get_sqla_engine())
meta.reflect()
@renders("dashboard_title")
def dashboard_link(self) -> Markup:
title = escape(self.dashboard_title or "<empty>")
return Markup(f'<a href="{self.url}">{title}</a>')
@property
def data(self) -> Dict[str, Any]:
positions = self.position_json
if positions:
positions = json.loads(positions)
return {
"id": self.id,
"metadata": self.params_dict,
"css": self.css,
"dashboard_title": self.dashboard_title,
"published": self.published,
"slug": self.slug,
"slices": [slc.data for slc in self.slices],
"position_json": positions,
}
@property
def params(self) -> str:
return self.json_metadata
@params.setter
def params(self, value: str) -> None:
self.json_metadata = value
@property
def position(self) -> Dict:
if self.position_json:
return json.loads(self.position_json)
return {}
@classmethod
def import_obj( # pylint: disable=too-many-locals,too-many-branches,too-many-statements
cls, dashboard_to_import: "Dashboard", import_time: Optional[int] = None
) -> int:
"""Imports the dashboard from the object to the database.
Once dashboard is imported, json_metadata field is extended and stores
remote_id and import_time. It helps to decide if the dashboard has to
be overridden or just copies over. Slices that belong to this
dashboard will be wired to existing tables. This function can be used
to import/export dashboards between multiple superset instances.
Audit metadata isn't copied over.
"""
def alter_positions(dashboard, old_to_new_slc_id_dict):
""" Updates slice_ids in the position json.
Sample position_json data:
{
"DASHBOARD_VERSION_KEY": "v2",
"DASHBOARD_ROOT_ID": {
"type": "DASHBOARD_ROOT_TYPE",
"id": "DASHBOARD_ROOT_ID",
"children": ["DASHBOARD_GRID_ID"]
},
"DASHBOARD_GRID_ID": {
"type": "DASHBOARD_GRID_TYPE",
"id": "DASHBOARD_GRID_ID",
"children": ["DASHBOARD_CHART_TYPE-2"]
},
"DASHBOARD_CHART_TYPE-2": {
"type": "DASHBOARD_CHART_TYPE",
"id": "DASHBOARD_CHART_TYPE-2",
"children": [],
"meta": {
"width": 4,
"height": 50,
"chartId": 118
}
},
}
"""
position_data = json.loads(dashboard.position_json)
position_json = position_data.values()
for value in position_json:
if (
isinstance(value, dict)
and value.get("meta")
and value.get("meta").get("chartId")
):
old_slice_id = value.get("meta").get("chartId")
if old_slice_id in old_to_new_slc_id_dict:
value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id]
dashboard.position_json = json.dumps(position_data)
logging.info(
"Started import of the dashboard: %s", dashboard_to_import.to_json()
)
session = db.session
logging.info("Dashboard has %d slices", len(dashboard_to_import.slices))
# copy slices object as Slice.import_slice will mutate the slice
# and will remove the existing dashboard - slice association
slices = copy(dashboard_to_import.slices)
old_to_new_slc_id_dict = {}
new_filter_immune_slices = []
new_filter_immune_slice_fields = {}
new_timed_refresh_immune_slices = []
new_expanded_slices = {}
i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = {
slc.params_dict["remote_id"]: slc
for slc in session.query(Slice).all()
if "remote_id" in slc.params_dict
}
for slc in slices:
logging.info(
"Importing slice %s from the dashboard: %s",
slc.to_json(),
dashboard_to_import.dashboard_title,
)
remote_slc = remote_id_slice_map.get(slc.id)
new_slc_id = Slice.import_obj(slc, remote_slc, import_time=import_time)
old_to_new_slc_id_dict[slc.id] = new_slc_id
# update json metadata that deals with slice ids
new_slc_id_str = "{}".format(new_slc_id)
old_slc_id_str = "{}".format(slc.id)
if (
"filter_immune_slices" in i_params_dict
and old_slc_id_str in i_params_dict["filter_immune_slices"]
):
new_filter_immune_slices.append(new_slc_id_str)
if (
"filter_immune_slice_fields" in i_params_dict
and old_slc_id_str in i_params_dict["filter_immune_slice_fields"]
):
new_filter_immune_slice_fields[new_slc_id_str] = i_params_dict[
"filter_immune_slice_fields"
][old_slc_id_str]
if (
"timed_refresh_immune_slices" in i_params_dict
and old_slc_id_str in i_params_dict["timed_refresh_immune_slices"]
):
new_timed_refresh_immune_slices.append(new_slc_id_str)
if (
"expanded_slices" in i_params_dict
and old_slc_id_str in i_params_dict["expanded_slices"]
):
new_expanded_slices[new_slc_id_str] = i_params_dict["expanded_slices"][
old_slc_id_str
]
# override the dashboard
existing_dashboard = None
for dash in session.query(Dashboard).all():
if (
"remote_id" in dash.params_dict
and dash.params_dict["remote_id"] == dashboard_to_import.id
):
existing_dashboard = dash
dashboard_to_import = dashboard_to_import.copy()
dashboard_to_import.id = None
dashboard_to_import.reset_ownership()
# position_json can be empty for dashboards
# with charts added from chart-edit page and without re-arranging
if dashboard_to_import.position_json:
alter_positions(dashboard_to_import, old_to_new_slc_id_dict)
dashboard_to_import.alter_params(import_time=import_time)
if new_expanded_slices:
dashboard_to_import.alter_params(expanded_slices=new_expanded_slices)
if new_filter_immune_slices:
dashboard_to_import.alter_params(
filter_immune_slices=new_filter_immune_slices
)
if new_filter_immune_slice_fields:
dashboard_to_import.alter_params(
filter_immune_slice_fields=new_filter_immune_slice_fields
)
if new_timed_refresh_immune_slices:
dashboard_to_import.alter_params(
timed_refresh_immune_slices=new_timed_refresh_immune_slices
)
new_slices = (
session.query(Slice)
.filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
.all()
)
if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
session.flush()
return existing_dashboard.id
dashboard_to_import.slices = new_slices
session.add(dashboard_to_import)
session.flush()
return dashboard_to_import.id # type: ignore
@classmethod
def export_dashboards( # pylint: disable=too-many-locals
cls, dashboard_ids: List
) -> str:
copied_dashboards = []
datasource_ids = set()
for dashboard_id in dashboard_ids:
# make sure that dashboard_id is an integer
dashboard_id = int(dashboard_id)
dashboard = (
db.session.query(Dashboard)
.options(subqueryload(Dashboard.slices))
.filter_by(id=dashboard_id)
.first()
)
# remove ids and relations (like owners, created by, slices, ...)
copied_dashboard = dashboard.copy()
for slc in dashboard.slices:
datasource_ids.add((slc.datasource_id, slc.datasource_type))
copied_slc = slc.copy()
# save original id into json
# we need it to update dashboard's json metadata on import
copied_slc.id = slc.id
# add extra params for the import
copied_slc.alter_params(
remote_id=slc.id,
datasource_name=slc.datasource.datasource_name,
schema=slc.datasource.schema,
database_name=slc.datasource.database.name,
)
# set slices without creating ORM relations
slices = copied_dashboard.__dict__.setdefault("slices", [])
slices.append(copied_slc)
copied_dashboard.alter_params(remote_id=dashboard_id)
copied_dashboards.append(copied_dashboard)
eager_datasources = []
for datasource_id, datasource_type in datasource_ids:
eager_datasource = ConnectorRegistry.get_eager_datasource(
db.session, datasource_type, datasource_id
)
copied_datasource = eager_datasource.copy()
copied_datasource.alter_params(
remote_id=eager_datasource.id,
database_name=eager_datasource.database.name,
)
datasource_class = copied_datasource.__class__
for field_name in datasource_class.export_children:
field_val = getattr(eager_datasource, field_name).copy()
# set children without creating ORM relations
copied_datasource.__dict__[field_name] = field_val
eager_datasources.append(copied_datasource)
return json.dumps(
{"dashboards": copied_dashboards, "datasources": eager_datasources},
cls=utils.DashboardEncoder,
indent=4,
)
# events for updating tags
if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert)
sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update)
sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete)

317
superset/models/slice.py Normal file
View File

@ -0,0 +1,317 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
from urllib import parse
import sqlalchemy as sqla
from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from markupsafe import escape, Markup
from sqlalchemy import Column, ForeignKey, Integer, String, Table, Text
from sqlalchemy.orm import make_transient, relationship
from superset import ConnectorRegistry, db, is_feature_enabled, security_manager
from superset.legacy import update_time_range
from superset.models.helpers import AuditMixinNullable, ImportMixin
from superset.models.tags import ChartUpdater
from superset.utils import core as utils
from superset.viz import BaseViz, viz_types
if TYPE_CHECKING:
# pylint: disable=unused-import
from superset.connectors.base.models import BaseDatasource
metadata = Model.metadata # pylint: disable=no-member
slice_user = Table(
"slice_user",
metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("slice_id", Integer, ForeignKey("slices.id")),
)
class Slice(
Model, AuditMixinNullable, ImportMixin
): # pylint: disable=too-many-public-methods
"""A slice is essentially a report or a view on data"""
__tablename__ = "slices"
id = Column(Integer, primary_key=True) # pylint: disable=invalid-name
slice_name = Column(String(250))
datasource_id = Column(Integer)
datasource_type = Column(String(200))
datasource_name = Column(String(2000))
viz_type = Column(String(250))
params = Column(Text)
description = Column(Text)
cache_timeout = Column(Integer)
perm = Column(String(1000))
schema_perm = Column(String(1000))
owners = relationship(security_manager.user_model, secondary=slice_user)
token = ""
export_fields = [
"slice_name",
"datasource_type",
"datasource_name",
"viz_type",
"params",
"cache_timeout",
]
def __repr__(self):
return self.slice_name or str(self.id)
@property
def cls_model(self) -> Type["BaseDatasource"]:
return ConnectorRegistry.sources[self.datasource_type]
@property
def datasource(self) -> "BaseDatasource":
return self.get_datasource
def clone(self) -> "Slice":
return Slice(
slice_name=self.slice_name,
datasource_id=self.datasource_id,
datasource_type=self.datasource_type,
datasource_name=self.datasource_name,
viz_type=self.viz_type,
params=self.params,
description=self.description,
cache_timeout=self.cache_timeout,
)
# pylint: disable=using-constant-test
@datasource.getter # type: ignore
@utils.memoized
def get_datasource(self) -> Optional["BaseDatasource"]:
return db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
@renders("datasource_name")
def datasource_link(self) -> Optional[Markup]:
# pylint: disable=no-member
datasource = self.datasource
return datasource.link if datasource else None
def datasource_name_text(self) -> Optional[str]:
# pylint: disable=no-member
datasource = self.datasource
return datasource.name if datasource else None
@property
def datasource_edit_url(self) -> Optional[str]:
# pylint: disable=no-member
datasource = self.datasource
return datasource.url if datasource else None
# pylint: enable=using-constant-test
@property # type: ignore
@utils.memoized
def viz(self) -> BaseViz:
d = json.loads(self.params)
viz_class = viz_types[self.viz_type]
return viz_class(datasource=self.datasource, form_data=d)
@property
def description_markeddown(self) -> str:
return utils.markdown(self.description)
@property
def data(self) -> Dict[str, Any]:
"""Data used to render slice in templates"""
d: Dict[str, Any] = {}
self.token = ""
try:
d = self.viz.data
self.token = d.get("token") # type: ignore
except Exception as e: # pylint: disable=broad-except
logging.exception(e)
d["error"] = str(e)
return {
"datasource": self.datasource_name,
"description": self.description,
"description_markeddown": self.description_markeddown,
"edit_url": self.edit_url,
"form_data": self.form_data,
"slice_id": self.id,
"slice_name": self.slice_name,
"slice_url": self.slice_url,
"modified": self.modified(),
"changed_on_humanized": self.changed_on_humanized,
"changed_on": self.changed_on.isoformat(),
}
@property
def json_data(self) -> str:
return json.dumps(self.data)
@property
def form_data(self) -> Dict[str, Any]:
form_data: Dict[str, Any] = {}
try:
form_data = json.loads(self.params)
except Exception as e: # pylint: disable=broad-except
logging.error("Malformed json in slice's params")
logging.exception(e)
form_data.update(
{
"slice_id": self.id,
"viz_type": self.viz_type,
"datasource": "{}__{}".format(self.datasource_id, self.datasource_type),
}
)
if self.cache_timeout:
form_data["cache_timeout"] = self.cache_timeout
update_time_range(form_data)
return form_data
def get_explore_url(
self,
base_url: str = "/superset/explore",
overrides: Optional[Dict[str, Any]] = None,
) -> str:
overrides = overrides or {}
form_data = {"slice_id": self.id}
form_data.update(overrides)
params = parse.quote(json.dumps(form_data))
return f"{base_url}/?form_data={params}"
@property
def slice_url(self) -> str:
"""Defines the url to access the slice"""
return self.get_explore_url()
@property
def explore_json_url(self) -> str:
"""Defines the url to access the slice"""
return self.get_explore_url("/superset/explore_json")
@property
def edit_url(self) -> str:
return f"/chart/edit/{self.id}"
@property
def chart(self) -> str:
return self.slice_name or "<empty>"
@property
def slice_link(self) -> Markup:
name = escape(self.chart)
return Markup(f'<a href="{self.url}">{name}</a>')
def get_viz(self, force: bool = False) -> BaseViz:
"""Creates :py:class:viz.BaseViz object from the url_params_multidict.
:return: object of the 'viz_type' type that is taken from the
url_params_multidict or self.params.
:rtype: :py:class:viz.BaseViz
"""
slice_params = json.loads(self.params)
slice_params["slice_id"] = self.id
slice_params["json"] = "false"
slice_params["slice_name"] = self.slice_name
slice_params["viz_type"] = self.viz_type if self.viz_type else "table"
return viz_types[slice_params.get("viz_type")](
self.datasource, form_data=slice_params, force=force
)
@property
def icons(self) -> str:
return f"""
<a
href="{self.datasource_edit_url}"
data-toggle="tooltip"
title="{self.datasource}">
<i class="fa fa-database"></i>
</a>
"""
@classmethod
def import_obj(
cls,
slc_to_import: "Slice",
slc_to_override: Optional["Slice"],
import_time: Optional[int] = None,
) -> int:
"""Inserts or overrides slc in the database.
remote_id and import_time fields in params_dict are set to track the
slice origin and ensure correct overrides for multiple imports.
Slice.perm is used to find the datasources and connect them.
:param Slice slc_to_import: Slice object to import
:param Slice slc_to_override: Slice to replace, id matches remote_id
:returns: The resulting id for the imported slice
:rtype: int
"""
session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
slc_to_import = slc_to_import.copy()
slc_to_import.reset_ownership()
params = slc_to_import.params_dict
datasource = ConnectorRegistry.get_datasource_by_name(
session,
slc_to_import.datasource_type,
params["datasource_name"],
params["schema"],
params["database_name"],
)
slc_to_import.datasource_id = datasource.id # type: ignore
if slc_to_override:
slc_to_override.override(slc_to_import)
session.flush()
return slc_to_override.id
session.add(slc_to_import)
logging.info("Final slice: %s", str(slc_to_import.to_json()))
session.flush()
return slc_to_import.id
@property
def url(self) -> str:
return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"
def set_related_perm(mapper, connection, target):
# pylint: disable=unused-argument
src_class = target.cls_model
id_ = target.datasource_id
if id_:
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
if ds:
target.perm = ds.perm
target.schema_perm = ds.schema_perm
sqla.event.listen(Slice, "before_insert", set_related_perm)
sqla.event.listen(Slice, "before_update", set_related_perm)
# events for updating tags
if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert)
sqla.event.listen(Slice, "after_update", ChartUpdater.after_update)
sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete)

View File

@ -26,7 +26,9 @@ from sqlalchemy import and_, func
from superset import app, db
from superset.extensions import celery_app
from superset.models.core import Dashboard, Log, Slice
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.tags import Tag, TaggedObject
from superset.utils.core import parse_human_datetime

View File

@ -21,7 +21,8 @@ import time
from datetime import datetime
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.core import Dashboard, Slice
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
def decode_dashboards(o):

View File

@ -20,10 +20,10 @@ from flask import request
from flask_appbuilder import expose
from flask_appbuilder.security.decorators import has_access_api
import superset.models.core as models
from superset import appbuilder, db, event_logger, security_manager
from superset.common.query_context import QueryContext
from superset.legacy import update_time_range
from superset.models.slice import Slice
from superset.utils import core as utils
from .base import api, BaseSupersetView, handle_api_exception
@ -63,7 +63,7 @@ class Api(BaseSupersetView):
form_data = {}
slice_id = request.args.get("slice_id")
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).one_or_none()
slc = db.session.query(Slice).filter_by(id=slice_id).one_or_none()
if slc:
form_data = slc.form_data.copy()

View File

@ -77,7 +77,9 @@ from superset.exceptions import (
SupersetTimeoutException,
)
from superset.jinja_context import get_template_processor
from superset.models.dashboard import Dashboard
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.models.slice import Slice
from superset.models.sql_lab import Query, TabState
from superset.models.user_attributes import UserAttribute
from superset.sql_parse import ParsedQuery
@ -291,7 +293,7 @@ if config["ENABLE_ACCESS_REQUEST"]:
class SliceModelView(SupersetModelView, DeleteMixin):
route_base = "/chart"
datamodel = SQLAInterface(models.Slice)
datamodel = SQLAInterface(Slice)
list_title = _("Charts")
show_title = _("Show Chart")
@ -605,9 +607,7 @@ class Superset(BaseSupersetView):
datasources = set()
dashboard_id = request.args.get("dashboard_id")
if dashboard_id:
dash = (
db.session.query(models.Dashboard).filter_by(id=int(dashboard_id)).one()
)
dash = db.session.query(Dashboard).filter_by(id=int(dashboard_id)).one()
datasources |= dash.datasources
datasource_id = request.args.get("datasource_id")
datasource_type = request.args.get("datasource_type")
@ -755,7 +755,7 @@ class Superset(BaseSupersetView):
force=False,
):
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).one()
slc = db.session.query(Slice).filter_by(id=slice_id).one()
return slc.get_viz()
else:
viz_type = form_data.get("viz_type", "table")
@ -1148,7 +1148,7 @@ class Superset(BaseSupersetView):
if action in ("saveas"):
if "slice_id" in form_data:
form_data.pop("slice_id") # don't save old slice_id
slc = models.Slice(owners=[g.user] if g.user else [])
slc = Slice(owners=[g.user] if g.user else [])
slc.params = json.dumps(form_data, indent=2, sort_keys=True)
slc.datasource_name = datasource_name
@ -1166,7 +1166,7 @@ class Superset(BaseSupersetView):
dash = None
if request.args.get("add_to_dash") == "existing":
dash = (
db.session.query(models.Dashboard)
db.session.query(Dashboard)
.filter_by(id=int(request.args.get("save_to_dashboard_id")))
.one()
)
@ -1197,7 +1197,7 @@ class Superset(BaseSupersetView):
status=400,
)
dash = models.Dashboard(
dash = Dashboard(
dashboard_title=request.args.get("new_dashboard_name"),
owners=[g.user] if g.user else [],
)
@ -1381,7 +1381,7 @@ class Superset(BaseSupersetView):
session = db.session()
data = json.loads(request.form.get("data"))
dash = models.Dashboard()
original_dash = session.query(models.Dashboard).get(dashboard_id)
original_dash = session.query(Dashboard).get(dashboard_id)
dash.owners = [g.user] if g.user else []
dash.dashboard_title = data["dashboard_title"]
@ -1426,7 +1426,7 @@ class Superset(BaseSupersetView):
def save_dash(self, dashboard_id):
"""Save a dashboard's metadata"""
session = db.session()
dash = session.query(models.Dashboard).get(dashboard_id)
dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
data = json.loads(request.form.get("data"))
self._set_dash_metadata(dash, data)
@ -1451,7 +1451,6 @@ class Superset(BaseSupersetView):
pass
session = db.session()
Slice = models.Slice
current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
dashboard.slices = current_slices
@ -1510,8 +1509,7 @@ class Superset(BaseSupersetView):
"""Add and save slices to a dashboard"""
data = json.loads(request.form.get("data"))
session = db.session()
Slice = models.Slice # noqa
dash = session.query(models.Dashboard).get(dashboard_id)
dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
dash.slices += new_slices
@ -1576,9 +1574,9 @@ class Superset(BaseSupersetView):
limit = 1000
qry = (
db.session.query(M.Log, M.Dashboard, M.Slice)
db.session.query(M.Log, M.Dashboard, Slice)
.outerjoin(M.Dashboard, M.Dashboard.id == M.Log.dashboard_id)
.outerjoin(M.Slice, M.Slice.id == M.Log.slice_id)
.outerjoin(Slice, Slice.id == M.Log.slice_id)
.filter(
and_(
~M.Log.action.in_(("queries", "shortner", "sql_json")),
@ -1643,13 +1641,13 @@ class Superset(BaseSupersetView):
@expose("/fave_dashboards/<user_id>/", methods=["GET"])
def fave_dashboards(self, user_id):
qry = (
db.session.query(models.Dashboard, models.FavStar.dttm)
db.session.query(Dashboard, models.FavStar.dttm)
.join(
models.FavStar,
and_(
models.FavStar.user_id == int(user_id),
models.FavStar.class_name == "Dashboard",
models.Dashboard.id == models.FavStar.obj_id,
Dashboard.id == models.FavStar.obj_id,
),
)
.order_by(models.FavStar.dttm.desc())
@ -1674,7 +1672,7 @@ class Superset(BaseSupersetView):
@has_access_api
@expose("/created_dashboards/<user_id>/", methods=["GET"])
def created_dashboards(self, user_id):
Dash = models.Dashboard
Dash = Dashboard
qry = (
db.session.query(Dash)
.filter(or_(Dash.created_by_fk == user_id, Dash.changed_by_fk == user_id))
@ -1700,7 +1698,6 @@ class Superset(BaseSupersetView):
"""List of slices a user created, or faved"""
if not user_id:
user_id = g.user.id
Slice = models.Slice
FavStar = models.FavStar
qry = (
db.session.query(Slice, FavStar.dttm)
@ -1709,7 +1706,7 @@ class Superset(BaseSupersetView):
and_(
models.FavStar.user_id == int(user_id),
models.FavStar.class_name == "slice",
models.Slice.id == models.FavStar.obj_id,
Slice.id == models.FavStar.obj_id,
),
isouter=True,
)
@ -1743,7 +1740,6 @@ class Superset(BaseSupersetView):
"""List of slices created by this user"""
if not user_id:
user_id = g.user.id
Slice = models.Slice
qry = (
db.session.query(Slice)
.filter(or_(Slice.created_by_fk == user_id, Slice.changed_by_fk == user_id))
@ -1770,13 +1766,13 @@ class Superset(BaseSupersetView):
if not user_id:
user_id = g.user.id
qry = (
db.session.query(models.Slice, models.FavStar.dttm)
db.session.query(Slice, models.FavStar.dttm)
.join(
models.FavStar,
and_(
models.FavStar.user_id == int(user_id),
models.FavStar.class_name == "slice",
models.Slice.id == models.FavStar.obj_id,
Slice.id == models.FavStar.obj_id,
),
)
.order_by(models.FavStar.dttm.desc())
@ -1820,7 +1816,7 @@ class Superset(BaseSupersetView):
status=400,
)
if slice_id:
slices = session.query(models.Slice).filter_by(id=slice_id).all()
slices = session.query(Slice).filter_by(id=slice_id).all()
if not slices:
return json_error_response(
__("Chart %(id)s not found", id=slice_id), status=404
@ -1845,7 +1841,7 @@ class Superset(BaseSupersetView):
status=404,
)
slices = (
session.query(models.Slice)
session.query(Slice)
.filter_by(datasource_id=table.id, datasource_type=table.type)
.all()
)
@ -1906,7 +1902,7 @@ class Superset(BaseSupersetView):
def publish(self, dashboard_id):
"""Gets and toggles published status on dashboards"""
session = db.session()
Dashboard = models.Dashboard
Dashboard = Dashboard
Role = ab_models.Role
dash = (
session.query(Dashboard).filter(Dashboard.id == dashboard_id).one_or_none()
@ -1940,7 +1936,7 @@ class Superset(BaseSupersetView):
def dashboard(self, dashboard_id):
"""Server side rendering for a dashboard"""
session = db.session()
qry = session.query(models.Dashboard)
qry = session.query(Dashboard)
if dashboard_id.isdigit():
qry = qry.filter_by(id=int(dashboard_id))
else:

View File

@ -17,7 +17,9 @@
from sqlalchemy import and_, or_
from superset import db, security_manager
from superset.models.core import Dashboard, FavStar, Slice
from superset.models.core import FavStar
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.views.base import BaseFilter
from ..base import get_user_roles

View File

@ -28,12 +28,13 @@ from wtforms import BooleanField, StringField
from superset import app, appbuilder, db, security_manager
from superset.exceptions import SupersetException
from superset.models.core import Dashboard, Slice
from superset.models.dashboard import Dashboard
from superset.models.schedules import (
DashboardEmailSchedule,
ScheduleType,
SliceEmailSchedule,
)
from superset.models.slice import Slice
from superset.tasks.schedules import schedule_email_report
from superset.utils.core import get_email_address_list, json_iso_dttm_ser
from superset.views.core import json_success

View File

@ -26,7 +26,8 @@ from werkzeug.routing import BaseConverter
from superset import app, appbuilder, db, utils
from superset.jinja_context import current_user_id, current_username
from superset.models.core import Dashboard, Slice
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.models.tags import ObjectTypes, Tag, TaggedObject, TagTypes

View File

@ -27,6 +27,7 @@ from superset import app, db, viz
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import SupersetException
from superset.legacy import update_time_range
from superset.models.slice import Slice
from superset.utils.core import QueryStatus, TimeRangeEndpoint
FORM_DATA_KEY_BLACKLIST: List[str] = []
@ -79,7 +80,7 @@ def get_viz(
slice_id=None, form_data=None, datasource_type=None, datasource_id=None, force=False
):
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).one()
slc = db.session.query(Slice).filter_by(id=slice_id).one()
return slc.get_viz()
viz_type = form_data.get("viz_type", "table")
@ -127,7 +128,7 @@ def get_form_data(slice_id=None, use_slice_data=False):
# Include the slice_form_data if request from explore or slice calls
# or if form_data only contains slice_id and additional filters
if slice_id and (use_slice_data or valid_slice_id):
slc = db.session.query(models.Slice).filter_by(id=slice_id).one_or_none()
slc = db.session.query(Slice).filter_by(id=slice_id).one_or_none()
if slc:
slice_form_data = slc.form_data.copy()
slice_form_data.update(form_data)
@ -209,7 +210,7 @@ def apply_display_max_row_limit(
def get_time_range_endpoints(
form_data: Dict[str, Any],
slc: Optional[models.Slice] = None,
slc: Optional[Slice] = None,
slice_id: Optional[int] = None,
) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]:
"""
@ -244,9 +245,7 @@ def get_time_range_endpoints(
if datasource_type == "table":
if not slc:
slc = (
db.session.query(models.Slice).filter_by(id=slice_id).one_or_none()
)
slc = db.session.query(Slice).filter_by(id=slice_id).one_or_none()
if slc:
endpoints = slc.datasource.database.get_extra().get(

View File

@ -1296,7 +1296,7 @@ class MultiLineViz(NVD3Viz):
def get_data(self, df):
fd = self.form_data
# Late imports to avoid circular import issues
from superset.models.core import Slice
from superset.models.slice import Slice
from superset import db
slice_ids1 = fd.get("line_charts")
@ -2104,7 +2104,7 @@ class DeckGLMultiLayer(BaseViz):
def get_data(self, df):
fd = self.form_data
# Late imports to avoid circular import issues
from superset.models.core import Slice
from superset.models.slice import Slice
from superset import db
slice_ids = fd.get("deck_slices")

View File

@ -30,7 +30,9 @@ from superset import db, security_manager
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.slice import Slice
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.utils.core import get_example_database
@ -107,7 +109,7 @@ class SupersetTestCase(TestCase):
self.assertNotIn("User confirmation needed", resp)
def get_slice(self, slice_name, session):
slc = session.query(models.Slice).filter_by(slice_name=slice_name).one()
slc = session.query(Slice).filter_by(slice_name=slice_name).one()
session.expunge_all()
return slc
@ -281,4 +283,4 @@ class SupersetTestCase(TestCase):
def get_dash_by_slug(self, dash_slug):
sesh = db.session()
return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first()
return sesh.query(Dashboard).filter_by(slug=dash_slug).first()

View File

@ -40,7 +40,9 @@ from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.models import core as models
from superset.models.dashboard import Dashboard
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.views import core as views
@ -221,7 +223,7 @@ class CoreTests(SupersetTestCase):
)
db.session.expunge_all()
new_slice_id = resp.json["form_data"]["slice_id"]
slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one()
slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
self.assertEqual(slc.slice_name, copy_name)
form_data.pop("slice_id") # We don't save the slice id when saving as
@ -241,7 +243,7 @@ class CoreTests(SupersetTestCase):
data={"form_data": json.dumps(form_data)},
)
db.session.expunge_all()
slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one()
slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
self.assertEqual(slc.slice_name, new_slice_name)
self.assertEqual(slc.viz.form_data, form_data)
@ -280,7 +282,7 @@ class CoreTests(SupersetTestCase):
def test_slices(self):
# Testing by hitting the two supported end points for all slices
self.login(username="admin")
Slc = models.Slice
Slc = Slice
urls = []
for slc in db.session.query(Slc).all():
urls += [
@ -333,7 +335,7 @@ class CoreTests(SupersetTestCase):
)
self.login(username="explore_beta", password="general")
Slc = models.Slice
Slc = Slice
urls = []
for slc in db.session.query(Slc).all():
urls += [(slc.slice_name, "slice_url", slc.slice_url)]
@ -554,7 +556,7 @@ class CoreTests(SupersetTestCase):
resp = self.get_json_resp(url)
self.assertEqual(resp["count"], 1)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
url = "/superset/favstar/Dashboard/{}/select/".format(dash.id)
resp = self.get_json_resp(url)
self.assertEqual(resp["count"], 1)
@ -579,7 +581,7 @@ class CoreTests(SupersetTestCase):
def test_slice_id_is_always_logged_correctly_on_web_request(self):
# superset/explore case
slc = db.session.query(models.Slice).filter_by(slice_name="Girls").one()
slc = db.session.query(Slice).filter_by(slice_name="Girls").one()
qry = db.session.query(models.Log).filter_by(slice_id=slc.id)
self.get_resp(slc.slice_url, {"form_data": json.dumps(slc.form_data)})
self.assertEqual(1, qry.count())
@ -587,7 +589,7 @@ class CoreTests(SupersetTestCase):
def test_slice_id_is_always_logged_correctly_on_ajax_request(self):
# superset/explore_json case
self.login(username="admin")
slc = db.session.query(models.Slice).filter_by(slice_name="Girls").one()
slc = db.session.query(Slice).filter_by(slice_name="Girls").one()
qry = db.session.query(models.Log).filter_by(slice_id=slc.id)
slc_url = slc.slice_url.replace("explore", "explore_json")
self.get_json_resp(slc_url, {"form_data": json.dumps(slc.form_data)})

View File

@ -25,6 +25,8 @@ from sqlalchemy import func
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from .base_tests import SupersetTestCase
@ -59,23 +61,23 @@ class DashboardTests(SupersetTestCase):
def test_dashboard(self):
self.login(username="admin")
urls = {}
for dash in db.session.query(models.Dashboard).all():
for dash in db.session.query(Dashboard).all():
urls[dash.dashboard_title] = dash.url
for title, url in urls.items():
assert escape(title) in self.client.get(url).data.decode("utf-8")
def test_new_dashboard(self):
self.login(username="admin")
dash_count_before = db.session.query(func.count(models.Dashboard.id)).first()[0]
dash_count_before = db.session.query(func.count(Dashboard.id)).first()[0]
url = "/dashboard/new/"
resp = self.get_resp(url)
self.assertIn("[ untitled dashboard ]", resp)
dash_count_after = db.session.query(func.count(models.Dashboard.id)).first()[0]
dash_count_after = db.session.query(func.count(Dashboard.id)).first()[0]
self.assertEqual(dash_count_before + 1, dash_count_after)
def test_dashboard_modes(self):
self.login(username="admin")
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
url = dash.url
if dash.url.find("?") == -1:
url += "?"
@ -88,7 +90,7 @@ class DashboardTests(SupersetTestCase):
def test_save_dash(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
positions = self.get_mock_positions(dash)
data = {
"css": "",
@ -102,7 +104,7 @@ class DashboardTests(SupersetTestCase):
def test_save_dash_with_filter(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="world_health").first()
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
positions = self.get_mock_positions(dash)
filters = {str(dash.slices[0].id): {"region": ["North America"]}}
@ -119,9 +121,7 @@ class DashboardTests(SupersetTestCase):
resp = self.get_resp(url, data=dict(data=json.dumps(data)))
self.assertIn("SUCCESS", resp)
updatedDash = (
db.session.query(models.Dashboard).filter_by(slug="world_health").first()
)
updatedDash = db.session.query(Dashboard).filter_by(slug="world_health").first()
new_url = updatedDash.url
self.assertIn("region", new_url)
@ -130,7 +130,7 @@ class DashboardTests(SupersetTestCase):
def test_save_dash_with_invalid_filters(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="world_health").first()
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
# add an invalid filter slice
positions = self.get_mock_positions(dash)
@ -148,15 +148,13 @@ class DashboardTests(SupersetTestCase):
resp = self.get_resp(url, data=dict(data=json.dumps(data)))
self.assertIn("SUCCESS", resp)
updatedDash = (
db.session.query(models.Dashboard).filter_by(slug="world_health").first()
)
updatedDash = db.session.query(Dashboard).filter_by(slug="world_health").first()
new_url = updatedDash.url
self.assertNotIn("region", new_url)
def test_save_dash_with_dashboard_title(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
origin_title = dash.dashboard_title
positions = self.get_mock_positions(dash)
data = {
@ -167,9 +165,7 @@ class DashboardTests(SupersetTestCase):
}
url = "/superset/save_dash/{}/".format(dash.id)
self.get_resp(url, data=dict(data=json.dumps(data)))
updatedDash = (
db.session.query(models.Dashboard).filter_by(slug="births").first()
)
updatedDash = db.session.query(Dashboard).filter_by(slug="births").first()
self.assertEqual(updatedDash.dashboard_title, "new title")
# bring back dashboard original title
data["dashboard_title"] = origin_title
@ -177,7 +173,7 @@ class DashboardTests(SupersetTestCase):
def test_save_dash_with_colors(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
positions = self.get_mock_positions(dash)
new_label_colors = {"data value": "random color"}
data = {
@ -191,9 +187,7 @@ class DashboardTests(SupersetTestCase):
}
url = "/superset/save_dash/{}/".format(dash.id)
self.get_resp(url, data=dict(data=json.dumps(data)))
updatedDash = (
db.session.query(models.Dashboard).filter_by(slug="births").first()
)
updatedDash = db.session.query(Dashboard).filter_by(slug="births").first()
self.assertIn("color_namespace", updatedDash.json_metadata)
self.assertIn("color_scheme", updatedDash.json_metadata)
self.assertIn("label_colors", updatedDash.json_metadata)
@ -205,7 +199,7 @@ class DashboardTests(SupersetTestCase):
def test_copy_dash(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
positions = self.get_mock_positions(dash)
new_label_colors = {"data value": "random color"}
data = {
@ -223,7 +217,7 @@ class DashboardTests(SupersetTestCase):
dash_id = dash.id
url = "/superset/save_dash/{}/".format(dash_id)
self.client.post(url, data=dict(data=json.dumps(data)))
dash = db.session.query(models.Dashboard).filter_by(id=dash_id).first()
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
orig_json_data = dash.data
# Verify that copy matches original
@ -241,16 +235,12 @@ class DashboardTests(SupersetTestCase):
def test_add_slices(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
new_slice = (
db.session.query(models.Slice)
.filter_by(slice_name="Energy Force Layout")
.first()
db.session.query(Slice).filter_by(slice_name="Energy Force Layout").first()
)
existing_slice = (
db.session.query(models.Slice)
.filter_by(slice_name="Girl Name Cloud")
.first()
db.session.query(Slice).filter_by(slice_name="Girl Name Cloud").first()
)
data = {
"slice_ids": [new_slice.data["slice_id"], existing_slice.data["slice_id"]]
@ -259,23 +249,21 @@ class DashboardTests(SupersetTestCase):
resp = self.client.post(url, data=dict(data=json.dumps(data)))
assert "SLICES ADDED" in resp.data.decode("utf-8")
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
new_slice = (
db.session.query(models.Slice)
.filter_by(slice_name="Energy Force Layout")
.first()
db.session.query(Slice).filter_by(slice_name="Energy Force Layout").first()
)
assert new_slice in dash.slices
assert len(set(dash.slices)) == len(dash.slices)
# cleaning up
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
dash.slices = [o for o in dash.slices if o.slice_name != "Energy Force Layout"]
db.session.commit()
def test_remove_slices(self, username="admin"):
self.login(username=username)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
origin_slices_length = len(dash.slices)
positions = self.get_mock_positions(dash)
@ -297,7 +285,7 @@ class DashboardTests(SupersetTestCase):
dash_id = dash.id
url = "/superset/save_dash/{}/".format(dash_id)
self.client.post(url, data=dict(data=json.dumps(data)))
dash = db.session.query(models.Dashboard).filter_by(id=dash_id).first()
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
# verify slices data
data = dash.data
@ -307,7 +295,7 @@ class DashboardTests(SupersetTestCase):
table = db.session.query(SqlaTable).filter_by(table_name="birth_names").one()
# Make the births dash published so it can be seen
births_dash = db.session.query(models.Dashboard).filter_by(slug="births").one()
births_dash = db.session.query(Dashboard).filter_by(slug="births").one()
births_dash.published = True
db.session.merge(births_dash)
@ -345,7 +333,7 @@ class DashboardTests(SupersetTestCase):
table = db.session.query(SqlaTable).filter_by(table_name="birth_names").one()
self.grant_public_access_to_table(table)
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
dash.owners = [security_manager.find_user("admin")]
dash.created_by = security_manager.find_user("admin")
db.session.merge(dash)
@ -354,7 +342,7 @@ class DashboardTests(SupersetTestCase):
assert "Births" in self.get_resp("/superset/dashboard/births/")
def test_only_owners_can_save(self):
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
dash.owners = []
db.session.merge(dash)
db.session.commit()
@ -365,18 +353,16 @@ class DashboardTests(SupersetTestCase):
alpha = security_manager.find_user("alpha")
dash = db.session.query(models.Dashboard).filter_by(slug="births").first()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
dash.owners = [alpha]
db.session.merge(dash)
db.session.commit()
self.test_save_dash("alpha")
def test_owners_can_view_empty_dashboard(self):
dash = (
db.session.query(models.Dashboard).filter_by(slug="empty_dashboard").first()
)
dash = db.session.query(Dashboard).filter_by(slug="empty_dashboard").first()
if not dash:
dash = models.Dashboard()
dash = Dashboard()
dash.dashboard_title = "Empty Dashboard"
dash.slug = "empty_dashboard"
else:
@ -394,9 +380,7 @@ class DashboardTests(SupersetTestCase):
def test_users_can_view_published_dashboard(self):
table = db.session.query(SqlaTable).filter_by(table_name="energy_usage").one()
# get a slice from the allowed table
slice = (
db.session.query(models.Slice).filter_by(slice_name="Energy Sankey").one()
)
slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
self.grant_public_access_to_table(table)
@ -404,13 +388,13 @@ class DashboardTests(SupersetTestCase):
published_dash_slug = f"published_dash_{random()}"
# Create a published and hidden dashboard and add them to the database
published_dash = models.Dashboard()
published_dash = Dashboard()
published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = published_dash_slug
published_dash.slices = [slice]
published_dash.published = True
hidden_dash = models.Dashboard()
hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = hidden_dash_slug
hidden_dash.slices = [slice]
@ -430,13 +414,13 @@ class DashboardTests(SupersetTestCase):
not_my_dash_slug = f"not_my_dash_{random()}"
# Create one dashboard I own and another that I don't
dash = models.Dashboard()
dash = Dashboard()
dash.dashboard_title = "My Dashboard"
dash.slug = my_dash_slug
dash.owners = [user]
dash.slices = []
hidden_dash = models.Dashboard()
hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Not My Dashboard"
hidden_dash.slug = not_my_dash_slug
hidden_dash.slices = []
@ -457,11 +441,11 @@ class DashboardTests(SupersetTestCase):
fav_dash_slug = f"my_favorite_dash_{random()}"
regular_dash_slug = f"regular_dash_{random()}"
favorite_dash = models.Dashboard()
favorite_dash = Dashboard()
favorite_dash.dashboard_title = "My Favorite Dashboard"
favorite_dash.slug = fav_dash_slug
regular_dash = models.Dashboard()
regular_dash = Dashboard()
regular_dash.dashboard_title = "A Plain Ol Dashboard"
regular_dash.slug = regular_dash_slug
@ -469,7 +453,7 @@ class DashboardTests(SupersetTestCase):
db.session.merge(regular_dash)
db.session.commit()
dash = db.session.query(models.Dashboard).filter_by(slug=fav_dash_slug).first()
dash = db.session.query(Dashboard).filter_by(slug=fav_dash_slug).first()
favorites = models.FavStar()
favorites.obj_id = dash.id
@ -490,7 +474,7 @@ class DashboardTests(SupersetTestCase):
slug = f"admin_owned_unpublished_dash_{random()}"
# Create a dashboard owned by admin and unpublished
dash = models.Dashboard()
dash = Dashboard()
dash.dashboard_title = "My Dashboard"
dash.slug = slug
dash.owners = [admin_user]

View File

@ -14,17 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
from datetime import datetime
from unittest.mock import Mock, patch
from superset import db, security_manager
from tests.test_app import app
from superset import db, security_manager
from .base_tests import SupersetTestCase
try:
from superset.connectors.druid.models import (
DruidCluster,

View File

@ -22,12 +22,13 @@ import unittest
from flask import g
from sqlalchemy.orm.session import make_transient
from superset.utils.dashboard_import_export import decode_dashboards
from tests.test_app import app
from superset.utils.dashboard_import_export import decode_dashboards
from superset import db, security_manager
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models import core as models
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from .base_tests import SupersetTestCase
@ -40,10 +41,10 @@ class ImportExportTests(SupersetTestCase):
with app.app_context():
# Imported data clean up
session = db.session
for slc in session.query(models.Slice):
for slc in session.query(Slice):
if "remote_id" in slc.params_dict:
session.delete(slc)
for dash in session.query(models.Dashboard):
for dash in session.query(Dashboard):
if "remote_id" in dash.params_dict:
session.delete(dash)
for table in session.query(SqlaTable):
@ -86,7 +87,7 @@ class ImportExportTests(SupersetTestCase):
if table:
ds_id = table.id
return models.Slice(
return Slice(
slice_name=name,
datasource_type="table",
viz_type="bubble",
@ -97,7 +98,7 @@ class ImportExportTests(SupersetTestCase):
def create_dashboard(self, title, id=0, slcs=[]):
json_metadata = {"remote_id": id}
return models.Dashboard(
return Dashboard(
id=id,
dashboard_title=title,
slices=slcs,
@ -132,13 +133,13 @@ class ImportExportTests(SupersetTestCase):
return datasource
def get_slice(self, slc_id):
return db.session.query(models.Slice).filter_by(id=slc_id).first()
return db.session.query(Slice).filter_by(id=slc_id).first()
def get_slice_by_name(self, name):
return db.session.query(models.Slice).filter_by(slice_name=name).first()
return db.session.query(Slice).filter_by(slice_name=name).first()
def get_dash(self, dash_id):
return db.session.query(models.Dashboard).filter_by(id=dash_id).first()
return db.session.query(Dashboard).filter_by(id=dash_id).first()
def get_datasource(self, datasource_id):
return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()
@ -292,7 +293,7 @@ class ImportExportTests(SupersetTestCase):
def test_import_1_slice(self):
expected_slice = self.create_slice("Import Me", id=10001)
slc_id = models.Slice.import_obj(expected_slice, None, import_time=1989)
slc_id = Slice.import_obj(expected_slice, None, import_time=1989)
slc = self.get_slice(slc_id)
self.assertEqual(slc.datasource.perm, slc.perm)
self.assert_slice_equals(expected_slice, slc)
@ -304,9 +305,9 @@ class ImportExportTests(SupersetTestCase):
table_id = self.get_table_by_name("wb_health_population").id
# table_id != 666, import func will have to find the table
slc_1 = self.create_slice("Import Me 1", ds_id=666, id=10002)
slc_id_1 = models.Slice.import_obj(slc_1, None)
slc_id_1 = Slice.import_obj(slc_1, None)
slc_2 = self.create_slice("Import Me 2", ds_id=666, id=10003)
slc_id_2 = models.Slice.import_obj(slc_2, None)
slc_id_2 = Slice.import_obj(slc_2, None)
imported_slc_1 = self.get_slice(slc_id_1)
imported_slc_2 = self.get_slice(slc_id_2)
@ -320,25 +321,25 @@ class ImportExportTests(SupersetTestCase):
def test_import_slices_for_non_existent_table(self):
with self.assertRaises(AttributeError):
models.Slice.import_obj(
Slice.import_obj(
self.create_slice("Import Me 3", id=10004, table_name="non_existent"),
None,
)
def test_import_slices_override(self):
slc = self.create_slice("Import Me New", id=10005)
slc_1_id = models.Slice.import_obj(slc, None, import_time=1990)
slc_1_id = Slice.import_obj(slc, None, import_time=1990)
slc.slice_name = "Import Me New"
imported_slc_1 = self.get_slice(slc_1_id)
slc_2 = self.create_slice("Import Me New", id=10005)
slc_2_id = models.Slice.import_obj(slc_2, imported_slc_1, import_time=1990)
slc_2_id = Slice.import_obj(slc_2, imported_slc_1, import_time=1990)
self.assertEqual(slc_1_id, slc_2_id)
imported_slc_2 = self.get_slice(slc_2_id)
self.assert_slice_equals(slc, imported_slc_2)
def test_import_empty_dashboard(self):
empty_dash = self.create_dashboard("empty_dashboard", id=10001)
imported_dash_id = models.Dashboard.import_obj(empty_dash, import_time=1989)
imported_dash_id = Dashboard.import_obj(empty_dash, import_time=1989)
imported_dash = self.get_dash(imported_dash_id)
self.assert_dash_equals(empty_dash, imported_dash, check_position=False)
@ -363,9 +364,7 @@ class ImportExportTests(SupersetTestCase):
""".format(
slc.id
)
imported_dash_id = models.Dashboard.import_obj(
dash_with_1_slice, import_time=1990
)
imported_dash_id = Dashboard.import_obj(dash_with_1_slice, import_time=1990)
imported_dash = self.get_dash(imported_dash_id)
expected_dash = self.create_dashboard("dash_with_1_slice", slcs=[slc], id=10002)
@ -400,9 +399,7 @@ class ImportExportTests(SupersetTestCase):
}
)
imported_dash_id = models.Dashboard.import_obj(
dash_with_2_slices, import_time=1991
)
imported_dash_id = Dashboard.import_obj(dash_with_2_slices, import_time=1991)
imported_dash = self.get_dash(imported_dash_id)
expected_dash = self.create_dashboard(
@ -431,9 +428,7 @@ class ImportExportTests(SupersetTestCase):
dash_to_import = self.create_dashboard(
"override_dashboard", slcs=[e_slc, b_slc], id=10004
)
imported_dash_id_1 = models.Dashboard.import_obj(
dash_to_import, import_time=1992
)
imported_dash_id_1 = Dashboard.import_obj(dash_to_import, import_time=1992)
# create new instances of the slices
e_slc = self.create_slice("e_slc", id=10009, table_name="energy_usage")
@ -442,7 +437,7 @@ class ImportExportTests(SupersetTestCase):
dash_to_import_override = self.create_dashboard(
"override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004
)
imported_dash_id_2 = models.Dashboard.import_obj(
imported_dash_id_2 = Dashboard.import_obj(
dash_to_import_override, import_time=1992
)
@ -472,7 +467,7 @@ class ImportExportTests(SupersetTestCase):
dash_with_1_slice.changed_by = admin_user
dash_with_1_slice.owners = [admin_user]
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash_id = Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
@ -492,7 +487,7 @@ class ImportExportTests(SupersetTestCase):
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash_id = Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
@ -508,7 +503,7 @@ class ImportExportTests(SupersetTestCase):
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash_id = Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)

View File

@ -23,7 +23,7 @@ from selenium.common.exceptions import WebDriverException
from tests.test_app import app
from superset import db
from superset.models.core import Dashboard, Slice
from superset.models.dashboard import Dashboard
from superset.models.schedules import (
DashboardEmailSchedule,
EmailDeliveryType,
@ -36,6 +36,7 @@ from superset.tasks.schedules import (
deliver_slice,
next_schedules,
)
from superset.models.slice import Slice
from tests.base_tests import SupersetTestCase
from .utils import read_fixture

View File

@ -24,7 +24,8 @@ from superset import app, appbuilder, db, security_manager, viz
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetSecurityException
from superset.models.core import Database, Slice
from superset.models.core import Database
from superset.models.slice import Slice
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase