From 2e564897f860192c3e3ecbe41cfbac6b3e557b35 Mon Sep 17 00:00:00 2001 From: cccs-Dustin <96579982+cccs-Dustin@users.noreply.github.com> Date: Fri, 23 Sep 2022 04:01:17 -0400 Subject: [PATCH] feat: Add dataset tagging to the back-end (#20892) Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com> --- superset/common/tags.py | 353 ++++++++++++------ superset/initialization/__init__.py | 4 + ...26_11-10_c82ee8a39623_add_implicit_tags.py | 2 +- superset/models/core.py | 9 +- superset/models/dashboard.py | 7 - superset/models/slice.py | 8 - superset/models/sql_lab.py | 7 - superset/tags/core.py | 88 +++++ superset/{models/tags.py => tags/models.py} | 46 ++- superset/tasks/cache.py | 2 +- superset/utils/url_map_converters.py | 2 +- superset/views/tags.py | 30 +- tests/integration_tests/fixtures/tags.py | 33 ++ tests/integration_tests/strategy_tests.py | 2 +- tests/integration_tests/tagging_tests.py | 276 ++++++++++++++ 15 files changed, 712 insertions(+), 157 deletions(-) create mode 100644 superset/tags/core.py rename superset/{models/tags.py => tags/models.py} (84%) create mode 100644 tests/integration_tests/fixtures/tags.py diff --git a/superset/common/tags.py b/superset/common/tags.py index 74c882cf9..d85a33b84 100644 --- a/superset/common/tags.py +++ b/superset/common/tags.py @@ -14,15 +14,145 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from sqlalchemy import Metadata +from typing import Any, List + +from sqlalchemy import MetaData from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError -from sqlalchemy.sql import and_, func, functions, join, literal, select +from sqlalchemy.sql import and_, func, join, literal, select -from superset.models.tags import ObjectTypes, TagTypes +from superset.tags.models import ObjectTypes, TagTypes -def add_types(engine: Engine, metadata: Metadata) -> None: +def add_types_to_charts( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + slices = metadata.tables["slices"] + + charts = ( + select( + [ + tag.c.id.label("tag_id"), + slices.c.id.label("object_id"), + literal(ObjectTypes.chart.name).label("object_type"), + ] + ) + .select_from( + join( + join(slices, tag, tag.c.name == "type:chart"), + tagged_object, + and_( + tagged_object.c.tag_id == tag.c.id, + tagged_object.c.object_id == slices.c.id, + tagged_object.c.object_type == "chart", + ), + isouter=True, + full=False, + ) + ) + .where(tagged_object.c.tag_id.is_(None)) + ) + query = tagged_object.insert().from_select(columns, charts) + engine.execute(query) + + +def add_types_to_dashboards( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + dashboard_table = metadata.tables["dashboards"] + + dashboards = ( + select( + [ + tag.c.id.label("tag_id"), + dashboard_table.c.id.label("object_id"), + literal(ObjectTypes.dashboard.name).label("object_type"), + ] + ) + .select_from( + join( + join(dashboard_table, tag, tag.c.name == "type:dashboard"), + tagged_object, + and_( + tagged_object.c.tag_id == tag.c.id, + tagged_object.c.object_id == dashboard_table.c.id, + tagged_object.c.object_type == "dashboard", + ), + isouter=True, + full=False, + ) + ) + .where(tagged_object.c.tag_id.is_(None)) + ) + query = tagged_object.insert().from_select(columns, dashboards) + engine.execute(query) + + +def add_types_to_saved_queries( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + saved_query = metadata.tables["saved_query"] + + saved_queries = ( + select( + [ + tag.c.id.label("tag_id"), + saved_query.c.id.label("object_id"), + literal(ObjectTypes.query.name).label("object_type"), + ] + ) + .select_from( + join( + join(saved_query, tag, tag.c.name == "type:query"), + tagged_object, + and_( + tagged_object.c.tag_id == tag.c.id, + tagged_object.c.object_id == saved_query.c.id, + tagged_object.c.object_type == "query", + ), + isouter=True, + full=False, + ) + ) + .where(tagged_object.c.tag_id.is_(None)) + ) + query = tagged_object.insert().from_select(columns, saved_queries) + engine.execute(query) + + +def add_types_to_datasets( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + tables = metadata.tables["tables"] + + datasets = ( + select( + [ + tag.c.id.label("tag_id"), + tables.c.id.label("object_id"), + literal(ObjectTypes.dataset.name).label("object_type"), + ] + ) + .select_from( + join( + join(tables, tag, tag.c.name == "type:dataset"), + tagged_object, + and_( + tagged_object.c.tag_id == tag.c.id, + tagged_object.c.object_id == tables.c.id, + tagged_object.c.object_type == "dataset", + ), + isouter=True, + full=False, + ) + ) + .where(tagged_object.c.tag_id.is_(None)) + ) + query = tagged_object.insert().from_select(columns, datasets) + engine.execute(query) + + +def add_types(engine: Engine, metadata: MetaData) -> None: """ Tag every object according to its type: @@ -68,13 +198,24 @@ def add_types(engine: Engine, metadata: Metadata) -> None: AND tagged_object.object_type = 'query' WHERE tagged_object.tag_id IS NULL; + INSERT INTO tagged_object (tag_id, object_id, object_type) + SELECT + tag.id AS tag_id, + tables.id AS object_id, + 'dataset' AS object_type + FROM tables + JOIN tag + ON tag.name = 'type:dataset' + LEFT OUTER JOIN tagged_object + ON tagged_object.tag_id = tag.id + AND tagged_object.object_id = tables.id + AND tagged_object.object_type = 'dataset' + WHERE tagged_object.tag_id IS NULL; + """ tag = metadata.tables["tag"] tagged_object = metadata.tables["tagged_object"] - slices = metadata.tables["slices"] - dashboards = metadata.tables["dashboards"] - saved_query = metadata.tables["saved_query"] columns = ["tag_id", "object_id", "object_type"] # add a tag for each object type @@ -85,6 +226,17 @@ def add_types(engine: Engine, metadata: Metadata) -> None: except IntegrityError: pass # already exists + add_types_to_charts(engine, metadata, tag, tagged_object, columns) + add_types_to_dashboards(engine, metadata, tag, tagged_object, columns) + add_types_to_saved_queries(engine, metadata, tag, tagged_object, columns) + add_types_to_datasets(engine, metadata, tag, tagged_object, columns) + + +def add_owners_to_charts( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + slices = metadata.tables["slices"] + charts = ( select( [ @@ -95,7 +247,11 @@ def add_types(engine: Engine, metadata: Metadata) -> None: ) .select_from( join( - join(slices, tag, tag.c.name == "type:chart"), + join( + slices, + tag, + tag.c.name == "owner:" + slices.c.created_by_fk, + ), tagged_object, and_( tagged_object.c.tag_id == tag.c.id, @@ -111,21 +267,31 @@ def add_types(engine: Engine, metadata: Metadata) -> None: query = tagged_object.insert().from_select(columns, charts) engine.execute(query) + +def add_owners_to_dashboards( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + dashboard_table = metadata.tables["dashboards"] + dashboards = ( select( [ tag.c.id.label("tag_id"), - dashboards.c.id.label("object_id"), + dashboard_table.c.id.label("object_id"), literal(ObjectTypes.dashboard.name).label("object_type"), ] ) .select_from( join( - join(dashboards, tag, tag.c.name == "type:dashboard"), + join( + dashboard_table, + tag, + tag.c.name == "owner:" + dashboard_table.c.created_by_fk, + ), tagged_object, and_( tagged_object.c.tag_id == tag.c.id, - tagged_object.c.object_id == dashboards.c.id, + tagged_object.c.object_id == dashboard_table.c.id, tagged_object.c.object_type == "dashboard", ), isouter=True, @@ -137,6 +303,12 @@ def add_types(engine: Engine, metadata: Metadata) -> None: query = tagged_object.insert().from_select(columns, dashboards) engine.execute(query) + +def add_owners_to_saved_queries( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + saved_query = metadata.tables["saved_query"] + saved_queries = ( select( [ @@ -147,7 +319,11 @@ def add_types(engine: Engine, metadata: Metadata) -> None: ) .select_from( join( - join(saved_query, tag, tag.c.name == "type:query"), + join( + saved_query, + tag, + tag.c.name == "owner:" + saved_query.c.created_by_fk, + ), tagged_object, and_( tagged_object.c.tag_id == tag.c.id, @@ -164,7 +340,43 @@ def add_types(engine: Engine, metadata: Metadata) -> None: engine.execute(query) -def add_owners(engine: Engine, metadata: Metadata) -> None: +def add_owners_to_datasets( + engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] +) -> None: + tables = metadata.tables["tables"] + + datasets = ( + select( + [ + tag.c.id.label("tag_id"), + tables.c.id.label("object_id"), + literal(ObjectTypes.dataset.name).label("object_type"), + ] + ) + .select_from( + join( + join( + tables, + tag, + tag.c.name == "owner:" + tables.c.created_by_fk, + ), + tagged_object, + and_( + tagged_object.c.tag_id == tag.c.id, + tagged_object.c.object_id == tables.c.id, + tagged_object.c.object_type == "dataset", + ), + isouter=True, + full=False, + ) + ) + .where(tagged_object.c.tag_id.is_(None)) + ) + query = tagged_object.insert().from_select(columns, datasets) + engine.execute(query) + + +def add_owners(engine: Engine, metadata: MetaData) -> None: """ Tag every object according to its owner: @@ -208,14 +420,24 @@ def add_owners(engine: Engine, metadata: Metadata) -> None: AND tagged_object.object_type = 'query' WHERE tagged_object.tag_id IS NULL; + SELECT + tag.id AS tag_id, + tables.id AS object_id, + 'dataset' AS object_type + FROM tables + JOIN tag + ON tag.name = CONCAT('owner:', tables.created_by_fk) + LEFT OUTER JOIN tagged_object + ON tagged_object.tag_id = tag.id + AND tagged_object.object_id = tables.id + AND tagged_object.object_type = 'dataset' + WHERE tagged_object.tag_id IS NULL; + """ tag = metadata.tables["tag"] tagged_object = metadata.tables["tagged_object"] users = metadata.tables["ab_user"] - slices = metadata.tables["slices"] - dashboards = metadata.tables["dashboards"] - saved_query = metadata.tables["saved_query"] columns = ["tag_id", "object_id", "object_type"] # create a custom tag for each user @@ -227,100 +449,13 @@ def add_owners(engine: Engine, metadata: Metadata) -> None: except IntegrityError: pass # already exists - charts = ( - select( - [ - tag.c.id.label("tag_id"), - slices.c.id.label("object_id"), - literal(ObjectTypes.chart.name).label("object_type"), - ] - ) - .select_from( - join( - join( - slices, - tag, - tag.c.name == functions.concat("owner:", slices.c.created_by_fk), - ), - tagged_object, - and_( - tagged_object.c.tag_id == tag.c.id, - tagged_object.c.object_id == slices.c.id, - tagged_object.c.object_type == "chart", - ), - isouter=True, - full=False, - ) - ) - .where(tagged_object.c.tag_id.is_(None)) - ) - query = tagged_object.insert().from_select(columns, charts) - engine.execute(query) - - dashboards = ( - select( - [ - tag.c.id.label("tag_id"), - dashboards.c.id.label("object_id"), - literal(ObjectTypes.dashboard.name).label("object_type"), - ] - ) - .select_from( - join( - join( - dashboards, - tag, - tag.c.name - == functions.concat("owner:", dashboards.c.created_by_fk), - ), - tagged_object, - and_( - tagged_object.c.tag_id == tag.c.id, - tagged_object.c.object_id == dashboards.c.id, - tagged_object.c.object_type == "dashboard", - ), - isouter=True, - full=False, - ) - ) - .where(tagged_object.c.tag_id.is_(None)) - ) - query = tagged_object.insert().from_select(columns, dashboards) - engine.execute(query) - - saved_queries = ( - select( - [ - tag.c.id.label("tag_id"), - saved_query.c.id.label("object_id"), - literal(ObjectTypes.query.name).label("object_type"), - ] - ) - .select_from( - join( - join( - saved_query, - tag, - tag.c.name - == functions.concat("owner:", saved_query.c.created_by_fk), - ), - tagged_object, - and_( - tagged_object.c.tag_id == tag.c.id, - tagged_object.c.object_id == saved_query.c.id, - tagged_object.c.object_type == "query", - ), - isouter=True, - full=False, - ) - ) - .where(tagged_object.c.tag_id.is_(None)) - ) - query = tagged_object.insert().from_select(columns, saved_queries) - engine.execute(query) + add_owners_to_charts(engine, metadata, tag, tagged_object, columns) + add_owners_to_dashboards(engine, metadata, tag, tagged_object, columns) + add_owners_to_saved_queries(engine, metadata, tag, tagged_object, columns) + add_owners_to_datasets(engine, metadata, tag, tagged_object, columns) -def add_favorites(engine: Engine, metadata: Metadata) -> None: +def add_favorites(engine: Engine, metadata: MetaData) -> None: """ Tag every object that was favorited: @@ -368,7 +503,7 @@ def add_favorites(engine: Engine, metadata: Metadata) -> None: join( favstar, tag, - tag.c.name == functions.concat("favorited_by:", favstar.c.user_id), + tag.c.name == "favorited_by:" + favstar.c.user_id, ), tagged_object, and_( diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 12d3692ac..598cf94e0 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -49,6 +49,7 @@ from superset.extensions import ( ) from superset.security import SupersetSecurityManager from superset.superset_typing import FlaskResponse +from superset.tags.core import register_sqla_event_listeners from superset.utils.core import pessimistic_connection_handling from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value @@ -426,6 +427,9 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods if flask_app_mutator: flask_app_mutator(self.superset_app) + if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + register_sqla_event_listeners() + self.init_views() def check_secret_key(self) -> None: diff --git a/superset/migrations/versions/2018-07-26_11-10_c82ee8a39623_add_implicit_tags.py b/superset/migrations/versions/2018-07-26_11-10_c82ee8a39623_add_implicit_tags.py index 8a1a5f989..0179ba7d0 100644 --- a/superset/migrations/versions/2018-07-26_11-10_c82ee8a39623_add_implicit_tags.py +++ b/superset/migrations/versions/2018-07-26_11-10_c82ee8a39623_add_implicit_tags.py @@ -33,7 +33,7 @@ from flask_appbuilder.models.mixins import AuditMixin from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String from sqlalchemy.ext.declarative import declarative_base, declared_attr -from superset.models.tags import ObjectTypes, TagTypes +from superset.tags.models import ObjectTypes, TagTypes from superset.utils.core import get_user_id Base = declarative_base() diff --git a/superset/models/core.py b/superset/models/core.py index 512adfebc..a8ab4df6b 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -53,13 +53,12 @@ from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import expression, Select -from superset import app, db_engine_specs, is_feature_enabled +from superset import app, db_engine_specs from superset.constants import PASSWORD_MASK from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain from superset.extensions import cache_manager, encrypted_field_factory, security_manager from superset.models.helpers import AuditMixinNullable, ImportExportMixin -from superset.models.tags import FavStarUpdater from superset.result_set import SupersetResultSet from superset.utils import cache as cache_util, core as utils from superset.utils.core import get_username @@ -809,9 +808,3 @@ class FavStar(Model): # pylint: disable=too-few-public-methods class_name = Column(String(50)) obj_id = Column(Integer) dttm = Column(DateTime, default=datetime.utcnow) - - -# events for updating tags -if is_feature_enabled("TAGGING_SYSTEM"): - sqla.event.listen(FavStar, "after_insert", FavStarUpdater.after_insert) - sqla.event.listen(FavStar, "after_delete", FavStarUpdater.after_delete) diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index b8dbf37e7..57567e616 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -53,7 +53,6 @@ from superset.extensions import cache_manager from superset.models.filter_set import FilterSet from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.models.slice import Slice -from superset.models.tags import DashboardUpdater from superset.models.user_attributes import UserAttribute from superset.tasks.thumbnails import cache_dashboard_thumbnail from superset.utils import core as utils @@ -454,12 +453,6 @@ def id_or_slug_filter(id_or_slug: Union[int, str]) -> BinaryExpression: OnDashboardChange = Callable[[Mapper, Connection, Dashboard], Any] -# events for updating tags -if is_feature_enabled("TAGGING_SYSTEM"): - sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert) - sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update) - sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete) - if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"): update_thumbnail: OnDashboardChange = lambda _, __, dash: dash.update_thumbnail() sqla.event.listen(Dashboard, "after_insert", update_thumbnail) diff --git a/superset/models/slice.py b/superset/models/slice.py index de0f3df59..d644e7b74 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -42,7 +42,6 @@ from sqlalchemy.orm.mapper import Mapper from superset import db, is_feature_enabled, security_manager from superset.legacy import update_time_range from superset.models.helpers import AuditMixinNullable, ImportExportMixin -from superset.models.tags import ChartUpdater from superset.tasks.thumbnails import cache_chart_thumbnail from superset.utils import core as utils from superset.utils.hashing import md5_sha_from_str @@ -367,13 +366,6 @@ def event_after_chart_changed( sqla.event.listen(Slice, "before_insert", set_related_perm) sqla.event.listen(Slice, "before_update", set_related_perm) -# events for updating tags -if is_feature_enabled("TAGGING_SYSTEM"): - sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert) - sqla.event.listen(Slice, "after_update", ChartUpdater.after_update) - sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete) - -# events for updating tags if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"): sqla.event.listen(Slice, "after_insert", event_after_chart_changed) sqla.event.listen(Slice, "after_update", event_after_chart_changed) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index d12af4908..408bc708d 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -49,7 +49,6 @@ from superset.models.helpers import ( ExtraJSONMixin, ImportExportMixin, ) -from superset.models.tags import QueryUpdater from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.superset_typing import ResultSetColumnType @@ -509,9 +508,3 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin): "description": description, "expanded": self.expanded, } - - -# events for updating tags -sqla.event.listen(SavedQuery, "after_insert", QueryUpdater.after_insert) -sqla.event.listen(SavedQuery, "after_update", QueryUpdater.after_update) -sqla.event.listen(SavedQuery, "after_delete", QueryUpdater.after_delete) diff --git a/superset/tags/core.py b/superset/tags/core.py new file mode 100644 index 000000000..f1f832c7a --- /dev/null +++ b/superset/tags/core.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +def register_sqla_event_listeners() -> None: + import sqlalchemy as sqla + + from superset.connectors.sqla.models import SqlaTable + from superset.models.core import FavStar + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.models.sql_lab import SavedQuery + from superset.tags.models import ( + ChartUpdater, + DashboardUpdater, + DatasetUpdater, + FavStarUpdater, + QueryUpdater, + ) + + sqla.event.listen(SqlaTable, "after_insert", DatasetUpdater.after_insert) + sqla.event.listen(SqlaTable, "after_update", DatasetUpdater.after_update) + sqla.event.listen(SqlaTable, "after_delete", DatasetUpdater.after_delete) + + sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert) + sqla.event.listen(Slice, "after_update", ChartUpdater.after_update) + sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete) + + sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert) + sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update) + sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete) + + sqla.event.listen(FavStar, "after_insert", FavStarUpdater.after_insert) + sqla.event.listen(FavStar, "after_delete", FavStarUpdater.after_delete) + + sqla.event.listen(SavedQuery, "after_insert", QueryUpdater.after_insert) + sqla.event.listen(SavedQuery, "after_update", QueryUpdater.after_update) + sqla.event.listen(SavedQuery, "after_delete", QueryUpdater.after_delete) + + +def clear_sqla_event_listeners() -> None: + import sqlalchemy as sqla + + from superset.connectors.sqla.models import SqlaTable + from superset.models.core import FavStar + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.models.sql_lab import SavedQuery + from superset.tags.models import ( + ChartUpdater, + DashboardUpdater, + DatasetUpdater, + FavStarUpdater, + QueryUpdater, + ) + + sqla.event.remove(SqlaTable, "after_insert", DatasetUpdater.after_insert) + sqla.event.remove(SqlaTable, "after_update", DatasetUpdater.after_update) + sqla.event.remove(SqlaTable, "after_delete", DatasetUpdater.after_delete) + + sqla.event.remove(Slice, "after_insert", ChartUpdater.after_insert) + sqla.event.remove(Slice, "after_update", ChartUpdater.after_update) + sqla.event.remove(Slice, "after_delete", ChartUpdater.after_delete) + + sqla.event.remove(Dashboard, "after_insert", DashboardUpdater.after_insert) + sqla.event.remove(Dashboard, "after_update", DashboardUpdater.after_update) + sqla.event.remove(Dashboard, "after_delete", DashboardUpdater.after_delete) + + sqla.event.remove(FavStar, "after_insert", FavStarUpdater.after_insert) + sqla.event.remove(FavStar, "after_delete", FavStarUpdater.after_delete) + + sqla.event.remove(SavedQuery, "after_insert", QueryUpdater.after_insert) + sqla.event.remove(SavedQuery, "after_update", QueryUpdater.after_update) + sqla.event.remove(SavedQuery, "after_delete", QueryUpdater.after_delete) diff --git a/superset/models/tags.py b/superset/tags/models.py similarity index 84% rename from superset/models/tags.py rename to superset/tags/models.py index 528206e67..89505146e 100644 --- a/superset/models/tags.py +++ b/superset/tags/models.py @@ -14,7 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import ( + absolute_import, + annotations, + division, + print_function, + unicode_literals, +) import enum from typing import List, Optional, TYPE_CHECKING, Union @@ -28,6 +34,7 @@ from sqlalchemy.orm.mapper import Mapper from superset.models.helpers import AuditMixinNullable if TYPE_CHECKING: + from superset.connectors.sqla.models import SqlaTable from superset.models.core import FavStar from superset.models.dashboard import Dashboard from superset.models.slice import Slice @@ -41,7 +48,7 @@ class TagTypes(enum.Enum): """ Types for tags. - Objects (queries, charts and dashboards) will have with implicit tags based + Objects (queries, charts, dashboards, and datasets) will have with implicit tags based on metadata: types, owners and who favorited them. This way, user "alice" can find all their objects by querying for the tag `owner:alice`. """ @@ -64,11 +71,12 @@ class ObjectTypes(enum.Enum): query = 1 chart = 2 dashboard = 3 + dataset = 4 class Tag(Model, AuditMixinNullable): - """A tag attached to an object (query, chart or dashboard).""" + """A tag attached to an object (query, chart, dashboard, or dataset).""" __tablename__ = "tag" id = Column(Integer, primary_key=True) @@ -103,6 +111,7 @@ def get_object_type(class_name: str) -> ObjectTypes: "slice": ObjectTypes.chart, "dashboard": ObjectTypes.dashboard, "query": ObjectTypes.query, + "dataset": ObjectTypes.dataset, } try: return mapping[class_name.lower()] @@ -116,13 +125,15 @@ class ObjectUpdater: @classmethod def get_owners_ids( - cls, target: Union["Dashboard", "FavStar", "Slice"] + cls, target: Union[Dashboard, FavStar, Slice, Query, SqlaTable] ) -> List[int]: raise NotImplementedError("Subclass should implement `get_owners_ids`") @classmethod def _add_owners( - cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"] + cls, + session: Session, + target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], ) -> None: for owner_id in cls.get_owners_ids(target): name = "owner:{0}".format(owner_id) @@ -137,7 +148,7 @@ class ObjectUpdater: cls, _mapper: Mapper, connection: Connection, - target: Union["Dashboard", "FavStar", "Slice"], + target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], ) -> None: session = Session(bind=connection) @@ -158,7 +169,7 @@ class ObjectUpdater: cls, _mapper: Mapper, connection: Connection, - target: Union["Dashboard", "FavStar", "Slice"], + target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], ) -> None: session = Session(bind=connection) @@ -187,7 +198,7 @@ class ObjectUpdater: cls, _mapper: Mapper, connection: Connection, - target: Union["Dashboard", "FavStar", "Slice"], + target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], ) -> None: session = Session(bind=connection) @@ -205,7 +216,7 @@ class ChartUpdater(ObjectUpdater): object_type = "chart" @classmethod - def get_owners_ids(cls, target: "Slice") -> List[int]: + def get_owners_ids(cls, target: Slice) -> List[int]: return [owner.id for owner in target.owners] @@ -214,7 +225,7 @@ class DashboardUpdater(ObjectUpdater): object_type = "dashboard" @classmethod - def get_owners_ids(cls, target: "Dashboard") -> List[int]: + def get_owners_ids(cls, target: Dashboard) -> List[int]: return [owner.id for owner in target.owners] @@ -223,14 +234,23 @@ class QueryUpdater(ObjectUpdater): object_type = "query" @classmethod - def get_owners_ids(cls, target: "Query") -> List[int]: + def get_owners_ids(cls, target: Query) -> List[int]: return [target.user_id] +class DatasetUpdater(ObjectUpdater): + + object_type = "dataset" + + @classmethod + def get_owners_ids(cls, target: SqlaTable) -> List[int]: + return [owner.id for owner in target.owners] + + class FavStarUpdater: @classmethod def after_insert( - cls, _mapper: Mapper, connection: Connection, target: "FavStar" + cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: session = Session(bind=connection) name = "favorited_by:{0}".format(target.user_id) @@ -246,7 +266,7 @@ class FavStarUpdater: @classmethod def after_delete( - cls, _mapper: Mapper, connection: Connection, target: "FavStar" + cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: session = Session(bind=connection) name = "favorited_by:{0}".format(target.user_id) diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 137ec068e..0bda1e708 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -28,7 +28,7 @@ from superset.extensions import celery_app from superset.models.core import Log from superset.models.dashboard import Dashboard from superset.models.slice import Slice -from superset.models.tags import Tag, TaggedObject +from superset.tags.models import Tag, TaggedObject from superset.utils.date_parser import parse_human_datetime from superset.utils.machine_auth import MachineAuthProvider diff --git a/superset/utils/url_map_converters.py b/superset/utils/url_map_converters.py index c6a14f3fd..c5eaf3b35 100644 --- a/superset/utils/url_map_converters.py +++ b/superset/utils/url_map_converters.py @@ -18,7 +18,7 @@ from typing import Any, List from werkzeug.routing import BaseConverter, Map -from superset.models.tags import ObjectTypes +from superset.tags.models import ObjectTypes class RegexConverter(BaseConverter): diff --git a/superset/views/tags.py b/superset/views/tags.py index 8ab2798f5..985d26179 100644 --- a/superset/views/tags.py +++ b/superset/views/tags.py @@ -28,12 +28,13 @@ from sqlalchemy import and_, func from werkzeug.exceptions import NotFound from superset import db, is_feature_enabled, utils +from superset.connectors.sqla.models import SqlaTable from superset.jinja_context import ExtraCache from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import SavedQuery -from superset.models.tags import ObjectTypes, Tag, TaggedObject, TagTypes from superset.superset_typing import FlaskResponse +from superset.tags.models import ObjectTypes, Tag, TaggedObject, TagTypes from .base import BaseSupersetView, json_success @@ -238,4 +239,31 @@ class TagView(BaseSupersetView): for obj in saved_queries ) + # datasets + if not types or "dataset" in types: + datasets = ( + db.session.query(SqlaTable) + .join( + TaggedObject, + and_( + TaggedObject.object_id == SqlaTable.id, + TaggedObject.object_type == ObjectTypes.dataset, + ), + ) + .join(Tag, TaggedObject.tag_id == Tag.id) + .filter(Tag.name.in_(tags)) + ) + results.extend( + { + "id": obj.id, + "type": ObjectTypes.dataset.name, + "name": obj.table_name, + "url": obj.sql_url(), + "changed_on": obj.changed_on, + "created_by": obj.created_by_fk, + "creator": obj.creator(), + } + for obj in datasets + ) + return json_success(json.dumps(results, default=utils.core.json_int_dttm_ser)) diff --git a/tests/integration_tests/fixtures/tags.py b/tests/integration_tests/fixtures/tags.py new file mode 100644 index 000000000..57fd4ec71 --- /dev/null +++ b/tests/integration_tests/fixtures/tags.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from superset.tags.core import clear_sqla_event_listeners, register_sqla_event_listeners +from tests.integration_tests.test_app import app + + +@pytest.fixture +def with_tagging_system_feature(): + with app.app_context(): + is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] + if not is_enabled: + app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True + register_sqla_event_listeners() + yield + app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False + clear_sqla_event_listeners() diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py index f31489bb0..e54ae865e 100644 --- a/tests/integration_tests/strategy_tests.py +++ b/tests/integration_tests/strategy_tests.py @@ -35,7 +35,7 @@ from superset.utils.database import get_example_database from superset import db from superset.models.core import Log -from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes +from superset.tags.models import get_tag, ObjectTypes, TaggedObject, TagTypes from superset.tasks.cache import ( DashboardTagsStrategy, TopNDashboardsStrategy, diff --git a/tests/integration_tests/tagging_tests.py b/tests/integration_tests/tagging_tests.py index 9ae8764d4..4ee10041d 100644 --- a/tests/integration_tests/tagging_tests.py +++ b/tests/integration_tests/tagging_tests.py @@ -15,11 +15,33 @@ # specific language governing permissions and limitations # under the License. +from unittest import mock + +import pytest + +from superset.connectors.sqla.models import SqlaTable +from superset.extensions import db +from superset.models.core import FavStar +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.models.sql_lab import SavedQuery +from superset.tags.models import TaggedObject +from superset.utils.core import DatasourceType +from superset.utils.database import get_main_database from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.conftest import with_feature_flags +from tests.integration_tests.fixtures.tags import with_tagging_system_feature class TestTagging(SupersetTestCase): + def query_tagged_object_table(self): + query = db.session.query(TaggedObject).all() + return query + + def clear_tagged_object_table(self): + db.session.query(TaggedObject).delete() + db.session.commit() + @with_feature_flags(TAGGING_SYSTEM=False) def test_tag_view_disabled(self): self.login("admin") @@ -31,3 +53,257 @@ class TestTagging(SupersetTestCase): self.login("admin") response = self.client.get("/tagview/tags/suggestions/") self.assertNotEqual(404, response.status_code) + + @pytest.mark.usefixtures("with_tagging_system_feature") + def test_dataset_tagging(self): + """ + Test to make sure that when a new dataset is created, + a corresponding tag in the tagged_objects table + is created + """ + + # Remove all existing rows in the tagged_object table + self.clear_tagged_object_table() + + # Test to make sure nothing is in the tagged_object table + self.assertEqual([], self.query_tagged_object_table()) + + # Create a dataset and add it to the db + test_dataset = SqlaTable( + table_name="foo", + schema=None, + owners=[], + database=get_main_database(), + sql=None, + extra='{"certification": 1}', + ) + db.session.add(test_dataset) + db.session.commit() + + # Test to make sure that a dataset tag was added to the tagged_object table + tags = self.query_tagged_object_table() + self.assertEqual(1, len(tags)) + self.assertEqual("ObjectTypes.dataset", str(tags[0].object_type)) + self.assertEqual(test_dataset.id, tags[0].object_id) + + # Cleanup the db + db.session.delete(test_dataset) + db.session.commit() + + # Test to make sure the tag is deleted when the associated object is deleted + self.assertEqual([], self.query_tagged_object_table()) + + @pytest.mark.usefixtures("with_tagging_system_feature") + def test_chart_tagging(self): + """ + Test to make sure that when a new chart is created, + a corresponding tag in the tagged_objects table + is created + """ + + # Remove all existing rows in the tagged_object table + self.clear_tagged_object_table() + + # Test to make sure nothing is in the tagged_object table + self.assertEqual([], self.query_tagged_object_table()) + + # Create a chart and add it to the db + test_chart = Slice( + slice_name="test_chart", + datasource_type=DatasourceType.TABLE, + viz_type="bubble", + datasource_id=1, + id=1, + ) + db.session.add(test_chart) + db.session.commit() + + # Test to make sure that a chart tag was added to the tagged_object table + tags = self.query_tagged_object_table() + self.assertEqual(1, len(tags)) + self.assertEqual("ObjectTypes.chart", str(tags[0].object_type)) + self.assertEqual(test_chart.id, tags[0].object_id) + + # Cleanup the db + db.session.delete(test_chart) + db.session.commit() + + # Test to make sure the tag is deleted when the associated object is deleted + self.assertEqual([], self.query_tagged_object_table()) + + @pytest.mark.usefixtures("with_tagging_system_feature") + def test_dashboard_tagging(self): + """ + Test to make sure that when a new dashboard is created, + a corresponding tag in the tagged_objects table + is created + """ + + # Remove all existing rows in the tagged_object table + self.clear_tagged_object_table() + + # Test to make sure nothing is in the tagged_object table + self.assertEqual([], self.query_tagged_object_table()) + + # Create a dashboard and add it to the db + test_dashboard = Dashboard() + test_dashboard.dashboard_title = "test_dashboard" + test_dashboard.slug = "test_slug" + test_dashboard.slices = [] + test_dashboard.published = True + + db.session.add(test_dashboard) + db.session.commit() + + # Test to make sure that a dashboard tag was added to the tagged_object table + tags = self.query_tagged_object_table() + self.assertEqual(1, len(tags)) + self.assertEqual("ObjectTypes.dashboard", str(tags[0].object_type)) + self.assertEqual(test_dashboard.id, tags[0].object_id) + + # Cleanup the db + db.session.delete(test_dashboard) + db.session.commit() + + # Test to make sure the tag is deleted when the associated object is deleted + self.assertEqual([], self.query_tagged_object_table()) + + @pytest.mark.usefixtures("with_tagging_system_feature") + def test_saved_query_tagging(self): + """ + Test to make sure that when a new saved query is + created, a corresponding tag in the tagged_objects + table is created + """ + + # Remove all existing rows in the tagged_object table + self.clear_tagged_object_table() + + # Test to make sure nothing is in the tagged_object table + self.assertEqual([], self.query_tagged_object_table()) + + # Create a saved query and add it to the db + test_saved_query = SavedQuery(id=1, label="test saved query") + db.session.add(test_saved_query) + db.session.commit() + + # Test to make sure that a saved query tag was added to the tagged_object table + tags = self.query_tagged_object_table() + + self.assertEqual(2, len(tags)) + + self.assertEqual("ObjectTypes.query", str(tags[0].object_type)) + self.assertEqual("owner:None", str(tags[0].tag.name)) + self.assertEqual("TagTypes.owner", str(tags[0].tag.type)) + self.assertEqual(test_saved_query.id, tags[0].object_id) + + self.assertEqual("ObjectTypes.query", str(tags[1].object_type)) + self.assertEqual("type:query", str(tags[1].tag.name)) + self.assertEqual("TagTypes.type", str(tags[1].tag.type)) + self.assertEqual(test_saved_query.id, tags[1].object_id) + + # Cleanup the db + db.session.delete(test_saved_query) + db.session.commit() + + # Test to make sure the tag is deleted when the associated object is deleted + self.assertEqual([], self.query_tagged_object_table()) + + @pytest.mark.usefixtures("with_tagging_system_feature") + def test_favorite_tagging(self): + """ + Test to make sure that when a new favorite object is + created, a corresponding tag in the tagged_objects + table is created + """ + + # Remove all existing rows in the tagged_object table + self.clear_tagged_object_table() + + # Test to make sure nothing is in the tagged_object table + self.assertEqual([], self.query_tagged_object_table()) + + # Create a favorited object and add it to the db + test_saved_query = FavStar(user_id=1, class_name="slice", obj_id=1) + db.session.add(test_saved_query) + db.session.commit() + + # Test to make sure that a favorited object tag was added to the tagged_object table + tags = self.query_tagged_object_table() + self.assertEqual(1, len(tags)) + self.assertEqual("ObjectTypes.chart", str(tags[0].object_type)) + self.assertEqual(test_saved_query.obj_id, tags[0].object_id) + + # Cleanup the db + db.session.delete(test_saved_query) + db.session.commit() + + # Test to make sure the tag is deleted when the associated object is deleted + self.assertEqual([], self.query_tagged_object_table()) + + @with_feature_flags(TAGGING_SYSTEM=False) + def test_tagging_system(self): + """ + Test to make sure that when the TAGGING_SYSTEM + feature flag is false, that no tags are created + """ + + # Remove all existing rows in the tagged_object table + self.clear_tagged_object_table() + + # Test to make sure nothing is in the tagged_object table + self.assertEqual([], self.query_tagged_object_table()) + + # Create a dataset and add it to the db + test_dataset = SqlaTable( + table_name="foo", + schema=None, + owners=[], + database=get_main_database(), + sql=None, + extra='{"certification": 1}', + ) + + # Create a chart and add it to the db + test_chart = Slice( + slice_name="test_chart", + datasource_type=DatasourceType.TABLE, + viz_type="bubble", + datasource_id=1, + id=1, + ) + + # Create a dashboard and add it to the db + test_dashboard = Dashboard() + test_dashboard.dashboard_title = "test_dashboard" + test_dashboard.slug = "test_slug" + test_dashboard.slices = [] + test_dashboard.published = True + + # Create a saved query and add it to the db + test_saved_query = SavedQuery(id=1, label="test saved query") + + # Create a favorited object and add it to the db + test_favorited_object = FavStar(user_id=1, class_name="slice", obj_id=1) + + db.session.add(test_dataset) + db.session.add(test_chart) + db.session.add(test_dashboard) + db.session.add(test_saved_query) + db.session.add(test_favorited_object) + db.session.commit() + + # Test to make sure that no tags were added to the tagged_object table + tags = self.query_tagged_object_table() + self.assertEqual(0, len(tags)) + + # Cleanup the db + db.session.delete(test_dataset) + db.session.delete(test_chart) + db.session.delete(test_dashboard) + db.session.delete(test_saved_query) + db.session.delete(test_favorited_object) + db.session.commit() + + # Test to make sure all the tags are deleted when the associated objects are deleted + self.assertEqual([], self.query_tagged_object_table())