From bc435e08d01b87efcf8774f29a7078cee8988e39 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Fri, 29 Jul 2022 21:51:35 -0400 Subject: [PATCH] fix: overwrite update override columns on PUT /dataset (#20862) * update override columns * save * fix overwrite with session.flush * write test * write test * layup * address concerns * address concerns --- superset/datasets/commands/update.py | 3 +- superset/datasets/dao.py | 48 +++++++++++++----- tests/integration_tests/datasets/api_tests.py | 50 +++++++++++++++++++ 3 files changed, 87 insertions(+), 14 deletions(-) diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index e3c908ceb..483a98e76 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -50,12 +50,13 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): self, model_id: int, data: Dict[str, Any], - override_columns: bool = False, + override_columns: Optional[bool] = False, ): self._model_id = model_id self._properties = data.copy() self._model: Optional[SqlaTable] = None self.override_columns = override_columns + self._properties["override_columns"] = override_columns def run(self) -> Model: self.validate() diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index a538a70c1..d260df361 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -147,14 +147,22 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods @classmethod def update( - cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True + cls, + model: SqlaTable, + properties: Dict[str, Any], + commit: bool = True, ) -> Optional[SqlaTable]: """ Updates a Dataset model on the metadata DB """ if "columns" in properties: - cls.update_columns(model, properties.pop("columns"), commit=commit) + cls.update_columns( + model, + properties.pop("columns"), + commit=commit, + override_columns=bool(properties.get("override_columns")), + ) if "metrics" in properties: cls.update_metrics(model, properties.pop("metrics"), commit=commit) @@ -167,6 +175,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods model: SqlaTable, property_columns: List[Dict[str, Any]], commit: bool = True, + override_columns: bool = False, ) -> None: """ Creates/updates and/or deletes a list of columns, based on a @@ -180,24 +189,37 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods column_by_id = {column.id: column for column in model.columns} seen = set() + original_cols = {obj.id for obj in model.columns} - for properties in property_columns: - if "id" in properties: - seen.add(properties["id"]) + if override_columns: + for id_ in original_cols: + DatasetDAO.delete_column(column_by_id[id_], commit=False) - DatasetDAO.update_column( - column_by_id[properties["id"]], - properties, - commit=False, - ) - else: + db.session.flush() + + for properties in property_columns: DatasetDAO.create_column( {**properties, "table_id": model.id}, commit=False, ) + else: + for properties in property_columns: + if "id" in properties: + seen.add(properties["id"]) - for id_ in {obj.id for obj in model.columns} - seen: - DatasetDAO.delete_column(column_by_id[id_], commit=False) + DatasetDAO.update_column( + column_by_id[properties["id"]], + properties, + commit=False, + ) + else: + DatasetDAO.create_column( + {**properties, "table_id": model.id}, + commit=False, + ) + + for id_ in {obj.id for obj in model.columns} - seen: + DatasetDAO.delete_column(column_by_id[id_], commit=False) if commit: db.session.commit() diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 46739f963..a993f0c0b 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -777,6 +777,56 @@ class TestDatasetApi(SupersetTestCase): db.session.delete(dataset) db.session.commit() + def test_update_dataset_item_w_override_columns_same_columns(self): + """ + Dataset API: Test update dataset with override columns + """ + if backend() == "sqlite": + return + + # Add default dataset + main_db = get_main_database() + dataset = self.insert_default_dataset() + prev_col_len = len(dataset.columns) + + cols = [ + { + "column_name": c.column_name, + "description": c.description, + "expression": c.expression, + "type": c.type, + "advanced_data_type": c.advanced_data_type, + "verbose_name": c.verbose_name, + } + for c in dataset.columns + ] + + cols.append( + { + "column_name": "new_col", + "description": "description", + "expression": "expression", + "type": "INTEGER", + "advanced_data_type": "ADVANCED_DATA_TYPE", + "verbose_name": "New Col", + } + ) + + self.login(username="admin") + dataset_data = { + "columns": cols, + } + uri = f"api/v1/dataset/{dataset.id}?override_columns=true" + rv = self.put_assert_metric(uri, dataset_data, "put") + + assert rv.status_code == 200 + + columns = db.session.query(TableColumn).filter_by(table_id=dataset.id).all() + assert len(columns) != prev_col_len + assert len(columns) == 3 + db.session.delete(dataset) + db.session.commit() + def test_update_dataset_create_column_and_metric(self): """ Dataset API: Test update dataset create column