guard against duplicate columns in datasource save (#8712)
* catch IntegrityError in datasource save * catch duplicate columns and wrap in exception handling decorators * use 409 * isort
This commit is contained in:
parent
98a82a0720
commit
a94464b9c9
|
|
@ -16,16 +16,18 @@
|
|||
# under the License.
|
||||
# pylint: disable=C,R,W
|
||||
import json
|
||||
from collections import Counter
|
||||
|
||||
from flask import request
|
||||
from flask_appbuilder import expose
|
||||
from flask_appbuilder.security.decorators import has_access_api
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from superset import appbuilder, db
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.models.core import Database
|
||||
|
||||
from .base import BaseSupersetView, json_error_response
|
||||
from .base import api, BaseSupersetView, handle_api_exception, json_error_response
|
||||
|
||||
|
||||
class Datasource(BaseSupersetView):
|
||||
|
|
@ -33,6 +35,8 @@ class Datasource(BaseSupersetView):
|
|||
|
||||
@expose("/save/", methods=["POST"])
|
||||
@has_access_api
|
||||
@api
|
||||
@handle_api_exception
|
||||
def save(self):
|
||||
datasource = json.loads(request.form.get("data"))
|
||||
datasource_id = datasource.get("id")
|
||||
|
|
@ -47,13 +51,29 @@ class Datasource(BaseSupersetView):
|
|||
.filter(orm_datasource.owner_class.id.in_(datasource["owners"]))
|
||||
.all()
|
||||
)
|
||||
|
||||
duplicates = [
|
||||
name
|
||||
for name, count in Counter(
|
||||
[col["column_name"] for col in datasource["columns"]]
|
||||
).items()
|
||||
if count > 1
|
||||
]
|
||||
if duplicates:
|
||||
return json_error_response(
|
||||
f"Duplicate column name(s): {','.join(duplicates)}", status="409"
|
||||
)
|
||||
|
||||
orm_datasource.update_from_object(datasource)
|
||||
data = orm_datasource.data
|
||||
db.session.commit()
|
||||
|
||||
return self.json_response(data)
|
||||
|
||||
@expose("/get/<datasource_type>/<datasource_id>/")
|
||||
@has_access_api
|
||||
@api
|
||||
@handle_api_exception
|
||||
def get(self, datasource_type, datasource_id):
|
||||
orm_datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
|
|
@ -68,6 +88,8 @@ class Datasource(BaseSupersetView):
|
|||
|
||||
@expose("/external_metadata/<datasource_type>/<datasource_id>/")
|
||||
@has_access_api
|
||||
@api
|
||||
@handle_api_exception
|
||||
def external_metadata(self, datasource_type=None, datasource_id=None):
|
||||
"""Gets column info from the source system"""
|
||||
if datasource_type == "druid":
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
# under the License.
|
||||
"""Unit tests for Superset"""
|
||||
import json
|
||||
from copy import deepcopy
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
from .fixtures.datasource import datasource_post
|
||||
|
|
@ -63,6 +64,33 @@ class DatasourceTests(SupersetTestCase):
|
|||
else:
|
||||
self.assertEqual(resp[k], datasource_post[k])
|
||||
|
||||
def test_save_duplicate_key(self):
|
||||
self.login(username="admin")
|
||||
tbl_id = self.get_table_by_name("birth_names").id
|
||||
datasource_post_copy = deepcopy(datasource_post)
|
||||
datasource_post_copy["id"] = tbl_id
|
||||
datasource_post_copy["columns"].extend(
|
||||
[
|
||||
{
|
||||
"column_name": "<new column>",
|
||||
"filterable": True,
|
||||
"groupby": True,
|
||||
"expression": "<enter SQL expression here>",
|
||||
"id": "somerandomid",
|
||||
},
|
||||
{
|
||||
"column_name": "<new column>",
|
||||
"filterable": True,
|
||||
"groupby": True,
|
||||
"expression": "<enter SQL expression here>",
|
||||
"id": "somerandomid2",
|
||||
},
|
||||
]
|
||||
)
|
||||
data = dict(data=json.dumps(datasource_post_copy))
|
||||
resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False)
|
||||
self.assertIn("Duplicate column name(s): <new column>", resp["error"])
|
||||
|
||||
def test_get_datasource(self):
|
||||
self.login(username="admin")
|
||||
tbl = self.get_table_by_name("birth_names")
|
||||
|
|
|
|||
Loading…
Reference in New Issue