diff --git a/superset/daos/tag.py b/superset/daos/tag.py index b6872a537..2acd221a3 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -390,7 +390,12 @@ class TagDAO(BaseDAO[Tag]): updated_tagged_objects = { (to_object_type(obj[0]), obj[1]) for obj in objects_to_tag } - tagged_objects_to_delete = current_tagged_objects - updated_tagged_objects + + tagged_objects_to_delete = ( + current_tagged_objects + if not objects_to_tag + else current_tagged_objects - updated_tagged_objects + ) for object_type, object_id in updated_tagged_objects: # create rows for new objects, and skip tags that already exist diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py index e8311ad52..883c498bc 100644 --- a/superset/tags/commands/create.py +++ b/superset/tags/commands/create.py @@ -67,25 +67,22 @@ class CreateCustomTagCommand(CreateMixin, BaseCommand): class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand): def __init__(self, data: dict[str, Any], bulk_create: bool = False): - self._tag = data["name"] - self._objects_to_tag = data.get("objects_to_tag") - self._description = data.get("description") + self._properties = data.copy() self._bulk_create = bulk_create def run(self) -> None: self.validate() try: - tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom) - if self._objects_to_tag: - TagDAO.create_tag_relationship( - objects_to_tag=self._objects_to_tag, - tag=tag, - bulk_create=self._bulk_create, - ) + tag_name = self._properties["name"] + tag = TagDAO.get_by_name(tag_name.strip(), TagTypes.custom) + TagDAO.create_tag_relationship( + objects_to_tag=self._properties.get("objects_to_tag", []), + tag=tag, + bulk_create=self._bulk_create, + ) - if self._description: - tag.description = self._description + tag.description = self._properties.get("description", "") db.session.commit() @@ -95,31 +92,21 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand): def validate(self) -> None: exceptions = [] - # Validate object_id - if self._objects_to_tag: - if any(obj_id == 0 for obj_type, obj_id in self._objects_to_tag): - exceptions.append(TagInvalidError()) + objects_to_tag = set(self._properties.get("objects_to_tag", [])) + skipped_tagged_objects: set[tuple[str, int]] = set() + for obj_type, obj_id in objects_to_tag: + object_type = to_object_type(obj_type) - # Validate object type - skipped_tagged_objects: list[tuple[str, int]] = [] - for obj_type, obj_id in self._objects_to_tag: - skipped_tagged_objects = [] - object_type = to_object_type(obj_type) + if not object_type: + exceptions.append(TagInvalidError(f"invalid object type {object_type}")) + try: + model = to_object_model(object_type, obj_id) # type: ignore + security_manager.raise_for_ownership(model) + except SupersetSecurityException: + # skip the object if the user doesn't have access + skipped_tagged_objects.add((obj_type, obj_id)) - if not object_type: - exceptions.append( - TagInvalidError(f"invalid object type {object_type}") - ) - try: - model = to_object_model(object_type, obj_id) # type: ignore - security_manager.raise_for_ownership(model) - except SupersetSecurityException: - # skip the object if the user doesn't have access - skipped_tagged_objects.append((obj_type, obj_id)) - - self._objects_to_tag = set(self._objects_to_tag) - set( - skipped_tagged_objects - ) + self._properties["objects_to_tag"] = objects_to_tag - skipped_tagged_objects if exceptions: raise TagInvalidError(exceptions=exceptions) diff --git a/superset/tags/commands/update.py b/superset/tags/commands/update.py index a13e4e8e7..cc1c9a2be 100644 --- a/superset/tags/commands/update.py +++ b/superset/tags/commands/update.py @@ -38,12 +38,10 @@ class UpdateTagCommand(UpdateMixin, BaseCommand): def run(self) -> Model: self.validate() if self._model: - if self._properties.get("objects_to_tag"): - # todo(hugh): can this manage duplication - TagDAO.create_tag_relationship( - objects_to_tag=self._properties["objects_to_tag"], - tag=self._model, - ) + TagDAO.create_tag_relationship( + objects_to_tag=self._properties.get("objects_to_tag", []), + tag=self._model, + ) if description := self._properties.get("description"): self._model.description = description if tag_name := self._properties.get("name"): @@ -63,11 +61,8 @@ class UpdateTagCommand(UpdateMixin, BaseCommand): # Validate object_id if objects_to_tag := self._properties.get("objects_to_tag"): - if any(obj_id == 0 for obj_type, obj_id in objects_to_tag): - exceptions.append(TagInvalidError(" invalid object_id")) - # Validate object type - for obj_type, obj_id in objects_to_tag: + for obj_type, _ in objects_to_tag: object_type = to_object_type(obj_type) if not object_type: exceptions.append( diff --git a/superset/tags/schemas.py b/superset/tags/schemas.py index 75fdc2410..a391fd2b8 100644 --- a/superset/tags/schemas.py +++ b/superset/tags/schemas.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from marshmallow import fields, Schema +from marshmallow.validate import Range from superset.dashboards.schemas import UserSchema @@ -60,7 +61,8 @@ class TagObjectSchema(Schema): name = fields.String() description = fields.String(required=False, allow_none=True) objects_to_tag = fields.List( - fields.Tuple((fields.String(), fields.Int())), required=False + fields.Tuple((fields.String(), fields.Int(validate=Range(min=1)))), + required=False, ) diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py index 639372a70..d4143bd4a 100644 --- a/tests/unit_tests/tags/commands/create_test.py +++ b/tests/unit_tests/tags/commands/create_test.py @@ -91,18 +91,16 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture) ) -def test_create_command_failed_validate( - session_with_data: Session, mocker: MockFixture -): +def test_create_command_success_clear(session_with_data: Session, mocker: MockFixture): from superset.connectors.sqla.models import SqlaTable from superset.daos.tag import TagDAO from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import Query, SavedQuery from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand - from superset.tags.commands.exceptions import TagInvalidError from superset.tags.models import ObjectTypes, TaggedObject + # Define a list of objects to tag query = session_with_data.query(SavedQuery).first() chart = session_with_data.query(Slice).first() dashboard = session_with_data.query(Dashboard).first() @@ -110,16 +108,22 @@ def test_create_command_failed_validate( mocker.patch( "superset.security.SupersetSecurityManager.is_admin", return_value=True ) - mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query) - mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=chart) + mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart) + mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=query) objects_to_tag = [ (ObjectTypes.query, query.id), (ObjectTypes.chart, chart.id), - (ObjectTypes.dashboard, 0), + (ObjectTypes.dashboard, dashboard.id), ] - with pytest.raises(TagInvalidError): - CreateCustomTagWithRelationshipsCommand( - data={"name": "test_tag", "objects_to_tag": objects_to_tag} - ).run() + CreateCustomTagWithRelationshipsCommand( + data={"name": "test_tag", "objects_to_tag": objects_to_tag} + ).run() + assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + + CreateCustomTagWithRelationshipsCommand( + data={"name": "test_tag", "objects_to_tag": []} + ).run() + + assert len(session_with_data.query(TaggedObject).all()) == 0