From 66f1e1f714ee611cecf60b0e0aa2fbde149a3a4f Mon Sep 17 00:00:00 2001
From: Levis Mbote <111055098+LevisNgigi@users.noreply.github.com>
Date: Mon, 13 Jan 2025 19:30:29 +0300
Subject: [PATCH] refactor(bulk_select): Fix bulk select tagging issues for
users (#31631)
---
.../src/features/tags/BulkTagModal.test.tsx | 114 ++++++++++++++++++
.../src/features/tags/BulkTagModal.tsx | 2 +-
superset/commands/tag/create.py | 33 ++---
3 files changed, 134 insertions(+), 15 deletions(-)
create mode 100644 superset-frontend/src/features/tags/BulkTagModal.test.tsx
diff --git a/superset-frontend/src/features/tags/BulkTagModal.test.tsx b/superset-frontend/src/features/tags/BulkTagModal.test.tsx
new file mode 100644
index 000000000..d590dd551
--- /dev/null
+++ b/superset-frontend/src/features/tags/BulkTagModal.test.tsx
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License. You may obtain
+ * a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed
+ * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
+ * OF ANY KIND, either express or implied. See the License for the specific language
+ * governing permissions and limitations under the License.
+ */
+
+import {
+ render,
+ screen,
+ fireEvent,
+ waitFor,
+} from 'spec/helpers/testing-library';
+import fetchMock from 'fetch-mock';
+import BulkTagModal from './BulkTagModal';
+
+const mockedProps = {
+ onHide: jest.fn(),
+ refreshData: jest.fn(),
+ addSuccessToast: jest.fn(),
+ addDangerToast: jest.fn(),
+ show: true,
+ selected: [
+ { original: { id: 1, name: 'Dashboard 1' } },
+ { original: { id: 2, name: 'Dashboard 2' } },
+ ],
+ resourceName: 'dashboard',
+};
+
+describe('BulkTagModal', () => {
+ afterEach(() => {
+ fetchMock.reset();
+ jest.clearAllMocks();
+ });
+
+ test('should render', () => {
+ const { container } = render();
+ expect(container).toBeInTheDocument();
+ });
+
+ test('renders the correct title and message', () => {
+ render();
+ expect(
+ screen.getByText(/you are adding tags to 2 dashboards/i),
+ ).toBeInTheDocument();
+ expect(screen.getByText('Bulk tag')).toBeInTheDocument();
+ });
+
+ test('renders tags input field', async () => {
+ render();
+ const tagsInput = await screen.findByRole('combobox', { name: /tags/i });
+ expect(tagsInput).toBeInTheDocument();
+ });
+
+ test('calls onHide when the Cancel button is clicked', () => {
+ render();
+ const cancelButton = screen.getByText('Cancel');
+ fireEvent.click(cancelButton);
+ expect(mockedProps.onHide).toHaveBeenCalled();
+ });
+
+ test('submits the selected tags and shows success toast', async () => {
+ fetchMock.post('glob:*/api/v1/tag/bulk_create', {
+ result: {
+ objects_tagged: [1, 2],
+ objects_skipped: [],
+ },
+ });
+
+ render();
+
+ const tagsInput = await screen.findByRole('combobox', { name: /tags/i });
+ fireEvent.change(tagsInput, { target: { value: 'Test Tag' } });
+ fireEvent.keyDown(tagsInput, { key: 'Enter', code: 'Enter' });
+
+ fireEvent.click(screen.getByText('Save'));
+
+ await waitFor(() => {
+ expect(mockedProps.addSuccessToast).toHaveBeenCalledWith(
+ 'Tagged 2 dashboards',
+ );
+ });
+
+ expect(mockedProps.refreshData).toHaveBeenCalled();
+ expect(mockedProps.onHide).toHaveBeenCalled();
+ });
+
+ test('handles API errors gracefully', async () => {
+ fetchMock.post('glob:*/api/v1/tag/bulk_create', 500);
+
+ render();
+
+ const tagsInput = await screen.findByRole('combobox', { name: /tags/i });
+ fireEvent.change(tagsInput, { target: { value: 'Test Tag' } });
+ fireEvent.keyDown(tagsInput, { key: 'Enter', code: 'Enter' });
+
+ fireEvent.click(screen.getByText('Save'));
+
+ await waitFor(() => {
+ expect(mockedProps.addDangerToast).toHaveBeenCalledWith(
+ 'Failed to tag items',
+ );
+ });
+ });
+});
diff --git a/superset-frontend/src/features/tags/BulkTagModal.tsx b/superset-frontend/src/features/tags/BulkTagModal.tsx
index 32ccd4b0d..319546128 100644
--- a/superset-frontend/src/features/tags/BulkTagModal.tsx
+++ b/superset-frontend/src/features/tags/BulkTagModal.tsx
@@ -59,7 +59,7 @@ const BulkTagModal: FC = ({
endpoint: `/api/v1/tag/bulk_create`,
jsonPayload: {
tags: tags.map(tag => ({
- name: tag.value,
+ name: tag.label,
objects_to_tag: selected.map(item => [
resourceName,
+item.original.id,
diff --git a/superset/commands/tag/create.py b/superset/commands/tag/create.py
index 775250dc8..b3788eb87 100644
--- a/superset/commands/tag/create.py
+++ b/superset/commands/tag/create.py
@@ -88,27 +88,32 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
def validate(self) -> None:
exceptions = []
objects_to_tag = set(self._properties.get("objects_to_tag", []))
+
for obj_type, obj_id in objects_to_tag:
object_type = to_object_type(obj_type)
# Validate object type
- for obj_type, obj_id in objects_to_tag:
- object_type = to_object_type(obj_type)
+ if not object_type:
+ exceptions.append(TagInvalidError(f"invalid object type {object_type}"))
+ continue
- if not object_type:
- exceptions.append(
- TagInvalidError(f"invalid object type {object_type}")
- )
- try:
- if model := to_object_model(object_type, obj_id): # type: ignore
+ try:
+ if model := to_object_model(object_type, obj_id):
+ try:
security_manager.raise_for_ownership(model)
- except SupersetSecurityException:
- # skip the object if the user doesn't have access
- self._skipped_tagged_objects.add((obj_type, obj_id))
+ except SupersetSecurityException:
+ if (
+ not model.created_by
+ or model.created_by != security_manager.current_user
+ ):
+ # skip the object if the user doesn't have access
+ self._skipped_tagged_objects.add((obj_type, obj_id))
+ except Exception as e:
+ exceptions.append(TagInvalidError(str(e)))
- self._properties["objects_to_tag"] = (
- set(objects_to_tag) - self._skipped_tagged_objects
- )
+ self._properties["objects_to_tag"] = (
+ set(objects_to_tag) - self._skipped_tagged_objects
+ )
if exceptions:
raise TagInvalidError(exceptions=exceptions)