Break some static methods out of superset.views.core.Superset (#10175)

This commit is contained in:
Will Barrett 2020-06-26 14:34:45 -07:00 committed by GitHub
parent f6ed46dcc0
commit 4965d87505
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 192 additions and 172 deletions

View File

@ -60,3 +60,15 @@ class ChartDAO(BaseDAO):
@staticmethod
def fetch_all_datasources() -> List["BaseDatasource"]:
return ConnectorRegistry.get_all_datasources(db.session)
@staticmethod
def save(slc: Slice, commit: bool = True) -> None:
db.session.add(slc)
if commit:
db.session.commit()
@staticmethod
def overwrite(slc: Slice, commit: bool = True) -> None:
db.session.merge(slc)
if commit:
db.session.commit()

View File

@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from typing import List, Optional
from typing import Any, Dict, List, Optional
from sqlalchemy.exc import SQLAlchemyError
@ -23,6 +24,8 @@ from superset.dao.base import BaseDAO
from superset.dashboards.filters import DashboardFilter
from superset.extensions import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
logger = logging.getLogger(__name__)
@ -76,3 +79,85 @@ class DashboardDAO(BaseDAO):
if commit:
db.session.rollback()
raise ex
@staticmethod
def set_dash_metadata( # pylint: disable=too-many-locals,too-many-branches,too-many-statements
dashboard: Dashboard,
data: Dict[Any, Any],
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
) -> None:
positions = data["positions"]
# find slices in the position data
slice_ids = []
slice_id_to_name = {}
for value in positions.values():
if isinstance(value, dict):
try:
slice_id = value["meta"]["chartId"]
slice_ids.append(slice_id)
slice_id_to_name[slice_id] = value["meta"]["sliceName"]
except KeyError:
pass
session = db.session()
current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
dashboard.slices = current_slices
# update slice names. this assumes user has permissions to update the slice
# we allow user set slice name be empty string
for slc in dashboard.slices:
try:
new_name = slice_id_to_name[slc.id]
if slc.slice_name != new_name:
slc.slice_name = new_name
session.merge(slc)
session.flush()
except KeyError:
pass
# remove leading and trailing white spaces in the dumped json
dashboard.position_json = json.dumps(
positions, indent=None, separators=(",", ":"), sort_keys=True
)
md = dashboard.params_dict
dashboard.css = data.get("css")
dashboard.dashboard_title = data["dashboard_title"]
if "timed_refresh_immune_slices" not in md:
md["timed_refresh_immune_slices"] = []
new_filter_scopes = {}
if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore
slc_id_dict: Dict[int, int] = {}
if old_to_new_slice_ids:
slc_id_dict = {
old: new
for old, new in old_to_new_slice_ids.items()
if new in slice_ids
}
else:
slc_id_dict = {sid: sid for sid in slice_ids}
new_filter_scopes = copy_filter_scopes(
old_to_new_slc_id_dict=slc_id_dict,
old_filter_scopes=json.loads(data["filter_scopes"] or "{}"),
)
if new_filter_scopes:
md["filter_scopes"] = new_filter_scopes
else:
md.pop("filter_scopes", None)
md["expanded_slices"] = data.get("expanded_slices", {})
md["refresh_frequency"] = data.get("refresh_frequency", 0)
default_filters_data = json.loads(data.get("default_filters", "{}"))
applicable_filters = {
key: v for key, v in default_filters_data.items() if int(key) in slice_ids
}
md["default_filters"] = json.dumps(applicable_filters)
if data.get("color_namespace"):
md["color_namespace"] = data.get("color_namespace")
if data.get("color_scheme"):
md["color_scheme"] = data.get("color_scheme")
if data.get("label_colors"):
md["label_colors"] = data.get("label_colors")
dashboard.json_metadata = json.dumps(md)

View File

@ -58,6 +58,7 @@ from superset import (
sql_lab,
viz,
)
from superset.charts.dao import ChartDAO
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import (
AnnotationDatasource,
@ -65,6 +66,7 @@ from superset.connectors.sqla.models import (
SqlMetric,
TableColumn,
)
from superset.dashboards.dao import DashboardDAO
from superset.exceptions import (
CertificateException,
DatabaseNotFound,
@ -86,7 +88,6 @@ from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.typing import FlaskResponse
from superset.utils import core as utils, dashboard_import_export
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
from superset.utils.dates import now_as_float
from superset.utils.decorators import etag_cache
from superset.views.base import (
@ -756,7 +757,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
Those should not be saved when saving the chart"""
return [f for f in filters if not f.get("isExtra")]
def save_or_overwrite_slice( # pylint: disable=too-many-arguments
def save_or_overwrite_slice( # pylint: disable=too-many-arguments,too-many-locals,no-self-use
self,
slc: Optional[Slice],
slice_add_perm: bool,
@ -789,9 +790,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
slc.slice_name = slice_name
if action == "saveas" and slice_add_perm:
self.save_slice(slc)
ChartDAO.save(slc)
msg = _("Chart [{}] has been saved").format(slc.slice_name)
flash(msg, "info")
elif action == "overwrite" and slice_overwrite_perm:
self.overwrite_slice(slc)
ChartDAO.overwrite(slc)
msg = _("Chart [{}] has been overwritten").format(slc.slice_name)
flash(msg, "info")
# Adding slice to a dashboard if requested
dash: Optional[Dashboard] = None
@ -859,22 +864,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return json_success(json.dumps(response))
@staticmethod
def save_slice(slc: Slice) -> None:
session = db.session()
msg = _("Chart [{}] has been saved").format(slc.slice_name)
session.add(slc)
session.commit()
flash(msg, "info")
@staticmethod
def overwrite_slice(slc: Slice) -> None:
session = db.session()
session.merge(slc)
session.commit()
msg = _("Chart [{}] has been overwritten").format(slc.slice_name)
flash(msg, "info")
@api
@has_access_api
@expose("/schemas/<int:db_id>/")
@ -1003,7 +992,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@api
@has_access_api
@expose("/copy_dash/<int:dashboard_id>/", methods=["GET", "POST"])
def copy_dash(self, dashboard_id: int) -> FlaskResponse:
def copy_dash( # pylint: disable=no-self-use
self, dashboard_id: int
) -> FlaskResponse:
"""Copy dashboard"""
session = db.session()
data = json.loads(request.form["data"])
@ -1035,7 +1026,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
dash.params = original_dash.params
self._set_dash_metadata(dash, data, old_to_new_slice_ids)
DashboardDAO.set_dash_metadata(dash, data, old_to_new_slice_ids)
session.add(dash)
session.commit()
dash_json = json.dumps(dash.data)
@ -1045,100 +1036,20 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@api
@has_access_api
@expose("/save_dash/<int:dashboard_id>/", methods=["GET", "POST"])
def save_dash(self, dashboard_id: int) -> FlaskResponse:
def save_dash( # pylint: disable=no-self-use
self, dashboard_id: int
) -> FlaskResponse:
"""Save a dashboard's metadata"""
session = db.session()
dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
data = json.loads(request.form["data"])
self._set_dash_metadata(dash, data)
DashboardDAO.set_dash_metadata(dash, data)
session.merge(dash)
session.commit()
session.close()
return json_success(json.dumps({"status": "SUCCESS"}))
@staticmethod
def _set_dash_metadata( # pylint: disable=too-many-locals,too-many-branches,too-many-statements
dashboard: Dashboard,
data: Dict[Any, Any],
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
) -> None:
positions = data["positions"]
# find slices in the position data
slice_ids = []
slice_id_to_name = {}
for value in positions.values():
if isinstance(value, dict):
try:
slice_id = value["meta"]["chartId"]
slice_ids.append(slice_id)
slice_id_to_name[slice_id] = value["meta"]["sliceName"]
except KeyError:
pass
session = db.session()
current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
dashboard.slices = current_slices
# update slice names. this assumes user has permissions to update the slice
# we allow user set slice name be empty string
for slc in dashboard.slices:
try:
new_name = slice_id_to_name[slc.id]
if slc.slice_name != new_name:
slc.slice_name = new_name
session.merge(slc)
session.flush()
except KeyError:
pass
# remove leading and trailing white spaces in the dumped json
dashboard.position_json = json.dumps(
positions, indent=None, separators=(",", ":"), sort_keys=True
)
md = dashboard.params_dict
dashboard.css = data.get("css")
dashboard.dashboard_title = data["dashboard_title"]
if "timed_refresh_immune_slices" not in md:
md["timed_refresh_immune_slices"] = []
new_filter_scopes = {}
if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore
slc_id_dict: Dict[int, int] = {}
if old_to_new_slice_ids:
slc_id_dict = {
old: new
for old, new in old_to_new_slice_ids.items()
if new in slice_ids
}
else:
slc_id_dict = {sid: sid for sid in slice_ids}
new_filter_scopes = copy_filter_scopes(
old_to_new_slc_id_dict=slc_id_dict,
old_filter_scopes=json.loads(data["filter_scopes"] or "{}"),
)
if new_filter_scopes:
md["filter_scopes"] = new_filter_scopes
else:
md.pop("filter_scopes", None)
md["expanded_slices"] = data.get("expanded_slices", {})
md["refresh_frequency"] = data.get("refresh_frequency", 0)
default_filters_data = json.loads(data.get("default_filters", "{}"))
applicable_filters = {
key: v for key, v in default_filters_data.items() if int(key) in slice_ids
}
md["default_filters"] = json.dumps(applicable_filters)
if data.get("color_namespace"):
md["color_namespace"] = data.get("color_namespace")
if data.get("color_scheme"):
md["color_scheme"] = data.get("color_scheme")
if data.get("label_colors"):
md["label_colors"] = data.get("label_colors")
dashboard.json_metadata = json.dumps(md)
@api
@has_access_api
@expose("/add_slices/<int:dashboard_id>/", methods=["POST"])

View File

@ -16,7 +16,6 @@
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import copy
import json
import unittest
from random import random
@ -37,19 +36,6 @@ from .base_tests import SupersetTestCase
class DashboardTests(SupersetTestCase):
def __init__(self, *args, **kwargs):
super(DashboardTests, self).__init__(*args, **kwargs)
@classmethod
def setUpClass(cls):
pass
def setUp(self):
pass
def tearDown(self):
pass
def get_mock_positions(self, dash):
positions = {"DASHBOARD_VERSION_KEY": "v2"}
for i, slc in enumerate(dash.slices):
@ -238,57 +224,6 @@ class DashboardTests(SupersetTestCase):
if key not in ["modified", "changed_on", "changed_on_humanized"]:
self.assertEqual(slc[key], resp["slices"][index][key])
def test_set_dash_metadata(self, username="admin"):
self.login(username=username)
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
data = dash.data
positions = data["position_json"]
data.update({"positions": positions})
original_data = copy.deepcopy(data)
# add filter scopes
filter_slice = dash.slices[0]
immune_slices = dash.slices[2:]
filter_scopes = {
str(filter_slice.id): {
"region": {
"scope": ["ROOT_ID"],
"immune": [slc.id for slc in immune_slices],
}
}
}
data.update({"filter_scopes": json.dumps(filter_scopes)})
views.Superset._set_dash_metadata(dash, data)
updated_metadata = json.loads(dash.json_metadata)
self.assertEqual(updated_metadata["filter_scopes"], filter_scopes)
# remove a slice and change slice ids (as copy slices)
removed_slice = immune_slices.pop()
removed_component = [
key
for (key, value) in positions.items()
if isinstance(value, dict)
and value.get("type") == "CHART"
and value["meta"]["chartId"] == removed_slice.id
]
positions.pop(removed_component[0], None)
data.update({"positions": positions})
views.Superset._set_dash_metadata(dash, data)
updated_metadata = json.loads(dash.json_metadata)
expected_filter_scopes = {
str(filter_slice.id): {
"region": {
"scope": ["ROOT_ID"],
"immune": [slc.id for slc in immune_slices],
}
}
}
self.assertEqual(updated_metadata["filter_scopes"], expected_filter_scopes)
# reset dash to original data
views.Superset._set_dash_metadata(dash, original_data)
def test_add_slices(self, username="admin"):
self.login(username=username)
dash = db.session.query(Dashboard).filter_by(slug="births").first()

View File

@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import copy
import json
import tests.test_app # pylint: disable=unused-import
from superset import db
from superset.dashboards.dao import DashboardDAO
from superset.models.dashboard import Dashboard
from tests.base_tests import SupersetTestCase
class DashboardDAOTests(SupersetTestCase):
def test_set_dash_metadata(self):
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
data = dash.data
positions = data["position_json"]
data.update({"positions": positions})
original_data = copy.deepcopy(data)
# add filter scopes
filter_slice = dash.slices[0]
immune_slices = dash.slices[2:]
filter_scopes = {
str(filter_slice.id): {
"region": {
"scope": ["ROOT_ID"],
"immune": [slc.id for slc in immune_slices],
}
}
}
data.update({"filter_scopes": json.dumps(filter_scopes)})
DashboardDAO.set_dash_metadata(dash, data)
updated_metadata = json.loads(dash.json_metadata)
self.assertEqual(updated_metadata["filter_scopes"], filter_scopes)
# remove a slice and change slice ids (as copy slices)
removed_slice = immune_slices.pop()
removed_component = [
key
for (key, value) in positions.items()
if isinstance(value, dict)
and value.get("type") == "CHART"
and value["meta"]["chartId"] == removed_slice.id
]
positions.pop(removed_component[0], None)
data.update({"positions": positions})
DashboardDAO.set_dash_metadata(dash, data)
updated_metadata = json.loads(dash.json_metadata)
expected_filter_scopes = {
str(filter_slice.id): {
"region": {
"scope": ["ROOT_ID"],
"immune": [slc.id for slc in immune_slices],
}
}
}
self.assertEqual(updated_metadata["filter_scopes"], expected_filter_scopes)
# reset dash to original data
DashboardDAO.set_dash_metadata(dash, original_data)