feat(tag): fast follow for Tags flatten api + update client with generator + some bug fixes (#25309)

This commit is contained in:
Hugh A. Miles II 2023-09-18 14:56:08 -04:00 committed by GitHub
parent 6e799e37f4
commit 090ae64dfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 124 additions and 37 deletions

View File

@ -45,13 +45,19 @@ const BulkTagModal: React.FC<BulkTagModalProps> = ({
addDangerToast,
}) => {
useEffect(() => {}, []);
const [tags, setTags] = useState<TaggableResourceOption[]>([]);
const onSave = async () => {
await SupersetClient.post({
endpoint: `/api/v1/tag/bulk_create`,
jsonPayload: {
tags: tags.map(tag => tag.value),
objects_to_tag: selected.map(item => [resourceName, +item.original.id]),
tags: tags.map(tag => ({
name: tag.value,
objects_to_tag: selected.map(item => [
resourceName,
+item.original.id,
]),
})),
},
})
.then(({ json = {} }) => {
@ -66,8 +72,6 @@ const BulkTagModal: React.FC<BulkTagModalProps> = ({
setTags([]);
};
const [tags, setTags] = useState<TaggableResourceOption[]>([]);
return (
<Modal
title={t('Bulk tag')}

View File

@ -412,4 +412,3 @@ class TagDAO(BaseDAO[Tag]):
)
db.session.add_all(tagged_objects)
db.session.commit()

View File

@ -260,7 +260,10 @@ class TagRestApi(BaseSupersetModelRestApi):
try:
for tag in item.get("tags"):
tagged_item: dict[str, Any] = self.add_model_schema.load(
{"name": tag, "objects_to_tag": item.get("objects_to_tag")}
{
"name": tag.get("name"),
"objects_to_tag": tag.get("objects_to_tag"),
}
)
CreateCustomTagWithRelationshipsCommand(
tagged_item, bulk_create=True

View File

@ -17,12 +17,13 @@
import logging
from typing import Any
from superset import db
from superset import db, security_manager
from superset.commands.base import BaseCommand, CreateMixin
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.tag import TagDAO
from superset.exceptions import SupersetSecurityException
from superset.tags.commands.exceptions import TagCreateFailedError, TagInvalidError
from superset.tags.commands.utils import to_object_type
from superset.tags.commands.utils import to_object_model, to_object_type
from superset.tags.models import ObjectTypes, TagTypes
logger = logging.getLogger(__name__)
@ -73,6 +74,7 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
def run(self) -> None:
self.validate()
try:
tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom)
if self._objects_to_tag:
@ -84,7 +86,8 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
if self._description:
tag.description = self._description
db.session.commit()
db.session.commit()
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
@ -98,12 +101,25 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
exceptions.append(TagInvalidError())
# Validate object type
skipped_tagged_objects: list[tuple[str, int]] = []
for obj_type, obj_id in self._objects_to_tag:
skipped_tagged_objects = []
object_type = to_object_type(obj_type)
if not object_type:
exceptions.append(
TagInvalidError(f"invalid object type {object_type}")
)
try:
model = to_object_model(object_type, obj_id) # type: ignore
security_manager.raise_for_ownership(model)
except SupersetSecurityException:
# skip the object if the user doesn't have access
skipped_tagged_objects.append((obj_type, obj_id))
self._objects_to_tag = set(self._objects_to_tag) - set(
skipped_tagged_objects
)
if exceptions:
raise TagInvalidError(exceptions=exceptions)

View File

@ -17,6 +17,12 @@
from typing import Optional, Union
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.query import SavedQueryDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.tags.models import ObjectTypes
@ -27,3 +33,15 @@ def to_object_type(object_type: Union[ObjectTypes, int, str]) -> Optional[Object
if object_type in [type_.value, type_.name]:
return type_
return None
def to_object_model(
object_type: ObjectTypes, object_id: int
) -> Optional[Union[Dashboard, SavedQuery, Slice]]:
if ObjectTypes.dashboard == object_type:
return DashboardDAO.find_by_id(object_id)
if ObjectTypes.query == object_type:
return SavedQueryDAO.find_by_id(object_id)
if ObjectTypes.chart == object_type:
return ChartDAO.find_by_id(object_id)
return None

View File

@ -54,27 +54,21 @@ class TagGetResponseSchema(Schema):
type = fields.String()
class TagPostSchema(Schema):
class TagObjectSchema(Schema):
name = fields.String()
description = fields.String(required=False, allow_none=True)
# resource id's to tag with tag
objects_to_tag = fields.List(
fields.Tuple((fields.String(), fields.Int())), required=False
)
class TagPostBulkSchema(Schema):
tags = fields.List(fields.String())
# resource id's to tag with tag
objects_to_tag = fields.List(
fields.Tuple((fields.String(), fields.Int())), required=False
)
tags = fields.List(fields.Nested(TagObjectSchema))
class TagPutSchema(Schema):
name = fields.String()
description = fields.String(required=False, allow_none=True)
# resource id's to tag with tag
objects_to_tag = fields.List(
fields.Tuple((fields.String(), fields.Int())), required=False
)
class TagPostSchema(TagObjectSchema):
pass
class TagPutSchema(TagObjectSchema):
pass

View File

@ -530,8 +530,23 @@ class TestTagApi(SupersetTestCase):
rv = self.client.post(
uri,
json={
"tags": ["tag1", "tag2", "tag3"],
"objects_to_tag": [["dashboard", dashboard.id], ["chart", chart.id]],
"tags": [
{
"name": "tag1",
"objects_to_tag": [
["dashboard", dashboard.id],
["chart", chart.id],
],
},
{
"name": "tag2",
"objects_to_tag": [["dashboard", dashboard.id]],
},
{
"name": "tag3",
"objects_to_tag": [["chart", chart.id]],
},
]
},
)
@ -547,11 +562,10 @@ class TestTagApi(SupersetTestCase):
TaggedObject.object_id == dashboard.id,
TaggedObject.object_type == ObjectTypes.dashboard,
)
assert tagged_objects.count() == 3
assert tagged_objects.count() == 2
tagged_objects = db.session.query(TaggedObject).filter(
# TaggedObject.tag_id.in_([tag.id for tag in tags]),
TaggedObject.object_id == chart.id,
TaggedObject.object_type == ObjectTypes.chart,
)
assert tagged_objects.count() == 3
assert tagged_objects.count() == 2

View File

@ -169,6 +169,3 @@ def test_create_tag_relationship(mocker):
# Verify that the correct number of TaggedObjects are added to the session
assert mock_session.add_all.call_count == 1
assert len(mock_session.add_all.call_args[0][0]) == len(objects_to_tag)
# Verify that commit is called
mock_session.commit.assert_called_once()

View File

@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset.utils.core import DatasourceType
@ -47,7 +48,7 @@ def session_with_data(session: Session):
yield session
def test_create_command_success(session_with_data: Session):
def test_create_command_success(session_with_data: Session, mocker: MockFixture):
from superset.connectors.sqla.models import SqlaTable
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
@ -61,6 +62,12 @@ def test_create_command_success(session_with_data: Session):
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=query)
objects_to_tag = [
(ObjectTypes.query, query.id),
(ObjectTypes.chart, chart.id),
@ -84,7 +91,9 @@ def test_create_command_success(session_with_data: Session):
)
def test_create_command_failed_validate(session_with_data: Session):
def test_create_command_failed_validate(
session_with_data: Session, mocker: MockFixture
):
from superset.connectors.sqla.models import SqlaTable
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
@ -98,6 +107,12 @@ def test_create_command_failed_validate(session_with_data: Session):
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query)
mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=chart)
objects_to_tag = [
(ObjectTypes.query, query.id),
(ObjectTypes.chart, chart.id),

View File

@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset.utils.core import DatasourceType
@ -56,13 +57,19 @@ def session_with_data(session: Session):
yield session
def test_update_command_success(session_with_data: Session):
def test_update_command_success(session_with_data: Session, mocker: MockFixture):
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.tags.commands.update import UpdateTagCommand
from superset.tags.models import ObjectTypes, TaggedObject
dashboard = session_with_data.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch(
"superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard
)
objects_to_tag = [
(ObjectTypes.dashboard, dashboard.id),
@ -84,7 +91,9 @@ def test_update_command_success(session_with_data: Session):
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
def test_update_command_success_duplicates(session_with_data: Session):
def test_update_command_success_duplicates(
session_with_data: Session, mocker: MockFixture
):
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@ -95,6 +104,14 @@ def test_update_command_success_duplicates(session_with_data: Session):
dashboard = session_with_data.query(Dashboard).first()
chart = session_with_data.query(Slice).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
mocker.patch(
"superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard
)
objects_to_tag = [
(ObjectTypes.dashboard, dashboard.id),
]
@ -124,14 +141,16 @@ def test_update_command_success_duplicates(session_with_data: Session):
assert changed_model.objects[0].object_id == chart.id
def test_update_command_failed_validation(session_with_data: Session):
def test_update_command_failed_validation(
session_with_data: Session, mocker: MockFixture
):
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand
from superset.tags.commands.exceptions import TagInvalidError
from superset.tags.commands.update import UpdateTagCommand
from superset.tags.models import ObjectTypes, TaggedObject
from superset.tags.models import ObjectTypes
dashboard = session_with_data.query(Dashboard).first()
chart = session_with_data.query(Slice).first()
@ -139,6 +158,14 @@ def test_update_command_failed_validation(session_with_data: Session):
(ObjectTypes.chart, chart.id),
]
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
mocker.patch(
"superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard
)
CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()