diff --git a/superset/migrations/versions/2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint.py b/superset/migrations/versions/2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint.py new file mode 100644 index 000000000..f7d6ca565 --- /dev/null +++ b/superset/migrations/versions/2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""chart-ds-constraint + +Revision ID: 7e67aecbf3f1 +Revises: b5ea9d343307 +Create Date: 2023-03-27 12:30:01.164594 + +""" + +# revision identifiers, used by Alembic. +revision = "7e67aecbf3f1" +down_revision = "07f9a902af1b" + +import json + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.ext.declarative import declarative_base + +from superset import db + +Base = declarative_base() + + +class Slice(Base): # type: ignore + __tablename__ = "slices" + + id = sa.Column(sa.Integer, primary_key=True) + params = sa.Column(sa.String(250)) + datasource_type = sa.Column(sa.String(200)) + + +def upgrade_slc(slc: Slice) -> None: + # clean up all charts with datasource_type not != table + slc.datasource_type = "table" + try: + params_dict = json.loads(slc.params) + ds_id, ds_type = params_dict["datasource"].split("__") + params_dict["datasource"] = f"{ds_id}__table" + slc.params = json.dumps(params_dict) + except Exception: + # skip any malformatted params + pass + + +def upgrade(): + bind = op.get_bind() + session = db.Session(bind=bind) + + with op.batch_alter_table("slices") as batch_op: + for slc in session.query(Slice).filter(Slice.datasource_type == "query").all(): + upgrade_slc(slc) + session.add(slc) + + batch_op.create_check_constraint( + "ck_chart_datasource", "datasource_type in ('table')" + ) + + session.commit() + session.close() + + +def downgrade(): + op.drop_constraint("ck_chart_datasource", "slices", type_="check") diff --git a/tests/integration_tests/migrations/7e67aecbf3f1_chart_ds_constraint__tests.py b/tests/integration_tests/migrations/7e67aecbf3f1_chart_ds_constraint__tests.py new file mode 100644 index 000000000..a30741c0a --- /dev/null +++ b/tests/integration_tests/migrations/7e67aecbf3f1_chart_ds_constraint__tests.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from importlib import import_module + +chart_ds_constraint = import_module( + "superset.migrations.versions." "2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint", +) + +Slice = chart_ds_constraint.Slice +upgrade_slice = chart_ds_constraint.upgrade_slc + +sample_params = { + "adhoc_filters": [], + "all_columns": ["country_name", "country_code", "region", "year", "SP_UWT_TFRT"], + "applied_time_extras": {}, + "datasource": "35__query", + "groupby": [], + "row_limit": 1000, + "time_range": "No filter", + "viz_type": "table", + "granularity_sqla": "year", + "percent_metrics": [], + "dashboards": [], +} + + +def test_upgrade(): + slc = Slice(datasource_type="query", params=json.dumps(sample_params)) + + upgrade_slice(slc) + + params = json.loads(slc.params) + assert slc.datasource_type == "table" + assert params.get("datasource") == "35__table" + + +def test_upgrade_bad_json(): + slc = Slice(datasource_type="query", params=json.dumps(sample_params)) + + assert None == upgrade_slice(slc)