From fd2d1c58c566d9312d6cfc5641a06ac2b03e753a Mon Sep 17 00:00:00 2001 From: Erik Ritter Date: Thu, 6 Aug 2020 15:33:48 -0700 Subject: [PATCH] Revert "chore: Cleanup database sessions (#10427)" (#10537) This reverts commit 7645fc85c3d6676a13ae76ca5133f83d8fb54dbe. --- superset/cli.py | 12 +- superset/commands/utils.py | 6 +- superset/common/query_context.py | 4 +- superset/connectors/connector_registry.py | 35 +++-- superset/connectors/druid/models.py | 52 ++++--- superset/connectors/druid/views.py | 5 +- superset/connectors/sqla/models.py | 29 +++- superset/dashboards/dao.py | 8 +- superset/models/dashboard.py | 36 ++--- superset/models/helpers.py | 13 +- superset/models/slice.py | 8 +- superset/models/tags.py | 81 ++++++---- superset/security/manager.py | 6 +- superset/sql_lab.py | 3 +- superset/tasks/cache.py | 23 ++- superset/tasks/schedules.py | 13 +- superset/utils/dashboard_import_export.py | 11 +- superset/utils/dict_import_export.py | 19 ++- superset/utils/import_datasource.py | 22 +-- superset/views/base.py | 3 +- superset/views/chart/views.py | 4 +- superset/views/core.py | 129 +++++++++------- superset/views/datasource.py | 6 +- superset/views/utils.py | 7 +- tests/access_tests.py | 92 ++++++----- tests/alerts_tests.py | 176 +++++++++++----------- tests/base_tests.py | 34 +++-- tests/celery_tests.py | 10 +- tests/charts/api_tests.py | 4 +- tests/core_tests.py | 24 +-- tests/database_api_tests.py | 5 +- tests/datasets/api_tests.py | 7 +- tests/dict_import_export_tests.py | 60 +++++--- tests/druid_tests.py | 10 +- tests/import_export_tests.py | 23 +-- tests/query_context_tests.py | 1 + tests/security_tests.py | 132 ++++++++-------- tests/sqllab_tests.py | 6 +- tests/strategy_tests.py | 6 +- 39 files changed, 637 insertions(+), 488 deletions(-) diff --git a/superset/cli.py b/superset/cli.py index ef9ee847f..ef176822d 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -197,9 +197,10 @@ def set_database_uri(database_name: str, uri: str) -> None: ) def refresh_druid(datasource: str, merge: bool) -> None: """Refresh druid datasources""" + session = db.session() from superset.connectors.druid.models import DruidCluster - for cluster in db.session.query(DruidCluster).all(): + for cluster in session.query(DruidCluster).all(): try: cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge) except Exception as ex: # pylint: disable=broad-except @@ -207,7 +208,7 @@ def refresh_druid(datasource: str, merge: bool) -> None: logger.exception(ex) cluster.metadata_last_refreshed = datetime.now() print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]") - db.session.commit() + session.commit() @superset.command() @@ -249,7 +250,7 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None: logger.info("Importing dashboard from file %s", file_) try: with file_.open() as data_stream: - dashboard_import_export.import_dashboards(data_stream) + dashboard_import_export.import_dashboards(db.session, data_stream) except Exception as ex: # pylint: disable=broad-except logger.error("Error when importing dashboard from file %s", file_) logger.error(ex) @@ -267,7 +268,7 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None: """Export dashboards to JSON""" from superset.utils import dashboard_import_export - data = dashboard_import_export.export_dashboards() + data = dashboard_import_export.export_dashboards(db.session) if print_stdout or not dashboard_file: print(data) if dashboard_file: @@ -320,7 +321,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None: try: with file_.open() as data_stream: dict_import_export.import_from_dict( - yaml.safe_load(data_stream), sync=sync_array + db.session, yaml.safe_load(data_stream), sync=sync_array ) except Exception as ex: # pylint: disable=broad-except logger.error("Error when importing datasources from file %s", file_) @@ -359,6 +360,7 @@ def export_datasources( from superset.utils import dict_import_export data = dict_import_export.export_to_dict( + session=db.session, recursive=True, back_references=back_references, include_defaults=include_defaults, diff --git a/superset/commands/utils.py b/superset/commands/utils.py index 66fd5433f..c0bd8b707 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -25,7 +25,7 @@ from superset.commands.exceptions import ( ) from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry -from superset.extensions import security_manager +from superset.extensions import db, security_manager def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]: @@ -50,6 +50,8 @@ def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[ def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource: try: - return ConnectorRegistry.get_datasource(datasource_type, datasource_id) + return ConnectorRegistry.get_datasource( + datasource_type, datasource_id, db.session + ) except (NoResultFound, KeyError): raise DatasourceNotFoundValidationError() diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 401f2625e..e602fbfac 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -23,7 +23,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union import numpy as np import pandas as pd -from superset import app, cache, security_manager +from superset import app, cache, db, security_manager from superset.common.query_object import QueryObject from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry @@ -64,7 +64,7 @@ class QueryContext: result_format: Optional[utils.ChartDataResultFormat] = None, ) -> None: self.datasource = ConnectorRegistry.get_datasource( - str(datasource["type"]), int(datasource["id"]) + str(datasource["type"]), int(datasource["id"]), db.session ) self.queries = [QueryObject(**query_obj) for query_obj in queries] self.force = force diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py index 7c47bc7b3..fff2f8e8f 100644 --- a/superset/connectors/connector_registry.py +++ b/superset/connectors/connector_registry.py @@ -17,9 +17,7 @@ from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING from sqlalchemy import or_ -from sqlalchemy.orm import subqueryload - -from superset.extensions import db +from sqlalchemy.orm import Session, subqueryload if TYPE_CHECKING: # pylint: disable=unused-import @@ -45,20 +43,20 @@ class ConnectorRegistry: @classmethod def get_datasource( - cls, datasource_type: str, datasource_id: int + cls, datasource_type: str, datasource_id: int, session: Session ) -> "BaseDatasource": return ( - db.session.query(cls.sources[datasource_type]) + session.query(cls.sources[datasource_type]) .filter_by(id=datasource_id) .one() ) @classmethod - def get_all_datasources(cls) -> List["BaseDatasource"]: + def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]: datasources: List["BaseDatasource"] = [] for source_type in ConnectorRegistry.sources: source_class = ConnectorRegistry.sources[source_type] - qry = db.session.query(source_class) + qry = session.query(source_class) qry = source_class.default_query(qry) datasources.extend(qry.all()) return datasources @@ -66,6 +64,7 @@ class ConnectorRegistry: @classmethod def get_datasource_by_name( # pylint: disable=too-many-arguments cls, + session: Session, datasource_type: str, datasource_name: str, schema: str, @@ -73,17 +72,21 @@ class ConnectorRegistry: ) -> Optional["BaseDatasource"]: datasource_class = ConnectorRegistry.sources[datasource_type] return datasource_class.get_datasource_by_name( - datasource_name, schema, database_name + session, datasource_name, schema, database_name ) @classmethod def query_datasources_by_permissions( # pylint: disable=invalid-name - cls, database: "Database", permissions: Set[str], schema_perms: Set[str], + cls, + session: Session, + database: "Database", + permissions: Set[str], + schema_perms: Set[str], ) -> List["BaseDatasource"]: # TODO(bogdan): add unit test datasource_class = ConnectorRegistry.sources[database.type] return ( - db.session.query(datasource_class) + session.query(datasource_class) .filter_by(database_id=database.id) .filter( or_( @@ -96,12 +99,12 @@ class ConnectorRegistry: @classmethod def get_eager_datasource( - cls, datasource_type: str, datasource_id: int + cls, session: Session, datasource_type: str, datasource_id: int ) -> "BaseDatasource": """Returns datasource with columns and metrics.""" datasource_class = ConnectorRegistry.sources[datasource_type] return ( - db.session.query(datasource_class) + session.query(datasource_class) .options( subqueryload(datasource_class.columns), subqueryload(datasource_class.metrics), @@ -112,9 +115,13 @@ class ConnectorRegistry: @classmethod def query_datasources_by_name( - cls, database: "Database", datasource_name: str, schema: Optional[str] = None, + cls, + session: Session, + database: "Database", + datasource_name: str, + schema: Optional[str] = None, ) -> List["BaseDatasource"]: datasource_class = ConnectorRegistry.sources[database.type] return datasource_class.query_datasources_by_name( - database, datasource_name, schema=schema + session, database, datasource_name, schema=schema ) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 0068f111c..162163f2b 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -45,7 +45,7 @@ from sqlalchemy import ( UniqueConstraint, ) from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import backref, relationship +from sqlalchemy.orm import backref, relationship, Session from sqlalchemy.sql import expression from sqlalchemy_utils import EncryptedType @@ -223,8 +223,9 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): Fetches metadata for the specified datasources and merges to the Superset database """ + session = db.session ds_list = ( - db.session.query(DruidDatasource) + session.query(DruidDatasource) .filter(DruidDatasource.cluster_id == self.id) .filter(DruidDatasource.datasource_name.in_(datasource_names)) ) @@ -233,8 +234,8 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): datasource = ds_map.get(ds_name, None) if not datasource: datasource = DruidDatasource(datasource_name=ds_name) - with db.session.no_autoflush: - db.session.add(datasource) + with session.no_autoflush: + session.add(datasource) flasher(_("Adding new datasource [{}]").format(ds_name), "success") ds_map[ds_name] = datasource elif refresh_all: @@ -244,7 +245,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): continue datasource.cluster = self datasource.merge_flag = merge_flag - db.session.flush() + session.flush() # Prepare multithreaded executation pool = ThreadPool() @@ -258,7 +259,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): cols = metadata[i] if cols: col_objs_list = ( - db.session.query(DruidColumn) + session.query(DruidColumn) .filter(DruidColumn.datasource_id == datasource.id) .filter(DruidColumn.column_name.in_(cols.keys())) ) @@ -271,15 +272,15 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): col_obj = DruidColumn( datasource_id=datasource.id, column_name=col ) - with db.session.no_autoflush: - db.session.add(col_obj) + with session.no_autoflush: + session.add(col_obj) col_obj.type = cols[col]["type"] col_obj.datasource = datasource if col_obj.type == "STRING": col_obj.groupby = True col_obj.filterable = True datasource.refresh_metrics() - db.session.commit() + session.commit() @hybrid_property def perm(self) -> str: @@ -389,7 +390,7 @@ class DruidColumn(Model, BaseColumn): .first() ) - return import_datasource.import_simple_obj(i_column, lookup_obj) + return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) class DruidMetric(Model, BaseMetric): @@ -458,7 +459,7 @@ class DruidMetric(Model, BaseMetric): .first() ) - return import_datasource.import_simple_obj(i_metric, lookup_obj) + return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj) druiddatasource_user = Table( @@ -634,7 +635,7 @@ class DruidDatasource(Model, BaseDatasource): return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first() return import_datasource.import_datasource( - i_datasource, lookup_cluster, lookup_datasource, import_time + db.session, i_datasource, lookup_cluster, lookup_datasource, import_time ) def latest_metadata(self) -> Optional[Dict[str, Any]]: @@ -704,10 +705,9 @@ class DruidDatasource(Model, BaseDatasource): refresh: bool = True, ) -> None: """Merges the ds config from druid_config into one stored in the db.""" + session = db.session datasource = ( - db.session.query(cls) - .filter_by(datasource_name=druid_config["name"]) - .first() + session.query(cls).filter_by(datasource_name=druid_config["name"]).first() ) # Create a new datasource. if not datasource: @@ -718,13 +718,13 @@ class DruidDatasource(Model, BaseDatasource): changed_by_fk=user.id, created_by_fk=user.id, ) - db.session.add(datasource) + session.add(datasource) elif not refresh: return dimensions = druid_config["dimensions"] col_objs = ( - db.session.query(DruidColumn) + session.query(DruidColumn) .filter(DruidColumn.datasource_id == datasource.id) .filter(DruidColumn.column_name.in_(dimensions)) ) @@ -741,10 +741,10 @@ class DruidDatasource(Model, BaseDatasource): type="STRING", datasource=datasource, ) - db.session.add(col_obj) + session.add(col_obj) # Import Druid metrics metric_objs = ( - db.session.query(DruidMetric) + session.query(DruidMetric) .filter(DruidMetric.datasource_id == datasource.id) .filter( DruidMetric.metric_name.in_( @@ -777,8 +777,8 @@ class DruidDatasource(Model, BaseDatasource): % druid_config["name"] ), ) - db.session.add(metric_obj) - db.session.commit() + session.add(metric_obj) + session.commit() @staticmethod def time_offset(granularity: Granularity) -> int: @@ -788,10 +788,10 @@ class DruidDatasource(Model, BaseDatasource): @classmethod def get_datasource_by_name( - cls, datasource_name: str, schema: str, database_name: str + cls, session: Session, datasource_name: str, schema: str, database_name: str ) -> Optional["DruidDatasource"]: query = ( - db.session.query(cls) + session.query(cls) .join(DruidCluster) .filter(cls.datasource_name == datasource_name) .filter(DruidCluster.cluster_name == database_name) @@ -1724,7 +1724,11 @@ class DruidDatasource(Model, BaseDatasource): @classmethod def query_datasources_by_name( - cls, database: Database, datasource_name: str, schema: Optional[str] = None, + cls, + session: Session, + database: Database, + datasource_name: str, + schema: Optional[str] = None, ) -> List["DruidDatasource"]: return [] diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py index 4a22bc2bc..4c2fbf991 100644 --- a/superset/connectors/druid/views.py +++ b/superset/connectors/druid/views.py @@ -365,10 +365,11 @@ class Druid(BaseSupersetView): self, refresh_all: bool = True ) -> FlaskResponse: """endpoint that refreshes druid datasources metadata""" + session = db.session() DruidCluster = ConnectorRegistry.sources[ # pylint: disable=invalid-name "druid" ].cluster_class - for cluster in db.session.query(DruidCluster).all(): + for cluster in session.query(DruidCluster).all(): cluster_name = cluster.cluster_name valid_cluster = True try: @@ -390,7 +391,7 @@ class Druid(BaseSupersetView): ), "info", ) - db.session.commit() + session.commit() return redirect("/druiddatasourcemodelview/list/") @has_access diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 82f30178d..530a2e10a 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -41,7 +41,7 @@ from sqlalchemy import ( Text, ) from sqlalchemy.exc import CompileError -from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty +from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table, text @@ -255,7 +255,7 @@ class TableColumn(Model, BaseColumn): .first() ) - return import_datasource.import_simple_obj(i_column, lookup_obj) + return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal( self, @@ -375,7 +375,7 @@ class SqlMetric(Model, BaseMetric): .first() ) - return import_datasource.import_simple_obj(i_metric, lookup_obj) + return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj) sqlatable_user = Table( @@ -503,11 +503,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at @classmethod def get_datasource_by_name( - cls, datasource_name: str, schema: Optional[str], database_name: str, + cls, + session: Session, + datasource_name: str, + schema: Optional[str], + database_name: str, ) -> Optional["SqlaTable"]: schema = schema or None query = ( - db.session.query(cls) + session.query(cls) .join(Database) .filter(cls.table_name == datasource_name) .filter(Database.database_name == database_name) @@ -1292,15 +1296,24 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at ) return import_datasource.import_datasource( - i_datasource, lookup_database, lookup_sqlatable, import_time, database_id, + db.session, + i_datasource, + lookup_database, + lookup_sqlatable, + import_time, + database_id, ) @classmethod def query_datasources_by_name( - cls, database: Database, datasource_name: str, schema: Optional[str] = None, + cls, + session: Session, + database: Database, + datasource_name: str, + schema: Optional[str] = None, ) -> List["SqlaTable"]: query = ( - db.session.query(cls) + session.query(cls) .filter_by(database_id=database.id) .filter_by(table_name=datasource_name) ) diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py index 6345bb70a..774e1c8d4 100644 --- a/superset/dashboards/dao.py +++ b/superset/dashboards/dao.py @@ -99,7 +99,9 @@ class DashboardDAO(BaseDAO): except KeyError: pass - current_slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + session = db.session() + current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + dashboard.slices = current_slices # update slice names. this assumes user has permissions to update the slice @@ -109,8 +111,8 @@ class DashboardDAO(BaseDAO): new_name = slice_id_to_name[slc.id] if slc.slice_name != new_name: slc.slice_name = new_name - db.session.merge(slc) - db.session.flush() + session.merge(slc) + session.flush() except KeyError: pass diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 04e9251c7..18445436e 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -37,7 +37,7 @@ from sqlalchemy import ( UniqueConstraint, ) from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import relationship, subqueryload +from sqlalchemy.orm import relationship, sessionmaker, subqueryload from sqlalchemy.orm.mapper import Mapper from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager @@ -62,17 +62,18 @@ config = app.config logger = logging.getLogger(__name__) -def copy_dashboard( # pylint: disable=unused-argument - mapper: Mapper, connection: Connection, target: "Dashboard" -) -> None: +def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None: + # pylint: disable=unused-argument dashboard_id = config["DASHBOARD_TEMPLATE_ID"] if dashboard_id is None: return - new_user = db.session.query(User).filter_by(id=target.id).first() + session_class = sessionmaker(autoflush=False) + session = session_class(bind=connection) + new_user = session.query(User).filter_by(id=target.id).first() # copy template dashboard to user - template = db.session.query(Dashboard).filter_by(id=int(dashboard_id)).first() + template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first() dashboard = Dashboard( dashboard_title=template.dashboard_title, position_json=template.position_json, @@ -82,15 +83,15 @@ def copy_dashboard( # pylint: disable=unused-argument slices=template.slices, owners=[new_user], ) - db.session.add(dashboard) - db.session.commit() + session.add(dashboard) + session.commit() # set dashboard as the welcome dashboard extra_attributes = UserAttribute( user_id=target.id, welcome_dashboard_id=dashboard.id ) - db.session.add(extra_attributes) - db.session.commit() + session.add(extra_attributes) + session.commit() sqla.event.listen(User, "after_insert", copy_dashboard) @@ -306,6 +307,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes logger.info( "Started import of the dashboard: %s", dashboard_to_import.to_json() ) + session = db.session logger.info("Dashboard has %d slices", len(dashboard_to_import.slices)) # copy slices object as Slice.import_slice will mutate the slice # and will remove the existing dashboard - slice association @@ -322,7 +324,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes i_params_dict = dashboard_to_import.params_dict remote_id_slice_map = { slc.params_dict["remote_id"]: slc - for slc in db.session.query(Slice).all() + for slc in session.query(Slice).all() if "remote_id" in slc.params_dict } for slc in slices: @@ -373,7 +375,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes # override the dashboard existing_dashboard = None - for dash in db.session.query(Dashboard).all(): + for dash in session.query(Dashboard).all(): if ( "remote_id" in dash.params_dict and dash.params_dict["remote_id"] == dashboard_to_import.id @@ -400,7 +402,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes ) new_slices = ( - db.session.query(Slice) + session.query(Slice) .filter(Slice.id.in_(old_to_new_slc_id_dict.values())) .all() ) @@ -408,12 +410,12 @@ class Dashboard( # pylint: disable=too-many-instance-attributes if existing_dashboard: existing_dashboard.override(dashboard_to_import) existing_dashboard.slices = new_slices - db.session.flush() + session.flush() return existing_dashboard.id dashboard_to_import.slices = new_slices - db.session.add(dashboard_to_import) - db.session.flush() + session.add(dashboard_to_import) + session.flush() return dashboard_to_import.id # type: ignore @classmethod @@ -455,7 +457,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes eager_datasources = [] for datasource_id, datasource_type in datasource_ids: eager_datasource = ConnectorRegistry.get_eager_datasource( - datasource_type, datasource_id + db.session, datasource_type, datasource_id ) copied_datasource = eager_datasource.copy() copied_datasource.alter_params( diff --git a/superset/models/helpers.py b/superset/models/helpers.py index c67c6b61e..d903d271a 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -34,9 +34,9 @@ from flask_appbuilder.models.mixins import AuditMixin from flask_appbuilder.security.sqla.models import User from sqlalchemy import and_, or_, UniqueConstraint from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import Session from sqlalchemy.orm.exc import MultipleResultsFound -from superset.extensions import db from superset.utils.core import QueryStatus logger = logging.getLogger(__name__) @@ -127,6 +127,7 @@ class ImportMixin: @classmethod def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals cls, + session: Session, dict_rep: Dict[Any, Any], parent: Optional[Any] = None, recursive: bool = True, @@ -177,7 +178,7 @@ class ImportMixin: # Check if object already exists in DB, break if more than one is found try: - obj_query = db.session.query(cls).filter(and_(*filters)) + obj_query = session.query(cls).filter(and_(*filters)) obj = obj_query.one_or_none() except MultipleResultsFound as ex: logger.error( @@ -195,7 +196,7 @@ class ImportMixin: logger.info("Importing new %s %s", obj.__tablename__, str(obj)) if cls.export_parent and parent: setattr(obj, cls.export_parent, parent) - db.session.add(obj) + session.add(obj) else: is_new_obj = False logger.info("Updating %s %s", obj.__tablename__, str(obj)) @@ -213,7 +214,7 @@ class ImportMixin: for c_obj in new_children.get(child, []): added.append( child_class.import_from_dict( - dict_rep=c_obj, parent=obj, sync=sync + session=session, dict_rep=c_obj, parent=obj, sync=sync ) ) # If children should get synced, delete the ones that did not @@ -227,11 +228,11 @@ class ImportMixin: for k in back_refs.keys() ] to_delete = set( - db.session.query(child_class).filter(and_(*delete_filters)) + session.query(child_class).filter(and_(*delete_filters)) ).difference(set(added)) for o in to_delete: logger.info("Deleting %s %s", child, str(obj)) - db.session.delete(o) + session.delete(o) return obj diff --git a/superset/models/slice.py b/superset/models/slice.py index 0a2e7d5f7..b7f9e0537 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -300,6 +300,7 @@ class Slice( :returns: The resulting id for the imported slice :rtype: int """ + session = db.session make_transient(slc_to_import) slc_to_import.dashboards = [] slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time) @@ -308,6 +309,7 @@ class Slice( slc_to_import.reset_ownership() params = slc_to_import.params_dict datasource = ConnectorRegistry.get_datasource_by_name( + session, slc_to_import.datasource_type, params["datasource_name"], params["schema"], @@ -316,11 +318,11 @@ class Slice( slc_to_import.datasource_id = datasource.id # type: ignore if slc_to_override: slc_to_override.override(slc_to_import) - db.session.flush() + session.flush() return slc_to_override.id - db.session.add(slc_to_import) + session.add(slc_to_import) logger.info("Final slice: %s", str(slc_to_import.to_json())) - db.session.flush() + session.flush() return slc_to_import.id @property diff --git a/superset/models/tags.py b/superset/models/tags.py index 1302ff581..c09bb1685 100644 --- a/superset/models/tags.py +++ b/superset/models/tags.py @@ -22,11 +22,10 @@ from typing import List, Optional, TYPE_CHECKING, Union from flask_appbuilder import Model from sqlalchemy import Column, Enum, ForeignKey, Integer, String from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, Session, sessionmaker from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.mapper import Mapper -from superset.extensions import db from superset.models.helpers import AuditMixinNullable if TYPE_CHECKING: @@ -35,6 +34,8 @@ if TYPE_CHECKING: from superset.models.slice import Slice # pylint: disable=unused-import from superset.models.sql_lab import Query # pylint: disable=unused-import +Session = sessionmaker(autoflush=False) + class TagTypes(enum.Enum): @@ -87,13 +88,13 @@ class TaggedObject(Model, AuditMixinNullable): tag = relationship("Tag", backref="objects") -def get_tag(name: str, type_: TagTypes) -> Tag: +def get_tag(name: str, session: Session, type_: TagTypes) -> Tag: try: - tag = db.session.query(Tag).filter_by(name=name, type=type_).one() + tag = session.query(Tag).filter_by(name=name, type=type_).one() except NoResultFound: tag = Tag(name=name, type=type_) - db.session.add(tag) - db.session.commit() + session.add(tag) + session.commit() return tag @@ -121,43 +122,52 @@ class ObjectUpdater: raise NotImplementedError("Subclass should implement `get_owners_ids`") @classmethod - def _add_owners(cls, target: Union["Dashboard", "FavStar", "Slice"]) -> None: + def _add_owners( + cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"] + ) -> None: for owner_id in cls.get_owners_ids(target): name = "owner:{0}".format(owner_id) - tag = get_tag(name, TagTypes.owner) + tag = get_tag(name, session, TagTypes.owner) tagged_object = TaggedObject( tag_id=tag.id, object_id=target.id, object_type=cls.object_type ) - db.session.add(tagged_object) + session.add(tagged_object) @classmethod - def after_insert( # pylint: disable=unused-argument + def after_insert( cls, mapper: Mapper, connection: Connection, target: Union["Dashboard", "FavStar", "Slice"], ) -> None: + # pylint: disable=unused-argument + session = Session(bind=connection) + # add `owner:` tags - cls._add_owners(target) + cls._add_owners(session, target) # add `type:` tags - tag = get_tag("type:{0}".format(cls.object_type), TagTypes.type) + tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type) tagged_object = TaggedObject( tag_id=tag.id, object_id=target.id, object_type=cls.object_type ) - db.session.add(tagged_object) - db.session.commit() + session.add(tagged_object) + + session.commit() @classmethod - def after_update( # pylint: disable=unused-argument + def after_update( cls, mapper: Mapper, connection: Connection, target: Union["Dashboard", "FavStar", "Slice"], ) -> None: + # pylint: disable=unused-argument + session = Session(bind=connection) + # delete current `owner:` tags query = ( - db.session.query(TaggedObject.id) + session.query(TaggedObject.id) .join(Tag) .filter( TaggedObject.object_type == cls.object_type, @@ -166,28 +176,32 @@ class ObjectUpdater: ) ) ids = [row[0] for row in query] - db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( + session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( synchronize_session=False ) # add `owner:` tags - cls._add_owners(target) - db.session.commit() + cls._add_owners(session, target) + + session.commit() @classmethod - def after_delete( # pylint: disable=unused-argument + def after_delete( cls, mapper: Mapper, connection: Connection, target: Union["Dashboard", "FavStar", "Slice"], ) -> None: + # pylint: disable=unused-argument + session = Session(bind=connection) + # delete row from `tagged_objects` - db.session.query(TaggedObject).filter( + session.query(TaggedObject).filter( TaggedObject.object_type == cls.object_type, TaggedObject.object_id == target.id, ).delete() - db.session.commit() + session.commit() class ChartUpdater(ObjectUpdater): @@ -219,26 +233,31 @@ class QueryUpdater(ObjectUpdater): class FavStarUpdater: @classmethod - def after_insert( # pylint: disable=unused-argument + def after_insert( cls, mapper: Mapper, connection: Connection, target: "FavStar" ) -> None: + # pylint: disable=unused-argument + session = Session(bind=connection) name = "favorited_by:{0}".format(target.user_id) - tag = get_tag(name, TagTypes.favorited_by) + tag = get_tag(name, session, TagTypes.favorited_by) tagged_object = TaggedObject( tag_id=tag.id, object_id=target.obj_id, object_type=get_object_type(target.class_name), ) - db.session.add(tagged_object) - db.session.commit() + session.add(tagged_object) + + session.commit() @classmethod - def after_delete( # pylint: disable=unused-argument - cls, mapper: Mapper, connection: Connection, target: "FavStar", + def after_delete( + cls, mapper: Mapper, connection: Connection, target: "FavStar" ) -> None: + # pylint: disable=unused-argument + session = Session(bind=connection) name = "favorited_by:{0}".format(target.user_id) query = ( - db.session.query(TaggedObject.id) + session.query(TaggedObject.id) .join(Tag) .filter( TaggedObject.object_id == target.obj_id, @@ -247,8 +266,8 @@ class FavStarUpdater: ) ) ids = [row[0] for row in query] - db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( + session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( synchronize_session=False ) - db.session.commit() + session.commit() diff --git a/superset/security/manager.py b/superset/security/manager.py index f731c4284..da92d1684 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -507,7 +507,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods user_perms = self.user_view_menu_names("datasource_access") schema_perms = self.user_view_menu_names("schema_access") user_datasources = ConnectorRegistry.query_datasources_by_permissions( - database, user_perms, schema_perms + self.get_session, database, user_perms, schema_perms ) if schema: names = {d.table_name for d in user_datasources if d.schema == schema} @@ -568,7 +568,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self.add_permission_view_menu(view_menu, perm) logger.info("Creating missing datasource permissions.") - datasources = ConnectorRegistry.get_all_datasources() + datasources = ConnectorRegistry.get_all_datasources(self.get_session) for datasource in datasources: merge_pv("datasource_access", datasource.get_perm()) merge_pv("schema_access", datasource.get_schema_perm()) @@ -901,7 +901,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if not (schema_perm and self.can_access("schema_access", schema_perm)): datasources = SqlaTable.query_datasources_by_name( - database, table_.table, schema=table_.schema + self.get_session, database, table_.table, schema=table_.schema ) # Access to any datasource is suffice. diff --git a/superset/sql_lab.py b/superset/sql_lab.py index d941473eb..8c3f24fc5 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -132,7 +132,8 @@ def session_scope(nullpool: bool) -> Iterator[Session]: ) if nullpool: engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool) - session_class = sessionmaker(bind=engine) + session_class = sessionmaker() + session_class.configure(bind=engine) session = session_class() else: session = db.session() diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index f4c9e3294..54b0dc1b1 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -134,7 +134,8 @@ class DummyStrategy(Strategy): name = "dummy" def get_urls(self) -> List[str]: - charts = db.session.query(Slice).all() + session = db.create_scoped_session() + charts = session.query(Slice).all() return [get_url(chart) for chart in charts] @@ -166,9 +167,10 @@ class TopNDashboardsStrategy(Strategy): def get_urls(self) -> List[str]: urls = [] + session = db.create_scoped_session() records = ( - db.session.query(Log.dashboard_id, func.count(Log.dashboard_id)) + session.query(Log.dashboard_id, func.count(Log.dashboard_id)) .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since)) .group_by(Log.dashboard_id) .order_by(func.count(Log.dashboard_id).desc()) @@ -176,9 +178,7 @@ class TopNDashboardsStrategy(Strategy): .all() ) dash_ids = [record.dashboard_id for record in records] - dashboards = ( - db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() - ) + dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() for dashboard in dashboards: for chart in dashboard.slices: form_data_with_filters = get_form_data(chart.id, dashboard) @@ -211,13 +211,14 @@ class DashboardTagsStrategy(Strategy): def get_urls(self) -> List[str]: urls = [] + session = db.create_scoped_session() - tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all() + tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all() tag_ids = [tag.id for tag in tags] # add dashboards that are tagged tagged_objects = ( - db.session.query(TaggedObject) + session.query(TaggedObject) .filter( and_( TaggedObject.object_type == "dashboard", @@ -227,16 +228,14 @@ class DashboardTagsStrategy(Strategy): .all() ) dash_ids = [tagged_object.object_id for tagged_object in tagged_objects] - tagged_dashboards = db.session.query(Dashboard).filter( - Dashboard.id.in_(dash_ids) - ) + tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)) for dashboard in tagged_dashboards: for chart in dashboard.slices: urls.append(get_url(chart)) # add charts that are tagged tagged_objects = ( - db.session.query(TaggedObject) + session.query(TaggedObject) .filter( and_( TaggedObject.object_type == "chart", @@ -246,7 +245,7 @@ class DashboardTagsStrategy(Strategy): .all() ) chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] - tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids)) + tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids)) for chart in tagged_charts: urls.append(get_url(chart)) diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 0a74ddf88..4fd55aaf6 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -47,6 +47,7 @@ from flask_login import login_user from retry.api import retry_call from selenium.common.exceptions import WebDriverException from selenium.webdriver import chrome, firefox +from sqlalchemy.orm import Session from werkzeug.http import parse_cookie from superset import app, db, security_manager, thumbnail_cache @@ -541,7 +542,8 @@ def schedule_alert_query( # pylint: disable=unused-argument is_test_alert: Optional[bool] = False, ) -> None: model_cls = get_scheduler_model(report_type) - schedule = db.session.query(model_cls).get(schedule_id) + dbsession = db.create_scoped_session() + schedule = dbsession.query(model_cls).get(schedule_id) # The user may have disabled the schedule. If so, ignore this if not schedule or not schedule.active: @@ -553,7 +555,7 @@ def schedule_alert_query( # pylint: disable=unused-argument deliver_alert(schedule.id, recipients) return - if run_alert_query(schedule.id): + if run_alert_query(schedule.id, dbsession): # deliver_dashboard OR deliver_slice return else: @@ -616,7 +618,7 @@ def deliver_alert(alert_id: int, recipients: Optional[str] = None) -> None: _deliver_email(recipients, deliver_as_group, subject, body, data, images) -def run_alert_query(alert_id: int) -> Optional[bool]: +def run_alert_query(alert_id: int, dbsession: Session) -> Optional[bool]: """ Execute alert.sql and return value if any rows are returned """ @@ -670,7 +672,7 @@ def run_alert_query(alert_id: int) -> Optional[bool]: state=state, ) ) - db.session.commit() + dbsession.commit() return None @@ -710,7 +712,8 @@ def schedule_window( if not model_cls: return None - schedules = db.session.query(model_cls).filter(model_cls.active.is_(True)) + dbsession = db.create_scoped_session() + schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True)) for schedule in schedules: logging.info("Processing schedule %s", schedule) diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py index f8f673deb..6ae500b40 100644 --- a/superset/utils/dashboard_import_export.py +++ b/superset/utils/dashboard_import_export.py @@ -22,10 +22,10 @@ from io import BytesIO from typing import Any, Dict, Optional from flask_babel import lazy_gettext as _ +from sqlalchemy.orm import Session from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.exceptions import DashboardImportException -from superset.extensions import db from superset.models.dashboard import Dashboard from superset.models.slice import Slice @@ -71,6 +71,7 @@ def decode_dashboards( # pylint: disable=too-many-return-statements def import_dashboards( + session: Session, data_stream: BytesIO, database_id: Optional[int] = None, import_time: Optional[int] = None, @@ -83,16 +84,16 @@ def import_dashboards( raise DashboardImportException(_("No data in file")) for table in data["datasources"]: type(table).import_obj(table, database_id, import_time=import_time) - db.session.commit() + session.commit() for dashboard in data["dashboards"]: Dashboard.import_obj(dashboard, import_time=import_time) - db.session.commit() + session.commit() -def export_dashboards() -> str: +def export_dashboards(session: Session) -> str: """Returns all dashboards metadata as a json dump""" logger.info("Starting export") - dashboards = db.session.query(Dashboard) + dashboards = session.query(Dashboard) dashboard_ids = [] for dashboard in dashboards: dashboard_ids.append(dashboard.id) diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index 8edae22e7..4d9e0496b 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -17,8 +17,9 @@ import logging from typing import Any, Dict, List, Optional +from sqlalchemy.orm import Session + from superset.connectors.druid.models import DruidCluster -from superset.extensions import db from superset.models.core import Database DATABASES_KEY = "databases" @@ -43,11 +44,11 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]: def export_to_dict( - recursive: bool, back_references: bool, include_defaults: bool + session: Session, recursive: bool, back_references: bool, include_defaults: bool ) -> Dict[str, Any]: """Exports databases and druid clusters to a dictionary""" logger.info("Starting export") - dbs = db.session.query(Database) + dbs = session.query(Database) databases = [ database.export_to_dict( recursive=recursive, @@ -57,7 +58,7 @@ def export_to_dict( for database in dbs ] logger.info("Exported %d %s", len(databases), DATABASES_KEY) - cls = db.session.query(DruidCluster) + cls = session.query(DruidCluster) clusters = [ cluster.export_to_dict( recursive=recursive, @@ -75,20 +76,22 @@ def export_to_dict( return data -def import_from_dict(data: Dict[str, Any], sync: Optional[List[str]] = None) -> None: +def import_from_dict( + session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None +) -> None: """Imports databases and druid clusters from dictionary""" if not sync: sync = [] if isinstance(data, dict): logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) for database in data.get(DATABASES_KEY, []): - Database.import_from_dict(database, sync=sync) + Database.import_from_dict(session, database, sync=sync) logger.info( "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY ) for datasource in data.get(DRUID_CLUSTERS_KEY, []): - DruidCluster.import_from_dict(datasource, sync=sync) - db.session.commit() + DruidCluster.import_from_dict(session, datasource, sync=sync) + session.commit() else: logger.info("Supplied object is not a dictionary.") diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py index a59a3d658..25da876b2 100644 --- a/superset/utils/import_datasource.py +++ b/superset/utils/import_datasource.py @@ -18,14 +18,14 @@ import logging from typing import Callable, Optional from flask_appbuilder import Model +from sqlalchemy.orm import Session from sqlalchemy.orm.session import make_transient -from superset.extensions import db - logger = logging.getLogger(__name__) def import_datasource( # pylint: disable=too-many-arguments + session: Session, i_datasource: Model, lookup_database: Callable[[Model], Model], lookup_datasource: Callable[[Model], Model], @@ -52,11 +52,11 @@ def import_datasource( # pylint: disable=too-many-arguments if datasource: datasource.override(i_datasource) - db.session.flush() + session.flush() else: datasource = i_datasource.copy() - db.session.add(datasource) - db.session.flush() + session.add(datasource) + session.flush() for metric in i_datasource.metrics: new_m = metric.copy() @@ -81,11 +81,13 @@ def import_datasource( # pylint: disable=too-many-arguments imported_c = i_datasource.column_class.import_obj(new_c) if imported_c.column_name not in [c.column_name for c in datasource.columns]: datasource.columns.append(imported_c) - db.session.flush() + session.flush() return datasource.id -def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model: +def import_simple_obj( + session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model] +) -> Model: make_transient(i_obj) i_obj.id = None i_obj.table = None @@ -95,9 +97,9 @@ def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Mod i_obj.table = None if existing_column: existing_column.override(i_obj) - db.session.flush() + session.flush() return existing_column - db.session.add(i_obj) - db.session.flush() + session.add(i_obj) + session.flush() return i_obj diff --git a/superset/views/base.py b/superset/views/base.py index 58c4943d0..7aeae79bd 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -487,7 +487,8 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool: roles = [r.name for r in get_user_roles()] if "Admin" in roles: return True - orig_obj = db.session.query(obj.__class__).filter_by(id=obj.id).first() + scoped_session = db.create_scoped_session() + orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first() # Making a list of owners that works across ORM models owners: List[User] = [] diff --git a/superset/views/chart/views.py b/superset/views/chart/views.py index db100a74c..0523e33aa 100644 --- a/superset/views/chart/views.py +++ b/superset/views/chart/views.py @@ -20,7 +20,7 @@ from flask_appbuilder import expose, has_access from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import lazy_gettext as _ -from superset import app +from superset import app, db from superset.connectors.connector_registry import ConnectorRegistry from superset.constants import RouteMethod from superset.models.slice import Slice @@ -56,7 +56,7 @@ class SliceModelView( def add(self) -> FlaskResponse: datasources = [ {"value": str(d.id) + "__" + d.type, "label": repr(d)} - for d in ConnectorRegistry.get_all_datasources() + for d in ConnectorRegistry.get_all_datasources(db.session) ] return self.render_template( "superset/add_slice.html", diff --git a/superset/views/core.py b/superset/views/core.py index bcbce029b..f3a226459 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -40,6 +40,7 @@ from sqlalchemy.exc import ( OperationalError, SQLAlchemyError, ) +from sqlalchemy.orm.session import Session from werkzeug.urls import Href import superset.models.core as models @@ -163,7 +164,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods sorted( [ datasource.short_data - for datasource in ConnectorRegistry.get_all_datasources() + for datasource in ConnectorRegistry.get_all_datasources(db.session) if datasource.short_data.get("name") ], key=lambda datasource: datasource["name"], @@ -202,7 +203,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ) db_ds_names.add(fullname) - existing_datasources = ConnectorRegistry.get_all_datasources() + existing_datasources = ConnectorRegistry.get_all_datasources(db.session) datasources = [d for d in existing_datasources if d.full_name in db_ds_names] role = security_manager.find_role(role_name) # remove all permissions @@ -269,15 +270,15 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @has_access @expose("/approve") def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use - def clean_fulfilled_requests() -> None: - for dar in db.session.query(DAR).all(): + def clean_fulfilled_requests(session: Session) -> None: + for dar in session.query(DAR).all(): datasource = ConnectorRegistry.get_datasource( - dar.datasource_type, dar.datasource_id + dar.datasource_type, dar.datasource_id, session ) if not datasource or security_manager.can_access_datasource(datasource): # datasource does not exist anymore - db.session.delete(dar) - db.session.commit() + session.delete(dar) + session.commit() datasource_type = request.args["datasource_type"] datasource_id = request.args["datasource_id"] @@ -285,7 +286,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods role_to_grant = request.args.get("role_to_grant") role_to_extend = request.args.get("role_to_extend") - datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id) + session = db.session + datasource = ConnectorRegistry.get_datasource( + datasource_type, datasource_id, session + ) if not datasource: flash(DATASOURCE_MISSING_ERR, "alert") @@ -297,7 +301,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return json_error_response(USER_MISSING_ERR) requests = ( - db.session.query(DAR) + session.query(DAR) .filter( DAR.datasource_id == datasource_id, DAR.datasource_type == datasource_type, @@ -357,13 +361,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods app.config, ) flash(msg, "info") - clean_fulfilled_requests() + clean_fulfilled_requests(session) else: flash(__("You have no permission to approve this request"), "danger") return redirect("/accessrequestsmodelview/list/") for request_ in requests: - db.session.delete(request_) - db.session.commit() + session.delete(request_) + session.commit() return redirect("/accessrequestsmodelview/list/") @has_access @@ -544,7 +548,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods database_id = request.form.get("db_id") try: dashboard_import_export.import_dashboards( - import_file.stream, database_id + db.session, import_file.stream, database_id ) success = True except DatabaseNotFound as ex: @@ -626,7 +630,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return redirect(error_redirect) datasource = ConnectorRegistry.get_datasource( - cast(str, datasource_type), datasource_id + cast(str, datasource_type), datasource_id, db.session ) if not datasource: flash(DATASOURCE_MISSING_ERR, "danger") @@ -745,7 +749,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods :raises SupersetSecurityException: If the user cannot access the resource """ # TODO: Cache endpoint by user, datasource and column - datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id) + datasource = ConnectorRegistry.get_datasource( + datasource_type, datasource_id, db.session + ) if not datasource: return json_error_response(DATASOURCE_MISSING_ERR) @@ -1009,9 +1015,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods self, dashboard_id: int ) -> FlaskResponse: """Copy dashboard""" + session = db.session() data = json.loads(request.form["data"]) dash = models.Dashboard() - original_dash = db.session.query(Dashboard).get(dashboard_id) + original_dash = session.query(Dashboard).get(dashboard_id) dash.owners = [g.user] if g.user else [] dash.dashboard_title = data["dashboard_title"] @@ -1022,8 +1029,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods for slc in original_dash.slices: new_slice = slc.clone() new_slice.owners = [g.user] if g.user else [] - db.session.add(new_slice) - db.session.flush() + session.add(new_slice) + session.flush() new_slice.dashboards.append(dash) old_to_new_slice_ids[slc.id] = new_slice.id @@ -1039,9 +1046,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods dash.params = original_dash.params DashboardDAO.set_dash_metadata(dash, data, old_to_new_slice_ids) - db.session.add(dash) - db.session.commit() + session.add(dash) + session.commit() dash_json = json.dumps(dash.data) + session.close() return json_success(dash_json) @api @@ -1051,12 +1059,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods self, dashboard_id: int ) -> FlaskResponse: """Save a dashboard's metadata""" - dash = db.session.query(Dashboard).get(dashboard_id) + session = db.session() + dash = session.query(Dashboard).get(dashboard_id) check_ownership(dash, raise_if_false=True) data = json.loads(request.form["data"]) DashboardDAO.set_dash_metadata(dash, data) - db.session.merge(dash) - db.session.commit() + session.merge(dash) + session.commit() + session.close() return json_success(json.dumps({"status": "SUCCESS"})) @api @@ -1067,12 +1077,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ) -> FlaskResponse: """Add and save slices to a dashboard""" data = json.loads(request.form["data"]) - dash = db.session.query(Dashboard).get(dashboard_id) + session = db.session() + dash = session.query(Dashboard).get(dashboard_id) check_ownership(dash, raise_if_false=True) - new_slices = db.session.query(Slice).filter(Slice.id.in_(data["slice_ids"])) + new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"])) dash.slices += new_slices - db.session.merge(dash) - db.session.commit() + session.merge(dash) + session.commit() + session.close() return "SLICES ADDED" @api @@ -1419,6 +1431,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods Note for slices a force refresh occurs. """ + session = db.session() slice_id = request.args.get("slice_id") dashboard_id = request.args.get("dashboard_id") table_name = request.args.get("table_name") @@ -1433,14 +1446,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods status=400, ) if slice_id: - slices = db.session.query(Slice).filter_by(id=slice_id).all() + slices = session.query(Slice).filter_by(id=slice_id).all() if not slices: return json_error_response( __("Chart %(id)s not found", id=slice_id), status=404 ) elif table_name and db_name: table = ( - db.session.query(SqlaTable) + session.query(SqlaTable) .join(models.Database) .filter( models.Database.database_name == db_name @@ -1457,7 +1470,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods status=404, ) slices = ( - db.session.query(Slice) + session.query(Slice) .filter_by(datasource_id=table.id, datasource_type=table.type) .all() ) @@ -1500,16 +1513,17 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods self, class_name: str, obj_id: int, action: str ) -> FlaskResponse: """Toggle favorite stars on Slices and Dashboard""" + session = db.session() FavStar = models.FavStar count = 0 favs = ( - db.session.query(FavStar) + session.query(FavStar) .filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id()) .all() ) if action == "select": if not favs: - db.session.add( + session.add( FavStar( class_name=class_name, obj_id=obj_id, @@ -1520,10 +1534,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods count = 1 elif action == "unselect": for fav in favs: - db.session.delete(fav) + session.delete(fav) else: count = len(favs) - db.session.commit() + session.commit() return json_success(json.dumps({"count": count})) @api @@ -1536,13 +1550,12 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods logger.warning( "This API endpoint is deprecated and will be removed in version 1.0.0" ) + session = db.session() Role = ab_models.Role dash = ( - db.session.query(Dashboard) - .filter(Dashboard.id == dashboard_id) - .one_or_none() + session.query(Dashboard).filter(Dashboard.id == dashboard_id).one_or_none() ) - admin_role = db.session.query(Role).filter(Role.name == "Admin").one_or_none() + admin_role = session.query(Role).filter(Role.name == "Admin").one_or_none() if request.method == "GET": if dash: @@ -1561,7 +1574,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ) dash.published = str(request.form["published"]).lower() == "true" - db.session.commit() + session.commit() return json_success(json.dumps({"published": dash.published})) @has_access @@ -1570,7 +1583,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods self, dashboard_id_or_slug: str ) -> FlaskResponse: """Server side rendering for a dashboard""" - qry = db.session.query(Dashboard) + session = db.session() + qry = session.query(Dashboard) if dashboard_id_or_slug.isdigit(): qry = qry.filter_by(id=int(dashboard_id_or_slug)) else: @@ -2028,7 +2042,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods "SQL validation does not support template parameters", status=400 ) - mydb = db.session.query(models.Database).filter_by(id=database_id).one_or_none() + session = db.session() + mydb = session.query(models.Database).filter_by(id=database_id).one_or_none() if not mydb: return json_error_response( "Database with id {} is missing.".format(database_id), status=400 @@ -2077,6 +2092,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @staticmethod def _sql_json_async( # pylint: disable=too-many-arguments + session: Session, rendered_query: str, query: Query, expand_data: bool, @@ -2085,6 +2101,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods """ Send SQL JSON query to celery workers. + :param session: SQLAlchemy session object :param rendered_query: the rendered query to perform by workers :param query: The query (SQLAlchemy) object :return: A Flask Response @@ -2115,7 +2132,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ) query.status = QueryStatus.FAILED query.error_message = msg - db.session.commit() + session.commit() return json_error_response("{}".format(msg)) resp = json_success( json.dumps( @@ -2125,11 +2142,12 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ), status=202, ) - db.session.commit() + session.commit() return resp @staticmethod def _sql_json_sync( + _session: Session, rendered_query: str, query: Query, expand_data: bool, @@ -2223,7 +2241,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods tab_name: str = cast(str, query_params.get("tab")) status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING - mydb = db.session.query(models.Database).get(database_id) + session = db.session() + mydb = session.query(models.Database).get(database_id) if not mydb: return json_error_response("Database with id %i is missing.", database_id) @@ -2254,13 +2273,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods client_id=client_id, ) try: - db.session.add(query) - db.session.flush() + session.add(query) + session.flush() query_id = query.id - db.session.commit() # shouldn't be necessary + session.commit() # shouldn't be necessary except SQLAlchemyError as ex: logger.error("Errors saving query details %s", str(ex)) - db.session.rollback() + session.rollback() raise Exception(_("Query record was not created as expected.")) if not query_id: raise Exception(_("Query record was not created as expected.")) @@ -2271,7 +2290,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods query.raise_for_access() except SupersetSecurityException as ex: query.status = QueryStatus.FAILED - db.session.commit() + session.commit() return json_errors_response([ex.error], status=403) try: @@ -2304,9 +2323,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods # Async request. if async_flag: - return self._sql_json_async(rendered_query, query, expand_data, log_params) + return self._sql_json_async( + session, rendered_query, query, expand_data, log_params + ) # Sync request. - return self._sql_json_sync(rendered_query, query, expand_data, log_params) + return self._sql_json_sync( + session, rendered_query, query, expand_data, log_params + ) @has_access @expose("/csv/") @@ -2375,7 +2398,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods """ datasource_id, datasource_type = request.args["datasourceKey"].split("__") - datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id) + datasource = ConnectorRegistry.get_datasource( + datasource_type, datasource_id, db.session + ) # Check if datasource exists if not datasource: return json_error_response(DATASOURCE_MISSING_ERR) diff --git a/superset/views/datasource.py b/superset/views/datasource.py index c2affcb50..2ce11027c 100644 --- a/superset/views/datasource.py +++ b/superset/views/datasource.py @@ -47,7 +47,7 @@ class Datasource(BaseSupersetView): datasource_type = datasource_dict.get("type") database_id = datasource_dict["database"].get("id") orm_datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id + datasource_type, datasource_id, db.session ) orm_datasource.database_id = database_id @@ -82,7 +82,7 @@ class Datasource(BaseSupersetView): def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse: try: orm_datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id + datasource_type, datasource_id, db.session ) if not orm_datasource.data: return json_error_response( @@ -102,7 +102,7 @@ class Datasource(BaseSupersetView): """Gets column info from the source system""" if datasource_type == "druid": datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id + datasource_type, datasource_id, db.session ) elif datasource_type == "table": database = ( diff --git a/superset/views/utils.py b/superset/views/utils.py index dd2aa064a..2a8b2ccc0 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -105,7 +105,9 @@ def get_viz( form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False ) -> BaseViz: viz_type = form_data.get("viz_type", "table") - datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id) + datasource = ConnectorRegistry.get_datasource( + datasource_type, datasource_id, db.session + ) viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force) return viz_obj @@ -291,7 +293,8 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"] def get_dashboard_extra_filters( slice_id: int, dashboard_id: int ) -> List[Dict[str, Any]]: - dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none() + session = db.session() + dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none() # is chart in this dashboard? if ( diff --git a/tests/access_tests.py b/tests/access_tests.py index 10a4867d7..d452d13aa 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -71,17 +71,13 @@ DB_ACCESS_ROLE = "db_access_role" SCHEMA_ACCESS_ROLE = "schema_access_role" -def create_access_request(ds_type, ds_name, role_name, user_name): +def create_access_request(session, ds_type, ds_name, role_name, user_name): ds_class = ConnectorRegistry.sources[ds_type] # TODO: generalize datasource names if ds_type == "table": - ds = db.session.query(ds_class).filter(ds_class.table_name == ds_name).first() + ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first() else: - ds = ( - db.session.query(ds_class) - .filter(ds_class.datasource_name == ds_name) - .first() - ) + ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first() ds_perm_view = security_manager.find_permission_view_menu( "datasource_access", ds.perm ) @@ -93,8 +89,8 @@ def create_access_request(ds_type, ds_name, role_name, user_name): datasource_type=ds_type, created_by_fk=security_manager.find_user(username=user_name).id, ) - db.session.add(access_request) - db.session.commit() + session.add(access_request) + session.commit() return access_request @@ -130,6 +126,7 @@ class TestRequestAccess(SupersetTestCase): override_me = security_manager.find_role("override_me") override_me.permissions = [] db.session.commit() + db.session.close() def test_override_role_permissions_is_admin_only(self): self.logout() @@ -214,6 +211,7 @@ class TestRequestAccess(SupersetTestCase): ) def test_clean_requests_after_role_extend(self): + session = db.session # Case 1. Gamma and gamma2 requested test_role1 on energy_usage access # Gamma already has role test_role1 @@ -223,10 +221,12 @@ class TestRequestAccess(SupersetTestCase): # gamma2 and gamma request table_role on energy usage if app.config["ENABLE_ACCESS_REQUEST"]: access_request1 = create_access_request( - "table", "random_time_series", TEST_ROLE_1, "gamma2" + session, "table", "random_time_series", TEST_ROLE_1, "gamma2" ) ds_1_id = access_request1.datasource_id - create_access_request("table", "random_time_series", TEST_ROLE_1, "gamma") + create_access_request( + session, "table", "random_time_series", TEST_ROLE_1, "gamma" + ) access_requests = self.get_access_requests("gamma", "table", ds_1_id) self.assertTrue(access_requests) # gamma gets test_role1 @@ -244,20 +244,22 @@ class TestRequestAccess(SupersetTestCase): gamma_user.roles.remove(security_manager.find_role("test_role1")) def test_clean_requests_after_alpha_grant(self): + session = db.session + # Case 2. Two access requests from gamma and gamma2 # Gamma becomes alpha, gamma2 gets granted # Check if request by gamma has been deleted access_request1 = create_access_request( - "table", "birth_names", TEST_ROLE_1, "gamma" + session, "table", "birth_names", TEST_ROLE_1, "gamma" ) - create_access_request("table", "birth_names", TEST_ROLE_2, "gamma2") + create_access_request(session, "table", "birth_names", TEST_ROLE_2, "gamma2") ds_1_id = access_request1.datasource_id # gamma becomes alpha alpha_role = security_manager.find_role("Alpha") gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.append(alpha_role) - db.session.commit() + session.commit() access_requests = self.get_access_requests("gamma", "table", ds_1_id) self.assertTrue(access_requests) self.client.get( @@ -268,21 +270,23 @@ class TestRequestAccess(SupersetTestCase): gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.remove(security_manager.find_role("Alpha")) - db.session.commit() + session.commit() def test_clean_requests_after_db_grant(self): + session = db.session + # Case 3. Two access requests from gamma and gamma2 # Gamma gets database access, gamma2 access request granted # Check if request by gamma has been deleted gamma_user = security_manager.find_user(username="gamma") access_request1 = create_access_request( - "table", "energy_usage", TEST_ROLE_1, "gamma" + session, "table", "energy_usage", TEST_ROLE_1, "gamma" ) - create_access_request("table", "energy_usage", TEST_ROLE_2, "gamma2") + create_access_request(session, "table", "energy_usage", TEST_ROLE_2, "gamma2") ds_1_id = access_request1.datasource_id # gamma gets granted database access - database = db.session.query(models.Database).first() + database = session.query(models.Database).first() security_manager.add_permission_view_menu("database_access", database.perm) ds_perm_view = security_manager.find_permission_view_menu( @@ -292,7 +296,7 @@ class TestRequestAccess(SupersetTestCase): security_manager.find_role(DB_ACCESS_ROLE), ds_perm_view ) gamma_user.roles.append(security_manager.find_role(DB_ACCESS_ROLE)) - db.session.commit() + session.commit() access_requests = self.get_access_requests("gamma", "table", ds_1_id) self.assertTrue(access_requests) # gamma2 request gets fulfilled @@ -304,21 +308,25 @@ class TestRequestAccess(SupersetTestCase): self.assertFalse(access_requests) gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.remove(security_manager.find_role(DB_ACCESS_ROLE)) - db.session.commit() + session.commit() def test_clean_requests_after_schema_grant(self): + session = db.session + # Case 4. Two access requests from gamma and gamma2 # Gamma gets schema access, gamma2 access request granted # Check if request by gamma has been deleted gamma_user = security_manager.find_user(username="gamma") access_request1 = create_access_request( - "table", "wb_health_population", TEST_ROLE_1, "gamma" + session, "table", "wb_health_population", TEST_ROLE_1, "gamma" + ) + create_access_request( + session, "table", "wb_health_population", TEST_ROLE_2, "gamma2" ) - create_access_request("table", "wb_health_population", TEST_ROLE_2, "gamma2") ds_1_id = access_request1.datasource_id ds = ( - db.session.query(SqlaTable) + session.query(SqlaTable) .filter_by(table_name="wb_health_population") .first() ) @@ -332,7 +340,7 @@ class TestRequestAccess(SupersetTestCase): security_manager.find_role(SCHEMA_ACCESS_ROLE), schema_perm_view ) gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE)) - db.session.commit() + session.commit() # gamma2 request gets fulfilled self.client.get( EXTEND_ROLE_REQUEST.format("table", ds_1_id, "gamma2", TEST_ROLE_2) @@ -343,24 +351,25 @@ class TestRequestAccess(SupersetTestCase): gamma_user.roles.remove(security_manager.find_role(SCHEMA_ACCESS_ROLE)) ds = ( - db.session.query(SqlaTable) + session.query(SqlaTable) .filter_by(table_name="wb_health_population") .first() ) ds.schema = None - db.session.commit() + session.commit() @mock.patch("superset.utils.core.send_mime_email") def test_approve(self, mock_send_mime): if app.config["ENABLE_ACCESS_REQUEST"]: + session = db.session TEST_ROLE_NAME = "table_role" security_manager.add_role(TEST_ROLE_NAME) # Case 1. Grant new role to the user. access_request1 = create_access_request( - "table", "unicode_test", TEST_ROLE_NAME, "gamma" + session, "table", "unicode_test", TEST_ROLE_NAME, "gamma" ) ds_1_id = access_request1.datasource_id self.get_resp( @@ -395,7 +404,7 @@ class TestRequestAccess(SupersetTestCase): # Case 2. Extend the role to have access to the table access_request2 = create_access_request( - "table", "energy_usage", TEST_ROLE_NAME, "gamma" + session, "table", "energy_usage", TEST_ROLE_NAME, "gamma" ) ds_2_id = access_request2.datasource_id energy_usage_perm = access_request2.datasource.perm @@ -439,7 +448,7 @@ class TestRequestAccess(SupersetTestCase): security_manager.add_role("druid_role") access_request3 = create_access_request( - "druid", "druid_ds_1", "druid_role", "gamma" + session, "druid", "druid_ds_1", "druid_role", "gamma" ) self.get_resp( GRANT_ROLE_REQUEST.format( @@ -454,7 +463,7 @@ class TestRequestAccess(SupersetTestCase): # Case 4. Extend the role to have access to the druid datasource access_request4 = create_access_request( - "druid", "druid_ds_2", "druid_role", "gamma" + session, "druid", "druid_ds_2", "druid_role", "gamma" ) druid_ds_2_perm = access_request4.datasource.perm @@ -474,18 +483,19 @@ class TestRequestAccess(SupersetTestCase): gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.remove(security_manager.find_role("druid_role")) gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME)) - db.session.delete(security_manager.find_role("druid_role")) - db.session.delete(security_manager.find_role(TEST_ROLE_NAME)) - db.session.commit() + session.delete(security_manager.find_role("druid_role")) + session.delete(security_manager.find_role(TEST_ROLE_NAME)) + session.commit() def test_request_access(self): if app.config["ENABLE_ACCESS_REQUEST"]: + session = db.session self.logout() self.login(username="gamma") gamma_user = security_manager.find_user(username="gamma") security_manager.add_role("dummy_role") gamma_user.roles.append(security_manager.find_role("dummy_role")) - db.session.commit() + session.commit() ACCESS_REQUEST = ( "/superset/request_access?" @@ -501,7 +511,7 @@ class TestRequestAccess(SupersetTestCase): # Request table access, there are no roles have this table. table1 = ( - db.session.query(SqlaTable) + session.query(SqlaTable) .filter_by(table_name="random_time_series") .first() ) @@ -516,7 +526,7 @@ class TestRequestAccess(SupersetTestCase): # Request access, roles exist that contains the table. # add table to the existing roles table3 = ( - db.session.query(SqlaTable).filter_by(table_name="energy_usage").first() + session.query(SqlaTable).filter_by(table_name="energy_usage").first() ) table_3_id = table3.id table3_perm = table3.perm @@ -535,7 +545,7 @@ class TestRequestAccess(SupersetTestCase): "datasource_access", table3_perm ), ) - db.session.commit() + session.commit() self.get_resp(ACCESS_REQUEST.format("table", table_3_id, "go")) access_request3 = self.get_access_requests("gamma", "table", table_3_id) @@ -549,7 +559,7 @@ class TestRequestAccess(SupersetTestCase): # Request druid access, there are no roles have this table. druid_ds_4 = ( - db.session.query(DruidDatasource) + session.query(DruidDatasource) .filter_by(datasource_name="druid_ds_1") .first() ) @@ -564,7 +574,7 @@ class TestRequestAccess(SupersetTestCase): # Case 5. Roles exist that contains the druid datasource. # add druid ds to the existing roles druid_ds_5 = ( - db.session.query(DruidDatasource) + session.query(DruidDatasource) .filter_by(datasource_name="druid_ds_2") .first() ) @@ -585,7 +595,7 @@ class TestRequestAccess(SupersetTestCase): "datasource_access", druid_ds_5_perm ), ) - db.session.commit() + session.commit() self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_5_id, "go")) access_request5 = self.get_access_requests("gamma", "druid", druid_ds_5_id) @@ -600,7 +610,7 @@ class TestRequestAccess(SupersetTestCase): # cleanup gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.remove(security_manager.find_role("dummy_role")) - db.session.commit() + session.commit() if __name__ == "__main__": diff --git a/tests/alerts_tests.py b/tests/alerts_tests.py index 07205810e..c78847cfa 100644 --- a/tests/alerts_tests.py +++ b/tests/alerts_tests.py @@ -32,118 +32,112 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def setup_module(): +@pytest.yield_fixture(scope="module") +def setup_database(): with app.app_context(): slice_id = db.session.query(Slice).all()[0].id database_id = utils.get_example_database().id - alerts = [ - Alert( - id=1, - label="alert_1", - active=True, - crontab="*/1 * * * *", - sql="SELECT 0", - alert_type="email", - slice_id=slice_id, - database_id=database_id, - ), - Alert( - id=2, - label="alert_2", - active=True, - crontab="*/1 * * * *", - sql="SELECT 55", - alert_type="email", - slice_id=slice_id, - database_id=database_id, - ), - Alert( - id=3, - label="alert_3", - active=False, - crontab="*/1 * * * *", - sql="UPDATE 55", - alert_type="email", - slice_id=slice_id, - database_id=database_id, - ), - Alert(id=4, active=False, label="alert_4", database_id=-1), - Alert(id=5, active=False, label="alert_5", database_id=database_id), - ] + alert1 = Alert( + id=1, + label="alert_1", + active=True, + crontab="*/1 * * * *", + sql="SELECT 0", + alert_type="email", + slice_id=slice_id, + database_id=database_id, + ) + alert2 = Alert( + id=2, + label="alert_2", + active=True, + crontab="*/1 * * * *", + sql="SELECT 55", + alert_type="email", + slice_id=slice_id, + database_id=database_id, + ) + alert3 = Alert( + id=3, + label="alert_3", + active=False, + crontab="*/1 * * * *", + sql="UPDATE 55", + alert_type="email", + slice_id=slice_id, + database_id=database_id, + ) + alert4 = Alert(id=4, active=False, label="alert_4", database_id=-1) + alert5 = Alert(id=5, active=False, label="alert_5", database_id=database_id) - db.session.bulk_save_objects(alerts) + for num in range(1, 6): + eval(f"db.session.add(alert{num})") db.session.commit() + yield db.session - -def teardown_module(): - with app.app_context(): db.session.query(AlertLog).delete() db.session.query(Alert).delete() @patch("superset.tasks.schedules.deliver_alert") @patch("superset.tasks.schedules.logging.Logger.error") -def test_run_alert_query(mock_error, mock_deliver_alert): - with app.app_context(): - run_alert_query(db.session.query(Alert).filter_by(id=1).one().id) - alert1 = db.session.query(Alert).filter_by(id=1).one() - assert mock_deliver_alert.call_count == 0 - assert len(alert1.logs) == 1 - assert alert1.logs[0].alert_id == 1 - assert alert1.logs[0].state == "pass" +def test_run_alert_query(mock_error, mock_deliver, setup_database): + database = setup_database + run_alert_query(database.query(Alert).filter_by(id=1).one().id, database) + alert1 = database.query(Alert).filter_by(id=1).one() + assert mock_deliver.call_count == 0 + assert len(alert1.logs) == 1 + assert alert1.logs[0].alert_id == 1 + assert alert1.logs[0].state == "pass" - run_alert_query(db.session.query(Alert).filter_by(id=2).one().id) - alert2 = db.session.query(Alert).filter_by(id=2).one() - assert mock_deliver_alert.call_count == 1 - assert len(alert2.logs) == 1 - assert alert2.logs[0].alert_id == 2 - assert alert2.logs[0].state == "trigger" + run_alert_query(database.query(Alert).filter_by(id=2).one().id, database) + alert2 = database.query(Alert).filter_by(id=2).one() + assert mock_deliver.call_count == 1 + assert len(alert2.logs) == 1 + assert alert2.logs[0].alert_id == 2 + assert alert2.logs[0].state == "trigger" - run_alert_query(db.session.query(Alert).filter_by(id=3).one().id) - alert3 = db.session.query(Alert).filter_by(id=3).one() - assert mock_deliver_alert.call_count == 1 - assert mock_error.call_count == 2 - assert len(alert3.logs) == 1 - assert alert3.logs[0].alert_id == 3 - assert alert3.logs[0].state == "error" + run_alert_query(database.query(Alert).filter_by(id=3).one().id, database) + alert3 = database.query(Alert).filter_by(id=3).one() + assert mock_deliver.call_count == 1 + assert mock_error.call_count == 2 + assert len(alert3.logs) == 1 + assert alert3.logs[0].alert_id == 3 + assert alert3.logs[0].state == "error" - run_alert_query(db.session.query(Alert).filter_by(id=4).one().id) - assert mock_deliver_alert.call_count == 1 - assert mock_error.call_count == 3 + run_alert_query(database.query(Alert).filter_by(id=4).one().id, database) + assert mock_deliver.call_count == 1 + assert mock_error.call_count == 3 - run_alert_query(db.session.query(Alert).filter_by(id=5).one().id) - assert mock_deliver_alert.call_count == 1 - assert mock_error.call_count == 4 + run_alert_query(database.query(Alert).filter_by(id=5).one().id, database) + assert mock_deliver.call_count == 1 + assert mock_error.call_count == 4 @patch("superset.tasks.schedules.deliver_alert") @patch("superset.tasks.schedules.run_alert_query") -def test_schedule_alert_query(mock_run_alert, mock_deliver_alert): - with app.app_context(): - active_alert = db.session.query(Alert).filter_by(id=1).one() - inactive_alert = db.session.query(Alert).filter_by(id=3).one() +def test_schedule_alert_query(mock_run_alert, mock_deliver_alert, setup_database): + database = setup_database + active_alert = database.query(Alert).filter_by(id=1).one() + inactive_alert = database.query(Alert).filter_by(id=3).one() - # Test that inactive alerts are no processed - schedule_alert_query( - report_type=ScheduleType.alert, schedule_id=inactive_alert.id - ) - assert mock_run_alert.call_count == 0 - assert mock_deliver_alert.call_count == 0 + # Test that inactive alerts are no processed + schedule_alert_query(report_type=ScheduleType.alert, schedule_id=inactive_alert.id) + assert mock_run_alert.call_count == 0 + assert mock_deliver_alert.call_count == 0 - # Test that active alerts with no recipients passed in are processed regularly - schedule_alert_query( - report_type=ScheduleType.alert, schedule_id=active_alert.id - ) - assert mock_run_alert.call_count == 1 - assert mock_deliver_alert.call_count == 0 + # Test that active alerts with no recipients passed in are processed regularly + schedule_alert_query(report_type=ScheduleType.alert, schedule_id=active_alert.id) + assert mock_run_alert.call_count == 1 + assert mock_deliver_alert.call_count == 0 - # Test that active alerts sent as a test are delivered immediately - schedule_alert_query( - report_type=ScheduleType.alert, - schedule_id=active_alert.id, - recipients="testing@email.com", - is_test_alert=True, - ) - assert mock_run_alert.call_count == 1 - assert mock_deliver_alert.call_count == 1 + # Test that active alerts sent as a test are delivered immediately + schedule_alert_query( + report_type=ScheduleType.alert, + schedule_id=active_alert.id, + recipients="testing@email.com", + is_test_alert=True, + ) + assert mock_run_alert.call_count == 1 + assert mock_deliver_alert.call_count == 1 diff --git a/tests/base_tests.py b/tests/base_tests.py index c74378f5a..e0a20a4a7 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -25,6 +25,7 @@ import pandas as pd from flask import Response from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase +from sqlalchemy.orm import Session from tests.test_app import app from superset.sql_parse import CtasMethod @@ -103,25 +104,24 @@ class SupersetTestCase(TestCase): # create druid cluster and druid datasources with app.app_context(): + session = db.session cluster = ( - db.session.query(DruidCluster) - .filter_by(cluster_name="druid_test") - .first() + session.query(DruidCluster).filter_by(cluster_name="druid_test").first() ) if not cluster: cluster = DruidCluster(cluster_name="druid_test") - db.session.add(cluster) - db.session.commit() + session.add(cluster) + session.commit() druid_datasource1 = DruidDatasource( datasource_name="druid_ds_1", cluster=cluster ) - db.session.add(druid_datasource1) + session.add(druid_datasource1) druid_datasource2 = DruidDatasource( datasource_name="druid_ds_2", cluster=cluster ) - db.session.add(druid_datasource2) - db.session.commit() + session.add(druid_datasource2) + session.commit() @staticmethod def get_table_by_id(table_id: int) -> SqlaTable: @@ -135,23 +135,25 @@ class SupersetTestCase(TestCase): except ImportError: return False - def get_or_create(self, cls, criteria, **kwargs): - obj = db.session.query(cls).filter_by(**criteria).first() + def get_or_create(self, cls, criteria, session, **kwargs): + obj = session.query(cls).filter_by(**criteria).first() if not obj: obj = cls(**criteria) obj.__dict__.update(**kwargs) - db.session.add(obj) - db.session.commit() + session.add(obj) + session.commit() return obj def login(self, username="admin", password="general"): resp = self.get_resp("/login/", data=dict(username=username, password=password)) self.assertNotIn("User confirmation needed", resp) - def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice: - slc = db.session.query(Slice).filter_by(slice_name=slice_name).one() + def get_slice( + self, slice_name: str, session: Session, expunge_from_session: bool = True + ) -> Slice: + slc = session.query(Slice).filter_by(slice_name=slice_name).one() if expunge_from_session: - db.session.expunge_all() + session.expunge_all() return slc @staticmethod @@ -300,6 +302,7 @@ class SupersetTestCase(TestCase): return self.get_or_create( cls=models.Database, criteria={"database_name": database_name}, + session=db.session, sqlalchemy_uri="sqlite:///:memory:", id=db_id, extra=extra, @@ -321,6 +324,7 @@ class SupersetTestCase(TestCase): return self.get_or_create( cls=models.Database, criteria={"database_name": database_name}, + session=db.session, sqlalchemy_uri="presto://user@host:8080/hive", id=db_id, ) diff --git a/tests/celery_tests.py b/tests/celery_tests.py index 68a7213ba..53190cb7d 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -99,13 +99,15 @@ class TestAppContext(SupersetTestCase): class TestCelery(SupersetTestCase): def get_query_by_name(self, sql): - query = db.session.query(Query).filter_by(sql=sql).first() - db.session.close() + session = db.session + query = session.query(Query).filter_by(sql=sql).first() + session.close() return query def get_query_by_id(self, id): - query = db.session.query(Query).filter_by(id=id).first() - db.session.close() + session = db.session + query = session.query(Query).filter_by(id=id).first() + session.close() return query @classmethod diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 0820518b5..5048a0a6d 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -58,7 +58,9 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): for owner in owners: user = db.session.query(security_manager.user_model).get(owner) obj_owners.append(user) - datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id) + datasource = ConnectorRegistry.get_datasource( + datasource_type, datasource_id, db.session + ) slice = Slice( slice_name=slice_name, datasource_id=datasource.id, diff --git a/tests/core_tests.py b/tests/core_tests.py index 2793795fb..ade009502 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -100,7 +100,7 @@ class TestCore(SupersetTestCase): def test_slice_endpoint(self): self.login(username="admin") - slc = self.get_slice("Girls") + slc = self.get_slice("Girls", db.session) resp = self.get_resp("/superset/slice/{}/".format(slc.id)) assert "Time Column" in resp assert "List Roles" in resp @@ -114,7 +114,7 @@ class TestCore(SupersetTestCase): def test_viz_cache_key(self): self.login(username="admin") - slc = self.get_slice("Girls") + slc = self.get_slice("Girls", db.session) viz = slc.viz qobj = viz.query_obj() @@ -233,7 +233,7 @@ class TestCore(SupersetTestCase): def test_save_slice(self): self.login(username="admin") slice_name = f"Energy Sankey" - slice_id = self.get_slice(slice_name).id + slice_id = self.get_slice(slice_name, db.session).id copy_name_prefix = "Test Sankey" copy_name = f"{copy_name_prefix}[save]{random.random()}" tbl_id = self.table_ids.get("energy_usage") @@ -299,7 +299,7 @@ class TestCore(SupersetTestCase): def test_filter_endpoint(self): self.login(username="admin") slice_name = "Energy Sankey" - slice_id = self.get_slice(slice_name).id + slice_id = self.get_slice(slice_name, db.session).id db.session.commit() tbl_id = self.table_ids.get("energy_usage") table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id) @@ -319,7 +319,9 @@ class TestCore(SupersetTestCase): def test_slice_data(self): # slice data should have some required attributes self.login(username="admin") - slc = self.get_slice(slice_name="Girls", expunge_from_session=False) + slc = self.get_slice( + slice_name="Girls", session=db.session, expunge_from_session=False + ) slc_data_attributes = slc.data.keys() assert "changed_on" in slc_data_attributes assert "modified" in slc_data_attributes @@ -370,7 +372,9 @@ class TestCore(SupersetTestCase): self.assertEqual(data, []) # make user owner of slice and verify that endpoint returns said slice - slc = self.get_slice(slice_name=slice_name, expunge_from_session=False) + slc = self.get_slice( + slice_name=slice_name, session=db.session, expunge_from_session=False + ) slc.owners = [user] db.session.merge(slc) db.session.commit() @@ -381,7 +385,9 @@ class TestCore(SupersetTestCase): self.assertEqual(data[0]["title"], slice_name) # remove ownership and ensure user no longer gets slice - slc = self.get_slice(slice_name=slice_name, expunge_from_session=False) + slc = self.get_slice( + slice_name=slice_name, session=db.session, expunge_from_session=False + ) slc.owners = [] db.session.merge(slc) db.session.commit() @@ -559,7 +565,7 @@ class TestCore(SupersetTestCase): db.session.commit() def test_warm_up_cache(self): - slc = self.get_slice("Girls") + slc = self.get_slice("Girls", db.session) data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id)) self.assertEqual( data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] @@ -784,7 +790,7 @@ class TestCore(SupersetTestCase): def test_user_profile(self, username="admin"): self.login(username=username) - slc = self.get_slice("Girls") + slc = self.get_slice("Girls", db.session) # Setting some faves url = f"/superset/favstar/Slice/{slc.id}/select/" diff --git a/tests/database_api_tests.py b/tests/database_api_tests.py index 7126c25ed..49ede7b23 100644 --- a/tests/database_api_tests.py +++ b/tests/database_api_tests.py @@ -178,11 +178,12 @@ class TestDatabaseApi(SupersetTestCase): """ Database API: Test get select star with datasource access """ + session = db.session table = SqlaTable( schema="main", table_name="ab_permission", database=get_main_database() ) - db.session.add(table) - db.session.commit() + session.add(table) + session.commit() tmp_table_perm = security_manager.find_permission_view_menu( "datasource_access", table.get_perm() diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index 7e078fd73..08c8d5ae2 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -156,7 +156,7 @@ class TestDatasetApi(SupersetTestCase): "template_params": None, } for key, value in expected_result.items(): - self.assertEqual(response["result"][key], value) + self.assertEqual(response["result"][key], expected_result[key]) self.assertEqual(len(response["result"]["columns"]), 8) self.assertEqual(len(response["result"]["metrics"]), 2) @@ -721,7 +721,10 @@ class TestDatasetApi(SupersetTestCase): ) cli_export = export_to_dict( - recursive=True, back_references=False, include_defaults=False, + session=db.session, + recursive=True, + back_references=False, + include_defaults=False, ) cli_export_tables = cli_export["databases"][0]["tables"] expected_response = [] diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py index 725ffdfec..dc0b8d8d8 100644 --- a/tests/dict_import_export_tests.py +++ b/tests/dict_import_export_tests.py @@ -47,13 +47,14 @@ class TestDictImportExport(SupersetTestCase): def delete_imports(cls): with app.app_context(): # Imported data clean up - for table in db.session.query(SqlaTable): + session = db.session + for table in session.query(SqlaTable): if DBREF in table.params_dict: - db.session.delete(table) - for datasource in db.session.query(DruidDatasource): + session.delete(table) + for datasource in session.query(DruidDatasource): if DBREF in datasource.params_dict: - db.session.delete(datasource) - db.session.commit() + session.delete(datasource) + session.commit() @classmethod def setUpClass(cls): @@ -89,7 +90,9 @@ class TestDictImportExport(SupersetTestCase): def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]): cluster_name = "druid_test" - cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name}) + cluster = self.get_or_create( + DruidCluster, {"cluster_name": cluster_name}, db.session + ) name = "{0}{1}".format(NAME_PREFIX, name) params = {DBREF: id, "database_name": cluster_name} @@ -156,7 +159,7 @@ class TestDictImportExport(SupersetTestCase): def test_import_table_no_metadata(self): table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1) - new_table = SqlaTable.import_from_dict(dict_table) + new_table = SqlaTable.import_from_dict(db.session, dict_table) db.session.commit() imported_id = new_table.id imported = self.get_table_by_id(imported_id) @@ -170,7 +173,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["col1"], metric_names=["metric1"], ) - imported_table = SqlaTable.import_from_dict(dict_table) + imported_table = SqlaTable.import_from_dict(db.session, dict_table) db.session.commit() imported = self.get_table_by_id(imported_table.id) self.assert_table_equals(table, imported) @@ -186,7 +189,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["c1", "c2"], metric_names=["m1", "m2"], ) - imported_table = SqlaTable.import_from_dict(dict_table) + imported_table = SqlaTable.import_from_dict(db.session, dict_table) db.session.commit() imported = self.get_table_by_id(imported_table.id) self.assert_table_equals(table, imported) @@ -196,7 +199,7 @@ class TestDictImportExport(SupersetTestCase): table, dict_table = self.create_table( "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] ) - imported_table = SqlaTable.import_from_dict(dict_table) + imported_table = SqlaTable.import_from_dict(db.session, dict_table) db.session.commit() table_over, dict_table_over = self.create_table( "table_override", @@ -204,7 +207,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_over_table = SqlaTable.import_from_dict(dict_table_over) + imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over) db.session.commit() imported_over = self.get_table_by_id(imported_over_table.id) @@ -224,7 +227,7 @@ class TestDictImportExport(SupersetTestCase): table, dict_table = self.create_table( "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] ) - imported_table = SqlaTable.import_from_dict(dict_table) + imported_table = SqlaTable.import_from_dict(db.session, dict_table) db.session.commit() table_over, dict_table_over = self.create_table( "table_override", @@ -233,7 +236,7 @@ class TestDictImportExport(SupersetTestCase): metric_names=["new_metric1"], ) imported_over_table = SqlaTable.import_from_dict( - dict_rep=dict_table_over, sync=["metrics", "columns"] + session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"] ) db.session.commit() @@ -257,7 +260,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_table = SqlaTable.import_from_dict(dict_table) + imported_table = SqlaTable.import_from_dict(db.session, dict_table) db.session.commit() copy_table, dict_copy_table = self.create_table( "copy_cat", @@ -265,7 +268,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_copy_table = SqlaTable.import_from_dict(dict_copy_table) + imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table) db.session.commit() self.assertEqual(imported_table.id, imported_copy_table.id) self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id)) @@ -278,7 +281,10 @@ class TestDictImportExport(SupersetTestCase): self.delete_fake_db() cli_export = export_to_dict( - recursive=True, back_references=False, include_defaults=False, + session=db.session, + recursive=True, + back_references=False, + include_defaults=False, ) self.get_resp("/login/", data=dict(username="admin", password="general")) resp = self.get_resp( @@ -297,7 +303,7 @@ class TestDictImportExport(SupersetTestCase): datasource, dict_datasource = self.create_druid_datasource( "pure_druid", id=ID_PREFIX + 1 ) - imported_cluster = DruidDatasource.import_from_dict(dict_datasource) + imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) db.session.commit() imported = self.get_datasource(imported_cluster.id) self.assert_datasource_equals(datasource, imported) @@ -309,7 +315,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["col1"], metric_names=["metric1"], ) - imported_cluster = DruidDatasource.import_from_dict(dict_datasource) + imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) db.session.commit() imported = self.get_datasource(imported_cluster.id) self.assert_datasource_equals(datasource, imported) @@ -325,7 +331,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["c1", "c2"], metric_names=["m1", "m2"], ) - imported_cluster = DruidDatasource.import_from_dict(dict_datasource) + imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) db.session.commit() imported = self.get_datasource(imported_cluster.id) self.assert_datasource_equals(datasource, imported) @@ -334,7 +340,7 @@ class TestDictImportExport(SupersetTestCase): datasource, dict_datasource = self.create_druid_datasource( "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] ) - imported_cluster = DruidDatasource.import_from_dict(dict_datasource) + imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) db.session.commit() table_over, table_over_dict = self.create_druid_datasource( "druid_override", @@ -342,7 +348,9 @@ class TestDictImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_over_cluster = DruidDatasource.import_from_dict(table_over_dict) + imported_over_cluster = DruidDatasource.import_from_dict( + db.session, table_over_dict + ) db.session.commit() imported_over = self.get_datasource(imported_over_cluster.id) self.assertEqual(imported_cluster.id, imported_over.id) @@ -358,7 +366,7 @@ class TestDictImportExport(SupersetTestCase): datasource, dict_datasource = self.create_druid_datasource( "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] ) - imported_cluster = DruidDatasource.import_from_dict(dict_datasource) + imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) db.session.commit() table_over, table_over_dict = self.create_druid_datasource( "druid_override", @@ -367,7 +375,7 @@ class TestDictImportExport(SupersetTestCase): metric_names=["new_metric1"], ) imported_over_cluster = DruidDatasource.import_from_dict( - dict_rep=table_over_dict, sync=["metrics", "columns"] + session=db.session, dict_rep=table_over_dict, sync=["metrics", "columns"] ) # syncing metrics and columns db.session.commit() imported_over = self.get_datasource(imported_over_cluster.id) @@ -387,7 +395,9 @@ class TestDictImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported = DruidDatasource.import_from_dict(dict_rep=dict_datasource) + imported = DruidDatasource.import_from_dict( + session=db.session, dict_rep=dict_datasource + ) db.session.commit() copy_datasource, dict_cp_datasource = self.create_druid_datasource( "copy_cat", @@ -395,7 +405,7 @@ class TestDictImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_copy = DruidDatasource.import_from_dict(dict_cp_datasource) + imported_copy = DruidDatasource.import_from_dict(db.session, dict_cp_datasource) db.session.commit() self.assertEqual(imported.id, imported_copy.id) diff --git a/tests/druid_tests.py b/tests/druid_tests.py index c75767152..648eb32cc 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -212,7 +212,9 @@ class TestDruid(SupersetTestCase): def test_druid_sync_from_config(self): CLUSTER_NAME = "new_druid" self.login() - cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME}) + cluster = self.get_or_create( + DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session + ) db.session.merge(cluster) db.session.commit() @@ -300,12 +302,15 @@ class TestDruid(SupersetTestCase): @unittest.skipUnless(app.config["DRUID_IS_ACTIVE"], "DRUID_IS_ACTIVE is false") def test_filter_druid_datasource(self): CLUSTER_NAME = "new_druid" - cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME}) + cluster = self.get_or_create( + DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session + ) db.session.merge(cluster) gamma_ds = self.get_or_create( DruidDatasource, {"datasource_name": "datasource_for_gamma", "cluster": cluster}, + db.session, ) gamma_ds.cluster = cluster db.session.merge(gamma_ds) @@ -313,6 +318,7 @@ class TestDruid(SupersetTestCase): no_gamma_ds = self.get_or_create( DruidDatasource, {"datasource_name": "datasource_not_for_gamma", "cluster": cluster}, + db.session, ) no_gamma_ds.cluster = cluster db.session.merge(no_gamma_ds) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index fc6ee51c5..e772d16b1 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -46,19 +46,20 @@ class TestImportExport(SupersetTestCase): def delete_imports(cls): with app.app_context(): # Imported data clean up - for slc in db.session.query(Slice): + session = db.session + for slc in session.query(Slice): if "remote_id" in slc.params_dict: - db.session.delete(slc) - for dash in db.session.query(Dashboard): + session.delete(slc) + for dash in session.query(Dashboard): if "remote_id" in dash.params_dict: - db.session.delete(dash) - for table in db.session.query(SqlaTable): + session.delete(dash) + for table in session.query(SqlaTable): if "remote_id" in table.params_dict: - db.session.delete(table) - for datasource in db.session.query(DruidDatasource): + session.delete(table) + for datasource in session.query(DruidDatasource): if "remote_id" in datasource.params_dict: - db.session.delete(datasource) - db.session.commit() + session.delete(datasource) + session.commit() @classmethod def setUpClass(cls): @@ -125,7 +126,9 @@ class TestImportExport(SupersetTestCase): def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]): cluster_name = "druid_test" - cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name}) + cluster = self.get_or_create( + DruidCluster, {"cluster_name": cluster_name}, db.session + ) params = {"remote_id": id, "database_name": cluster_name} datasource = DruidDatasource( diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index e016b13a9..f816bcd59 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -83,6 +83,7 @@ class TestQueryContext(SupersetTestCase): datasource = ConnectorRegistry.get_datasource( datasource_type=payload["datasource"]["type"], datasource_id=payload["datasource"]["id"], + session=db.session, ) description_original = datasource.description datasource.description = "temporary description" diff --git a/tests/security_tests.py b/tests/security_tests.py index a161adabe..60d20fde0 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -69,8 +69,9 @@ class TestRolePermission(SupersetTestCase): """Testing export role permissions.""" def setUp(self): + session = db.session security_manager.add_role(SCHEMA_ACCESS_ROLE) - db.session.commit() + session.commit() ds = ( db.session.query(SqlaTable) @@ -81,7 +82,7 @@ class TestRolePermission(SupersetTestCase): ds.schema_perm = ds.get_schema_perm() ds_slices = ( - db.session.query(Slice) + session.query(Slice) .filter_by(datasource_type="table") .filter_by(datasource_id=ds.id) .all() @@ -91,11 +92,12 @@ class TestRolePermission(SupersetTestCase): create_schema_perm("[examples].[temp_schema]") gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE)) - db.session.commit() + session.commit() def tearDown(self): + session = db.session ds = ( - db.session.query(SqlaTable) + session.query(SqlaTable) .filter_by(table_name="wb_health_population") .first() ) @@ -103,7 +105,7 @@ class TestRolePermission(SupersetTestCase): ds.schema = None ds.schema_perm = None ds_slices = ( - db.session.query(Slice) + session.query(Slice) .filter_by(datasource_type="table") .filter_by(datasource_id=ds.id) .all() @@ -112,20 +114,21 @@ class TestRolePermission(SupersetTestCase): s.schema_perm = None delete_schema_perm(schema_perm) - db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) - db.session.commit() + session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) + session.commit() def test_set_perm_sqla_table(self): + session = db.session table = SqlaTable( schema="tmp_schema", table_name="tmp_perm_table", database=get_example_database(), ) - db.session.add(table) - db.session.commit() + session.add(table) + session.commit() stored_table = ( - db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() + session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() ) self.assertEqual( stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})" @@ -144,9 +147,9 @@ class TestRolePermission(SupersetTestCase): # table name change stored_table.table_name = "tmp_perm_table_v2" - db.session.commit() + session.commit() stored_table = ( - db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() ) self.assertEqual( stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})" @@ -166,9 +169,9 @@ class TestRolePermission(SupersetTestCase): # schema name change stored_table.schema = "tmp_schema_v2" - db.session.commit() + session.commit() stored_table = ( - db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() ) self.assertEqual( stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})" @@ -188,13 +191,13 @@ class TestRolePermission(SupersetTestCase): # database change new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db") - db.session.add(new_db) + session.add(new_db) stored_table.database = ( - db.session.query(Database).filter_by(database_name="tmp_db").one() + session.query(Database).filter_by(database_name="tmp_db").one() ) - db.session.commit() + session.commit() stored_table = ( - db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() ) self.assertEqual( stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})" @@ -214,9 +217,9 @@ class TestRolePermission(SupersetTestCase): # no schema stored_table.schema = None - db.session.commit() + session.commit() stored_table = ( - db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() ) self.assertEqual( stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})" @@ -228,25 +231,26 @@ class TestRolePermission(SupersetTestCase): ) self.assertIsNone(stored_table.schema_perm) - db.session.delete(new_db) - db.session.delete(stored_table) - db.session.commit() + session.delete(new_db) + session.delete(stored_table) + session.commit() def test_set_perm_druid_datasource(self): + session = db.session druid_cluster = ( - db.session.query(DruidCluster).filter_by(cluster_name="druid_test").one() + session.query(DruidCluster).filter_by(cluster_name="druid_test").one() ) datasource = DruidDatasource( datasource_name="tmp_datasource", cluster=druid_cluster, cluster_id=druid_cluster.id, ) - db.session.add(datasource) - db.session.commit() + session.add(datasource) + session.commit() # store without a schema stored_datasource = ( - db.session.query(DruidDatasource) + session.query(DruidDatasource) .filter_by(datasource_name="tmp_datasource") .one() ) @@ -263,7 +267,7 @@ class TestRolePermission(SupersetTestCase): # store with a schema stored_datasource.datasource_name = "tmp_schema.tmp_datasource" - db.session.commit() + session.commit() self.assertEqual( stored_datasource.perm, f"[druid_test].[tmp_schema.tmp_datasource](id:{stored_datasource.id})", @@ -280,15 +284,16 @@ class TestRolePermission(SupersetTestCase): ) ) - db.session.delete(stored_datasource) - db.session.commit() + session.delete(stored_datasource) + session.commit() def test_set_perm_druid_cluster(self): + session = db.session cluster = DruidCluster(cluster_name="tmp_druid_cluster") - db.session.add(cluster) + session.add(cluster) stored_cluster = ( - db.session.query(DruidCluster) + session.query(DruidCluster) .filter_by(cluster_name="tmp_druid_cluster") .one() ) @@ -302,7 +307,7 @@ class TestRolePermission(SupersetTestCase): ) stored_cluster.cluster_name = "tmp_druid_cluster2" - db.session.commit() + session.commit() self.assertEqual( stored_cluster.perm, f"[tmp_druid_cluster2].(id:{stored_cluster.id})" ) @@ -312,17 +317,18 @@ class TestRolePermission(SupersetTestCase): ) ) - db.session.delete(stored_cluster) - db.session.commit() + session.delete(stored_cluster) + session.commit() def test_set_perm_database(self): + session = db.session database = Database( database_name="tmp_database", sqlalchemy_uri="sqlite://test" ) - db.session.add(database) + session.add(database) stored_db = ( - db.session.query(Database).filter_by(database_name="tmp_database").one() + session.query(Database).filter_by(database_name="tmp_database").one() ) self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})") self.assertIsNotNone( @@ -332,9 +338,9 @@ class TestRolePermission(SupersetTestCase): ) stored_db.database_name = "tmp_database2" - db.session.commit() + session.commit() stored_db = ( - db.session.query(Database).filter_by(database_name="tmp_database2").one() + session.query(Database).filter_by(database_name="tmp_database2").one() ) self.assertEqual(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})") self.assertIsNotNone( @@ -343,8 +349,8 @@ class TestRolePermission(SupersetTestCase): ) ) - db.session.delete(stored_db) - db.session.commit() + session.delete(stored_db) + session.commit() def test_hybrid_perm_druid_cluster(self): cluster = DruidCluster(cluster_name="tmp_druid_cluster3") @@ -394,13 +400,14 @@ class TestRolePermission(SupersetTestCase): db.session.commit() def test_set_perm_slice(self): + session = db.session database = Database( database_name="tmp_database", sqlalchemy_uri="sqlite://test" ) table = SqlaTable(table_name="tmp_perm_table", database=database) - db.session.add(database) - db.session.add(table) - db.session.commit() + session.add(database) + session.add(table) + session.commit() # no schema permission slice = Slice( @@ -409,10 +416,10 @@ class TestRolePermission(SupersetTestCase): datasource_name="tmp_perm_table", slice_name="slice_name", ) - db.session.add(slice) - db.session.commit() + session.add(slice) + session.commit() - slice = db.session.query(Slice).filter_by(slice_name="slice_name").one() + slice = session.query(Slice).filter_by(slice_name="slice_name").one() self.assertEqual(slice.perm, table.perm) self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})") self.assertEqual(slice.schema_perm, table.schema_perm) @@ -420,7 +427,7 @@ class TestRolePermission(SupersetTestCase): table.schema = "tmp_perm_schema" table.table_name = "tmp_perm_table_v2" - db.session.commit() + session.commit() # TODO(bogdan): modify slice permissions on the table update. self.assertNotEquals(slice.perm, table.perm) self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})") @@ -433,7 +440,7 @@ class TestRolePermission(SupersetTestCase): # updating slice refreshes the permissions slice.slice_name = "slice_name_v2" - db.session.commit() + session.commit() self.assertEqual(slice.perm, table.perm) self.assertEqual( slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" @@ -441,10 +448,11 @@ class TestRolePermission(SupersetTestCase): self.assertEqual(slice.schema_perm, table.schema_perm) self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]") - db.session.delete(slice) - db.session.delete(table) - db.session.delete(database) - db.session.commit() + session.delete(slice) + session.delete(table) + session.delete(database) + + session.commit() # TODO test slice permission @@ -524,11 +532,11 @@ class TestRolePermission(SupersetTestCase): self.assertNotIn("Girl Name Cloud", data) # birth_names slice, no access def test_sqllab_gamma_user_schema_access_to_sqllab(self): - example_db = ( - db.session.query(Database).filter_by(database_name="examples").one() - ) + session = db.session + + example_db = session.query(Database).filter_by(database_name="examples").one() example_db.expose_in_sqllab = True - db.session.commit() + session.commit() arguments = { "keys": ["none"], @@ -951,10 +959,12 @@ class TestRowLevelSecurity(SupersetTestCase): rls_entry = None def setUp(self): + session = db.session + # Create the RowLevelSecurityFilter self.rls_entry = RowLevelSecurityFilter() self.rls_entry.tables.extend( - db.session.query(SqlaTable) + session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) .all() ) @@ -964,11 +974,13 @@ class TestRowLevelSecurity(SupersetTestCase): ) # db.session.query(Role).filter_by(name="Gamma").first()) self.rls_entry.roles.append(security_manager.find_role("Alpha")) db.session.add(self.rls_entry) + db.session.commit() def tearDown(self): - db.session.delete(self.rls_entry) - db.session.commit() + session = db.session + session.delete(self.rls_entry) + session.commit() # Do another test to make sure it doesn't alter another query def test_rls_filter_alters_query(self): diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 9315d3968..bff8d9dbc 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -63,6 +63,7 @@ class TestSqlLab(SupersetTestCase): self.logout() db.session.query(Query).delete() db.session.commit() + db.session.close() def test_sql_json(self): self.login("admin") @@ -459,6 +460,7 @@ class TestSqlLab(SupersetTestCase): Test query api with can_access_all_queries perm added to gamma and make sure all queries show up. """ + session = db.session # Add all_query_access perm to Gamma user all_queries_view = security_manager.find_permission_view_menu( @@ -468,7 +470,7 @@ class TestSqlLab(SupersetTestCase): security_manager.add_permission_role( security_manager.find_role("gamma_sqllab"), all_queries_view ) - db.session.commit() + session.commit() # Test search_queries for Admin user self.run_some_queries() @@ -485,7 +487,7 @@ class TestSqlLab(SupersetTestCase): security_manager.find_role("gamma_sqllab"), all_queries_view ) - db.session.commit() + session.commit() def test_query_admin_can_access_all_queries(self) -> None: """ diff --git a/tests/strategy_tests.py b/tests/strategy_tests.py index 49e234900..c4f0019c6 100644 --- a/tests/strategy_tests.py +++ b/tests/strategy_tests.py @@ -194,7 +194,7 @@ class TestCacheWarmUp(SupersetTestCase): db.session.commit() def test_dashboard_tags(self): - tag1 = get_tag("tag1", TagTypes.custom) + tag1 = get_tag("tag1", db.session, TagTypes.custom) # delete first to make test idempotent self.reset_tag(tag1) @@ -204,7 +204,7 @@ class TestCacheWarmUp(SupersetTestCase): self.assertEqual(result, expected) # tag dashboard 'births' with `tag1` - tag1 = get_tag("tag1", TagTypes.custom) + tag1 = get_tag("tag1", db.session, TagTypes.custom) dash = self.get_dash_by_slug("births") tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices]) tagged_object = TaggedObject( @@ -216,7 +216,7 @@ class TestCacheWarmUp(SupersetTestCase): self.assertEqual(sorted(strategy.get_urls()), tag1_urls) strategy = DashboardTagsStrategy(["tag2"]) - tag2 = get_tag("tag2", TagTypes.custom) + tag2 = get_tag("tag2", db.session, TagTypes.custom) self.reset_tag(tag2) result = sorted(strategy.get_urls())