diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 431918c6b..6bf69bbb8 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -138,9 +138,7 @@ class TestConnectionDatabaseCommand(BaseCommand): with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context( - override_ssh_tunnel=ssh_tunnel - ) as engine: + with database.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine: try: alive = func_timeout( app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(), diff --git a/superset/commands/database/validate.py b/superset/commands/database/validate.py index 83bbc4e90..e550f51d7 100644 --- a/superset/commands/database/validate.py +++ b/superset/commands/database/validate.py @@ -102,7 +102,7 @@ class ValidateDatabaseParametersCommand(BaseCommand): database.db_engine_spec.mutate_db_for_connection_test(database) alive = False - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: try: with closing(engine.raw_connection()) as conn: alive = engine.dialect.do_ping(conn) diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index 04fc81e24..50bb916b0 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -217,7 +217,7 @@ def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None: ) else: logger.warning("Loading data outside the import transaction") - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: df.to_sql( dataset.table_name, con=engine, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index bcb4035c9..ec1cc741d 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -771,7 +771,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ... connection.execute(sql) """ - return database.get_sqla_engine_with_context(schema=schema, source=source) + return database.get_sqla_engine(schema=schema, source=source) @classmethod def get_timestamp_expr( diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index a8d834276..63860e8aa 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -456,7 +456,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met In BigQuery, a catalog is called a "project". """ engine: Engine - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: client = cls._get_client(engine) projects = client.list_projects() diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index ad96aecac..9ce27d495 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -29,7 +29,7 @@ from .helpers import get_example_url, get_table_connector_registry def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: tbl_name = "bart_lines" database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index c9e38f168..2e711bef2 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -63,7 +63,7 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None: pdf.ds = pd.to_datetime(pdf.ds, unit="ms") pdf = pdf.head(100) if sample else pdf - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name pdf.to_sql( @@ -91,7 +91,7 @@ def load_birth_names( ) -> None: """Loading birth name dataset from a zip file in the repo""" database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name tbl_name = "birth_names" diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py index 3caf63758..59c257bc8 100644 --- a/superset/examples/country_map.py +++ b/superset/examples/country_map.py @@ -40,7 +40,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N tbl_name = "birth_france_by_region" database = database_utils.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/energy.py b/superset/examples/energy.py index 998ee97a3..1f11c0f3f 100644 --- a/superset/examples/energy.py +++ b/superset/examples/energy.py @@ -42,7 +42,7 @@ def load_energy( tbl_name = "energy_usage" database = database_utils.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/flights.py b/superset/examples/flights.py index c7890cfa1..a42df2023 100644 --- a/superset/examples/flights.py +++ b/superset/examples/flights.py @@ -27,7 +27,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: """Loading random time series data from a zip file in the repo""" tbl_name = "flights" database = database_utils.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py index 6f7cc6402..95cccadc2 100644 --- a/superset/examples/long_lat.py +++ b/superset/examples/long_lat.py @@ -39,7 +39,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None """Loading lat/long data from a csv file in the repo""" tbl_name = "long_lat" database = database_utils.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index 4c1e79631..91799b2c2 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -39,7 +39,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals """Loading time series data from a zip file in the repo""" tbl_name = "multiformat_time_series" database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/paris.py b/superset/examples/paris.py index fa5c77b84..cea784be7 100644 --- a/superset/examples/paris.py +++ b/superset/examples/paris.py @@ -28,7 +28,7 @@ from .helpers import get_example_url, get_table_connector_registry def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None: tbl_name = "paris_iris_mapping" database = database_utils.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index 4a2d10aee..9b5306781 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -37,7 +37,7 @@ def load_random_time_series_data( """Loading random time series data from a zip file in the repo""" tbl_name = "random_time_series" database = database_utils.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py index ba5905f58..d97ffd3ae 100644 --- a/superset/examples/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -30,7 +30,7 @@ def load_sf_population_polygons( ) -> None: tbl_name = "sf_population_polygons" database = database_utils.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py index 6ca33a87c..371f03d18 100644 --- a/superset/examples/supported_charts_dashboard.py +++ b/superset/examples/supported_charts_dashboard.py @@ -439,7 +439,7 @@ def load_supported_charts_dashboard() -> None: """Loading a dashboard featuring supported charts""" database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name tbl_name = "birth_names" diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 5e895fd78..74ea2c43a 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -48,7 +48,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s """Loads the world bank health dataset, slices and a dashboard""" tbl_name = "wb_health_population" database = superset.utils.database.get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py index bdfe1ae1e..ea6ce118c 100644 --- a/superset/extensions/metadb.py +++ b/superset/extensions/metadb.py @@ -315,7 +315,7 @@ class SupersetShillelaghAdapter(Adapter): # store this callable for later whenever we need an engine self.engine_context = partial( - database.get_sqla_engine_with_context, + database.get_sqla_engine, self.schema, ) diff --git a/superset/models/core.py b/superset/models/core.py index 92f6946f1..bfd4c3959 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -382,7 +382,7 @@ class Database( ) @contextmanager - def get_sqla_engine_with_context( + def get_sqla_engine( self, schema: str | None = None, nullpool: bool = True, @@ -424,6 +424,11 @@ class Database( sqlalchemy_uri=sqlalchemy_uri, ) + # The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a + # reference to the old method to prevent breaking third-party applications. + # TODO (betodealmeida): Remove in 5.0 + get_sqla_engine_with_context = get_sqla_engine + def _get_sqla_engine( self, schema: str | None = None, @@ -531,7 +536,7 @@ class Database( nullpool: bool = True, source: utils.QuerySource | None = None, ) -> Connection: - with self.get_sqla_engine_with_context( + with self.get_sqla_engine( schema=schema, nullpool=nullpool, source=source ) as engine: with closing(engine.raw_connection()) as conn: @@ -574,7 +579,7 @@ class Database( mutator: Callable[[pd.DataFrame], None] | None = None, ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) - with self.get_sqla_engine_with_context(schema) as engine: + with self.get_sqla_engine(schema) as engine: engine_url = engine.url mutate_after_split = config["MUTATE_AFTER_SPLIT"] sql_query_mutator = config["SQL_QUERY_MUTATOR"] @@ -636,7 +641,7 @@ class Database( return df def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str: - with self.get_sqla_engine_with_context(schema) as engine: + with self.get_sqla_engine(schema) as engine: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) # pylint: disable=protected-access @@ -656,7 +661,7 @@ class Database( cols: list[ResultSetColumnType] | None = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" - with self.get_sqla_engine_with_context(schema) as engine: + with self.get_sqla_engine(schema) as engine: return self.db_engine_spec.select_star( self, table_name, @@ -753,9 +758,7 @@ class Database( def get_inspector_with_context( self, ssh_tunnel: SSHTunnel | None = None ) -> Inspector: - with self.get_sqla_engine_with_context( - override_ssh_tunnel=ssh_tunnel - ) as engine: + with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine: yield sqla.inspect(engine) @cache_util.memoized_func( @@ -835,7 +838,7 @@ class Database( def get_table(self, table_name: str, schema: str | None = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) - with self.get_sqla_engine_with_context() as engine: + with self.get_sqla_engine() as engine: return Table( table_name, meta, @@ -939,11 +942,11 @@ class Database( return self.perm # type: ignore def has_table(self, table: Table) -> bool: - with self.get_sqla_engine_with_context() as engine: + with self.get_sqla_engine() as engine: return engine.has_table(table.table_name, table.schema or None) def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool: - with self.get_sqla_engine_with_context() as engine: + with self.get_sqla_engine() as engine: return engine.has_table(table_name, schema) @classmethod @@ -962,7 +965,7 @@ class Database( return view_name in view_names def has_view(self, view_name: str, schema: str | None = None) -> bool: - with self.get_sqla_engine_with_context(schema) as engine: + with self.get_sqla_engine(schema) as engine: return engine.run_callable( self._has_view, engine.dialect, view_name, schema ) diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 0a0d789c7..aa961a2ff 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -217,7 +217,7 @@ class Dashboard(AuditMixinNullable, ImportExportMixin, Model): @property def sqla_metadata(self) -> None: # pylint: disable=no-member - with self.get_sqla_engine_with_context() as engine: + with self.get_sqla_engine() as engine: meta = MetaData(bind=engine) meta.reflect() diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 4b2287390..ad90e664b 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1390,7 +1390,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate(template_processor=tp)) - with self.database.get_sqla_engine_with_context() as engine: + with self.database.get_sqla_engine() as engine: sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) sql = self._apply_cte(sql, cte) sql = self.mutate_query_from_config(sql) @@ -1992,7 +1992,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods and db_engine_spec.allows_hidden_cc_in_orderby and col.name in [select_col.name for select_col in select_exprs] ): - with self.database.get_sqla_engine_with_context() as engine: + with self.database.get_sqla_engine() as engine: quote = engine.dialect.identifier_preparer.quote col = literal_column(quote(col.name)) direction = sa.asc if ascending else sa.desc diff --git a/superset/sql_lab.py b/superset/sql_lab.py index e34f7e2fd..e87ae9c5b 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -644,7 +644,7 @@ def cancel_query(query: Query) -> bool: if cancel_query_id is None: return False - with query.database.get_sqla_engine_with_context( + with query.database.get_sqla_engine( query.schema, source=QuerySource.SQL_LAB ) as engine: with closing(engine.raw_connection()) as conn: diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 4852f70ee..8e7d8c720 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -160,9 +160,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): logger.info("Validating %i statement(s)", len(statements)) # todo(hughhh): update this to use new database.get_raw_connection() # this function keeps stalling CI - with database.get_sqla_engine_with_context( - schema, source=QuerySource.SQL_LAB - ) as engine: + with database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) as engine: # Sharing a single connection and cursor across the # execution of all statements (if many) annotations: list[SQLValidationAnnotation] = [] diff --git a/superset/utils/core.py b/superset/utils/core.py index de1034ddb..988baed0a 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1170,7 +1170,7 @@ def get_example_default_schema() -> str | None: Return the default schema of the examples database, if any. """ database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: return inspect(engine).default_schema_name diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 67bd9ad73..fc082ecb4 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -184,7 +184,7 @@ def add_data( database = get_example_database() table_exists = database.has_table_by_name(table_name) - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: if columns is None: if not table_exists: raise Exception( # pylint: disable=broad-exception-raised diff --git a/tests/conftest.py b/tests/conftest.py index 9d13e5817..c659a8524 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,7 +72,7 @@ def example_db_provider() -> Callable[[], Database]: @fixture(scope="session") def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine: with app.app_context(): - with example_db_provider().get_sqla_engine_with_context() as engine: + with example_db_provider().get_sqla_engine() as engine: return engine diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 5774d8920..384e6674a 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -113,7 +113,7 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None: """Drop table if it exists, works on any DB""" sql = f"DROP {table_type} IF EXISTS {table_name}" database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: engine.execute(sql) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index b90416587..cc11c4df4 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -212,7 +212,7 @@ def setup_presto_if_needed(): if backend in {"presto", "hive"}: database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: drop_from_schema(engine, CTAS_SCHEMA_NAME) engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}") engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}") @@ -343,7 +343,7 @@ def physical_dataset(): example_database = get_example_database() - with example_database.get_sqla_engine_with_context() as engine: + with example_database.get_sqla_engine() as engine: quoter = get_identifier_quoter(engine.name) # sqlite can only execute one statement at a time engine.execute( diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index 741f4c1bc..85be02cff 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -71,7 +71,7 @@ def _setup_csv_upload(): yield upload_db = get_upload_db() - with upload_db.get_sqla_engine_with_context() as engine: + with upload_db.get_sqla_engine() as engine: engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}") engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}") engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}") @@ -268,7 +268,7 @@ def test_import_csv_enforced_schema(mock_event_logger): table=CSV_UPLOAD_TABLE_W_SCHEMA, ) - with get_upload_db().get_sqla_engine_with_context() as engine: + with get_upload_db().get_sqla_engine() as engine: data = engine.execute( f"SELECT * from {ADMIN_SCHEMA_NAME}.{CSV_UPLOAD_TABLE_W_SCHEMA} ORDER BY b" ).fetchall() @@ -294,7 +294,7 @@ def test_import_csv_enforced_schema(mock_event_logger): assert success_msg in resp # Clean up - with get_upload_db().get_sqla_engine_with_context() as engine: + with get_upload_db().get_sqla_engine() as engine: engine.execute(f"DROP TABLE {full_table_name}") @@ -380,7 +380,7 @@ def test_import_csv(mock_event_logger): extra={"null_values": '["", "john"]', "if_exists": "replace"}, ) # make sure that john and empty string are replaced with None - with test_db.get_sqla_engine_with_context() as engine: + with test_db.get_sqla_engine() as engine: data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY c").fetchall() assert data == [(None, 1, "x"), ("paul", 2, None)] # default null values @@ -390,7 +390,7 @@ def test_import_csv(mock_event_logger): assert data == [("john", 1, "x"), ("paul", 2, None)] # cleanup - with get_upload_db().get_sqla_engine_with_context() as engine: + with get_upload_db().get_sqla_engine() as engine: engine.execute(f"DROP TABLE {full_table_name}") # with dtype @@ -403,12 +403,12 @@ def test_import_csv(mock_event_logger): # you can change the type to something compatible, like an object to string # or an int to a float # file upload should work as normal - with test_db.get_sqla_engine_with_context() as engine: + with test_db.get_sqla_engine() as engine: data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY b").fetchall() assert data == [("john", 1), ("paul", 2)] # cleanup - with get_upload_db().get_sqla_engine_with_context() as engine: + with get_upload_db().get_sqla_engine() as engine: engine.execute(f"DROP TABLE {full_table_name}") # with dtype - wrong type @@ -475,7 +475,7 @@ def test_import_excel(mock_event_logger): table=EXCEL_UPLOAD_TABLE, ) - with test_db.get_sqla_engine_with_context() as engine: + with test_db.get_sqla_engine() as engine: data = engine.execute( f"SELECT * from {EXCEL_UPLOAD_TABLE} ORDER BY b" ).fetchall() @@ -541,7 +541,7 @@ def test_import_parquet(mock_event_logger): ) assert success_msg_f1 in resp - with test_db.get_sqla_engine_with_context() as engine: + with test_db.get_sqla_engine() as engine: data = engine.execute( f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b" ).fetchall() @@ -554,7 +554,7 @@ def test_import_parquet(mock_event_logger): success_msg_f2 = f"Columnar file {escaped_parquet(ZIP_FILENAME)} uploaded to table {escaped_double_quotes(full_table_name)}" assert success_msg_f2 in resp - with test_db.get_sqla_engine_with_context() as engine: + with test_db.get_sqla_engine() as engine: data = engine.execute( f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b" ).fetchall() diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index b94424225..e25a74e40 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -895,7 +895,7 @@ class TestDatabaseApi(SupersetTestCase): if database.backend == "mysql": query = query.replace('"', "`") - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: engine.execute(query) self.login(ADMIN_USERNAME) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 939c03a4e..3597bcdb0 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -718,7 +718,7 @@ class TestDatasetApi(SupersetTestCase): return example_db = get_example_database() - with example_db.get_sqla_engine_with_context() as engine: + with example_db.get_sqla_engine() as engine: engine.execute( f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two" ) @@ -739,7 +739,7 @@ class TestDatasetApi(SupersetTestCase): uri = f'api/v1/dataset/{data.get("id")}' rv = self.client.delete(uri) assert rv.status_code == 200 - with example_db.get_sqla_engine_with_context() as engine: + with example_db.get_sqla_engine() as engine: engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names") def test_create_dataset_validate_database(self): @@ -800,7 +800,7 @@ class TestDatasetApi(SupersetTestCase): mock_get_table.return_value = None example_db = get_example_database() - with example_db.get_sqla_engine_with_context() as engine: + with example_db.get_sqla_engine() as engine: engine = engine dialect = engine.dialect @@ -2389,7 +2389,7 @@ class TestDatasetApi(SupersetTestCase): self.login(ADMIN_USERNAME) examples_db = get_example_database() - with examples_db.get_sqla_engine_with_context() as engine: + with examples_db.get_sqla_engine() as engine: engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api") engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col") @@ -2415,7 +2415,7 @@ class TestDatasetApi(SupersetTestCase): db.session.delete(table) db.session.commit() - with examples_db.get_sqla_engine_with_context() as engine: + with examples_db.get_sqla_engine() as engine: engine.execute("DROP TABLE test_create_sqla_table_api") @pytest.mark.usefixtures( diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index cdf3cb6d9..806346693 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -563,7 +563,7 @@ class TestCreateDatasetCommand(SupersetTestCase): def test_create_dataset_command(self): examples_db = get_example_database() - with examples_db.get_sqla_engine_with_context() as engine: + with examples_db.get_sqla_engine() as engine: engine.execute("DROP TABLE IF EXISTS test_create_dataset_command") engine.execute( "CREATE TABLE test_create_dataset_command AS SELECT 2 as col" @@ -585,7 +585,7 @@ class TestCreateDatasetCommand(SupersetTestCase): self.assertEqual([owner.username for owner in table.owners], ["admin"]) db.session.delete(table) - with examples_db.get_sqla_engine_with_context() as engine: + with examples_db.get_sqla_engine() as engine: engine.execute("DROP TABLE test_create_dataset_command") db.session.commit() diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 4b02bb59a..34da3df35 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -47,7 +47,7 @@ def create_test_table_context(database: Database): schema = get_example_default_schema() full_table_name = f"{schema}.test_table" if schema else "test_table" - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: engine.execute( f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second" ) @@ -56,7 +56,7 @@ def create_test_table_context(database: Database): yield db.session - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: engine.execute(f"DROP TABLE {full_table_name}") diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index 374d99c02..d4b2e14d5 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -193,7 +193,7 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g): mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False mock_execute = mock.MagicMock(return_value=True) - mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = ( + mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = ( mock_execute ) table_name = "foobar" @@ -220,7 +220,7 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g): mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False mock_execute = mock.MagicMock(return_value=True) - mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = ( + mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = ( mock_execute ) table_name = "foobar" diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py index 9687fb4af..5d938e054 100644 --- a/tests/integration_tests/fixtures/energy_dashboard.py +++ b/tests/integration_tests/fixtures/energy_dashboard.py @@ -38,7 +38,7 @@ ENERGY_USAGE_TBL_NAME = "energy_usage" def load_energy_table_data(): with app.app_context(): database = get_example_database() - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: df = _get_dataframe() df.to_sql( ENERGY_USAGE_TBL_NAME, @@ -52,7 +52,7 @@ def load_energy_table_data(): ) yield with app.app_context(): - with get_example_database().get_sqla_engine_with_context() as engine: + with get_example_database().get_sqla_engine() as engine: engine.execute("DROP TABLE IF EXISTS energy_usage") diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py index 78178bcde..e68e8f079 100644 --- a/tests/integration_tests/fixtures/unicode_dashboard.py +++ b/tests/integration_tests/fixtures/unicode_dashboard.py @@ -37,7 +37,7 @@ UNICODE_TBL_NAME = "unicode_test" @pytest.fixture(scope="session") def load_unicode_data(): with app.app_context(): - with get_example_database().get_sqla_engine_with_context() as engine: + with get_example_database().get_sqla_engine() as engine: _get_dataframe().to_sql( UNICODE_TBL_NAME, engine, @@ -51,7 +51,7 @@ def load_unicode_data(): yield with app.app_context(): - with get_example_database().get_sqla_engine_with_context() as engine: + with get_example_database().get_sqla_engine() as engine: engine.execute("DROP TABLE IF EXISTS unicode_test") diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index a53cd76aa..6c3d29eb4 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -50,7 +50,7 @@ def load_world_bank_data(): "country_name": String(255), "region": String(255), } - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: _get_dataframe(database).to_sql( WB_HEALTH_POPULATION, engine, @@ -64,7 +64,7 @@ def load_world_bank_data(): yield with app.app_context(): - with get_example_database().get_sqla_engine_with_context() as engine: + with get_example_database().get_sqla_engine() as engine: engine.execute("DROP TABLE IF EXISTS wb_health_population") diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 2a4c33a28..b9cbd9332 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -56,22 +56,22 @@ class TestDatabaseModel(SupersetTestCase): sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - with model.get_sqla_engine_with_context() as engine: + with model.get_sqla_engine() as engine: db = make_url(engine.url).database self.assertEqual("hive/default", db) - with model.get_sqla_engine_with_context(schema="core_db") as engine: + with model.get_sqla_engine(schema="core_db") as engine: db = make_url(engine.url).database self.assertEqual("hive/core_db", db) sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - with model.get_sqla_engine_with_context() as engine: + with model.get_sqla_engine() as engine: db = make_url(engine.url).database self.assertEqual("hive", db) - with model.get_sqla_engine_with_context(schema="core_db") as engine: + with model.get_sqla_engine(schema="core_db") as engine: db = make_url(engine.url).database self.assertEqual("hive/core_db", db) @@ -79,11 +79,11 @@ class TestDatabaseModel(SupersetTestCase): sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - with model.get_sqla_engine_with_context() as engine: + with model.get_sqla_engine() as engine: db = make_url(engine.url).database self.assertEqual("prod", db) - with model.get_sqla_engine_with_context(schema="foo") as engine: + with model.get_sqla_engine(schema="foo") as engine: db = make_url(engine.url).database self.assertEqual("prod", db) @@ -97,11 +97,11 @@ class TestDatabaseModel(SupersetTestCase): sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - with model.get_sqla_engine_with_context() as engine: + with model.get_sqla_engine() as engine: db = make_url(engine.url).database self.assertEqual("default", db) - with model.get_sqla_engine_with_context(schema="core_db") as engine: + with model.get_sqla_engine(schema="core_db") as engine: db = make_url(engine.url).database self.assertEqual("core_db", db) @@ -112,11 +112,11 @@ class TestDatabaseModel(SupersetTestCase): sqlalchemy_uri = "mysql://root@localhost/superset" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - with model.get_sqla_engine_with_context() as engine: + with model.get_sqla_engine() as engine: db = make_url(engine.url).database self.assertEqual("superset", db) - with model.get_sqla_engine_with_context(schema="staging") as engine: + with model.get_sqla_engine(schema="staging") as engine: db = make_url(engine.url).database self.assertEqual("staging", db) @@ -130,12 +130,12 @@ class TestDatabaseModel(SupersetTestCase): with override_user(example_user): model.impersonate_user = True - with model.get_sqla_engine_with_context() as engine: + with model.get_sqla_engine() as engine: username = make_url(engine.url).username self.assertEqual(example_user.username, username) model.impersonate_user = False - with model.get_sqla_engine_with_context() as engine: + with model.get_sqla_engine() as engine: username = make_url(engine.url).username self.assertNotEqual(example_user.username, username) @@ -295,7 +295,7 @@ class TestDatabaseModel(SupersetTestCase): db = get_example_database() table_name = "energy_usage" sql = db.select_star(table_name, show_cols=False, latest_partition=False) - with db.get_sqla_engine_with_context() as engine: + with db.get_sqla_engine() as engine: quote = engine.dialect.identifier_preparer.quote_identifier source = quote(table_name) if db.backend in {"presto", "hive"} else table_name diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index 0c353d1fa..9e92841a6 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -150,13 +150,13 @@ def assert_log(state: str, error_message: Optional[str] = None): @contextmanager def create_test_table_context(database: Database): - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: engine.execute("CREATE TABLE test_table AS SELECT 1 as first, 2 as second") engine.execute("INSERT INTO test_table (first, second) VALUES (1, 2)") engine.execute("INSERT INTO test_table (first, second) VALUES (3, 4)") yield db.session - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: engine.execute("DROP TABLE test_table") diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index ae8b160ae..850cc9ada 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -38,7 +38,7 @@ class TestPrestoValidator(SupersetTestCase): self.validator = PrestoDBSQLValidator self.database = MagicMock() self.database_engine = ( - self.database.get_sqla_engine_with_context.return_value.__enter__.return_value + self.database.get_sqla_engine.return_value.__enter__.return_value ) self.database_conn = self.database_engine.raw_connection.return_value self.database_cursor = self.database_conn.cursor.return_value diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 6cae6f6a1..0359317e3 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -313,7 +313,7 @@ class TestDatabaseModel(SupersetTestCase): query = table.database.compile_sqla_query(sqla_query.sqla_query) database = table.database - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine() as engine: quote = engine.dialect.identifier_preparer.quote_identifier for metric_label in {"metric using jinja macro", "same but different"}: diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 8f4c42ee2..ccc76a039 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -212,7 +212,7 @@ class TestSqlLab(SupersetTestCase): # assertions db.session.commit() examples_db = get_example_database() - with examples_db.get_sqla_engine_with_context() as engine: + with examples_db.get_sqla_engine() as engine: data = engine.execute( f"SELECT * FROM admin_database.{tmp_table_name}" ).fetchall() @@ -296,7 +296,7 @@ class TestSqlLab(SupersetTestCase): "SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"] ) - with examples_db.get_sqla_engine_with_context() as engine: + with examples_db.get_sqla_engine() as engine: engine.execute( f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2" ) @@ -325,7 +325,7 @@ class TestSqlLab(SupersetTestCase): self.assertEqual(1, len(data["data"])) db.session.query(Query).delete() - with get_example_database().get_sqla_engine_with_context() as engine: + with get_example_database().get_sqla_engine() as engine: engine.execute(f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table") db.session.commit() diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py b/tests/unit_tests/extensions/test_sqlalchemy.py index df36dc44e..24c849f55 100644 --- a/tests/unit_tests/extensions/test_sqlalchemy.py +++ b/tests/unit_tests/extensions/test_sqlalchemy.py @@ -59,7 +59,7 @@ def database1(session: Session) -> Iterator["Database"]: @pytest.fixture def table1(session: Session, database1: "Database") -> Iterator[None]: - with database1.get_sqla_engine_with_context() as engine: + with database1.get_sqla_engine() as engine: conn = engine.connect() conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)") conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)") @@ -92,7 +92,7 @@ def database2(session: Session) -> Iterator["Database"]: @pytest.fixture def table2(session: Session, database2: "Database") -> Iterator[None]: - with database2.get_sqla_engine_with_context() as engine: + with database2.get_sqla_engine() as engine: conn = engine.connect() conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)") conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')") diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 5d6c1fcbc..beefd3ea3 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -220,7 +220,7 @@ def test_get_prequeries(mocker: MockFixture) -> None: """ mocker.patch.object( Database, - "get_sqla_engine_with_context", + "get_sqla_engine", ) db_engine_spec = mocker.patch.object(Database, "db_engine_spec") db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"] diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py index 6d9597c0d..e3c59cbcb 100644 --- a/tests/unit_tests/models/helpers_test.py +++ b/tests/unit_tests/models/helpers_test.py @@ -54,13 +54,13 @@ def test_values_for_column(mocker: MockerFixture, session: Session) -> None: # since we're using an in-memory SQLite database, make sure we always # return the same engine where the table was created @contextmanager - def mock_get_sqla_engine_with_context(): + def mock_get_sqla_engine(): yield engine mocker.patch.object( database, - "get_sqla_engine_with_context", - new=mock_get_sqla_engine_with_context, + "get_sqla_engine", + new=mock_get_sqla_engine, ) table = SqlaTable(