feat: refactor all `get_sqla_engine` to use contextmanager in codebase (#21943)

Hugh A. Miles II, 2022-11-15 13:45:14 -05:00, committed by GitHub
parent 06f87e1467
commit e23efefc46
41 changed files with 635 additions and 595 deletions
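
The refactor replaces direct calls to Database.get_sqla_engine() with the get_sqla_engine_with_context() context manager at every call site, so the engine is acquired and released within a with block. A minimal before/after sketch of the calling pattern, assuming a Superset Database instance named `database` (the variable names and query are illustrative):

    import pandas as pd

    # Before: the engine was fetched directly and never explicitly scoped.
    engine = database.get_sqla_engine(schema="main")
    df = pd.read_sql_query(sql="SELECT 1", con=engine)

    # After: the engine comes from a context manager, so any setup/teardown
    # the Database model performs is tied to the with block.
    with database.get_sqla_engine_with_context(schema="main") as engine:
        df = pd.read_sql_query(sql="SELECT 1", con=engine)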

View File

@ -804,13 +804,13 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
engine = self.database.get_sqla_engine()
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()
df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()
def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR

View File

@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
)
db_engine_spec = dataset.database.db_engine_spec
engine = dataset.database.get_sqla_engine(schema=dataset.schema)
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
# TODO(villebro): refactor to use same code that's used by
# sql_lab.py:execute_sql_statements
try:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
cols = result_set.columns
with dataset.database.get_sqla_engine_with_context(
schema=dataset.schema
) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
cols = result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols
@ -155,14 +159,17 @@ def get_columns_description(
) -> List[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
with closing(database.get_sqla_engine().raw_connection()) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
with database.get_sqla_engine_with_context() as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
return result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
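
Several call sites nest contextlib.closing around a raw DBAPI connection inside the engine context, as get_columns_description does above. A hedged sketch of that nesting, assuming `database` is a Superset Database and `sql` is a single SELECT statement:

    from contextlib import closing

    with database.get_sqla_engine_with_context() as engine:
        # closing() guarantees the raw connection (and its cursor) is closed
        # when the inner block exits, even on error.
        with closing(engine.raw_connection()) as conn:
            cursor = conn.cursor()
            cursor.execute(sql)
            rows = cursor.fetchall()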

View File

@ -90,7 +90,6 @@ class TestConnectionDatabaseCommand(BaseCommand):
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
engine = database.get_sqla_engine()
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
@ -100,31 +99,32 @@ class TestConnectionDatabaseCommand(BaseCommand):
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)
try:
alive = func_timeout(
int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
ping,
args=(engine,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
alive = False
# So we stop losing the original message if any
ex_str = str(ex)
with database.get_sqla_engine_with_context() as engine:
try:
alive = func_timeout(
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
ping,
args=(engine,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
alive = False
# So we stop losing the original message if any
ex_str = str(ex)
if not alive:
raise DBAPIError(ex_str or None, None, None)

View File

@ -101,21 +101,22 @@ class ValidateDatabaseParametersCommand(BaseCommand):
database.set_sqlalchemy_uri(sqlalchemy_uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
engine = database.get_sqla_engine()
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex
alive = False
with database.get_sqla_engine_with_context() as engine:
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex
if not alive:
raise DatabaseOfflineError(

View File

@ -166,17 +166,26 @@ def load_data(
if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"):
logger.info("Loading data inside the import transaction")
connection = session.connection()
df.to_sql(
dataset.table_name,
con=connection,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
else:
logger.warning("Loading data outside the import transaction")
connection = database.get_sqla_engine()
df.to_sql(
dataset.table_name,
con=connection,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
with database.get_sqla_engine_with_context() as engine:
df.to_sql(
dataset.table_name,
con=engine,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)

View File

@ -23,6 +23,7 @@ from datetime import datetime
from typing import (
Any,
Callable,
ContextManager,
Dict,
List,
Match,
@ -480,8 +481,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: "Database",
schema: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> Engine:
return database.get_sqla_engine(schema=schema, source=source)
) -> ContextManager[Engine]:
"""
Return an engine context manager.
>>> with DBEngineSpec.get_engine(database, schema, source) as engine:
... connection = engine.connect()
... connection.execute(sql)
"""
return database.get_sqla_engine_with_context(schema=schema, source=source)
@classmethod
def get_timestamp_expr(
@ -903,17 +912,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""
engine = cls.get_engine(database)
to_sql_kwargs["name"] = table.table
if table.schema:
# Only add schema when it is preset and non empty.
to_sql_kwargs["schema"] = table.schema
if engine.dialect.supports_multivalues_insert:
to_sql_kwargs["method"] = "multi"
with cls.get_engine(database) as engine:
if engine.dialect.supports_multivalues_insert:
to_sql_kwargs["method"] = "multi"
df.to_sql(con=engine, **to_sql_kwargs)
df.to_sql(con=engine, **to_sql_kwargs)
@classmethod
def convert_dttm( # pylint: disable=unused-argument
@ -1286,13 +1295,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()
engine = cls.get_engine(database, schema=schema, source=source)
costs = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
with cls.get_engine(database, schema=schema, source=source) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(
cls.estimate_statement_cost(processed_statement, cursor)
)
return costs
@classmethod

View File

@ -340,8 +340,12 @@ class BigQueryEngineSpec(BaseEngineSpec):
if not table.schema:
raise Exception("The table schema must be defined")
engine = cls.get_engine(database)
to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host}
to_gbq_kwargs = {}
with cls.get_engine(database) as engine:
to_gbq_kwargs = {
"destination_table": str(table),
"project_id": engine.url.host,
}
# Add credentials if they are set on the SQLAlchemy dialect.
creds = engine.dialect.credentials_info

View File

@ -109,11 +109,11 @@ class GSheetsEngineSpec(SqliteEngineSpec):
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
engine = cls.get_engine(database, schema=schema_name)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]
with cls.get_engine(database, schema=schema_name) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]
try:
metadata = json.loads(results)

View File

@ -185,8 +185,6 @@ class HiveEngineSpec(PrestoEngineSpec):
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""
engine = cls.get_engine(database)
if to_sql_kwargs["if_exists"] == "append":
raise SupersetException("Append operation not currently supported")
@ -205,7 +203,8 @@ class HiveEngineSpec(PrestoEngineSpec):
if table_exists:
raise SupersetException("Table already exists")
elif to_sql_kwargs["if_exists"] == "replace":
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
with cls.get_engine(database) as engine:
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
def _get_hive_type(dtype: np.dtype) -> str:
hive_type_by_dtype = {

View File

@ -462,12 +462,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
).strip()
params = {}
engine = cls.get_engine(database, schema=schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
return sorted([row[0] for row in results])
@ -989,17 +988,17 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError
engine = cls.get_engine(database, schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)
with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)
except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]
except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]
@classmethod
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:

View File

@ -29,31 +29,31 @@ from .helpers import get_example_url, get_table_connector_registry
def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "bart_lines"
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("bart-lines.json.gz")
df = pd.read_json(url, encoding="latin-1", compression="gzip")
df["path_json"] = df.path.map(json.dumps)
df["polyline"] = df.path.map(polyline.encode)
del df["path"]
if not only_metadata and (not table_exists or force):
url = get_example_url("bart-lines.json.gz")
df = pd.read_json(url, encoding="latin-1", compression="gzip")
df["path_json"] = df.path.map(json.dumps)
df["polyline"] = df.path.map(polyline.encode)
del df["path"]
df.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"polyline": Text,
"path_json": Text,
},
index=False,
)
df.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"polyline": Text,
"path_json": Text,
},
index=False,
)
print("Creating table {} reference".format(tbl_name))
table = get_table_connector_registry()

View File

@ -76,25 +76,25 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf = pdf.head(100) if sample else pdf
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
# TODO(bkyryliuk): use TIMESTAMP type for presto
"ds": DateTime if database.backend != "presto" else String(255),
"gender": String(16),
"state": String(10),
"name": String(255),
},
method="multi",
index=False,
)
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
# TODO(bkyryliuk): use TIMESTAMP type for presto
"ds": DateTime if database.backend != "presto" else String(255),
"gender": String(16),
"state": String(10),
"name": String(255),
},
method="multi",
index=False,
)
print("Done loading table!")
print("-" * 80)
@ -104,8 +104,8 @@ def load_birth_names(
) -> None:
"""Loading birth name dataset from a zip file in the repo"""
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"
table_exists = database.has_table_by_name(tbl_name, schema=schema)

View File

@ -39,38 +39,39 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
"""Loading data for map with country map"""
tbl_name = "birth_france_by_region"
database = database_utils.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("birth_france_data_for_country_map.csv")
data = pd.read_csv(url, encoding="utf-8")
data["dttm"] = datetime.datetime.now().date()
data.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"DEPT_ID": String(10),
"2003": BigInteger,
"2004": BigInteger,
"2005": BigInteger,
"2006": BigInteger,
"2007": BigInteger,
"2008": BigInteger,
"2009": BigInteger,
"2010": BigInteger,
"2011": BigInteger,
"2012": BigInteger,
"2013": BigInteger,
"2014": BigInteger,
"dttm": Date(),
},
index=False,
)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("birth_france_data_for_country_map.csv")
data = pd.read_csv(url, encoding="utf-8")
data["dttm"] = datetime.datetime.now().date()
data.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"DEPT_ID": String(10),
"2003": BigInteger,
"2004": BigInteger,
"2005": BigInteger,
"2006": BigInteger,
"2007": BigInteger,
"2008": BigInteger,
"2009": BigInteger,
"2010": BigInteger,
"2011": BigInteger,
"2012": BigInteger,
"2013": BigInteger,
"2014": BigInteger,
"dttm": Date(),
},
index=False,
)
print("Done loading table!")
print("-" * 80)

View File

@ -41,24 +41,25 @@ def load_energy(
"""Loads an energy related dataset to use with sankey and graphs"""
tbl_name = "energy_usage"
database = database_utils.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("energy.json.gz")
pdf = pd.read_json(url, compression="gzip")
pdf = pdf.head(100) if sample else pdf
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={"source": String(255), "target": String(255), "value": Float()},
index=False,
method="multi",
)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("energy.json.gz")
pdf = pd.read_json(url, compression="gzip")
pdf = pdf.head(100) if sample else pdf
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={"source": String(255), "target": String(255), "value": Float()},
index=False,
method="multi",
)
print("Creating table [wb_health_population] reference")
table = get_table_connector_registry()

View File

@ -27,35 +27,37 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
"""Loading random time series data from a zip file in the repo"""
tbl_name = "flights"
database = database_utils.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
flight_data_url = get_example_url("flight_data.csv.gz")
pdf = pd.read_csv(flight_data_url, encoding="latin-1", compression="gzip")
if not only_metadata and (not table_exists or force):
flight_data_url = get_example_url("flight_data.csv.gz")
pdf = pd.read_csv(flight_data_url, encoding="latin-1", compression="gzip")
# Loading airports info to join and get lat/long
airports_url = get_example_url("airports.csv.gz")
airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip")
airports = airports.set_index("IATA_CODE")
# Loading airports info to join and get lat/long
airports_url = get_example_url("airports.csv.gz")
airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip")
airports = airports.set_index("IATA_CODE")
pdf[ # pylint: disable=unsupported-assignment-operation,useless-suppression
"ds"
] = (pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str))
pdf.ds = pd.to_datetime(pdf.ds)
pdf.drop(columns=["DAY", "MONTH", "YEAR"])
pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime},
index=False,
)
pdf[ # pylint: disable=unsupported-assignment-operation,useless-suppression
"ds"
] = (
pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
)
pdf.ds = pd.to_datetime(pdf.ds)
pdf.drop(columns=["DAY", "MONTH", "YEAR"])
pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime},
index=False,
)
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()

View File

@ -39,49 +39,51 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
"""Loading lat/long data from a csv file in the repo"""
tbl_name = "long_lat"
database = database_utils.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("san_francisco.csv.gz")
pdf = pd.read_csv(url, encoding="utf-8", compression="gzip")
start = datetime.datetime.now().replace(
hour=0, minute=0, second=0, microsecond=0
)
pdf["datetime"] = [
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
for i in range(len(pdf))
]
pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1)
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"longitude": Float(),
"latitude": Float(),
"number": Float(),
"street": String(100),
"unit": String(10),
"city": String(50),
"district": String(50),
"region": String(50),
"postcode": Float(),
"id": String(100),
"datetime": DateTime(),
"occupancy": Float(),
"radius_miles": Float(),
"geohash": String(12),
"delimited": String(60),
},
index=False,
)
if not only_metadata and (not table_exists or force):
url = get_example_url("san_francisco.csv.gz")
pdf = pd.read_csv(url, encoding="utf-8", compression="gzip")
start = datetime.datetime.now().replace(
hour=0, minute=0, second=0, microsecond=0
)
pdf["datetime"] = [
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
for i in range(len(pdf))
]
pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
pdf["geohash"] = pdf[["LAT", "LON"]].apply(
lambda x: geohash.encode(*x), axis=1
)
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"longitude": Float(),
"latitude": Float(),
"number": Float(),
"street": String(100),
"unit": String(10),
"city": String(50),
"district": String(50),
"region": String(50),
"postcode": Float(),
"id": String(100),
"datetime": DateTime(),
"occupancy": Float(),
"radius_miles": Float(),
"geohash": String(12),
"delimited": String(60),
},
index=False,
)
print("Done loading table!")
print("-" * 80)

View File

@ -39,41 +39,41 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
"""Loading time series data from a zip file in the repo"""
tbl_name = "multiformat_time_series"
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("multiformat_time_series.json.gz")
pdf = pd.read_json(url, compression="gzip")
# TODO(bkyryliuk): move load examples data into the pytest fixture
if database.backend == "presto":
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S")
else:
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
if not only_metadata and (not table_exists or force):
url = get_example_url("multiformat_time_series.json.gz")
pdf = pd.read_json(url, compression="gzip")
# TODO(bkyryliuk): move load examples data into the pytest fixture
if database.backend == "presto":
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S")
else:
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"ds": String(255) if database.backend == "presto" else Date,
"ds2": String(255) if database.backend == "presto" else DateTime,
"epoch_s": BigInteger,
"epoch_ms": BigInteger,
"string0": String(100),
"string1": String(100),
"string2": String(100),
"string3": String(100),
},
index=False,
)
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"ds": String(255) if database.backend == "presto" else Date,
"ds2": String(255) if database.backend == "presto" else DateTime,
"epoch_s": BigInteger,
"epoch_ms": BigInteger,
"string0": String(100),
"string1": String(100),
"string2": String(100),
"string3": String(100),
},
index=False,
)
print("Done loading table!")
print("-" * 80)

View File

@ -28,29 +28,29 @@ from .helpers import get_example_url, get_table_connector_registry
def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "paris_iris_mapping"
database = database_utils.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("paris_iris.json.gz")
df = pd.read_json(url, compression="gzip")
df["features"] = df.features.map(json.dumps)
if not only_metadata and (not table_exists or force):
url = get_example_url("paris_iris.json.gz")
df = pd.read_json(url, compression="gzip")
df["features"] = df.features.map(json.dumps)
df.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"features": Text,
"type": Text,
},
index=False,
)
df.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"features": Text,
"type": Text,
},
index=False,
)
print("Creating table {} reference".format(tbl_name))
table = get_table_connector_registry()

View File

@ -37,28 +37,28 @@ def load_random_time_series_data(
"""Loading random time series data from a zip file in the repo"""
tbl_name = "random_time_series"
database = database_utils.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("random_time_series.json.gz")
pdf = pd.read_json(url, compression="gzip")
if database.backend == "presto":
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S")
else:
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
if not only_metadata and (not table_exists or force):
url = get_example_url("random_time_series.json.gz")
pdf = pd.read_json(url, compression="gzip")
if database.backend == "presto":
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S")
else:
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime if database.backend != "presto" else String(255)},
index=False,
)
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime if database.backend != "presto" else String(255)},
index=False,
)
print("Done loading table!")
print("-" * 80)

View File

@ -30,29 +30,29 @@ def load_sf_population_polygons(
) -> None:
tbl_name = "sf_population_polygons"
database = database_utils.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
url = get_example_url("sf_population.json.gz")
df = pd.read_json(url, compression="gzip")
df["contour"] = df.contour.map(json.dumps)
if not only_metadata and (not table_exists or force):
url = get_example_url("sf_population.json.gz")
df = pd.read_json(url, compression="gzip")
df["contour"] = df.contour.map(json.dumps)
df.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"zipcode": BigInteger,
"population": BigInteger,
"contour": Text,
"area": Float,
},
index=False,
)
df.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
dtype={
"zipcode": BigInteger,
"population": BigInteger,
"contour": Text,
"area": Float,
},
index=False,
)
print("Creating table {} reference".format(tbl_name))
table = get_table_connector_registry()

View File

@ -453,11 +453,11 @@ def load_supported_charts_dashboard() -> None:
"""Loading a dashboard featuring supported charts"""
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"
table_exists = database.has_table_by_name(tbl_name, schema=schema)
tbl_name = "birth_names"
table_exists = database.has_table_by_name(tbl_name, schema=schema)
if table_exists:
table = get_table_connector_registry()

View File

@ -51,37 +51,38 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population"
database = superset.utils.database.get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
with database.get_sqla_engine_with_context() as engine:
if not only_metadata and (not table_exists or force):
url = get_example_url("countries.json.gz")
pdf = pd.read_json(url, compression="gzip")
pdf.columns = [col.replace(".", "_") for col in pdf.columns]
if database.backend == "presto":
pdf.year = pd.to_datetime(pdf.year)
pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S")
else:
pdf.year = pd.to_datetime(pdf.year)
pdf = pdf.head(100) if sample else pdf
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=50,
dtype={
# TODO(bkyryliuk): use TIMESTAMP type for presto
"year": DateTime if database.backend != "presto" else String(255),
"country_code": String(3),
"country_name": String(255),
"region": String(255),
},
method="multi",
index=False,
)
if not only_metadata and (not table_exists or force):
url = get_example_url("countries.json.gz")
pdf = pd.read_json(url, compression="gzip")
pdf.columns = [col.replace(".", "_") for col in pdf.columns]
if database.backend == "presto":
pdf.year = pd.to_datetime(pdf.year)
pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S")
else:
pdf.year = pd.to_datetime(pdf.year)
pdf = pdf.head(100) if sample else pdf
pdf.to_sql(
tbl_name,
engine,
schema=schema,
if_exists="replace",
chunksize=50,
dtype={
# TODO(bkyryliuk): use TIMESTAMP type for presto
"year": DateTime if database.backend != "presto" else String(255),
"country_code": String(3),
"country_name": String(255),
"region": String(255),
},
method="multi",
index=False,
)
print("Creating table [wb_health_population] reference")
table = get_table_connector_registry()

View File

@ -369,12 +369,9 @@ class Database(
nullpool: bool = True,
source: Optional[utils.QuerySource] = None,
) -> Engine:
try:
yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
def get_sqla_engine(
def _get_sqla_engine(
self,
schema: Optional[str] = None,
nullpool: bool = True,
@ -392,7 +389,7 @@ class Database(
)
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database.get_sqla_engine(). Masked URL: %s", str(masked_url))
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
params = extra.get("engine_params", {})
if nullpool:
@ -442,7 +439,7 @@ class Database(
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
engine = self.get_sqla_engine(schema)
engine = self._get_sqla_engine(schema)
def needs_conversion(df_series: pd.Series) -> bool:
return (
@ -487,7 +484,7 @@ class Database(
return df
def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
engine = self.get_sqla_engine(schema=schema)
engine = self._get_sqla_engine(schema=schema)
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
@ -508,7 +505,7 @@ class Database(
cols: Optional[List[Dict[str, Any]]] = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
return self.db_engine_spec.select_star(
self,
table_name,
@ -533,7 +530,7 @@ class Database(
@property
def inspector(self) -> Inspector:
engine = self.get_sqla_engine()
engine = self._get_sqla_engine()
return sqla.inspect(engine)
@cache_util.memoized_func(
@ -674,7 +671,7 @@ class Database(
meta,
schema=schema or None,
autoload=True,
autoload_with=self.get_sqla_engine(),
autoload_with=self._get_sqla_engine(),
)
def get_table_comment(
@ -765,11 +762,11 @@ class Database(
return self.perm # type: ignore
def has_table(self, table: Table) -> bool:
engine = self.get_sqla_engine()
engine = self._get_sqla_engine()
return engine.has_table(table.table_name, table.schema or None)
def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
engine = self.get_sqla_engine()
engine = self._get_sqla_engine()
return engine.has_table(table_name, schema)
@classmethod
@ -788,7 +785,7 @@ class Database(
return view_name in view_names
def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
engine = self.get_sqla_engine()
engine = self._get_sqla_engine()
return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
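
In superset/models/core.py the public API becomes a thin context manager around the now-private _get_sqla_engine helper, and the old try/except wrapper is dropped from the generator body. A simplified sketch of that shape, assuming a contextlib.contextmanager decorator as on the existing method (class internals and the exact type annotations are elided or approximated):

    from contextlib import contextmanager
    from typing import Iterator, Optional

    from sqlalchemy.engine import Engine

    class Database:  # minimal stand-in for superset.models.core.Database
        def _get_sqla_engine(self, schema: Optional[str] = None,
                             nullpool: bool = True, source=None) -> Engine:
            ...  # builds and returns the SQLAlchemy engine (elided)

        @contextmanager
        def get_sqla_engine_with_context(self, schema: Optional[str] = None,
                                         nullpool: bool = True,
                                         source=None) -> Iterator[Engine]:
            # Yield the engine from the private helper; callers consume it
            # in a with block.
            yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)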

View File

@ -224,8 +224,9 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
@property
def sqla_metadata(self) -> None:
# pylint: disable=no-member
meta = MetaData(bind=self.get_sqla_engine())
meta.reflect()
with self.get_sqla_engine_with_context() as engine:
meta = MetaData(bind=engine)
meta.reflect()
@property
def status(self) -> utils.DashboardStatus:

View File

@ -55,8 +55,9 @@ class FilterSet(Model, AuditMixinNullable):
@property
def sqla_metadata(self) -> None:
# pylint: disable=no-member
meta = MetaData(bind=self.get_sqla_engine())
meta.reflect()
with self.get_sqla_engine_with_context() as engine:
meta = MetaData(bind=engine)
meta.reflect()
@property
def changed_by_name(self) -> str:

View File

@ -1281,13 +1281,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if limit:
qry = qry.limit(limit)
engine = self.database.get_sqla_engine() # type: ignore
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
with self.database.get_sqla_engine_with_context() as engine: # type: ignore
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()
df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()
def get_timestamp_expression(
self,

View File

@ -463,61 +463,66 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
)
)
engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
with closing(engine.raw_connection()) as conn:
# closing the connection closes the cursor as well
cursor = conn.cursor()
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(cancel_query_key, cancel_query_id)
session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
# Check if stopped
session.refresh(query)
if query.status == QueryStatus.STOPPED:
payload.update({"status": query.status})
return payload
with database.get_sqla_engine_with_context(
query.schema, source=QuerySource.SQL_LAB
) as engine:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
with closing(engine.raw_connection()) as conn:
# closing the connection closes the cursor as well
cursor = conn.cursor()
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(cancel_query_key, cancel_query_id)
session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
# Check if stopped
session.refresh(query)
if query.status == QueryStatus.STOPPED:
payload.update({"status": query.status})
return payload
# For CTAS we create the table only on the last statement
apply_ctas = query.select_as_cta and (
query.ctas_method == CtasMethod.VIEW
or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
)
# For CTAS we create the table only on the last statement
apply_ctas = query.select_as_cta and (
query.ctas_method == CtasMethod.VIEW
or (
query.ctas_method == CtasMethod.TABLE
and i == len(statements) - 1
)
)
# Run statement
msg = f"Running statement {i+1} out of {statement_count}"
logger.info("Query %s: %s", str(query_id), msg)
query.set_extra_json_key("progress", msg)
session.commit()
try:
result_set = execute_sql_statement(
statement,
query,
session,
cursor,
log_params,
apply_ctas,
)
except SqlLabQueryStoppedException:
payload.update({"status": QueryStatus.STOPPED})
return payload
except Exception as ex: # pylint: disable=broad-except
msg = str(ex)
prefix_message = (
f"[Statement {i+1} out of {statement_count}]"
if statement_count > 1
else ""
)
payload = handle_query_error(
ex, query, session, payload, prefix_message
)
return payload
# Run statement
msg = f"Running statement {i+1} out of {statement_count}"
logger.info("Query %s: %s", str(query_id), msg)
query.set_extra_json_key("progress", msg)
session.commit()
try:
result_set = execute_sql_statement(
statement,
query,
session,
cursor,
log_params,
apply_ctas,
)
except SqlLabQueryStoppedException:
payload.update({"status": QueryStatus.STOPPED})
return payload
except Exception as ex: # pylint: disable=broad-except
msg = str(ex)
prefix_message = (
f"[Statement {i+1} out of {statement_count}]"
if statement_count > 1
else ""
)
payload = handle_query_error(
ex, query, session, payload, prefix_message
)
return payload
# Commit the connection so CTA queries will create the table.
conn.commit()
# Commit the connection so CTA queries will create the table.
conn.commit()
# Success, updating the query entry in database
query.rows = result_set.size
@ -622,10 +627,11 @@ def cancel_query(query: Query) -> bool:
if cancel_query_id is None:
return False
engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
return query.database.db_engine_spec.cancel_query(
cursor, query, cancel_query_id
)
with query.database.get_sqla_engine_with_context(
query.schema, source=QuerySource.SQL_LAB
) as engine:
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
return query.database.db_engine_spec.cancel_query(
cursor, query, cancel_query_id
)

View File

@ -162,16 +162,18 @@ class PrestoDBSQLValidator(BaseSQLValidator):
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: List[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(statement, database, cursor)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))
with database.get_sqla_engine_with_context(
schema, source=QuerySource.SQL_LAB
) as engine:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: List[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(statement, database, cursor)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))
return annotations

View File

@ -1284,8 +1284,8 @@ def get_example_default_schema() -> Optional[str]:
Return the default schema of the examples database, if any.
"""
database = get_example_database()
engine = database.get_sqla_engine()
return inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
return inspect(engine).default_schema_name
def backend() -> str:

View File

@ -187,29 +187,29 @@ def add_data(
database = get_example_database()
table_exists = database.has_table_by_name(table_name)
engine = database.get_sqla_engine()
if columns is None:
if not table_exists:
raise Exception(
f"The table {table_name} does not exist. To create it you need to "
"pass a list of column names and types."
)
with database.get_sqla_engine_with_context() as engine:
if columns is None:
if not table_exists:
raise Exception(
f"The table {table_name} does not exist. To create it you need to "
"pass a list of column names and types."
)
inspector = inspect(engine)
columns = inspector.get_columns(table_name)
inspector = inspect(engine)
columns = inspector.get_columns(table_name)
# create table if needed
column_objects = get_column_objects(columns)
metadata = MetaData()
table = Table(table_name, metadata, *column_objects)
metadata.create_all(engine)
# create table if needed
column_objects = get_column_objects(columns)
metadata = MetaData()
table = Table(table_name, metadata, *column_objects)
metadata.create_all(engine)
if not append:
engine.execute(table.delete())
if not append:
engine.execute(table.delete())
data = generate_data(columns, num_rows)
engine.execute(table.insert(), data)
data = generate_data(columns, num_rows)
engine.execute(table.insert(), data)
def get_column_objects(columns: List[ColumnInfo]) -> List[Column]:

View File

@ -1379,11 +1379,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
engine = database.get_sqla_engine()
with closing(engine.raw_connection()) as conn:
if engine.dialect.do_ping(conn):
return json_success('"OK"')
with database.get_sqla_engine_with_context() as engine:
with closing(engine.raw_connection()) as conn:
if engine.dialect.do_ping(conn):
return json_success('"OK"')
raise DBAPIError(None, None, None)
except CertificateException as ex:

View File

@ -171,7 +171,7 @@ def example_db_provider() -> Callable[[], Database]: # type: ignore
return self._db
def _load_lazy_data_to_decouple_from_session(self) -> None:
self._db.get_sqla_engine() # type: ignore
self._db._get_sqla_engine() # type: ignore
self._db.backend # type: ignore
def remove(self) -> None:
@ -336,37 +336,38 @@ def physical_dataset():
from superset.connectors.sqla.utils import get_identifier_quoter
example_database = get_example_database()
engine = example_database.get_sqla_engine()
quoter = get_identifier_quoter(engine.name)
# sqlite can only execute one statement at a time
engine.execute(
f"""
CREATE TABLE IF NOT EXISTS physical_dataset(
col1 INTEGER,
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
);
with example_database.get_sqla_engine_with_context() as engine:
quoter = get_identifier_quoter(engine.name)
# sqlite can only execute one statement at a time
engine.execute(
f"""
CREATE TABLE IF NOT EXISTS physical_dataset(
col1 INTEGER,
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
);
"""
)
engine.execute(
"""
INSERT INTO physical_dataset values
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
"""
)
engine.execute(
"""
INSERT INTO physical_dataset values
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
"""
)
)
dataset = SqlaTable(
table_name="physical_dataset",

View File

@ -641,7 +641,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
@ -664,7 +664,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
)
mock_event_logger.assert_called()
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
@ -713,7 +713,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
)
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
@ -738,7 +738,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
mock_event_logger.assert_called()
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)

View File

@ -227,8 +227,10 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
return_value="account_info"
)
mock_get_engine.return_value.url.host = "google-host"
mock_get_engine.return_value.dialect.credentials_info = "secrets"
mock_get_engine.return_value.__enter__.return_value.url.host = "google-host"
mock_get_engine.return_value.__enter__.return_value.dialect.credentials_info = (
"secrets"
)
BigQueryEngineSpec.df_to_sql(
database=database,

View File

@ -204,7 +204,9 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine.return_value.execute = mock_execute
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
mock_execute
)
table_name = "foobar"
with app.app_context():
@ -229,7 +231,9 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine.return_value.execute = mock_execute
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
mock_execute
)
table_name = "foobar"
schema = "schema"
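
Because callers now enter the engine through a with block, unit tests hang their mocks off the context manager protocol: the engine mock lives behind return_value.__enter__.return_value rather than plain return_value. A small sketch of the pattern used in the updated tests, assuming unittest.mock (names are illustrative):

    from unittest import mock

    database = mock.MagicMock()
    mock_execute = mock.MagicMock(return_value=True)

    # Code under test does:
    #     with database.get_sqla_engine_with_context() as engine:
    #         engine.execute(...)
    # so the engine is whatever __enter__ returns on the mocked context manager.
    engine_mock = (
        database.get_sqla_engine_with_context.return_value.__enter__.return_value
    )
    engine_mock.execute = mock_execute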

View File

@ -37,12 +37,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
schema = "schema"
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
mock_execute.assert_called_once_with(
@ -60,10 +61,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
@ -821,13 +822,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
False
)
schema = "schema"
@ -839,7 +840,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_create_view_exception(self):
mock_execute = mock.MagicMock(side_effect=Exception())
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
schema = "schema"
@ -852,7 +853,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
mock_execute = mock.MagicMock(side_effect=DatabaseError())
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
schema = "schema"

View File

@ -51,8 +51,8 @@ def load_unicode_data():
yield
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS unicode_test")
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS unicode_test")
@pytest.fixture()

View File

@ -64,8 +64,8 @@ def load_world_bank_data():
yield
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS wb_health_population")
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS wb_health_population")
@pytest.fixture()

View File

@ -164,7 +164,7 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://gamma@localhost"
@ -177,7 +177,7 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://localhost"
@ -197,7 +197,7 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri="trino://localhost"
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://localhost"
@ -209,7 +209,7 @@ class TestDatabaseModel(SupersetTestCase):
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert (
@ -242,7 +242,7 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@ -255,7 +255,7 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@ -380,21 +380,7 @@ class TestDatabaseModel(SupersetTestCase):
)
mocked_create_engine.side_effect = Exception()
with self.assertRaises(SupersetException):
model.get_sqla_engine()
# todo(hughhh): update this test
# @mock.patch("superset.models.core.create_engine")
# def test_get_sqla_engine_with_context(self, mocked_create_engine):
# model = Database(
# database_name="test_database",
# sqlalchemy_uri="mysql://root@localhost",
# )
# model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
# return_value={Exception: SupersetException}
# )
# mocked_create_engine.side_effect = Exception()
# with self.assertRaises(SupersetException):
# model.get_sqla_engine()
model._get_sqla_engine()
class TestSqlaTableModel(SupersetTestCase):

View File

@ -174,7 +174,9 @@ class TestPrestoValidator(SupersetTestCase):
def setUp(self):
self.validator = PrestoDBSQLValidator
self.database = MagicMock()
self.database_engine = self.database.get_sqla_engine.return_value
self.database_engine = (
self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
)
self.database_conn = self.database_engine.raw_connection.return_value
self.database_cursor = self.database_conn.cursor.return_value
self.database_cursor.poll.return_value = None

View File

@ -733,7 +733,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@ -786,7 +786,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = True
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@ -836,7 +836,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False