diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 8889ea626..21344681f 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -45,7 +45,7 @@ from sqlalchemy import ( Text, UniqueConstraint, ) -from sqlalchemy.orm import backref, relationship, RelationshipProperty, Session +from sqlalchemy.orm import backref, relationship, Session from sqlalchemy_utils import EncryptedType from superset import conf, db, security_manager @@ -222,7 +222,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): session = db.session ds_list = ( session.query(DruidDatasource) - .filter(DruidDatasource.cluster_name == self.cluster_name) + .filter(DruidDatasource.cluster_id == self.id) .filter(DruidDatasource.datasource_name.in_(datasource_names)) ) ds_map = {ds.name: ds for ds in ds_list} @@ -468,7 +468,7 @@ class DruidDatasource(Model, BaseDatasource): """ORM object referencing Druid datasources (tables)""" __tablename__ = "datasources" - __table_args__ = (UniqueConstraint("datasource_name", "cluster_name"),) + __table_args__ = (UniqueConstraint("datasource_name", "cluster_id"),) type = "druid" query_language = "json" @@ -484,11 +484,9 @@ class DruidDatasource(Model, BaseDatasource): is_hidden = Column(Boolean, default=False) filter_select_enabled = Column(Boolean, default=True) # override default fetch_values_from = Column(String(100)) - cluster_name = Column( - String(250), ForeignKey("clusters.cluster_name"), nullable=False - ) + cluster_id = Column(Integer, ForeignKey("clusters.id"), nullable=False) cluster = relationship( - "DruidCluster", backref="datasources", foreign_keys=[cluster_name] + "DruidCluster", backref="datasources", foreign_keys=[cluster_id] ) owners = relationship( owner_class, secondary=druiddatasource_user, backref="druiddatasources" @@ -499,7 +497,7 @@ class DruidDatasource(Model, BaseDatasource): "is_hidden", "description", "default_endpoint", - "cluster_name", + "cluster_id", "offset", "cache_timeout", "params", @@ -511,7 +509,15 @@ class DruidDatasource(Model, BaseDatasource): export_children = ["columns", "metrics"] @property - def database(self) -> RelationshipProperty: + def cluster_name(self) -> str: + cluster = ( + self.cluster + or db.session.query(DruidCluster).filter_by(id=self.cluster_id).one() + ) + return cluster.cluster_name + + @property + def database(self) -> DruidCluster: return self.cluster @property @@ -608,17 +614,13 @@ class DruidDatasource(Model, BaseDatasource): db.session.query(DruidDatasource) .filter( DruidDatasource.datasource_name == d.datasource_name, - DruidCluster.cluster_name == d.cluster_name, + DruidDatasource.cluster_id == d.cluster_id, ) .first() ) def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]: - return ( - db.session.query(DruidCluster) - .filter_by(cluster_name=d.cluster_name) - .one() - ) + return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first() return import_datasource.import_datasource( db.session, i_datasource, lookup_cluster, lookup_datasource, import_time @@ -1615,12 +1617,7 @@ class DruidDatasource(Model, BaseDatasource): def query_datasources_by_name( cls, session: Session, database: Database, datasource_name: str, schema=None ) -> List["DruidDatasource"]: - return ( - session.query(cls) - .filter_by(cluster_name=database.id) - .filter_by(datasource_name=datasource_name) - .all() - ) + return [] def external_metadata(self) -> List[Dict]: self.merge_flag = True diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py index 07f9c2ab5..d7d2ba084 100644 --- a/superset/connectors/druid/views.py +++ b/superset/connectors/druid/views.py @@ -341,7 +341,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin with db.session.no_autoflush: query = db.session.query(models.DruidDatasource).filter( models.DruidDatasource.datasource_name == datasource.datasource_name, - models.DruidDatasource.cluster_name == datasource.cluster.id, + models.DruidDatasource.cluster_id == datasource.cluster_id, ) if db.session.query(query.exists()).scalar(): raise Exception(get_datasource_exist_error_msg(datasource.full_name)) diff --git a/superset/migrations/versions/e96dbf2cfef0_datasource_cluster_fk.py b/superset/migrations/versions/e96dbf2cfef0_datasource_cluster_fk.py new file mode 100644 index 000000000..e94505ba3 --- /dev/null +++ b/superset/migrations/versions/e96dbf2cfef0_datasource_cluster_fk.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""datasource_cluster_fk + +Revision ID: e96dbf2cfef0 +Revises: 817e1c9b09d0 +Create Date: 2020-01-08 01:17:40.127610 + +""" +import sqlalchemy as sa +from alembic import op + +from superset import db +from superset.utils.core import ( + generic_find_fk_constraint_name, + generic_find_uq_constraint_name, +) + +# revision identifiers, used by Alembic. +revision = "e96dbf2cfef0" +down_revision = "817e1c9b09d0" + + +def upgrade(): + bind = op.get_bind() + insp = sa.engine.reflection.Inspector.from_engine(bind) + + # Add cluster_id column + with op.batch_alter_table("datasources") as batch_op: + batch_op.add_column(sa.Column("cluster_id", sa.Integer())) + + # Update cluster_id values + metadata = sa.MetaData(bind=bind) + datasources = sa.Table("datasources", metadata, autoload=True) + clusters = sa.Table("clusters", metadata, autoload=True) + + statement = datasources.update().values( + cluster_id=sa.select([clusters.c.id]) + .where(datasources.c.cluster_name == clusters.c.cluster_name) + .as_scalar() + ) + bind.execute(statement) + + with op.batch_alter_table("datasources") as batch_op: + # Drop cluster_name column + fk_constraint_name = generic_find_fk_constraint_name( + "datasources", {"cluster_name"}, "clusters", insp + ) + uq_constraint_name = generic_find_uq_constraint_name( + "datasources", {"cluster_name", "datasource_name"}, insp + ) + batch_op.drop_constraint(fk_constraint_name, type_="foreignkey") + batch_op.drop_constraint(uq_constraint_name, type_="unique") + batch_op.drop_column("cluster_name") + + # Add constraints to cluster_id column + batch_op.alter_column("cluster_id", existing_type=sa.Integer, nullable=False) + batch_op.create_unique_constraint( + "uq_datasources_cluster_id", ["cluster_id", "datasource_name"] + ) + batch_op.create_foreign_key( + "fk_datasources_cluster_id_clusters", "clusters", ["cluster_id"], ["id"] + ) + + +def downgrade(): + bind = op.get_bind() + insp = sa.engine.reflection.Inspector.from_engine(bind) + + # Add cluster_name column + with op.batch_alter_table("datasources") as batch_op: + batch_op.add_column(sa.Column("cluster_name", sa.String(250))) + + # Update cluster_name values + metadata = sa.MetaData(bind=bind) + datasources = sa.Table("datasources", metadata, autoload=True) + clusters = sa.Table("clusters", metadata, autoload=True) + + statement = datasources.update().values( + cluster_name=sa.select([clusters.c.cluster_name]) + .where(datasources.c.cluster_id == clusters.c.id) + .as_scalar() + ) + bind.execute(statement) + + with op.batch_alter_table("datasources") as batch_op: + # Drop cluster_id column + fk_constraint_name = generic_find_fk_constraint_name( + "datasources", {"id"}, "clusters", insp + ) + uq_constraint_name = generic_find_uq_constraint_name( + "datasources", {"cluster_id", "datasource_name"}, insp + ) + batch_op.drop_constraint(fk_constraint_name, type_="foreignkey") + batch_op.drop_constraint(uq_constraint_name, type_="unique") + batch_op.drop_column("cluster_id") + + # Add constraints to cluster_name column + batch_op.alter_column( + "cluster_name", existing_type=sa.String(250), nullable=False + ) + batch_op.create_unique_constraint( + "uq_datasources_cluster_name", ["cluster_name", "datasource_name"] + ) + batch_op.create_foreign_key( + "fk_datasources_cluster_name_clusters", + "clusters", + ["cluster_name"], + ["cluster_name"], + ) diff --git a/superset/utils/core.py b/superset/utils/core.py index 4ba09fa3d..ce55814b2 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -35,7 +35,7 @@ from email.mime.text import MIMEText from email.utils import formatdate from enum import Enum from time import struct_time -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union from urllib.parse import unquote_plus import bleach @@ -46,7 +46,8 @@ import parsedatetime import sqlalchemy as sa from dateutil.parser import parse from dateutil.relativedelta import relativedelta -from flask import current_app, flash, g, Markup, render_template +from flask import current_app, flash, Flask, g, Markup, render_template +from flask_appbuilder import SQLA from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ from sqlalchemy import event, exc, select, Text @@ -487,7 +488,9 @@ def readfile(file_path: str) -> Optional[str]: return content -def generic_find_constraint_name(table, columns, referenced, db): +def generic_find_constraint_name( + table: str, columns: Set[str], referenced: str, db: SQLA +): """Utility to find a constraint name in alembic migrations""" t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) @@ -496,7 +499,9 @@ def generic_find_constraint_name(table, columns, referenced, db): return fk.name -def generic_find_fk_constraint_name(table, columns, referenced, insp): +def generic_find_fk_constraint_name( + table: str, columns: Set[str], referenced: str, insp +): """Utility to find a foreign-key constraint name in alembic migrations""" for fk in insp.get_foreign_keys(table): if ( diff --git a/tests/base_tests.py b/tests/base_tests.py index 0549c97b2..280e98aee 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -64,6 +64,7 @@ class SupersetTestCase(TestCase): @classmethod def create_druid_test_objects(cls): # create druid cluster and druid datasources + with app.app_context(): session = db.session cluster = ( @@ -75,11 +76,11 @@ class SupersetTestCase(TestCase): session.commit() druid_datasource1 = DruidDatasource( - datasource_name="druid_ds_1", cluster_name="druid_test" + datasource_name="druid_ds_1", cluster=cluster ) session.add(druid_datasource1) druid_datasource2 = DruidDatasource( - datasource_name="druid_ds_2", cluster_name="druid_test" + datasource_name="druid_ds_2", cluster=cluster ) session.add(druid_datasource2) session.commit() diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py index d30443e3c..404709cac 100644 --- a/tests/dict_import_export_tests.py +++ b/tests/dict_import_export_tests.py @@ -23,7 +23,12 @@ import yaml from tests.test_app import app from superset import db -from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric +from superset.connectors.druid.models import ( + DruidColumn, + DruidDatasource, + DruidMetric, + DruidCluster, +) from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.utils.core import get_example_database from superset.utils.dict_import_export import export_to_dict @@ -87,11 +92,15 @@ class DictImportExportTests(SupersetTestCase): return table, dict_rep def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]): - name = "{0}{1}".format(NAME_PREFIX, name) cluster_name = "druid_test" + cluster = self.get_or_create( + DruidCluster, {"cluster_name": cluster_name}, db.session + ) + + name = "{0}{1}".format(NAME_PREFIX, name) params = {DBREF: id, "database_name": cluster_name} dict_rep = { - "cluster_name": cluster_name, + "cluster_id": cluster.id, "datasource_name": name, "id": id, "params": json.dumps(params), @@ -102,7 +111,7 @@ class DictImportExportTests(SupersetTestCase): datasource = DruidDatasource( id=id, datasource_name=name, - cluster_name=cluster_name, + cluster_id=cluster.id, params=json.dumps(params), ) for col_name in cols_names: diff --git a/tests/druid_tests.py b/tests/druid_tests.py index 059ac4c7e..4a8fe5387 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -131,9 +131,7 @@ class DruidTests(SupersetTestCase): ) if cluster: for datasource in ( - db.session.query(DruidDatasource) - .filter_by(cluster_name=cluster.cluster_name) - .all() + db.session.query(DruidDatasource).filter_by(cluster_id=cluster.id).all() ): db.session.delete(datasource) @@ -358,9 +356,7 @@ class DruidTests(SupersetTestCase): ) if cluster: for datasource in ( - db.session.query(DruidDatasource) - .filter_by(cluster_name=cluster.cluster_name) - .all() + db.session.query(DruidDatasource).filter_by(cluster_id=cluster.id).all() ): db.session.delete(datasource) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 807da18ef..b932bf031 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -25,7 +25,12 @@ from sqlalchemy.orm.session import make_transient from tests.test_app import app from superset.utils.dashboard_import_export import decode_dashboards from superset import db, security_manager -from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric +from superset.connectors.druid.models import ( + DruidColumn, + DruidDatasource, + DruidMetric, + DruidCluster, +) from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.models.dashboard import Dashboard from superset.models.slice import Slice @@ -119,11 +124,16 @@ class ImportExportTests(SupersetTestCase): return table def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]): - params = {"remote_id": id, "database_name": "druid_test"} + cluster_name = "druid_test" + cluster = self.get_or_create( + DruidCluster, {"cluster_name": cluster_name}, db.session + ) + + params = {"remote_id": id, "database_name": cluster_name} datasource = DruidDatasource( id=id, datasource_name=name, - cluster_name="druid_test", + cluster_id=cluster.id, params=json.dumps(params), ) for col_name in cols_names: diff --git a/tests/security_tests.py b/tests/security_tests.py index 3b792e4e8..67877bc5a 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -238,7 +238,7 @@ class RolePermissionTests(SupersetTestCase): datasource = DruidDatasource( datasource_name="tmp_datasource", cluster=druid_cluster, - cluster_name="druid_test", + cluster_id=druid_cluster.id, ) session.add(datasource) session.commit()