fix: overwrite update override columns on PUT /dataset (#20862)
* update override columns * save * fix overwrite with session.flush * write test * write test * layup * address concerns * address concerns
This commit is contained in:
parent
67e3dc7c7b
commit
bc435e08d0
|
|
@ -50,12 +50,13 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
|||
self,
|
||||
model_id: int,
|
||||
data: Dict[str, Any],
|
||||
override_columns: bool = False,
|
||||
override_columns: Optional[bool] = False,
|
||||
):
|
||||
self._model_id = model_id
|
||||
self._properties = data.copy()
|
||||
self._model: Optional[SqlaTable] = None
|
||||
self.override_columns = override_columns
|
||||
self._properties["override_columns"] = override_columns
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
|
|
|
|||
|
|
@ -147,14 +147,22 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def update(
|
||||
cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
|
||||
cls,
|
||||
model: SqlaTable,
|
||||
properties: Dict[str, Any],
|
||||
commit: bool = True,
|
||||
) -> Optional[SqlaTable]:
|
||||
"""
|
||||
Updates a Dataset model on the metadata DB
|
||||
"""
|
||||
|
||||
if "columns" in properties:
|
||||
cls.update_columns(model, properties.pop("columns"), commit=commit)
|
||||
cls.update_columns(
|
||||
model,
|
||||
properties.pop("columns"),
|
||||
commit=commit,
|
||||
override_columns=bool(properties.get("override_columns")),
|
||||
)
|
||||
|
||||
if "metrics" in properties:
|
||||
cls.update_metrics(model, properties.pop("metrics"), commit=commit)
|
||||
|
|
@ -167,6 +175,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
|
|||
model: SqlaTable,
|
||||
property_columns: List[Dict[str, Any]],
|
||||
commit: bool = True,
|
||||
override_columns: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Creates/updates and/or deletes a list of columns, based on a
|
||||
|
|
@ -180,24 +189,37 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
|
|||
|
||||
column_by_id = {column.id: column for column in model.columns}
|
||||
seen = set()
|
||||
original_cols = {obj.id for obj in model.columns}
|
||||
|
||||
for properties in property_columns:
|
||||
if "id" in properties:
|
||||
seen.add(properties["id"])
|
||||
if override_columns:
|
||||
for id_ in original_cols:
|
||||
DatasetDAO.delete_column(column_by_id[id_], commit=False)
|
||||
|
||||
DatasetDAO.update_column(
|
||||
column_by_id[properties["id"]],
|
||||
properties,
|
||||
commit=False,
|
||||
)
|
||||
else:
|
||||
db.session.flush()
|
||||
|
||||
for properties in property_columns:
|
||||
DatasetDAO.create_column(
|
||||
{**properties, "table_id": model.id},
|
||||
commit=False,
|
||||
)
|
||||
else:
|
||||
for properties in property_columns:
|
||||
if "id" in properties:
|
||||
seen.add(properties["id"])
|
||||
|
||||
for id_ in {obj.id for obj in model.columns} - seen:
|
||||
DatasetDAO.delete_column(column_by_id[id_], commit=False)
|
||||
DatasetDAO.update_column(
|
||||
column_by_id[properties["id"]],
|
||||
properties,
|
||||
commit=False,
|
||||
)
|
||||
else:
|
||||
DatasetDAO.create_column(
|
||||
{**properties, "table_id": model.id},
|
||||
commit=False,
|
||||
)
|
||||
|
||||
for id_ in {obj.id for obj in model.columns} - seen:
|
||||
DatasetDAO.delete_column(column_by_id[id_], commit=False)
|
||||
|
||||
if commit:
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -777,6 +777,56 @@ class TestDatasetApi(SupersetTestCase):
|
|||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
def test_update_dataset_item_w_override_columns_same_columns(self):
|
||||
"""
|
||||
Dataset API: Test update dataset with override columns
|
||||
"""
|
||||
if backend() == "sqlite":
|
||||
return
|
||||
|
||||
# Add default dataset
|
||||
main_db = get_main_database()
|
||||
dataset = self.insert_default_dataset()
|
||||
prev_col_len = len(dataset.columns)
|
||||
|
||||
cols = [
|
||||
{
|
||||
"column_name": c.column_name,
|
||||
"description": c.description,
|
||||
"expression": c.expression,
|
||||
"type": c.type,
|
||||
"advanced_data_type": c.advanced_data_type,
|
||||
"verbose_name": c.verbose_name,
|
||||
}
|
||||
for c in dataset.columns
|
||||
]
|
||||
|
||||
cols.append(
|
||||
{
|
||||
"column_name": "new_col",
|
||||
"description": "description",
|
||||
"expression": "expression",
|
||||
"type": "INTEGER",
|
||||
"advanced_data_type": "ADVANCED_DATA_TYPE",
|
||||
"verbose_name": "New Col",
|
||||
}
|
||||
)
|
||||
|
||||
self.login(username="admin")
|
||||
dataset_data = {
|
||||
"columns": cols,
|
||||
}
|
||||
uri = f"api/v1/dataset/{dataset.id}?override_columns=true"
|
||||
rv = self.put_assert_metric(uri, dataset_data, "put")
|
||||
|
||||
assert rv.status_code == 200
|
||||
|
||||
columns = db.session.query(TableColumn).filter_by(table_id=dataset.id).all()
|
||||
assert len(columns) != prev_col_len
|
||||
assert len(columns) == 3
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
def test_update_dataset_create_column_and_metric(self):
|
||||
"""
|
||||
Dataset API: Test update dataset create column
|
||||
|
|
|
|||
Loading…
Reference in New Issue