diff --git a/superset/views/datasource.py b/superset/views/datasource.py index 900dbd090..12c6916ed 100644 --- a/superset/views/datasource.py +++ b/superset/views/datasource.py @@ -16,16 +16,18 @@ # under the License. # pylint: disable=C,R,W import json +from collections import Counter from flask import request from flask_appbuilder import expose from flask_appbuilder.security.decorators import has_access_api +from sqlalchemy.exc import IntegrityError from superset import appbuilder, db from superset.connectors.connector_registry import ConnectorRegistry from superset.models.core import Database -from .base import BaseSupersetView, json_error_response +from .base import api, BaseSupersetView, handle_api_exception, json_error_response class Datasource(BaseSupersetView): @@ -33,6 +35,8 @@ class Datasource(BaseSupersetView): @expose("/save/", methods=["POST"]) @has_access_api + @api + @handle_api_exception def save(self): datasource = json.loads(request.form.get("data")) datasource_id = datasource.get("id") @@ -47,13 +51,29 @@ class Datasource(BaseSupersetView): .filter(orm_datasource.owner_class.id.in_(datasource["owners"])) .all() ) + + duplicates = [ + name + for name, count in Counter( + [col["column_name"] for col in datasource["columns"]] + ).items() + if count > 1 + ] + if duplicates: + return json_error_response( + f"Duplicate column name(s): {','.join(duplicates)}", status="409" + ) + orm_datasource.update_from_object(datasource) data = orm_datasource.data db.session.commit() + return self.json_response(data) @expose("/get/<datasource_type>/<datasource_id>/") @has_access_api + @api + @handle_api_exception def get(self, datasource_type, datasource_id): orm_datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, db.session @@ -68,6 +88,8 @@ class Datasource(BaseSupersetView): @expose("/external_metadata/<datasource_type>/<datasource_id>/") @has_access_api + @api + @handle_api_exception def external_metadata(self, datasource_type=None, datasource_id=None): """Gets column info from the source system""" 
if datasource_type == "druid": diff --git a/tests/datasource_tests.py b/tests/datasource_tests.py index 3bffdf560..fa3616c32 100644 --- a/tests/datasource_tests.py +++ b/tests/datasource_tests.py @@ -16,6 +16,7 @@ # under the License. """Unit tests for Superset""" import json +from copy import deepcopy from .base_tests import SupersetTestCase from .fixtures.datasource import datasource_post @@ -63,6 +64,33 @@ class DatasourceTests(SupersetTestCase): else: self.assertEqual(resp[k], datasource_post[k]) + def test_save_duplicate_key(self): + self.login(username="admin") + tbl_id = self.get_table_by_name("birth_names").id + datasource_post_copy = deepcopy(datasource_post) + datasource_post_copy["id"] = tbl_id + datasource_post_copy["columns"].extend( + [ + { + "column_name": "<new column>", + "filterable": True, + "groupby": True, + "expression": "", + "id": "somerandomid", + }, + { + "column_name": "<new column>", + "filterable": True, + "groupby": True, + "expression": "", + "id": "somerandomid2", + }, + ] + ) + data = dict(data=json.dumps(datasource_post_copy)) + resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False) + self.assertIn("Duplicate column name(s): ", resp["error"]) + def test_get_datasource(self): self.login(username="admin") tbl = self.get_table_by_name("birth_names")