refactor: rename get_sqla_engine_with_context (#28012)
This commit is contained in:
parent
06077d42a8
commit
99a1601aea
|
|
@ -138,9 +138,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
with closing(engine.raw_connection()) as conn:
|
||||
return engine.dialect.do_ping(conn)
|
||||
|
||||
with database.get_sqla_engine_with_context(
|
||||
override_ssh_tunnel=ssh_tunnel
|
||||
) as engine:
|
||||
with database.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
|
||||
try:
|
||||
alive = func_timeout(
|
||||
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
alive = False
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
try:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
|
|
|
|||
|
|
@ -217,7 +217,7 @@ def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None:
|
|||
)
|
||||
else:
|
||||
logger.warning("Loading data outside the import transaction")
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
df.to_sql(
|
||||
dataset.table_name,
|
||||
con=engine,
|
||||
|
|
|
|||
|
|
@ -771,7 +771,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
... connection.execute(sql)
|
||||
|
||||
"""
|
||||
return database.get_sqla_engine_with_context(schema=schema, source=source)
|
||||
return database.get_sqla_engine(schema=schema, source=source)
|
||||
|
||||
@classmethod
|
||||
def get_timestamp_expr(
|
||||
|
|
|
|||
|
|
@ -456,7 +456,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
In BigQuery, a catalog is called a "project".
|
||||
"""
|
||||
engine: Engine
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
client = cls._get_client(engine)
|
||||
projects = client.list_projects()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from .helpers import get_example_url, get_table_connector_registry
|
|||
def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
|
||||
tbl_name = "bart_lines"
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
|
|||
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
|
||||
pdf = pdf.head(100) if sample else pdf
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
|
||||
pdf.to_sql(
|
||||
|
|
@ -91,7 +91,7 @@ def load_birth_names(
|
|||
) -> None:
|
||||
"""Loading birth name dataset from a zip file in the repo"""
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
|
||||
tbl_name = "birth_names"
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
|
|||
tbl_name = "birth_france_by_region"
|
||||
database = database_utils.get_example_database()
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def load_energy(
|
|||
tbl_name = "energy_usage"
|
||||
database = database_utils.get_example_database()
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
|
|||
"""Loading random time series data from a zip file in the repo"""
|
||||
tbl_name = "flights"
|
||||
database = database_utils.get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
|
|||
"""Loading lat/long data from a csv file in the repo"""
|
||||
tbl_name = "long_lat"
|
||||
database = database_utils.get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
|
|||
"""Loading time series data from a zip file in the repo"""
|
||||
tbl_name = "multiformat_time_series"
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from .helpers import get_example_url, get_table_connector_registry
|
|||
def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
|
||||
tbl_name = "paris_iris_mapping"
|
||||
database = database_utils.get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def load_random_time_series_data(
|
|||
"""Loading random time series data from a zip file in the repo"""
|
||||
tbl_name = "random_time_series"
|
||||
database = database_utils.get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def load_sf_population_polygons(
|
|||
) -> None:
|
||||
tbl_name = "sf_population_polygons"
|
||||
database = database_utils.get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -439,7 +439,7 @@ def load_supported_charts_dashboard() -> None:
|
|||
"""Loading a dashboard featuring supported charts"""
|
||||
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
|
||||
tbl_name = "birth_names"
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
|
|||
"""Loads the world bank health dataset, slices and a dashboard"""
|
||||
tbl_name = "wb_health_population"
|
||||
database = superset.utils.database.get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -315,7 +315,7 @@ class SupersetShillelaghAdapter(Adapter):
|
|||
|
||||
# store this callable for later whenever we need an engine
|
||||
self.engine_context = partial(
|
||||
database.get_sqla_engine_with_context,
|
||||
database.get_sqla_engine,
|
||||
self.schema,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -382,7 +382,7 @@ class Database(
|
|||
)
|
||||
|
||||
@contextmanager
|
||||
def get_sqla_engine_with_context(
|
||||
def get_sqla_engine(
|
||||
self,
|
||||
schema: str | None = None,
|
||||
nullpool: bool = True,
|
||||
|
|
@ -424,6 +424,11 @@ class Database(
|
|||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
)
|
||||
|
||||
# The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a
|
||||
# reference to the old method to prevent breaking third-party applications.
|
||||
# TODO (betodealmeida): Remove in 5.0
|
||||
get_sqla_engine_with_context = get_sqla_engine
|
||||
|
||||
def _get_sqla_engine(
|
||||
self,
|
||||
schema: str | None = None,
|
||||
|
|
@ -531,7 +536,7 @@ class Database(
|
|||
nullpool: bool = True,
|
||||
source: utils.QuerySource | None = None,
|
||||
) -> Connection:
|
||||
with self.get_sqla_engine_with_context(
|
||||
with self.get_sqla_engine(
|
||||
schema=schema, nullpool=nullpool, source=source
|
||||
) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
|
|
@ -574,7 +579,7 @@ class Database(
|
|||
mutator: Callable[[pd.DataFrame], None] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
sqls = self.db_engine_spec.parse_sql(sql)
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
with self.get_sqla_engine(schema) as engine:
|
||||
engine_url = engine.url
|
||||
mutate_after_split = config["MUTATE_AFTER_SPLIT"]
|
||||
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
|
||||
|
|
@ -636,7 +641,7 @@ class Database(
|
|||
return df
|
||||
|
||||
def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
with self.get_sqla_engine(schema) as engine:
|
||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
|
@ -656,7 +661,7 @@ class Database(
|
|||
cols: list[ResultSetColumnType] | None = None,
|
||||
) -> str:
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
with self.get_sqla_engine(schema) as engine:
|
||||
return self.db_engine_spec.select_star(
|
||||
self,
|
||||
table_name,
|
||||
|
|
@ -753,9 +758,7 @@ class Database(
|
|||
def get_inspector_with_context(
|
||||
self, ssh_tunnel: SSHTunnel | None = None
|
||||
) -> Inspector:
|
||||
with self.get_sqla_engine_with_context(
|
||||
override_ssh_tunnel=ssh_tunnel
|
||||
) as engine:
|
||||
with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
|
||||
yield sqla.inspect(engine)
|
||||
|
||||
@cache_util.memoized_func(
|
||||
|
|
@ -835,7 +838,7 @@ class Database(
|
|||
def get_table(self, table_name: str, schema: str | None = None) -> Table:
|
||||
extra = self.get_extra()
|
||||
meta = MetaData(**extra.get("metadata_params", {}))
|
||||
with self.get_sqla_engine_with_context() as engine:
|
||||
with self.get_sqla_engine() as engine:
|
||||
return Table(
|
||||
table_name,
|
||||
meta,
|
||||
|
|
@ -939,11 +942,11 @@ class Database(
|
|||
return self.perm # type: ignore
|
||||
|
||||
def has_table(self, table: Table) -> bool:
|
||||
with self.get_sqla_engine_with_context() as engine:
|
||||
with self.get_sqla_engine() as engine:
|
||||
return engine.has_table(table.table_name, table.schema or None)
|
||||
|
||||
def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool:
|
||||
with self.get_sqla_engine_with_context() as engine:
|
||||
with self.get_sqla_engine() as engine:
|
||||
return engine.has_table(table_name, schema)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -962,7 +965,7 @@ class Database(
|
|||
return view_name in view_names
|
||||
|
||||
def has_view(self, view_name: str, schema: str | None = None) -> bool:
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
with self.get_sqla_engine(schema) as engine:
|
||||
return engine.run_callable(
|
||||
self._has_view, engine.dialect, view_name, schema
|
||||
)
|
||||
|
|
|
|||
|
|
@ -217,7 +217,7 @@ class Dashboard(AuditMixinNullable, ImportExportMixin, Model):
|
|||
@property
|
||||
def sqla_metadata(self) -> None:
|
||||
# pylint: disable=no-member
|
||||
with self.get_sqla_engine_with_context() as engine:
|
||||
with self.get_sqla_engine() as engine:
|
||||
meta = MetaData(bind=engine)
|
||||
meta.reflect()
|
||||
|
||||
|
|
|
|||
|
|
@ -1390,7 +1390,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
if self.fetch_values_predicate:
|
||||
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
|
||||
|
||||
with self.database.get_sqla_engine_with_context() as engine:
|
||||
with self.database.get_sqla_engine() as engine:
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
|
|
@ -1992,7 +1992,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
and db_engine_spec.allows_hidden_cc_in_orderby
|
||||
and col.name in [select_col.name for select_col in select_exprs]
|
||||
):
|
||||
with self.database.get_sqla_engine_with_context() as engine:
|
||||
with self.database.get_sqla_engine() as engine:
|
||||
quote = engine.dialect.identifier_preparer.quote
|
||||
col = literal_column(quote(col.name))
|
||||
direction = sa.asc if ascending else sa.desc
|
||||
|
|
|
|||
|
|
@ -644,7 +644,7 @@ def cancel_query(query: Query) -> bool:
|
|||
if cancel_query_id is None:
|
||||
return False
|
||||
|
||||
with query.database.get_sqla_engine_with_context(
|
||||
with query.database.get_sqla_engine(
|
||||
query.schema, source=QuerySource.SQL_LAB
|
||||
) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
|
|
|
|||
|
|
@ -160,9 +160,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
logger.info("Validating %i statement(s)", len(statements))
|
||||
# todo(hughhh): update this to use new database.get_raw_connection()
|
||||
# this function keeps stalling CI
|
||||
with database.get_sqla_engine_with_context(
|
||||
schema, source=QuerySource.SQL_LAB
|
||||
) as engine:
|
||||
with database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) as engine:
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
annotations: list[SQLValidationAnnotation] = []
|
||||
|
|
|
|||
|
|
@ -1170,7 +1170,7 @@ def get_example_default_schema() -> str | None:
|
|||
Return the default schema of the examples database, if any.
|
||||
"""
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
return inspect(engine).default_schema_name
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ def add_data(
|
|||
database = get_example_database()
|
||||
table_exists = database.has_table_by_name(table_name)
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
if columns is None:
|
||||
if not table_exists:
|
||||
raise Exception( # pylint: disable=broad-exception-raised
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ def example_db_provider() -> Callable[[], Database]:
|
|||
@fixture(scope="session")
|
||||
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
|
||||
with app.app_context():
|
||||
with example_db_provider().get_sqla_engine_with_context() as engine:
|
||||
with example_db_provider().get_sqla_engine() as engine:
|
||||
return engine
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
|
|||
"""Drop table if it exists, works on any DB"""
|
||||
sql = f"DROP {table_type} IF EXISTS {table_name}"
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.execute(sql)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ def setup_presto_if_needed():
|
|||
|
||||
if backend in {"presto", "hive"}:
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
drop_from_schema(engine, CTAS_SCHEMA_NAME)
|
||||
engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
|
||||
engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")
|
||||
|
|
@ -343,7 +343,7 @@ def physical_dataset():
|
|||
|
||||
example_database = get_example_database()
|
||||
|
||||
with example_database.get_sqla_engine_with_context() as engine:
|
||||
with example_database.get_sqla_engine() as engine:
|
||||
quoter = get_identifier_quoter(engine.name)
|
||||
# sqlite can only execute one statement at a time
|
||||
engine.execute(
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def _setup_csv_upload():
|
|||
yield
|
||||
|
||||
upload_db = get_upload_db()
|
||||
with upload_db.get_sqla_engine_with_context() as engine:
|
||||
with upload_db.get_sqla_engine() as engine:
|
||||
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
|
||||
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
|
||||
engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
|
||||
|
|
@ -268,7 +268,7 @@ def test_import_csv_enforced_schema(mock_event_logger):
|
|||
table=CSV_UPLOAD_TABLE_W_SCHEMA,
|
||||
)
|
||||
|
||||
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||
with get_upload_db().get_sqla_engine() as engine:
|
||||
data = engine.execute(
|
||||
f"SELECT * from {ADMIN_SCHEMA_NAME}.{CSV_UPLOAD_TABLE_W_SCHEMA} ORDER BY b"
|
||||
).fetchall()
|
||||
|
|
@ -294,7 +294,7 @@ def test_import_csv_enforced_schema(mock_event_logger):
|
|||
assert success_msg in resp
|
||||
|
||||
# Clean up
|
||||
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||
with get_upload_db().get_sqla_engine() as engine:
|
||||
engine.execute(f"DROP TABLE {full_table_name}")
|
||||
|
||||
|
||||
|
|
@ -380,7 +380,7 @@ def test_import_csv(mock_event_logger):
|
|||
extra={"null_values": '["", "john"]', "if_exists": "replace"},
|
||||
)
|
||||
# make sure that john and empty string are replaced with None
|
||||
with test_db.get_sqla_engine_with_context() as engine:
|
||||
with test_db.get_sqla_engine() as engine:
|
||||
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY c").fetchall()
|
||||
assert data == [(None, 1, "x"), ("paul", 2, None)]
|
||||
# default null values
|
||||
|
|
@ -390,7 +390,7 @@ def test_import_csv(mock_event_logger):
|
|||
assert data == [("john", 1, "x"), ("paul", 2, None)]
|
||||
|
||||
# cleanup
|
||||
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||
with get_upload_db().get_sqla_engine() as engine:
|
||||
engine.execute(f"DROP TABLE {full_table_name}")
|
||||
|
||||
# with dtype
|
||||
|
|
@ -403,12 +403,12 @@ def test_import_csv(mock_event_logger):
|
|||
# you can change the type to something compatible, like an object to string
|
||||
# or an int to a float
|
||||
# file upload should work as normal
|
||||
with test_db.get_sqla_engine_with_context() as engine:
|
||||
with test_db.get_sqla_engine() as engine:
|
||||
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY b").fetchall()
|
||||
assert data == [("john", 1), ("paul", 2)]
|
||||
|
||||
# cleanup
|
||||
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||
with get_upload_db().get_sqla_engine() as engine:
|
||||
engine.execute(f"DROP TABLE {full_table_name}")
|
||||
|
||||
# with dtype - wrong type
|
||||
|
|
@ -475,7 +475,7 @@ def test_import_excel(mock_event_logger):
|
|||
table=EXCEL_UPLOAD_TABLE,
|
||||
)
|
||||
|
||||
with test_db.get_sqla_engine_with_context() as engine:
|
||||
with test_db.get_sqla_engine() as engine:
|
||||
data = engine.execute(
|
||||
f"SELECT * from {EXCEL_UPLOAD_TABLE} ORDER BY b"
|
||||
).fetchall()
|
||||
|
|
@ -541,7 +541,7 @@ def test_import_parquet(mock_event_logger):
|
|||
)
|
||||
assert success_msg_f1 in resp
|
||||
|
||||
with test_db.get_sqla_engine_with_context() as engine:
|
||||
with test_db.get_sqla_engine() as engine:
|
||||
data = engine.execute(
|
||||
f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b"
|
||||
).fetchall()
|
||||
|
|
@ -554,7 +554,7 @@ def test_import_parquet(mock_event_logger):
|
|||
success_msg_f2 = f"Columnar file {escaped_parquet(ZIP_FILENAME)} uploaded to table {escaped_double_quotes(full_table_name)}"
|
||||
assert success_msg_f2 in resp
|
||||
|
||||
with test_db.get_sqla_engine_with_context() as engine:
|
||||
with test_db.get_sqla_engine() as engine:
|
||||
data = engine.execute(
|
||||
f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b"
|
||||
).fetchall()
|
||||
|
|
|
|||
|
|
@ -895,7 +895,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
if database.backend == "mysql":
|
||||
query = query.replace('"', "`")
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.execute(query)
|
||||
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
|
|||
|
|
@ -718,7 +718,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
return
|
||||
|
||||
example_db = get_example_database()
|
||||
with example_db.get_sqla_engine_with_context() as engine:
|
||||
with example_db.get_sqla_engine() as engine:
|
||||
engine.execute(
|
||||
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
|
||||
)
|
||||
|
|
@ -739,7 +739,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
uri = f'api/v1/dataset/{data.get("id")}'
|
||||
rv = self.client.delete(uri)
|
||||
assert rv.status_code == 200
|
||||
with example_db.get_sqla_engine_with_context() as engine:
|
||||
with example_db.get_sqla_engine() as engine:
|
||||
engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names")
|
||||
|
||||
def test_create_dataset_validate_database(self):
|
||||
|
|
@ -800,7 +800,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
mock_get_table.return_value = None
|
||||
|
||||
example_db = get_example_database()
|
||||
with example_db.get_sqla_engine_with_context() as engine:
|
||||
with example_db.get_sqla_engine() as engine:
|
||||
engine = engine
|
||||
dialect = engine.dialect
|
||||
|
||||
|
|
@ -2389,7 +2389,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
|
||||
examples_db = get_example_database()
|
||||
with examples_db.get_sqla_engine_with_context() as engine:
|
||||
with examples_db.get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api")
|
||||
engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col")
|
||||
|
||||
|
|
@ -2415,7 +2415,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
db.session.delete(table)
|
||||
db.session.commit()
|
||||
|
||||
with examples_db.get_sqla_engine_with_context() as engine:
|
||||
with examples_db.get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE test_create_sqla_table_api")
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
|
|
|
|||
|
|
@ -563,7 +563,7 @@ class TestCreateDatasetCommand(SupersetTestCase):
|
|||
|
||||
def test_create_dataset_command(self):
|
||||
examples_db = get_example_database()
|
||||
with examples_db.get_sqla_engine_with_context() as engine:
|
||||
with examples_db.get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS test_create_dataset_command")
|
||||
engine.execute(
|
||||
"CREATE TABLE test_create_dataset_command AS SELECT 2 as col"
|
||||
|
|
@ -585,7 +585,7 @@ class TestCreateDatasetCommand(SupersetTestCase):
|
|||
self.assertEqual([owner.username for owner in table.owners], ["admin"])
|
||||
|
||||
db.session.delete(table)
|
||||
with examples_db.get_sqla_engine_with_context() as engine:
|
||||
with examples_db.get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE test_create_dataset_command")
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def create_test_table_context(database: Database):
|
|||
schema = get_example_default_schema()
|
||||
full_table_name = f"{schema}.test_table" if schema else "test_table"
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.execute(
|
||||
f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
|
||||
)
|
||||
|
|
@ -56,7 +56,7 @@ def create_test_table_context(database: Database):
|
|||
|
||||
yield db.session
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.execute(f"DROP TABLE {full_table_name}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -193,7 +193,7 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
|
|||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
|
||||
mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
|
|
@ -220,7 +220,7 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
|||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
|
||||
mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ ENERGY_USAGE_TBL_NAME = "energy_usage"
|
|||
def load_energy_table_data():
|
||||
with app.app_context():
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
df = _get_dataframe()
|
||||
df.to_sql(
|
||||
ENERGY_USAGE_TBL_NAME,
|
||||
|
|
@ -52,7 +52,7 @@ def load_energy_table_data():
|
|||
)
|
||||
yield
|
||||
with app.app_context():
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
with get_example_database().get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS energy_usage")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ UNICODE_TBL_NAME = "unicode_test"
|
|||
@pytest.fixture(scope="session")
|
||||
def load_unicode_data():
|
||||
with app.app_context():
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
with get_example_database().get_sqla_engine() as engine:
|
||||
_get_dataframe().to_sql(
|
||||
UNICODE_TBL_NAME,
|
||||
engine,
|
||||
|
|
@ -51,7 +51,7 @@ def load_unicode_data():
|
|||
|
||||
yield
|
||||
with app.app_context():
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
with get_example_database().get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS unicode_test")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def load_world_bank_data():
|
|||
"country_name": String(255),
|
||||
"region": String(255),
|
||||
}
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
_get_dataframe(database).to_sql(
|
||||
WB_HEALTH_POPULATION,
|
||||
engine,
|
||||
|
|
@ -64,7 +64,7 @@ def load_world_bank_data():
|
|||
|
||||
yield
|
||||
with app.app_context():
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
with get_example_database().get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS wb_health_population")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -56,22 +56,22 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default"
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
|
||||
with model.get_sqla_engine_with_context() as engine:
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive/default", db)
|
||||
|
||||
with model.get_sqla_engine_with_context(schema="core_db") as engine:
|
||||
with model.get_sqla_engine(schema="core_db") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive/core_db", db)
|
||||
|
||||
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
|
||||
with model.get_sqla_engine_with_context() as engine:
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive", db)
|
||||
|
||||
with model.get_sqla_engine_with_context(schema="core_db") as engine:
|
||||
with model.get_sqla_engine(schema="core_db") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive/core_db", db)
|
||||
|
||||
|
|
@ -79,11 +79,11 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
|
||||
with model.get_sqla_engine_with_context() as engine:
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("prod", db)
|
||||
|
||||
with model.get_sqla_engine_with_context(schema="foo") as engine:
|
||||
with model.get_sqla_engine(schema="foo") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("prod", db)
|
||||
|
||||
|
|
@ -97,11 +97,11 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL"
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
|
||||
with model.get_sqla_engine_with_context() as engine:
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("default", db)
|
||||
|
||||
with model.get_sqla_engine_with_context(schema="core_db") as engine:
|
||||
with model.get_sqla_engine(schema="core_db") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("core_db", db)
|
||||
|
||||
|
|
@ -112,11 +112,11 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
sqlalchemy_uri = "mysql://root@localhost/superset"
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
|
||||
with model.get_sqla_engine_with_context() as engine:
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("superset", db)
|
||||
|
||||
with model.get_sqla_engine_with_context(schema="staging") as engine:
|
||||
with model.get_sqla_engine(schema="staging") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("staging", db)
|
||||
|
||||
|
|
@ -130,12 +130,12 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
with override_user(example_user):
|
||||
model.impersonate_user = True
|
||||
with model.get_sqla_engine_with_context() as engine:
|
||||
with model.get_sqla_engine() as engine:
|
||||
username = make_url(engine.url).username
|
||||
self.assertEqual(example_user.username, username)
|
||||
|
||||
model.impersonate_user = False
|
||||
with model.get_sqla_engine_with_context() as engine:
|
||||
with model.get_sqla_engine() as engine:
|
||||
username = make_url(engine.url).username
|
||||
self.assertNotEqual(example_user.username, username)
|
||||
|
||||
|
|
@ -295,7 +295,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
db = get_example_database()
|
||||
table_name = "energy_usage"
|
||||
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
|
||||
with db.get_sqla_engine_with_context() as engine:
|
||||
with db.get_sqla_engine() as engine:
|
||||
quote = engine.dialect.identifier_preparer.quote_identifier
|
||||
|
||||
source = quote(table_name) if db.backend in {"presto", "hive"} else table_name
|
||||
|
|
|
|||
|
|
@ -150,13 +150,13 @@ def assert_log(state: str, error_message: Optional[str] = None):
|
|||
|
||||
@contextmanager
|
||||
def create_test_table_context(database: Database):
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.execute("CREATE TABLE test_table AS SELECT 1 as first, 2 as second")
|
||||
engine.execute("INSERT INTO test_table (first, second) VALUES (1, 2)")
|
||||
engine.execute("INSERT INTO test_table (first, second) VALUES (3, 4)")
|
||||
|
||||
yield db.session
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.execute("DROP TABLE test_table")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
self.validator = PrestoDBSQLValidator
|
||||
self.database = MagicMock()
|
||||
self.database_engine = (
|
||||
self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
|
||||
self.database.get_sqla_engine.return_value.__enter__.return_value
|
||||
)
|
||||
self.database_conn = self.database_engine.raw_connection.return_value
|
||||
self.database_cursor = self.database_conn.cursor.return_value
|
||||
|
|
|
|||
|
|
@ -313,7 +313,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
query = table.database.compile_sqla_query(sqla_query.sqla_query)
|
||||
|
||||
database = table.database
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine() as engine:
|
||||
quote = engine.dialect.identifier_preparer.quote_identifier
|
||||
|
||||
for metric_label in {"metric using jinja macro", "same but different"}:
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
# assertions
|
||||
db.session.commit()
|
||||
examples_db = get_example_database()
|
||||
with examples_db.get_sqla_engine_with_context() as engine:
|
||||
with examples_db.get_sqla_engine() as engine:
|
||||
data = engine.execute(
|
||||
f"SELECT * FROM admin_database.{tmp_table_name}"
|
||||
).fetchall()
|
||||
|
|
@ -296,7 +296,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
"SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"]
|
||||
)
|
||||
|
||||
with examples_db.get_sqla_engine_with_context() as engine:
|
||||
with examples_db.get_sqla_engine() as engine:
|
||||
engine.execute(
|
||||
f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2"
|
||||
)
|
||||
|
|
@ -325,7 +325,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
self.assertEqual(1, len(data["data"]))
|
||||
|
||||
db.session.query(Query).delete()
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
with get_example_database().get_sqla_engine() as engine:
|
||||
engine.execute(f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table")
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ def database1(session: Session) -> Iterator["Database"]:
|
|||
|
||||
@pytest.fixture
|
||||
def table1(session: Session, database1: "Database") -> Iterator[None]:
|
||||
with database1.get_sqla_engine_with_context() as engine:
|
||||
with database1.get_sqla_engine() as engine:
|
||||
conn = engine.connect()
|
||||
conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)")
|
||||
conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)")
|
||||
|
|
@ -92,7 +92,7 @@ def database2(session: Session) -> Iterator["Database"]:
|
|||
|
||||
@pytest.fixture
|
||||
def table2(session: Session, database2: "Database") -> Iterator[None]:
|
||||
with database2.get_sqla_engine_with_context() as engine:
|
||||
with database2.get_sqla_engine() as engine:
|
||||
conn = engine.connect()
|
||||
conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)")
|
||||
conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')")
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ def test_get_prequeries(mocker: MockFixture) -> None:
|
|||
"""
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_sqla_engine_with_context",
|
||||
"get_sqla_engine",
|
||||
)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
|
||||
|
|
|
|||
|
|
@ -54,13 +54,13 @@ def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
|
|||
# since we're using an in-memory SQLite database, make sure we always
|
||||
# return the same engine where the table was created
|
||||
@contextmanager
|
||||
def mock_get_sqla_engine_with_context():
|
||||
def mock_get_sqla_engine():
|
||||
yield engine
|
||||
|
||||
mocker.patch.object(
|
||||
database,
|
||||
"get_sqla_engine_with_context",
|
||||
new=mock_get_sqla_engine_with_context,
|
||||
"get_sqla_engine",
|
||||
new=mock_get_sqla_engine,
|
||||
)
|
||||
|
||||
table = SqlaTable(
|
||||
|
|
|
|||
Loading…
Reference in New Issue