fix: overwrite update override columns on PUT /dataset (#20862)

* update override columns

* save

* fix overwrite with session.flush

* write test

* write test

* layup

* address concerns

* address concerns
This commit is contained in:
Hugh A. Miles II 2022-07-29 21:51:35 -04:00 committed by GitHub
parent 67e3dc7c7b
commit bc435e08d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 14 deletions

View File

@ -50,12 +50,13 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
self,
model_id: int,
data: Dict[str, Any],
override_columns: bool = False,
override_columns: Optional[bool] = False,
):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[SqlaTable] = None
self.override_columns = override_columns
self._properties["override_columns"] = override_columns
def run(self) -> Model:
self.validate()

View File

@ -147,14 +147,22 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
@classmethod
def update(
cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
cls,
model: SqlaTable,
properties: Dict[str, Any],
commit: bool = True,
) -> Optional[SqlaTable]:
"""
Updates a Dataset model on the metadata DB
"""
if "columns" in properties:
cls.update_columns(model, properties.pop("columns"), commit=commit)
cls.update_columns(
model,
properties.pop("columns"),
commit=commit,
override_columns=bool(properties.get("override_columns")),
)
if "metrics" in properties:
cls.update_metrics(model, properties.pop("metrics"), commit=commit)
@ -167,6 +175,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model: SqlaTable,
property_columns: List[Dict[str, Any]],
commit: bool = True,
override_columns: bool = False,
) -> None:
"""
Creates/updates and/or deletes a list of columns, based on a
@ -180,24 +189,37 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
column_by_id = {column.id: column for column in model.columns}
seen = set()
original_cols = {obj.id for obj in model.columns}
for properties in property_columns:
if "id" in properties:
seen.add(properties["id"])
if override_columns:
for id_ in original_cols:
DatasetDAO.delete_column(column_by_id[id_], commit=False)
DatasetDAO.update_column(
column_by_id[properties["id"]],
properties,
commit=False,
)
else:
db.session.flush()
for properties in property_columns:
DatasetDAO.create_column(
{**properties, "table_id": model.id},
commit=False,
)
else:
for properties in property_columns:
if "id" in properties:
seen.add(properties["id"])
for id_ in {obj.id for obj in model.columns} - seen:
DatasetDAO.delete_column(column_by_id[id_], commit=False)
DatasetDAO.update_column(
column_by_id[properties["id"]],
properties,
commit=False,
)
else:
DatasetDAO.create_column(
{**properties, "table_id": model.id},
commit=False,
)
for id_ in {obj.id for obj in model.columns} - seen:
DatasetDAO.delete_column(column_by_id[id_], commit=False)
if commit:
db.session.commit()

View File

@ -777,6 +777,56 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_item_w_override_columns_same_columns(self):
"""
Dataset API: Test update dataset with override columns
"""
if backend() == "sqlite":
return
# Add default dataset
main_db = get_main_database()
dataset = self.insert_default_dataset()
prev_col_len = len(dataset.columns)
cols = [
{
"column_name": c.column_name,
"description": c.description,
"expression": c.expression,
"type": c.type,
"advanced_data_type": c.advanced_data_type,
"verbose_name": c.verbose_name,
}
for c in dataset.columns
]
cols.append(
{
"column_name": "new_col",
"description": "description",
"expression": "expression",
"type": "INTEGER",
"advanced_data_type": "ADVANCED_DATA_TYPE",
"verbose_name": "New Col",
}
)
self.login(username="admin")
dataset_data = {
"columns": cols,
}
uri = f"api/v1/dataset/{dataset.id}?override_columns=true"
rv = self.put_assert_metric(uri, dataset_data, "put")
assert rv.status_code == 200
columns = db.session.query(TableColumn).filter_by(table_id=dataset.id).all()
assert len(columns) != prev_col_len
assert len(columns) == 3
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_create_column_and_metric(self):
"""
Dataset API: Test update dataset create column