refactor: rename get_sqla_engine_with_context (#28012)

This commit is contained in:
Beto Dealmeida 2024-04-12 13:31:05 -04:00 committed by GitHub
parent 06077d42a8
commit 99a1601aea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 99 additions and 100 deletions

View File

@ -138,9 +138,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)
with database.get_sqla_engine_with_context(
override_ssh_tunnel=ssh_tunnel
) as engine:
with database.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
try:
alive = func_timeout(
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),

View File

@ -102,7 +102,7 @@ class ValidateDatabaseParametersCommand(BaseCommand):
database.db_engine_spec.mutate_db_for_connection_test(database)
alive = False
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)

View File

@ -217,7 +217,7 @@ def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None:
)
else:
logger.warning("Loading data outside the import transaction")
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
df.to_sql(
dataset.table_name,
con=engine,

View File

@ -771,7 +771,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
... connection.execute(sql)
"""
return database.get_sqla_engine_with_context(schema=schema, source=source)
return database.get_sqla_engine(schema=schema, source=source)
@classmethod
def get_timestamp_expr(

View File

@ -456,7 +456,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
In BigQuery, a catalog is called a "project".
"""
engine: Engine
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
projects = client.list_projects()

View File

@ -29,7 +29,7 @@ from .helpers import get_example_url, get_table_connector_registry
def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "bart_lines"
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -63,7 +63,7 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf = pdf.head(100) if sample else pdf
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
pdf.to_sql(
@ -91,7 +91,7 @@ def load_birth_names(
) -> None:
"""Loading birth name dataset from a zip file in the repo"""
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"

View File

@ -40,7 +40,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
tbl_name = "birth_france_by_region"
database = database_utils.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -42,7 +42,7 @@ def load_energy(
tbl_name = "energy_usage"
database = database_utils.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -27,7 +27,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
"""Loading random time series data from a zip file in the repo"""
tbl_name = "flights"
database = database_utils.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -39,7 +39,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
"""Loading lat/long data from a csv file in the repo"""
tbl_name = "long_lat"
database = database_utils.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -39,7 +39,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
"""Loading time series data from a zip file in the repo"""
tbl_name = "multiformat_time_series"
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -28,7 +28,7 @@ from .helpers import get_example_url, get_table_connector_registry
def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "paris_iris_mapping"
database = database_utils.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -37,7 +37,7 @@ def load_random_time_series_data(
"""Loading random time series data from a zip file in the repo"""
tbl_name = "random_time_series"
database = database_utils.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -30,7 +30,7 @@ def load_sf_population_polygons(
) -> None:
tbl_name = "sf_population_polygons"
database = database_utils.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -439,7 +439,7 @@ def load_supported_charts_dashboard() -> None:
"""Loading a dashboard featuring supported charts"""
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"

View File

@ -48,7 +48,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population"
database = superset.utils.database.get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

View File

@ -315,7 +315,7 @@ class SupersetShillelaghAdapter(Adapter):
# store this callable for later whenever we need an engine
self.engine_context = partial(
database.get_sqla_engine_with_context,
database.get_sqla_engine,
self.schema,
)

View File

@ -382,7 +382,7 @@ class Database(
)
@contextmanager
def get_sqla_engine_with_context(
def get_sqla_engine(
self,
schema: str | None = None,
nullpool: bool = True,
@ -424,6 +424,11 @@ class Database(
sqlalchemy_uri=sqlalchemy_uri,
)
# The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a
# reference to the old method to prevent breaking third-party applications.
# TODO (betodealmeida): Remove in 5.0
get_sqla_engine_with_context = get_sqla_engine
def _get_sqla_engine(
self,
schema: str | None = None,
@ -531,7 +536,7 @@ class Database(
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Connection:
with self.get_sqla_engine_with_context(
with self.get_sqla_engine(
schema=schema, nullpool=nullpool, source=source
) as engine:
with closing(engine.raw_connection()) as conn:
@ -574,7 +579,7 @@ class Database(
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
with self.get_sqla_engine_with_context(schema) as engine:
with self.get_sqla_engine(schema) as engine:
engine_url = engine.url
mutate_after_split = config["MUTATE_AFTER_SPLIT"]
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
@ -636,7 +641,7 @@ class Database(
return df
def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
with self.get_sqla_engine_with_context(schema) as engine:
with self.get_sqla_engine(schema) as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
# pylint: disable=protected-access
@ -656,7 +661,7 @@ class Database(
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
with self.get_sqla_engine_with_context(schema) as engine:
with self.get_sqla_engine(schema) as engine:
return self.db_engine_spec.select_star(
self,
table_name,
@ -753,9 +758,7 @@ class Database(
def get_inspector_with_context(
self, ssh_tunnel: SSHTunnel | None = None
) -> Inspector:
with self.get_sqla_engine_with_context(
override_ssh_tunnel=ssh_tunnel
) as engine:
with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
yield sqla.inspect(engine)
@cache_util.memoized_func(
@ -835,7 +838,7 @@ class Database(
def get_table(self, table_name: str, schema: str | None = None) -> Table:
extra = self.get_extra()
meta = MetaData(**extra.get("metadata_params", {}))
with self.get_sqla_engine_with_context() as engine:
with self.get_sqla_engine() as engine:
return Table(
table_name,
meta,
@ -939,11 +942,11 @@ class Database(
return self.perm # type: ignore
def has_table(self, table: Table) -> bool:
with self.get_sqla_engine_with_context() as engine:
with self.get_sqla_engine() as engine:
return engine.has_table(table.table_name, table.schema or None)
def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool:
with self.get_sqla_engine_with_context() as engine:
with self.get_sqla_engine() as engine:
return engine.has_table(table_name, schema)
@classmethod
@ -962,7 +965,7 @@ class Database(
return view_name in view_names
def has_view(self, view_name: str, schema: str | None = None) -> bool:
with self.get_sqla_engine_with_context(schema) as engine:
with self.get_sqla_engine(schema) as engine:
return engine.run_callable(
self._has_view, engine.dialect, view_name, schema
)

View File

@ -217,7 +217,7 @@ class Dashboard(AuditMixinNullable, ImportExportMixin, Model):
@property
def sqla_metadata(self) -> None:
# pylint: disable=no-member
with self.get_sqla_engine_with_context() as engine:
with self.get_sqla_engine() as engine:
meta = MetaData(bind=engine)
meta.reflect()

View File

@ -1390,7 +1390,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
with self.database.get_sqla_engine_with_context() as engine:
with self.database.get_sqla_engine() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
@ -1992,7 +1992,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):
with self.database.get_sqla_engine_with_context() as engine:
with self.database.get_sqla_engine() as engine:
quote = engine.dialect.identifier_preparer.quote
col = literal_column(quote(col.name))
direction = sa.asc if ascending else sa.desc

View File

@ -644,7 +644,7 @@ def cancel_query(query: Query) -> bool:
if cancel_query_id is None:
return False
with query.database.get_sqla_engine_with_context(
with query.database.get_sqla_engine(
query.schema, source=QuerySource.SQL_LAB
) as engine:
with closing(engine.raw_connection()) as conn:

View File

@ -160,9 +160,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
logger.info("Validating %i statement(s)", len(statements))
# todo(hughhh): update this to use new database.get_raw_connection()
# this function keeps stalling CI
with database.get_sqla_engine_with_context(
schema, source=QuerySource.SQL_LAB
) as engine:
with database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) as engine:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: list[SQLValidationAnnotation] = []

View File

@ -1170,7 +1170,7 @@ def get_example_default_schema() -> str | None:
Return the default schema of the examples database, if any.
"""
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
return inspect(engine).default_schema_name

View File

@ -184,7 +184,7 @@ def add_data(
database = get_example_database()
table_exists = database.has_table_by_name(table_name)
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
if columns is None:
if not table_exists:
raise Exception( # pylint: disable=broad-exception-raised

View File

@ -72,7 +72,7 @@ def example_db_provider() -> Callable[[], Database]:
@fixture(scope="session")
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
with app.app_context():
with example_db_provider().get_sqla_engine_with_context() as engine:
with example_db_provider().get_sqla_engine() as engine:
return engine

View File

@ -113,7 +113,7 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
"""Drop table if it exists, works on any DB"""
sql = f"DROP {table_type} IF EXISTS {table_name}"
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
engine.execute(sql)

View File

@ -212,7 +212,7 @@ def setup_presto_if_needed():
if backend in {"presto", "hive"}:
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
drop_from_schema(engine, CTAS_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")
@ -343,7 +343,7 @@ def physical_dataset():
example_database = get_example_database()
with example_database.get_sqla_engine_with_context() as engine:
with example_database.get_sqla_engine() as engine:
quoter = get_identifier_quoter(engine.name)
# sqlite can only execute one statement at a time
engine.execute(

View File

@ -71,7 +71,7 @@ def _setup_csv_upload():
yield
upload_db = get_upload_db()
with upload_db.get_sqla_engine_with_context() as engine:
with upload_db.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
@ -268,7 +268,7 @@ def test_import_csv_enforced_schema(mock_event_logger):
table=CSV_UPLOAD_TABLE_W_SCHEMA,
)
with get_upload_db().get_sqla_engine_with_context() as engine:
with get_upload_db().get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {ADMIN_SCHEMA_NAME}.{CSV_UPLOAD_TABLE_W_SCHEMA} ORDER BY b"
).fetchall()
@ -294,7 +294,7 @@ def test_import_csv_enforced_schema(mock_event_logger):
assert success_msg in resp
# Clean up
with get_upload_db().get_sqla_engine_with_context() as engine:
with get_upload_db().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")
@ -380,7 +380,7 @@ def test_import_csv(mock_event_logger):
extra={"null_values": '["", "john"]', "if_exists": "replace"},
)
# make sure that john and empty string are replaced with None
with test_db.get_sqla_engine_with_context() as engine:
with test_db.get_sqla_engine() as engine:
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY c").fetchall()
assert data == [(None, 1, "x"), ("paul", 2, None)]
# default null values
@ -390,7 +390,7 @@ def test_import_csv(mock_event_logger):
assert data == [("john", 1, "x"), ("paul", 2, None)]
# cleanup
with get_upload_db().get_sqla_engine_with_context() as engine:
with get_upload_db().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")
# with dtype
@ -403,12 +403,12 @@ def test_import_csv(mock_event_logger):
# you can change the type to something compatible, like an object to string
# or an int to a float
# file upload should work as normal
with test_db.get_sqla_engine_with_context() as engine:
with test_db.get_sqla_engine() as engine:
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY b").fetchall()
assert data == [("john", 1), ("paul", 2)]
# cleanup
with get_upload_db().get_sqla_engine_with_context() as engine:
with get_upload_db().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")
# with dtype - wrong type
@ -475,7 +475,7 @@ def test_import_excel(mock_event_logger):
table=EXCEL_UPLOAD_TABLE,
)
with test_db.get_sqla_engine_with_context() as engine:
with test_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {EXCEL_UPLOAD_TABLE} ORDER BY b"
).fetchall()
@ -541,7 +541,7 @@ def test_import_parquet(mock_event_logger):
)
assert success_msg_f1 in resp
with test_db.get_sqla_engine_with_context() as engine:
with test_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b"
).fetchall()
@ -554,7 +554,7 @@ def test_import_parquet(mock_event_logger):
success_msg_f2 = f"Columnar file {escaped_parquet(ZIP_FILENAME)} uploaded to table {escaped_double_quotes(full_table_name)}"
assert success_msg_f2 in resp
with test_db.get_sqla_engine_with_context() as engine:
with test_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b"
).fetchall()

View File

@ -895,7 +895,7 @@ class TestDatabaseApi(SupersetTestCase):
if database.backend == "mysql":
query = query.replace('"', "`")
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
engine.execute(query)
self.login(ADMIN_USERNAME)

View File

@ -718,7 +718,7 @@ class TestDatasetApi(SupersetTestCase):
return
example_db = get_example_database()
with example_db.get_sqla_engine_with_context() as engine:
with example_db.get_sqla_engine() as engine:
engine.execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)
@ -739,7 +739,7 @@ class TestDatasetApi(SupersetTestCase):
uri = f'api/v1/dataset/{data.get("id")}'
rv = self.client.delete(uri)
assert rv.status_code == 200
with example_db.get_sqla_engine_with_context() as engine:
with example_db.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names")
def test_create_dataset_validate_database(self):
@ -800,7 +800,7 @@ class TestDatasetApi(SupersetTestCase):
mock_get_table.return_value = None
example_db = get_example_database()
with example_db.get_sqla_engine_with_context() as engine:
with example_db.get_sqla_engine() as engine:
engine = engine
dialect = engine.dialect
@ -2389,7 +2389,7 @@ class TestDatasetApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
examples_db = get_example_database()
with examples_db.get_sqla_engine_with_context() as engine:
with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api")
engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col")
@ -2415,7 +2415,7 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(table)
db.session.commit()
with examples_db.get_sqla_engine_with_context() as engine:
with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE test_create_sqla_table_api")
@pytest.mark.usefixtures(

View File

@ -563,7 +563,7 @@ class TestCreateDatasetCommand(SupersetTestCase):
def test_create_dataset_command(self):
examples_db = get_example_database()
with examples_db.get_sqla_engine_with_context() as engine:
with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS test_create_dataset_command")
engine.execute(
"CREATE TABLE test_create_dataset_command AS SELECT 2 as col"
@ -585,7 +585,7 @@ class TestCreateDatasetCommand(SupersetTestCase):
self.assertEqual([owner.username for owner in table.owners], ["admin"])
db.session.delete(table)
with examples_db.get_sqla_engine_with_context() as engine:
with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE test_create_dataset_command")
db.session.commit()

View File

@ -47,7 +47,7 @@ def create_test_table_context(database: Database):
schema = get_example_default_schema()
full_table_name = f"{schema}.test_table" if schema else "test_table"
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
engine.execute(
f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
)
@ -56,7 +56,7 @@ def create_test_table_context(database: Database):
yield db.session
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")

View File

@ -193,7 +193,7 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = (
mock_execute
)
table_name = "foobar"
@ -220,7 +220,7 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = (
mock_execute
)
table_name = "foobar"

View File

@ -38,7 +38,7 @@ ENERGY_USAGE_TBL_NAME = "energy_usage"
def load_energy_table_data():
with app.app_context():
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
@ -52,7 +52,7 @@ def load_energy_table_data():
)
yield
with app.app_context():
with get_example_database().get_sqla_engine_with_context() as engine:
with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS energy_usage")

View File

@ -37,7 +37,7 @@ UNICODE_TBL_NAME = "unicode_test"
@pytest.fixture(scope="session")
def load_unicode_data():
with app.app_context():
with get_example_database().get_sqla_engine_with_context() as engine:
with get_example_database().get_sqla_engine() as engine:
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
engine,
@ -51,7 +51,7 @@ def load_unicode_data():
yield
with app.app_context():
with get_example_database().get_sqla_engine_with_context() as engine:
with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS unicode_test")

View File

@ -50,7 +50,7 @@ def load_world_bank_data():
"country_name": String(255),
"region": String(255),
}
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
engine,
@ -64,7 +64,7 @@ def load_world_bank_data():
yield
with app.app_context():
with get_example_database().get_sqla_engine_with_context() as engine:
with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS wb_health_population")

View File

@ -56,22 +56,22 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
with model.get_sqla_engine_with_context() as engine:
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("hive/default", db)
with model.get_sqla_engine_with_context(schema="core_db") as engine:
with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("hive/core_db", db)
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
with model.get_sqla_engine_with_context() as engine:
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("hive", db)
with model.get_sqla_engine_with_context(schema="core_db") as engine:
with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("hive/core_db", db)
@ -79,11 +79,11 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
with model.get_sqla_engine_with_context() as engine:
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)
with model.get_sqla_engine_with_context(schema="foo") as engine:
with model.get_sqla_engine(schema="foo") as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)
@ -97,11 +97,11 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
with model.get_sqla_engine_with_context() as engine:
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("default", db)
with model.get_sqla_engine_with_context(schema="core_db") as engine:
with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("core_db", db)
@ -112,11 +112,11 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "mysql://root@localhost/superset"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
with model.get_sqla_engine_with_context() as engine:
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("superset", db)
with model.get_sqla_engine_with_context(schema="staging") as engine:
with model.get_sqla_engine(schema="staging") as engine:
db = make_url(engine.url).database
self.assertEqual("staging", db)
@ -130,12 +130,12 @@ class TestDatabaseModel(SupersetTestCase):
with override_user(example_user):
model.impersonate_user = True
with model.get_sqla_engine_with_context() as engine:
with model.get_sqla_engine() as engine:
username = make_url(engine.url).username
self.assertEqual(example_user.username, username)
model.impersonate_user = False
with model.get_sqla_engine_with_context() as engine:
with model.get_sqla_engine() as engine:
username = make_url(engine.url).username
self.assertNotEqual(example_user.username, username)
@ -295,7 +295,7 @@ class TestDatabaseModel(SupersetTestCase):
db = get_example_database()
table_name = "energy_usage"
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
with db.get_sqla_engine_with_context() as engine:
with db.get_sqla_engine() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
source = quote(table_name) if db.backend in {"presto", "hive"} else table_name

View File

@ -150,13 +150,13 @@ def assert_log(state: str, error_message: Optional[str] = None):
@contextmanager
def create_test_table_context(database: Database):
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
engine.execute("CREATE TABLE test_table AS SELECT 1 as first, 2 as second")
engine.execute("INSERT INTO test_table (first, second) VALUES (1, 2)")
engine.execute("INSERT INTO test_table (first, second) VALUES (3, 4)")
yield db.session
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
engine.execute("DROP TABLE test_table")

View File

@ -38,7 +38,7 @@ class TestPrestoValidator(SupersetTestCase):
self.validator = PrestoDBSQLValidator
self.database = MagicMock()
self.database_engine = (
self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
self.database.get_sqla_engine.return_value.__enter__.return_value
)
self.database_conn = self.database_engine.raw_connection.return_value
self.database_cursor = self.database_conn.cursor.return_value

View File

@ -313,7 +313,7 @@ class TestDatabaseModel(SupersetTestCase):
query = table.database.compile_sqla_query(sqla_query.sqla_query)
database = table.database
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
for metric_label in {"metric using jinja macro", "same but different"}:

View File

@ -212,7 +212,7 @@ class TestSqlLab(SupersetTestCase):
# assertions
db.session.commit()
examples_db = get_example_database()
with examples_db.get_sqla_engine_with_context() as engine:
with examples_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * FROM admin_database.{tmp_table_name}"
).fetchall()
@ -296,7 +296,7 @@ class TestSqlLab(SupersetTestCase):
"SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"]
)
with examples_db.get_sqla_engine_with_context() as engine:
with examples_db.get_sqla_engine() as engine:
engine.execute(
f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2"
)
@ -325,7 +325,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(1, len(data["data"]))
db.session.query(Query).delete()
with get_example_database().get_sqla_engine_with_context() as engine:
with get_example_database().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table")
db.session.commit()

View File

@ -59,7 +59,7 @@ def database1(session: Session) -> Iterator["Database"]:
@pytest.fixture
def table1(session: Session, database1: "Database") -> Iterator[None]:
with database1.get_sqla_engine_with_context() as engine:
with database1.get_sqla_engine() as engine:
conn = engine.connect()
conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)")
conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)")
@ -92,7 +92,7 @@ def database2(session: Session) -> Iterator["Database"]:
@pytest.fixture
def table2(session: Session, database2: "Database") -> Iterator[None]:
with database2.get_sqla_engine_with_context() as engine:
with database2.get_sqla_engine() as engine:
conn = engine.connect()
conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)")
conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')")

View File

@ -220,7 +220,7 @@ def test_get_prequeries(mocker: MockFixture) -> None:
"""
mocker.patch.object(
Database,
"get_sqla_engine_with_context",
"get_sqla_engine",
)
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]

View File

@ -54,13 +54,13 @@ def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
# since we're using an in-memory SQLite database, make sure we always
# return the same engine where the table was created
@contextmanager
def mock_get_sqla_engine_with_context():
def mock_get_sqla_engine():
yield engine
mocker.patch.object(
database,
"get_sqla_engine_with_context",
new=mock_get_sqla_engine_with_context,
"get_sqla_engine",
new=mock_get_sqla_engine,
)
table = SqlaTable(