From bece2ea3e4b9979f3f45f63aa490f499095c7078 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Thu, 16 Nov 2023 19:04:04 -0800 Subject: [PATCH] chore: Remove unnecessary autoflush from tagging and key/value workflows (#26009) --- superset/key_value/commands/delete.py | 8 +--- superset/key_value/commands/get.py | 7 +--- superset/key_value/commands/update.py | 5 +-- superset/key_value/commands/upsert.py | 5 +-- superset/tags/models.py | 38 +++++-------------- .../key_value/commands/create_test.py | 12 ++---- .../key_value/commands/update_test.py | 6 +-- .../key_value/commands/upsert_test.py | 8 +--- 8 files changed, 21 insertions(+), 68 deletions(-) diff --git a/superset/key_value/commands/delete.py b/superset/key_value/commands/delete.py index b3cf84be0..8b9095c09 100644 --- a/superset/key_value/commands/delete.py +++ b/superset/key_value/commands/delete.py @@ -57,13 +57,7 @@ class DeleteKeyValueCommand(BaseCommand): def delete(self) -> bool: filter_ = get_filter(self.resource, self.key) - entry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() - ) - if entry: + if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first(): db.session.delete(entry) db.session.commit() return True diff --git a/superset/key_value/commands/get.py b/superset/key_value/commands/get.py index 9d659f3bc..8a7a250f1 100644 --- a/superset/key_value/commands/get.py +++ b/superset/key_value/commands/get.py @@ -66,12 +66,7 @@ class GetKeyValueCommand(BaseCommand): def get(self) -> Optional[Any]: filter_ = get_filter(self.resource, self.key) - entry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() - ) + entry = db.session.query(KeyValueEntry).filter_by(**filter_).first() if entry and (entry.expires_on is None or entry.expires_on > datetime.now()): return self.codec.decode(entry.value) return None diff --git a/superset/key_value/commands/update.py b/superset/key_value/commands/update.py index becd6d9ca..4bcd49624 100644 --- a/superset/key_value/commands/update.py +++ b/superset/key_value/commands/update.py @@ -77,10 +77,7 @@ class UpdateKeyValueCommand(BaseCommand): def update(self) -> Optional[Key]: filter_ = get_filter(self.resource, self.key) entry: KeyValueEntry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() + db.session.query(KeyValueEntry).filter_by(**filter_).first() ) if entry: entry.value = self.codec.encode(self.value) diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py index c5668f116..9a4092c00 100644 --- a/superset/key_value/commands/upsert.py +++ b/superset/key_value/commands/upsert.py @@ -81,10 +81,7 @@ class UpsertKeyValueCommand(BaseCommand): def upsert(self) -> Key: filter_ = get_filter(self.resource, self.key) entry: KeyValueEntry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() + db.session.query(KeyValueEntry).filter_by(**filter_).first() ) if entry: entry.value = self.codec.encode(self.value) diff --git a/superset/tags/models.py b/superset/tags/models.py index a469c7a33..7a77677a3 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -20,9 +20,9 @@ import enum from typing import TYPE_CHECKING from flask_appbuilder import Model -from sqlalchemy import Column, Enum, ForeignKey, Integer, String, Table, Text +from sqlalchemy import Column, Enum, ForeignKey, Integer, orm, String, Table, Text from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import relationship, Session, sessionmaker +from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.orm.mapper import Mapper from superset import security_manager @@ -35,7 +35,7 @@ if TYPE_CHECKING: from superset.models.slice import Slice from superset.models.sql_lab import Query -Session = sessionmaker(autoflush=False) +Session = sessionmaker() user_favorite_tag_table = Table( "user_favorite_tag", @@ -111,7 +111,7 @@ class TaggedObject(Model, AuditMixinNullable): tag = relationship("Tag", back_populates="objects", overlaps="tags") -def get_tag(name: str, session: Session, type_: TagType) -> Tag: +def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag: tag_name = name.strip() tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none() if tag is None: @@ -148,7 +148,7 @@ class ObjectUpdater: @classmethod def _add_owners( cls, - session: Session, + session: orm.Session, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: for owner_id in cls.get_owners_ids(target): @@ -166,9 +166,7 @@ class ObjectUpdater: connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - session = Session(bind=connection) - - try: + with Session(bind=connection) as session: # add `owner:` tags cls._add_owners(session, target) @@ -179,8 +177,6 @@ class ObjectUpdater: ) session.add(tagged_object) session.commit() - finally: - session.close() @classmethod def after_update( @@ -189,9 +185,7 @@ class ObjectUpdater: connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - session = Session(bind=connection) - - try: + with Session(bind=connection) as session: # delete current `owner:` tags query = ( session.query(TaggedObject.id) @@ -210,8 +204,6 @@ class ObjectUpdater: # add `owner:` tags cls._add_owners(session, target) session.commit() - finally: - session.close() @classmethod def after_delete( @@ -220,9 +212,7 @@ class ObjectUpdater: connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - session = Session(bind=connection) - - try: + with Session(bind=connection) as session: # delete row from `tagged_objects` session.query(TaggedObject).filter( TaggedObject.object_type == cls.object_type, @@ -230,8 +220,6 @@ class ObjectUpdater: ).delete() session.commit() - finally: - session.close() class ChartUpdater(ObjectUpdater): @@ -271,8 +259,7 @@ class FavStarUpdater: def after_insert( cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: - session = Session(bind=connection) - try: + with Session(bind=connection) as session: name = f"favorited_by:{target.user_id}" tag = get_tag(name, session, TagType.favorited_by) tagged_object = TaggedObject( @@ -282,15 +269,12 @@ class FavStarUpdater: ) session.add(tagged_object) session.commit() - finally: - session.close() @classmethod def after_delete( cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: - session = Session(bind=connection) - try: + with Session(bind=connection) as session: name = f"favorited_by:{target.user_id}" query = ( session.query(TaggedObject.id) @@ -307,5 +291,3 @@ class FavStarUpdater: ) session.commit() - finally: - session.close() diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py index a2ee3d13a..c7ba076b5 100644 --- a/tests/integration_tests/key_value/commands/create_test.py +++ b/tests/integration_tests/key_value/commands/create_test.py @@ -46,9 +46,7 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None: value=JSON_VALUE, codec=JSON_CODEC, ).run() - entry = ( - db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one() assert json.loads(entry.value) == JSON_VALUE assert entry.created_by_fk == admin.id db.session.delete(entry) @@ -63,9 +61,7 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: key = CreateKeyValueCommand( resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC ).run() - entry = ( - db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one() assert json.loads(entry.value) == JSON_VALUE assert entry.created_by_fk == admin.id db.session.delete(entry) @@ -93,9 +89,7 @@ def test_create_pickle_entry(app_context: AppContext, admin: User) -> None: value=PICKLE_VALUE, codec=PICKLE_CODEC, ).run() - entry = ( - db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one() assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE) assert entry.created_by_fk == admin.id db.session.delete(entry) diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py index 2c0fc3e31..816a6f857 100644 --- a/tests/integration_tests/key_value/commands/update_test.py +++ b/tests/integration_tests/key_value/commands/update_test.py @@ -57,7 +57,7 @@ def test_update_id_entry( ).run() assert key is not None assert key.id == ID_KEY - entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one() + entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -79,9 +79,7 @@ def test_update_uuid_entry( ).run() assert key is not None assert key.uuid == UUID_KEY - entry = ( - db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py index c26b66d02..9b094ef65 100644 --- a/tests/integration_tests/key_value/commands/upsert_test.py +++ b/tests/integration_tests/key_value/commands/upsert_test.py @@ -57,9 +57,7 @@ def test_upsert_id_entry( ).run() assert key is not None assert key.id == ID_KEY - entry = ( - db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -81,9 +79,7 @@ def test_upsert_uuid_entry( ).run() assert key is not None assert key.uuid == UUID_KEY - entry = ( - db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id