feat: Add dataset tagging to the back-end (#20892)

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
This commit is contained in:
cccs-Dustin 2022-09-23 04:01:17 -04:00 committed by GitHub
parent dc539087c7
commit 2e564897f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 712 additions and 157 deletions

View File

@ -14,15 +14,145 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from sqlalchemy import Metadata
from typing import Any, List
from sqlalchemy import MetaData
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import and_, func, functions, join, literal, select
from sqlalchemy.sql import and_, func, join, literal, select
from superset.models.tags import ObjectTypes, TagTypes
from superset.tags.models import ObjectTypes, TagTypes
def add_types(engine: Engine, metadata: Metadata) -> None:
def add_types_to_charts(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
slices = metadata.tables["slices"]
charts = (
select(
[
tag.c.id.label("tag_id"),
slices.c.id.label("object_id"),
literal(ObjectTypes.chart.name).label("object_type"),
]
)
.select_from(
join(
join(slices, tag, tag.c.name == "type:chart"),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == slices.c.id,
tagged_object.c.object_type == "chart",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, charts)
engine.execute(query)
def add_types_to_dashboards(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
dashboards = (
select(
[
tag.c.id.label("tag_id"),
dashboard_table.c.id.label("object_id"),
literal(ObjectTypes.dashboard.name).label("object_type"),
]
)
.select_from(
join(
join(dashboard_table, tag, tag.c.name == "type:dashboard"),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == dashboard_table.c.id,
tagged_object.c.object_type == "dashboard",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, dashboards)
engine.execute(query)
def add_types_to_saved_queries(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
saved_query = metadata.tables["saved_query"]
saved_queries = (
select(
[
tag.c.id.label("tag_id"),
saved_query.c.id.label("object_id"),
literal(ObjectTypes.query.name).label("object_type"),
]
)
.select_from(
join(
join(saved_query, tag, tag.c.name == "type:query"),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == saved_query.c.id,
tagged_object.c.object_type == "query",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, saved_queries)
engine.execute(query)
def add_types_to_datasets(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
tables = metadata.tables["tables"]
datasets = (
select(
[
tag.c.id.label("tag_id"),
tables.c.id.label("object_id"),
literal(ObjectTypes.dataset.name).label("object_type"),
]
)
.select_from(
join(
join(tables, tag, tag.c.name == "type:dataset"),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == tables.c.id,
tagged_object.c.object_type == "dataset",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, datasets)
engine.execute(query)
def add_types(engine: Engine, metadata: MetaData) -> None:
"""
Tag every object according to its type:
@ -68,13 +198,24 @@ def add_types(engine: Engine, metadata: Metadata) -> None:
AND tagged_object.object_type = 'query'
WHERE tagged_object.tag_id IS NULL;
INSERT INTO tagged_object (tag_id, object_id, object_type)
SELECT
tag.id AS tag_id,
tables.id AS object_id,
'dataset' AS object_type
FROM tables
JOIN tag
ON tag.name = 'type:dataset'
LEFT OUTER JOIN tagged_object
ON tagged_object.tag_id = tag.id
AND tagged_object.object_id = tables.id
AND tagged_object.object_type = 'dataset'
WHERE tagged_object.tag_id IS NULL;
"""
tag = metadata.tables["tag"]
tagged_object = metadata.tables["tagged_object"]
slices = metadata.tables["slices"]
dashboards = metadata.tables["dashboards"]
saved_query = metadata.tables["saved_query"]
columns = ["tag_id", "object_id", "object_type"]
# add a tag for each object type
@ -85,6 +226,17 @@ def add_types(engine: Engine, metadata: Metadata) -> None:
except IntegrityError:
pass # already exists
add_types_to_charts(engine, metadata, tag, tagged_object, columns)
add_types_to_dashboards(engine, metadata, tag, tagged_object, columns)
add_types_to_saved_queries(engine, metadata, tag, tagged_object, columns)
add_types_to_datasets(engine, metadata, tag, tagged_object, columns)
def add_owners_to_charts(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
slices = metadata.tables["slices"]
charts = (
select(
[
@ -95,7 +247,11 @@ def add_types(engine: Engine, metadata: Metadata) -> None:
)
.select_from(
join(
join(slices, tag, tag.c.name == "type:chart"),
join(
slices,
tag,
tag.c.name == "owner:" + slices.c.created_by_fk,
),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
@ -111,21 +267,31 @@ def add_types(engine: Engine, metadata: Metadata) -> None:
query = tagged_object.insert().from_select(columns, charts)
engine.execute(query)
def add_owners_to_dashboards(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
dashboards = (
select(
[
tag.c.id.label("tag_id"),
dashboards.c.id.label("object_id"),
dashboard_table.c.id.label("object_id"),
literal(ObjectTypes.dashboard.name).label("object_type"),
]
)
.select_from(
join(
join(dashboards, tag, tag.c.name == "type:dashboard"),
join(
dashboard_table,
tag,
tag.c.name == "owner:" + dashboard_table.c.created_by_fk,
),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == dashboards.c.id,
tagged_object.c.object_id == dashboard_table.c.id,
tagged_object.c.object_type == "dashboard",
),
isouter=True,
@ -137,6 +303,12 @@ def add_types(engine: Engine, metadata: Metadata) -> None:
query = tagged_object.insert().from_select(columns, dashboards)
engine.execute(query)
def add_owners_to_saved_queries(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
saved_query = metadata.tables["saved_query"]
saved_queries = (
select(
[
@ -147,7 +319,11 @@ def add_types(engine: Engine, metadata: Metadata) -> None:
)
.select_from(
join(
join(saved_query, tag, tag.c.name == "type:query"),
join(
saved_query,
tag,
tag.c.name == "owner:" + saved_query.c.created_by_fk,
),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
@ -164,7 +340,43 @@ def add_types(engine: Engine, metadata: Metadata) -> None:
engine.execute(query)
def add_owners(engine: Engine, metadata: Metadata) -> None:
def add_owners_to_datasets(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
tables = metadata.tables["tables"]
datasets = (
select(
[
tag.c.id.label("tag_id"),
tables.c.id.label("object_id"),
literal(ObjectTypes.dataset.name).label("object_type"),
]
)
.select_from(
join(
join(
tables,
tag,
tag.c.name == "owner:" + tables.c.created_by_fk,
),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == tables.c.id,
tagged_object.c.object_type == "dataset",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, datasets)
engine.execute(query)
def add_owners(engine: Engine, metadata: MetaData) -> None:
"""
Tag every object according to its owner:
@ -208,14 +420,24 @@ def add_owners(engine: Engine, metadata: Metadata) -> None:
AND tagged_object.object_type = 'query'
WHERE tagged_object.tag_id IS NULL;
SELECT
tag.id AS tag_id,
tables.id AS object_id,
'dataset' AS object_type
FROM tables
JOIN tag
ON tag.name = CONCAT('owner:', tables.created_by_fk)
LEFT OUTER JOIN tagged_object
ON tagged_object.tag_id = tag.id
AND tagged_object.object_id = tables.id
AND tagged_object.object_type = 'dataset'
WHERE tagged_object.tag_id IS NULL;
"""
tag = metadata.tables["tag"]
tagged_object = metadata.tables["tagged_object"]
users = metadata.tables["ab_user"]
slices = metadata.tables["slices"]
dashboards = metadata.tables["dashboards"]
saved_query = metadata.tables["saved_query"]
columns = ["tag_id", "object_id", "object_type"]
# create a custom tag for each user
@ -227,100 +449,13 @@ def add_owners(engine: Engine, metadata: Metadata) -> None:
except IntegrityError:
pass # already exists
charts = (
select(
[
tag.c.id.label("tag_id"),
slices.c.id.label("object_id"),
literal(ObjectTypes.chart.name).label("object_type"),
]
)
.select_from(
join(
join(
slices,
tag,
tag.c.name == functions.concat("owner:", slices.c.created_by_fk),
),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == slices.c.id,
tagged_object.c.object_type == "chart",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, charts)
engine.execute(query)
dashboards = (
select(
[
tag.c.id.label("tag_id"),
dashboards.c.id.label("object_id"),
literal(ObjectTypes.dashboard.name).label("object_type"),
]
)
.select_from(
join(
join(
dashboards,
tag,
tag.c.name
== functions.concat("owner:", dashboards.c.created_by_fk),
),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == dashboards.c.id,
tagged_object.c.object_type == "dashboard",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, dashboards)
engine.execute(query)
saved_queries = (
select(
[
tag.c.id.label("tag_id"),
saved_query.c.id.label("object_id"),
literal(ObjectTypes.query.name).label("object_type"),
]
)
.select_from(
join(
join(
saved_query,
tag,
tag.c.name
== functions.concat("owner:", saved_query.c.created_by_fk),
),
tagged_object,
and_(
tagged_object.c.tag_id == tag.c.id,
tagged_object.c.object_id == saved_query.c.id,
tagged_object.c.object_type == "query",
),
isouter=True,
full=False,
)
)
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, saved_queries)
engine.execute(query)
add_owners_to_charts(engine, metadata, tag, tagged_object, columns)
add_owners_to_dashboards(engine, metadata, tag, tagged_object, columns)
add_owners_to_saved_queries(engine, metadata, tag, tagged_object, columns)
add_owners_to_datasets(engine, metadata, tag, tagged_object, columns)
def add_favorites(engine: Engine, metadata: Metadata) -> None:
def add_favorites(engine: Engine, metadata: MetaData) -> None:
"""
Tag every object that was favorited:
@ -368,7 +503,7 @@ def add_favorites(engine: Engine, metadata: Metadata) -> None:
join(
favstar,
tag,
tag.c.name == functions.concat("favorited_by:", favstar.c.user_id),
tag.c.name == "favorited_by:" + favstar.c.user_id,
),
tagged_object,
and_(

View File

@ -49,6 +49,7 @@ from superset.extensions import (
)
from superset.security import SupersetSecurityManager
from superset.superset_typing import FlaskResponse
from superset.tags.core import register_sqla_event_listeners
from superset.utils.core import pessimistic_connection_handling
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
@ -426,6 +427,9 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
if flask_app_mutator:
flask_app_mutator(self.superset_app)
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
register_sqla_event_listeners()
self.init_views()
def check_secret_key(self) -> None:

View File

@ -33,7 +33,7 @@ from flask_appbuilder.models.mixins import AuditMixin
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from superset.models.tags import ObjectTypes, TagTypes
from superset.tags.models import ObjectTypes, TagTypes
from superset.utils.core import get_user_id
Base = declarative_base()

View File

@ -53,13 +53,12 @@ from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import expression, Select
from superset import app, db_engine_specs, is_feature_enabled
from superset import app, db_engine_specs
from superset.constants import PASSWORD_MASK
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import MetricType, TimeGrain
from superset.extensions import cache_manager, encrypted_field_factory, security_manager
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.models.tags import FavStarUpdater
from superset.result_set import SupersetResultSet
from superset.utils import cache as cache_util, core as utils
from superset.utils.core import get_username
@ -809,9 +808,3 @@ class FavStar(Model): # pylint: disable=too-few-public-methods
class_name = Column(String(50))
obj_id = Column(Integer)
dttm = Column(DateTime, default=datetime.utcnow)
# events for updating tags
if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(FavStar, "after_insert", FavStarUpdater.after_insert)
sqla.event.listen(FavStar, "after_delete", FavStarUpdater.after_delete)

View File

@ -53,7 +53,6 @@ from superset.extensions import cache_manager
from superset.models.filter_set import FilterSet
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.models.slice import Slice
from superset.models.tags import DashboardUpdater
from superset.models.user_attributes import UserAttribute
from superset.tasks.thumbnails import cache_dashboard_thumbnail
from superset.utils import core as utils
@ -454,12 +453,6 @@ def id_or_slug_filter(id_or_slug: Union[int, str]) -> BinaryExpression:
OnDashboardChange = Callable[[Mapper, Connection, Dashboard], Any]
# events for updating tags
if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert)
sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update)
sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete)
if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"):
update_thumbnail: OnDashboardChange = lambda _, __, dash: dash.update_thumbnail()
sqla.event.listen(Dashboard, "after_insert", update_thumbnail)

View File

@ -42,7 +42,6 @@ from sqlalchemy.orm.mapper import Mapper
from superset import db, is_feature_enabled, security_manager
from superset.legacy import update_time_range
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.models.tags import ChartUpdater
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils import core as utils
from superset.utils.hashing import md5_sha_from_str
@ -367,13 +366,6 @@ def event_after_chart_changed(
sqla.event.listen(Slice, "before_insert", set_related_perm)
sqla.event.listen(Slice, "before_update", set_related_perm)
# events for updating tags
if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert)
sqla.event.listen(Slice, "after_update", ChartUpdater.after_update)
sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete)
# events for updating tags
if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"):
sqla.event.listen(Slice, "after_insert", event_after_chart_changed)
sqla.event.listen(Slice, "after_update", event_after_chart_changed)

View File

@ -49,7 +49,6 @@ from superset.models.helpers import (
ExtraJSONMixin,
ImportExportMixin,
)
from superset.models.tags import QueryUpdater
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sqllab.limiting_factor import LimitingFactor
from superset.superset_typing import ResultSetColumnType
@ -509,9 +508,3 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin):
"description": description,
"expanded": self.expanded,
}
# events for updating tags
sqla.event.listen(SavedQuery, "after_insert", QueryUpdater.after_insert)
sqla.event.listen(SavedQuery, "after_update", QueryUpdater.after_update)
sqla.event.listen(SavedQuery, "after_delete", QueryUpdater.after_delete)

88
superset/tags/core.py Normal file
View File

@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
def register_sqla_event_listeners() -> None:
import sqlalchemy as sqla
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import FavStar
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.tags.models import (
ChartUpdater,
DashboardUpdater,
DatasetUpdater,
FavStarUpdater,
QueryUpdater,
)
sqla.event.listen(SqlaTable, "after_insert", DatasetUpdater.after_insert)
sqla.event.listen(SqlaTable, "after_update", DatasetUpdater.after_update)
sqla.event.listen(SqlaTable, "after_delete", DatasetUpdater.after_delete)
sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert)
sqla.event.listen(Slice, "after_update", ChartUpdater.after_update)
sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete)
sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert)
sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update)
sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete)
sqla.event.listen(FavStar, "after_insert", FavStarUpdater.after_insert)
sqla.event.listen(FavStar, "after_delete", FavStarUpdater.after_delete)
sqla.event.listen(SavedQuery, "after_insert", QueryUpdater.after_insert)
sqla.event.listen(SavedQuery, "after_update", QueryUpdater.after_update)
sqla.event.listen(SavedQuery, "after_delete", QueryUpdater.after_delete)
def clear_sqla_event_listeners() -> None:
import sqlalchemy as sqla
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import FavStar
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.tags.models import (
ChartUpdater,
DashboardUpdater,
DatasetUpdater,
FavStarUpdater,
QueryUpdater,
)
sqla.event.remove(SqlaTable, "after_insert", DatasetUpdater.after_insert)
sqla.event.remove(SqlaTable, "after_update", DatasetUpdater.after_update)
sqla.event.remove(SqlaTable, "after_delete", DatasetUpdater.after_delete)
sqla.event.remove(Slice, "after_insert", ChartUpdater.after_insert)
sqla.event.remove(Slice, "after_update", ChartUpdater.after_update)
sqla.event.remove(Slice, "after_delete", ChartUpdater.after_delete)
sqla.event.remove(Dashboard, "after_insert", DashboardUpdater.after_insert)
sqla.event.remove(Dashboard, "after_update", DashboardUpdater.after_update)
sqla.event.remove(Dashboard, "after_delete", DashboardUpdater.after_delete)
sqla.event.remove(FavStar, "after_insert", FavStarUpdater.after_insert)
sqla.event.remove(FavStar, "after_delete", FavStarUpdater.after_delete)
sqla.event.remove(SavedQuery, "after_insert", QueryUpdater.after_insert)
sqla.event.remove(SavedQuery, "after_update", QueryUpdater.after_update)
sqla.event.remove(SavedQuery, "after_delete", QueryUpdater.after_delete)

View File

@ -14,7 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)
import enum
from typing import List, Optional, TYPE_CHECKING, Union
@ -28,6 +34,7 @@ from sqlalchemy.orm.mapper import Mapper
from superset.models.helpers import AuditMixinNullable
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import FavStar
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@ -41,7 +48,7 @@ class TagTypes(enum.Enum):
"""
Types for tags.
Objects (queries, charts and dashboards) will have with implicit tags based
Objects (queries, charts, dashboards, and datasets) will have with implicit tags based
on metadata: types, owners and who favorited them. This way, user "alice"
can find all their objects by querying for the tag `owner:alice`.
"""
@ -64,11 +71,12 @@ class ObjectTypes(enum.Enum):
query = 1
chart = 2
dashboard = 3
dataset = 4
class Tag(Model, AuditMixinNullable):
"""A tag attached to an object (query, chart or dashboard)."""
"""A tag attached to an object (query, chart, dashboard, or dataset)."""
__tablename__ = "tag"
id = Column(Integer, primary_key=True)
@ -103,6 +111,7 @@ def get_object_type(class_name: str) -> ObjectTypes:
"slice": ObjectTypes.chart,
"dashboard": ObjectTypes.dashboard,
"query": ObjectTypes.query,
"dataset": ObjectTypes.dataset,
}
try:
return mapping[class_name.lower()]
@ -116,13 +125,15 @@ class ObjectUpdater:
@classmethod
def get_owners_ids(
cls, target: Union["Dashboard", "FavStar", "Slice"]
cls, target: Union[Dashboard, FavStar, Slice, Query, SqlaTable]
) -> List[int]:
raise NotImplementedError("Subclass should implement `get_owners_ids`")
@classmethod
def _add_owners(
cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
cls,
session: Session,
target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
) -> None:
for owner_id in cls.get_owners_ids(target):
name = "owner:{0}".format(owner_id)
@ -137,7 +148,7 @@ class ObjectUpdater:
cls,
_mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
) -> None:
session = Session(bind=connection)
@ -158,7 +169,7 @@ class ObjectUpdater:
cls,
_mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
) -> None:
session = Session(bind=connection)
@ -187,7 +198,7 @@ class ObjectUpdater:
cls,
_mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
) -> None:
session = Session(bind=connection)
@ -205,7 +216,7 @@ class ChartUpdater(ObjectUpdater):
object_type = "chart"
@classmethod
def get_owners_ids(cls, target: "Slice") -> List[int]:
def get_owners_ids(cls, target: Slice) -> List[int]:
return [owner.id for owner in target.owners]
@ -214,7 +225,7 @@ class DashboardUpdater(ObjectUpdater):
object_type = "dashboard"
@classmethod
def get_owners_ids(cls, target: "Dashboard") -> List[int]:
def get_owners_ids(cls, target: Dashboard) -> List[int]:
return [owner.id for owner in target.owners]
@ -223,14 +234,23 @@ class QueryUpdater(ObjectUpdater):
object_type = "query"
@classmethod
def get_owners_ids(cls, target: "Query") -> List[int]:
def get_owners_ids(cls, target: Query) -> List[int]:
return [target.user_id]
class DatasetUpdater(ObjectUpdater):
object_type = "dataset"
@classmethod
def get_owners_ids(cls, target: SqlaTable) -> List[int]:
return [owner.id for owner in target.owners]
class FavStarUpdater:
@classmethod
def after_insert(
cls, _mapper: Mapper, connection: Connection, target: "FavStar"
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
@ -246,7 +266,7 @@ class FavStarUpdater:
@classmethod
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: "FavStar"
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)

View File

@ -28,7 +28,7 @@ from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.tags import Tag, TaggedObject
from superset.tags.models import Tag, TaggedObject
from superset.utils.date_parser import parse_human_datetime
from superset.utils.machine_auth import MachineAuthProvider

View File

@ -18,7 +18,7 @@ from typing import Any, List
from werkzeug.routing import BaseConverter, Map
from superset.models.tags import ObjectTypes
from superset.tags.models import ObjectTypes
class RegexConverter(BaseConverter):

View File

@ -28,12 +28,13 @@ from sqlalchemy import and_, func
from werkzeug.exceptions import NotFound
from superset import db, is_feature_enabled, utils
from superset.connectors.sqla.models import SqlaTable
from superset.jinja_context import ExtraCache
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.models.tags import ObjectTypes, Tag, TaggedObject, TagTypes
from superset.superset_typing import FlaskResponse
from superset.tags.models import ObjectTypes, Tag, TaggedObject, TagTypes
from .base import BaseSupersetView, json_success
@ -238,4 +239,31 @@ class TagView(BaseSupersetView):
for obj in saved_queries
)
# datasets
if not types or "dataset" in types:
datasets = (
db.session.query(SqlaTable)
.join(
TaggedObject,
and_(
TaggedObject.object_id == SqlaTable.id,
TaggedObject.object_type == ObjectTypes.dataset,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.filter(Tag.name.in_(tags))
)
results.extend(
{
"id": obj.id,
"type": ObjectTypes.dataset.name,
"name": obj.table_name,
"url": obj.sql_url(),
"changed_on": obj.changed_on,
"created_by": obj.created_by_fk,
"creator": obj.creator(),
}
for obj in datasets
)
return json_success(json.dumps(results, default=utils.core.json_int_dttm_ser))

View File

@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.tags.core import clear_sqla_event_listeners, register_sqla_event_listeners
from tests.integration_tests.test_app import app
@pytest.fixture
def with_tagging_system_feature():
with app.app_context():
is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"]
if not is_enabled:
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True
register_sqla_event_listeners()
yield
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False
clear_sqla_event_listeners()

View File

@ -35,7 +35,7 @@ from superset.utils.database import get_example_database
from superset import db
from superset.models.core import Log
from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes
from superset.tags.models import get_tag, ObjectTypes, TaggedObject, TagTypes
from superset.tasks.cache import (
DashboardTagsStrategy,
TopNDashboardsStrategy,

View File

@ -15,11 +15,33 @@
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from superset.connectors.sqla.models import SqlaTable
from superset.extensions import db
from superset.models.core import FavStar
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.tags.models import TaggedObject
from superset.utils.core import DatasourceType
from superset.utils.database import get_main_database
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.fixtures.tags import with_tagging_system_feature
class TestTagging(SupersetTestCase):
def query_tagged_object_table(self):
query = db.session.query(TaggedObject).all()
return query
def clear_tagged_object_table(self):
db.session.query(TaggedObject).delete()
db.session.commit()
@with_feature_flags(TAGGING_SYSTEM=False)
def test_tag_view_disabled(self):
self.login("admin")
@ -31,3 +53,257 @@ class TestTagging(SupersetTestCase):
self.login("admin")
response = self.client.get("/tagview/tags/suggestions/")
self.assertNotEqual(404, response.status_code)
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_dataset_tagging(self):
"""
Test to make sure that when a new dataset is created,
a corresponding tag in the tagged_objects table
is created
"""
# Remove all existing rows in the tagged_object table
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
# Create a dataset and add it to the db
test_dataset = SqlaTable(
table_name="foo",
schema=None,
owners=[],
database=get_main_database(),
sql=None,
extra='{"certification": 1}',
)
db.session.add(test_dataset)
db.session.commit()
# Test to make sure that a dataset tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectTypes.dataset", str(tags[0].object_type))
self.assertEqual(test_dataset.id, tags[0].object_id)
# Cleanup the db
db.session.delete(test_dataset)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_chart_tagging(self):
"""
Test to make sure that when a new chart is created,
a corresponding tag in the tagged_objects table
is created
"""
# Remove all existing rows in the tagged_object table
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
# Create a chart and add it to the db
test_chart = Slice(
slice_name="test_chart",
datasource_type=DatasourceType.TABLE,
viz_type="bubble",
datasource_id=1,
id=1,
)
db.session.add(test_chart)
db.session.commit()
# Test to make sure that a chart tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectTypes.chart", str(tags[0].object_type))
self.assertEqual(test_chart.id, tags[0].object_id)
# Cleanup the db
db.session.delete(test_chart)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_dashboard_tagging(self):
"""
Test to make sure that when a new dashboard is created,
a corresponding tag in the tagged_objects table
is created
"""
# Remove all existing rows in the tagged_object table
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
# Create a dashboard and add it to the db
test_dashboard = Dashboard()
test_dashboard.dashboard_title = "test_dashboard"
test_dashboard.slug = "test_slug"
test_dashboard.slices = []
test_dashboard.published = True
db.session.add(test_dashboard)
db.session.commit()
# Test to make sure that a dashboard tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectTypes.dashboard", str(tags[0].object_type))
self.assertEqual(test_dashboard.id, tags[0].object_id)
# Cleanup the db
db.session.delete(test_dashboard)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_saved_query_tagging(self):
"""
Test to make sure that when a new saved query is
created, a corresponding tag in the tagged_objects
table is created
"""
# Remove all existing rows in the tagged_object table
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
# Create a saved query and add it to the db
test_saved_query = SavedQuery(id=1, label="test saved query")
db.session.add(test_saved_query)
db.session.commit()
# Test to make sure that a saved query tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(2, len(tags))
self.assertEqual("ObjectTypes.query", str(tags[0].object_type))
self.assertEqual("owner:None", str(tags[0].tag.name))
self.assertEqual("TagTypes.owner", str(tags[0].tag.type))
self.assertEqual(test_saved_query.id, tags[0].object_id)
self.assertEqual("ObjectTypes.query", str(tags[1].object_type))
self.assertEqual("type:query", str(tags[1].tag.name))
self.assertEqual("TagTypes.type", str(tags[1].tag.type))
self.assertEqual(test_saved_query.id, tags[1].object_id)
# Cleanup the db
db.session.delete(test_saved_query)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_favorite_tagging(self):
"""
Test to make sure that when a new favorite object is
created, a corresponding tag in the tagged_objects
table is created
"""
# Remove all existing rows in the tagged_object table
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
# Create a favorited object and add it to the db
test_saved_query = FavStar(user_id=1, class_name="slice", obj_id=1)
db.session.add(test_saved_query)
db.session.commit()
# Test to make sure that a favorited object tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectTypes.chart", str(tags[0].object_type))
self.assertEqual(test_saved_query.obj_id, tags[0].object_id)
# Cleanup the db
db.session.delete(test_saved_query)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
@with_feature_flags(TAGGING_SYSTEM=False)
def test_tagging_system(self):
"""
Test to make sure that when the TAGGING_SYSTEM
feature flag is false, that no tags are created
"""
# Remove all existing rows in the tagged_object table
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
# Create a dataset and add it to the db
test_dataset = SqlaTable(
table_name="foo",
schema=None,
owners=[],
database=get_main_database(),
sql=None,
extra='{"certification": 1}',
)
# Create a chart and add it to the db
test_chart = Slice(
slice_name="test_chart",
datasource_type=DatasourceType.TABLE,
viz_type="bubble",
datasource_id=1,
id=1,
)
# Create a dashboard and add it to the db
test_dashboard = Dashboard()
test_dashboard.dashboard_title = "test_dashboard"
test_dashboard.slug = "test_slug"
test_dashboard.slices = []
test_dashboard.published = True
# Create a saved query and add it to the db
test_saved_query = SavedQuery(id=1, label="test saved query")
# Create a favorited object and add it to the db
test_favorited_object = FavStar(user_id=1, class_name="slice", obj_id=1)
db.session.add(test_dataset)
db.session.add(test_chart)
db.session.add(test_dashboard)
db.session.add(test_saved_query)
db.session.add(test_favorited_object)
db.session.commit()
# Test to make sure that no tags were added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(0, len(tags))
# Cleanup the db
db.session.delete(test_dataset)
db.session.delete(test_chart)
db.session.delete(test_dashboard)
db.session.delete(test_saved_query)
db.session.delete(test_favorited_object)
db.session.commit()
# Test to make sure all the tags are deleted when the associated objects are deleted
self.assertEqual([], self.query_tagged_object_table())