feat: refactor all `get_sqla_engine` to use contextmanager in codebase (#21943)
This commit is contained in:
parent
06f87e1467
commit
e23efefc46
|
|
@ -804,13 +804,13 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
if self.fetch_values_predicate:
|
||||
qry = qry.where(self.get_fetch_values_predicate())
|
||||
|
||||
engine = self.database.get_sqla_engine()
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
with self.database.get_sqla_engine_with_context() as engine:
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
|
||||
df = pd.read_sql_query(sql=sql, con=engine)
|
||||
return df[column_name].to_list()
|
||||
df = pd.read_sql_query(sql=sql, con=engine)
|
||||
return df[column_name].to_list()
|
||||
|
||||
def mutate_query_from_config(self, sql: str) -> str:
|
||||
"""Apply config's SQL_QUERY_MUTATOR
|
||||
|
|
|
|||
|
|
@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
|
|||
)
|
||||
|
||||
db_engine_spec = dataset.database.db_engine_spec
|
||||
engine = dataset.database.get_sqla_engine(schema=dataset.schema)
|
||||
sql = dataset.get_template_processor().process_template(
|
||||
dataset.sql, **dataset.template_params_dict
|
||||
)
|
||||
|
|
@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
|
|||
# TODO(villebro): refactor to use same code that's used by
|
||||
# sql_lab.py:execute_sql_statements
|
||||
try:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
|
||||
db_engine_spec.execute(cursor, query)
|
||||
result = db_engine_spec.fetch_data(cursor, limit=1)
|
||||
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
|
||||
cols = result_set.columns
|
||||
with dataset.database.get_sqla_engine_with_context(
|
||||
schema=dataset.schema
|
||||
) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
|
||||
db_engine_spec.execute(cursor, query)
|
||||
result = db_engine_spec.fetch_data(cursor, limit=1)
|
||||
result_set = SupersetResultSet(
|
||||
result, cursor.description, db_engine_spec
|
||||
)
|
||||
cols = result_set.columns
|
||||
except Exception as ex:
|
||||
raise SupersetGenericDBErrorException(message=str(ex)) from ex
|
||||
return cols
|
||||
|
|
@ -155,14 +159,17 @@ def get_columns_description(
|
|||
) -> List[ResultSetColumnType]:
|
||||
db_engine_spec = database.db_engine_spec
|
||||
try:
|
||||
with closing(database.get_sqla_engine().raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
query = database.apply_limit_to_sql(query, limit=1)
|
||||
cursor.execute(query)
|
||||
db_engine_spec.execute(cursor, query)
|
||||
result = db_engine_spec.fetch_data(cursor, limit=1)
|
||||
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
|
||||
return result_set.columns
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
query = database.apply_limit_to_sql(query, limit=1)
|
||||
cursor.execute(query)
|
||||
db_engine_spec.execute(cursor, query)
|
||||
result = db_engine_spec.fetch_data(cursor, limit=1)
|
||||
result_set = SupersetResultSet(
|
||||
result, cursor.description, db_engine_spec
|
||||
)
|
||||
return result_set.columns
|
||||
except Exception as ex:
|
||||
raise SupersetGenericDBErrorException(message=str(ex)) from ex
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,6 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
database.set_sqlalchemy_uri(uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
engine = database.get_sqla_engine()
|
||||
event_logger.log_with_context(
|
||||
action="test_connection_attempt",
|
||||
engine=database.db_engine_spec.__name__,
|
||||
|
|
@ -100,31 +99,32 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
with closing(engine.raw_connection()) as conn:
|
||||
return engine.dialect.do_ping(conn)
|
||||
|
||||
try:
|
||||
alive = func_timeout(
|
||||
int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
|
||||
ping,
|
||||
args=(engine,),
|
||||
)
|
||||
except (sqlite3.ProgrammingError, RuntimeError):
|
||||
# SQLite can't run on a separate thread, so ``func_timeout`` fails
|
||||
# RuntimeError catches the equivalent error from duckdb.
|
||||
alive = engine.dialect.do_ping(engine)
|
||||
except FunctionTimedOut as ex:
|
||||
raise SupersetTimeoutException(
|
||||
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
|
||||
message=(
|
||||
"Please check your connection details and database settings, "
|
||||
"and ensure that your database is accepting connections, "
|
||||
"then try connecting again."
|
||||
),
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
||||
) from ex
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
alive = False
|
||||
# So we stop losing the original message if any
|
||||
ex_str = str(ex)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
try:
|
||||
alive = func_timeout(
|
||||
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
|
||||
ping,
|
||||
args=(engine,),
|
||||
)
|
||||
except (sqlite3.ProgrammingError, RuntimeError):
|
||||
# SQLite can't run on a separate thread, so ``func_timeout`` fails
|
||||
# RuntimeError catches the equivalent error from duckdb.
|
||||
alive = engine.dialect.do_ping(engine)
|
||||
except FunctionTimedOut as ex:
|
||||
raise SupersetTimeoutException(
|
||||
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
|
||||
message=(
|
||||
"Please check your connection details and database settings, "
|
||||
"and ensure that your database is accepting connections, "
|
||||
"then try connecting again."
|
||||
),
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
||||
) from ex
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
alive = False
|
||||
# So we stop losing the original message if any
|
||||
ex_str = str(ex)
|
||||
|
||||
if not alive:
|
||||
raise DBAPIError(ex_str or None, None, None)
|
||||
|
|
|
|||
|
|
@ -101,21 +101,22 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||
database.set_sqlalchemy_uri(sqlalchemy_uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
engine = database.get_sqla_engine()
|
||||
try:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except Exception as ex:
|
||||
url = make_url_safe(sqlalchemy_uri)
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
"password": url.password,
|
||||
"port": url.port,
|
||||
"username": url.username,
|
||||
"database": url.database,
|
||||
}
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
raise DatabaseTestConnectionFailedError(errors) from ex
|
||||
alive = False
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
try:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except Exception as ex:
|
||||
url = make_url_safe(sqlalchemy_uri)
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
"password": url.password,
|
||||
"port": url.port,
|
||||
"username": url.username,
|
||||
"database": url.database,
|
||||
}
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
raise DatabaseTestConnectionFailedError(errors) from ex
|
||||
|
||||
if not alive:
|
||||
raise DatabaseOfflineError(
|
||||
|
|
|
|||
|
|
@ -166,17 +166,26 @@ def load_data(
|
|||
if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"):
|
||||
logger.info("Loading data inside the import transaction")
|
||||
connection = session.connection()
|
||||
df.to_sql(
|
||||
dataset.table_name,
|
||||
con=connection,
|
||||
schema=dataset.schema,
|
||||
if_exists="replace",
|
||||
chunksize=CHUNKSIZE,
|
||||
dtype=dtype,
|
||||
index=False,
|
||||
method="multi",
|
||||
)
|
||||
else:
|
||||
logger.warning("Loading data outside the import transaction")
|
||||
connection = database.get_sqla_engine()
|
||||
|
||||
df.to_sql(
|
||||
dataset.table_name,
|
||||
con=connection,
|
||||
schema=dataset.schema,
|
||||
if_exists="replace",
|
||||
chunksize=CHUNKSIZE,
|
||||
dtype=dtype,
|
||||
index=False,
|
||||
method="multi",
|
||||
)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
df.to_sql(
|
||||
dataset.table_name,
|
||||
con=engine,
|
||||
schema=dataset.schema,
|
||||
if_exists="replace",
|
||||
chunksize=CHUNKSIZE,
|
||||
dtype=dtype,
|
||||
index=False,
|
||||
method="multi",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from datetime import datetime
|
|||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Dict,
|
||||
List,
|
||||
Match,
|
||||
|
|
@ -480,8 +481,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
database: "Database",
|
||||
schema: Optional[str] = None,
|
||||
source: Optional[utils.QuerySource] = None,
|
||||
) -> Engine:
|
||||
return database.get_sqla_engine(schema=schema, source=source)
|
||||
) -> ContextManager[Engine]:
|
||||
"""
|
||||
Return an engine context manager.
|
||||
|
||||
>>> with DBEngineSpec.get_engine(database, schema, source) as engine:
|
||||
... connection = engine.connect()
|
||||
... connection.execute(sql)
|
||||
|
||||
"""
|
||||
return database.get_sqla_engine_with_context(schema=schema, source=source)
|
||||
|
||||
@classmethod
|
||||
def get_timestamp_expr(
|
||||
|
|
@ -903,17 +912,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
|
||||
"""
|
||||
|
||||
engine = cls.get_engine(database)
|
||||
to_sql_kwargs["name"] = table.table
|
||||
|
||||
if table.schema:
|
||||
# Only add schema when it is preset and non empty.
|
||||
to_sql_kwargs["schema"] = table.schema
|
||||
|
||||
if engine.dialect.supports_multivalues_insert:
|
||||
to_sql_kwargs["method"] = "multi"
|
||||
with cls.get_engine(database) as engine:
|
||||
if engine.dialect.supports_multivalues_insert:
|
||||
to_sql_kwargs["method"] = "multi"
|
||||
|
||||
df.to_sql(con=engine, **to_sql_kwargs)
|
||||
df.to_sql(con=engine, **to_sql_kwargs)
|
||||
|
||||
@classmethod
|
||||
def convert_dttm( # pylint: disable=unused-argument
|
||||
|
|
@ -1286,13 +1295,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
statements = parsed_query.get_statements()
|
||||
|
||||
engine = cls.get_engine(database, schema=schema, source=source)
|
||||
costs = []
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in statements:
|
||||
processed_statement = cls.process_statement(statement, database)
|
||||
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
|
||||
with cls.get_engine(database, schema=schema, source=source) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in statements:
|
||||
processed_statement = cls.process_statement(statement, database)
|
||||
costs.append(
|
||||
cls.estimate_statement_cost(processed_statement, cursor)
|
||||
)
|
||||
return costs
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -340,8 +340,12 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
|||
if not table.schema:
|
||||
raise Exception("The table schema must be defined")
|
||||
|
||||
engine = cls.get_engine(database)
|
||||
to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host}
|
||||
to_gbq_kwargs = {}
|
||||
with cls.get_engine(database) as engine:
|
||||
to_gbq_kwargs = {
|
||||
"destination_table": str(table),
|
||||
"project_id": engine.url.host,
|
||||
}
|
||||
|
||||
# Add credentials if they are set on the SQLAlchemy dialect.
|
||||
creds = engine.dialect.credentials_info
|
||||
|
|
|
|||
|
|
@ -109,11 +109,11 @@ class GSheetsEngineSpec(SqliteEngineSpec):
|
|||
table_name: str,
|
||||
schema_name: Optional[str],
|
||||
) -> Dict[str, Any]:
|
||||
engine = cls.get_engine(database, schema=schema_name)
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
|
||||
results = cursor.fetchone()[0]
|
||||
with cls.get_engine(database, schema=schema_name) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
|
||||
results = cursor.fetchone()[0]
|
||||
|
||||
try:
|
||||
metadata = json.loads(results)
|
||||
|
|
|
|||
|
|
@ -185,8 +185,6 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
|
||||
"""
|
||||
|
||||
engine = cls.get_engine(database)
|
||||
|
||||
if to_sql_kwargs["if_exists"] == "append":
|
||||
raise SupersetException("Append operation not currently supported")
|
||||
|
||||
|
|
@ -205,7 +203,8 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
if table_exists:
|
||||
raise SupersetException("Table already exists")
|
||||
elif to_sql_kwargs["if_exists"] == "replace":
|
||||
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
|
||||
with cls.get_engine(database) as engine:
|
||||
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
|
||||
|
||||
def _get_hive_type(dtype: np.dtype) -> str:
|
||||
hive_type_by_dtype = {
|
||||
|
|
|
|||
|
|
@ -462,12 +462,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
).strip()
|
||||
params = {}
|
||||
|
||||
engine = cls.get_engine(database, schema=schema)
|
||||
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
with cls.get_engine(database, schema=schema) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
|
||||
return sorted([row[0] for row in results])
|
||||
|
||||
|
|
@ -989,17 +988,17 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
# pylint: disable=import-outside-toplevel
|
||||
from pyhive.exc import DatabaseError
|
||||
|
||||
engine = cls.get_engine(database, schema)
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = f"SHOW CREATE VIEW {schema}.{table}"
|
||||
try:
|
||||
cls.execute(cursor, sql)
|
||||
with cls.get_engine(database, schema=schema) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = f"SHOW CREATE VIEW {schema}.{table}"
|
||||
try:
|
||||
cls.execute(cursor, sql)
|
||||
|
||||
except DatabaseError: # not a VIEW
|
||||
return None
|
||||
rows = cls.fetch_data(cursor, 1)
|
||||
return rows[0][0]
|
||||
except DatabaseError: # not a VIEW
|
||||
return None
|
||||
rows = cls.fetch_data(cursor, 1)
|
||||
return rows[0][0]
|
||||
|
||||
@classmethod
|
||||
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
|
||||
|
|
|
|||
|
|
@ -29,31 +29,31 @@ from .helpers import get_example_url, get_table_connector_registry
|
|||
def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
|
||||
tbl_name = "bart_lines"
|
||||
database = get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("bart-lines.json.gz")
|
||||
df = pd.read_json(url, encoding="latin-1", compression="gzip")
|
||||
df["path_json"] = df.path.map(json.dumps)
|
||||
df["polyline"] = df.path.map(polyline.encode)
|
||||
del df["path"]
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("bart-lines.json.gz")
|
||||
df = pd.read_json(url, encoding="latin-1", compression="gzip")
|
||||
df["path_json"] = df.path.map(json.dumps)
|
||||
df["polyline"] = df.path.map(polyline.encode)
|
||||
del df["path"]
|
||||
|
||||
df.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"color": String(255),
|
||||
"name": String(255),
|
||||
"polyline": Text,
|
||||
"path_json": Text,
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
df.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"color": String(255),
|
||||
"name": String(255),
|
||||
"polyline": Text,
|
||||
"path_json": Text,
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
|
||||
print("Creating table {} reference".format(tbl_name))
|
||||
table = get_table_connector_registry()
|
||||
|
|
|
|||
|
|
@ -76,25 +76,25 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
|
|||
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
|
||||
pdf = pdf.head(100) if sample else pdf
|
||||
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
database.get_sqla_engine(),
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
# TODO(bkyryliuk): use TIMESTAMP type for presto
|
||||
"ds": DateTime if database.backend != "presto" else String(255),
|
||||
"gender": String(16),
|
||||
"state": String(10),
|
||||
"name": String(255),
|
||||
},
|
||||
method="multi",
|
||||
index=False,
|
||||
)
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
# TODO(bkyryliuk): use TIMESTAMP type for presto
|
||||
"ds": DateTime if database.backend != "presto" else String(255),
|
||||
"gender": String(16),
|
||||
"state": String(10),
|
||||
"name": String(255),
|
||||
},
|
||||
method="multi",
|
||||
index=False,
|
||||
)
|
||||
print("Done loading table!")
|
||||
print("-" * 80)
|
||||
|
||||
|
|
@ -104,8 +104,8 @@ def load_birth_names(
|
|||
) -> None:
|
||||
"""Loading birth name dataset from a zip file in the repo"""
|
||||
database = get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
|
||||
tbl_name = "birth_names"
|
||||
table_exists = database.has_table_by_name(tbl_name, schema=schema)
|
||||
|
|
|
|||
|
|
@ -39,38 +39,39 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
|
|||
"""Loading data for map with country map"""
|
||||
tbl_name = "birth_france_by_region"
|
||||
database = database_utils.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("birth_france_data_for_country_map.csv")
|
||||
data = pd.read_csv(url, encoding="utf-8")
|
||||
data["dttm"] = datetime.datetime.now().date()
|
||||
data.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"DEPT_ID": String(10),
|
||||
"2003": BigInteger,
|
||||
"2004": BigInteger,
|
||||
"2005": BigInteger,
|
||||
"2006": BigInteger,
|
||||
"2007": BigInteger,
|
||||
"2008": BigInteger,
|
||||
"2009": BigInteger,
|
||||
"2010": BigInteger,
|
||||
"2011": BigInteger,
|
||||
"2012": BigInteger,
|
||||
"2013": BigInteger,
|
||||
"2014": BigInteger,
|
||||
"dttm": Date(),
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("birth_france_data_for_country_map.csv")
|
||||
data = pd.read_csv(url, encoding="utf-8")
|
||||
data["dttm"] = datetime.datetime.now().date()
|
||||
data.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"DEPT_ID": String(10),
|
||||
"2003": BigInteger,
|
||||
"2004": BigInteger,
|
||||
"2005": BigInteger,
|
||||
"2006": BigInteger,
|
||||
"2007": BigInteger,
|
||||
"2008": BigInteger,
|
||||
"2009": BigInteger,
|
||||
"2010": BigInteger,
|
||||
"2011": BigInteger,
|
||||
"2012": BigInteger,
|
||||
"2013": BigInteger,
|
||||
"2014": BigInteger,
|
||||
"dttm": Date(),
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
print("Done loading table!")
|
||||
print("-" * 80)
|
||||
|
||||
|
|
|
|||
|
|
@ -41,24 +41,25 @@ def load_energy(
|
|||
"""Loads an energy related dataset to use with sankey and graphs"""
|
||||
tbl_name = "energy_usage"
|
||||
database = database_utils.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("energy.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
pdf = pdf.head(100) if sample else pdf
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={"source": String(255), "target": String(255), "value": Float()},
|
||||
index=False,
|
||||
method="multi",
|
||||
)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("energy.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
pdf = pdf.head(100) if sample else pdf
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={"source": String(255), "target": String(255), "value": Float()},
|
||||
index=False,
|
||||
method="multi",
|
||||
)
|
||||
|
||||
print("Creating table [wb_health_population] reference")
|
||||
table = get_table_connector_registry()
|
||||
|
|
|
|||
|
|
@ -27,35 +27,37 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
|
|||
"""Loading random time series data from a zip file in the repo"""
|
||||
tbl_name = "flights"
|
||||
database = database_utils.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
flight_data_url = get_example_url("flight_data.csv.gz")
|
||||
pdf = pd.read_csv(flight_data_url, encoding="latin-1", compression="gzip")
|
||||
if not only_metadata and (not table_exists or force):
|
||||
flight_data_url = get_example_url("flight_data.csv.gz")
|
||||
pdf = pd.read_csv(flight_data_url, encoding="latin-1", compression="gzip")
|
||||
|
||||
# Loading airports info to join and get lat/long
|
||||
airports_url = get_example_url("airports.csv.gz")
|
||||
airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip")
|
||||
airports = airports.set_index("IATA_CODE")
|
||||
# Loading airports info to join and get lat/long
|
||||
airports_url = get_example_url("airports.csv.gz")
|
||||
airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip")
|
||||
airports = airports.set_index("IATA_CODE")
|
||||
|
||||
pdf[ # pylint: disable=unsupported-assignment-operation,useless-suppression
|
||||
"ds"
|
||||
] = (pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str))
|
||||
pdf.ds = pd.to_datetime(pdf.ds)
|
||||
pdf.drop(columns=["DAY", "MONTH", "YEAR"])
|
||||
pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
|
||||
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={"ds": DateTime},
|
||||
index=False,
|
||||
)
|
||||
pdf[ # pylint: disable=unsupported-assignment-operation,useless-suppression
|
||||
"ds"
|
||||
] = (
|
||||
pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
|
||||
)
|
||||
pdf.ds = pd.to_datetime(pdf.ds)
|
||||
pdf.drop(columns=["DAY", "MONTH", "YEAR"])
|
||||
pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
|
||||
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={"ds": DateTime},
|
||||
index=False,
|
||||
)
|
||||
|
||||
table = get_table_connector_registry()
|
||||
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
|
||||
|
|
|
|||
|
|
@ -39,49 +39,51 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
|
|||
"""Loading lat/long data from a csv file in the repo"""
|
||||
tbl_name = "long_lat"
|
||||
database = database_utils.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("san_francisco.csv.gz")
|
||||
pdf = pd.read_csv(url, encoding="utf-8", compression="gzip")
|
||||
start = datetime.datetime.now().replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
pdf["datetime"] = [
|
||||
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
|
||||
for i in range(len(pdf))
|
||||
]
|
||||
pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
|
||||
pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
|
||||
pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1)
|
||||
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"longitude": Float(),
|
||||
"latitude": Float(),
|
||||
"number": Float(),
|
||||
"street": String(100),
|
||||
"unit": String(10),
|
||||
"city": String(50),
|
||||
"district": String(50),
|
||||
"region": String(50),
|
||||
"postcode": Float(),
|
||||
"id": String(100),
|
||||
"datetime": DateTime(),
|
||||
"occupancy": Float(),
|
||||
"radius_miles": Float(),
|
||||
"geohash": String(12),
|
||||
"delimited": String(60),
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("san_francisco.csv.gz")
|
||||
pdf = pd.read_csv(url, encoding="utf-8", compression="gzip")
|
||||
start = datetime.datetime.now().replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
pdf["datetime"] = [
|
||||
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
|
||||
for i in range(len(pdf))
|
||||
]
|
||||
pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
|
||||
pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
|
||||
pdf["geohash"] = pdf[["LAT", "LON"]].apply(
|
||||
lambda x: geohash.encode(*x), axis=1
|
||||
)
|
||||
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"longitude": Float(),
|
||||
"latitude": Float(),
|
||||
"number": Float(),
|
||||
"street": String(100),
|
||||
"unit": String(10),
|
||||
"city": String(50),
|
||||
"district": String(50),
|
||||
"region": String(50),
|
||||
"postcode": Float(),
|
||||
"id": String(100),
|
||||
"datetime": DateTime(),
|
||||
"occupancy": Float(),
|
||||
"radius_miles": Float(),
|
||||
"geohash": String(12),
|
||||
"delimited": String(60),
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
print("Done loading table!")
|
||||
print("-" * 80)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,41 +39,41 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
|
|||
"""Loading time series data from a zip file in the repo"""
|
||||
tbl_name = "multiformat_time_series"
|
||||
database = get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("multiformat_time_series.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
# TODO(bkyryliuk): move load examples data into the pytest fixture
|
||||
if database.backend == "presto":
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d")
|
||||
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
|
||||
pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S")
|
||||
else:
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("multiformat_time_series.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
# TODO(bkyryliuk): move load examples data into the pytest fixture
|
||||
if database.backend == "presto":
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d")
|
||||
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
|
||||
pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S")
|
||||
else:
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
|
||||
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"ds": String(255) if database.backend == "presto" else Date,
|
||||
"ds2": String(255) if database.backend == "presto" else DateTime,
|
||||
"epoch_s": BigInteger,
|
||||
"epoch_ms": BigInteger,
|
||||
"string0": String(100),
|
||||
"string1": String(100),
|
||||
"string2": String(100),
|
||||
"string3": String(100),
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"ds": String(255) if database.backend == "presto" else Date,
|
||||
"ds2": String(255) if database.backend == "presto" else DateTime,
|
||||
"epoch_s": BigInteger,
|
||||
"epoch_ms": BigInteger,
|
||||
"string0": String(100),
|
||||
"string1": String(100),
|
||||
"string2": String(100),
|
||||
"string3": String(100),
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
print("Done loading table!")
|
||||
print("-" * 80)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,29 +28,29 @@ from .helpers import get_example_url, get_table_connector_registry
|
|||
def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
|
||||
tbl_name = "paris_iris_mapping"
|
||||
database = database_utils.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("paris_iris.json.gz")
|
||||
df = pd.read_json(url, compression="gzip")
|
||||
df["features"] = df.features.map(json.dumps)
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("paris_iris.json.gz")
|
||||
df = pd.read_json(url, compression="gzip")
|
||||
df["features"] = df.features.map(json.dumps)
|
||||
|
||||
df.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"color": String(255),
|
||||
"name": String(255),
|
||||
"features": Text,
|
||||
"type": Text,
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
df.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"color": String(255),
|
||||
"name": String(255),
|
||||
"features": Text,
|
||||
"type": Text,
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
|
||||
print("Creating table {} reference".format(tbl_name))
|
||||
table = get_table_connector_registry()
|
||||
|
|
|
|||
|
|
@ -37,28 +37,28 @@ def load_random_time_series_data(
|
|||
"""Loading random time series data from a zip file in the repo"""
|
||||
tbl_name = "random_time_series"
|
||||
database = database_utils.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("random_time_series.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
if database.backend == "presto":
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S")
|
||||
else:
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("random_time_series.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
if database.backend == "presto":
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S")
|
||||
else:
|
||||
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
|
||||
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={"ds": DateTime if database.backend != "presto" else String(255)},
|
||||
index=False,
|
||||
)
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={"ds": DateTime if database.backend != "presto" else String(255)},
|
||||
index=False,
|
||||
)
|
||||
print("Done loading table!")
|
||||
print("-" * 80)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,29 +30,29 @@ def load_sf_population_polygons(
|
|||
) -> None:
|
||||
tbl_name = "sf_population_polygons"
|
||||
database = database_utils.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("sf_population.json.gz")
|
||||
df = pd.read_json(url, compression="gzip")
|
||||
df["contour"] = df.contour.map(json.dumps)
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("sf_population.json.gz")
|
||||
df = pd.read_json(url, compression="gzip")
|
||||
df["contour"] = df.contour.map(json.dumps)
|
||||
|
||||
df.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"zipcode": BigInteger,
|
||||
"population": BigInteger,
|
||||
"contour": Text,
|
||||
"area": Float,
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
df.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype={
|
||||
"zipcode": BigInteger,
|
||||
"population": BigInteger,
|
||||
"contour": Text,
|
||||
"area": Float,
|
||||
},
|
||||
index=False,
|
||||
)
|
||||
|
||||
print("Creating table {} reference".format(tbl_name))
|
||||
table = get_table_connector_registry()
|
||||
|
|
|
|||
|
|
@ -453,11 +453,11 @@ def load_supported_charts_dashboard() -> None:
|
|||
"""Loading a dashboard featuring supported charts"""
|
||||
|
||||
database = get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
schema = inspect(engine).default_schema_name
|
||||
|
||||
tbl_name = "birth_names"
|
||||
table_exists = database.has_table_by_name(tbl_name, schema=schema)
|
||||
tbl_name = "birth_names"
|
||||
table_exists = database.has_table_by_name(tbl_name, schema=schema)
|
||||
|
||||
if table_exists:
|
||||
table = get_table_connector_registry()
|
||||
|
|
|
|||
|
|
@ -51,37 +51,38 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
|
|||
"""Loads the world bank health dataset, slices and a dashboard"""
|
||||
tbl_name = "wb_health_population"
|
||||
database = superset.utils.database.get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("countries.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
pdf.columns = [col.replace(".", "_") for col in pdf.columns]
|
||||
if database.backend == "presto":
|
||||
pdf.year = pd.to_datetime(pdf.year)
|
||||
pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S")
|
||||
else:
|
||||
pdf.year = pd.to_datetime(pdf.year)
|
||||
pdf = pdf.head(100) if sample else pdf
|
||||
schema = inspect(engine).default_schema_name
|
||||
table_exists = database.has_table_by_name(tbl_name)
|
||||
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=50,
|
||||
dtype={
|
||||
# TODO(bkyryliuk): use TIMESTAMP type for presto
|
||||
"year": DateTime if database.backend != "presto" else String(255),
|
||||
"country_code": String(3),
|
||||
"country_name": String(255),
|
||||
"region": String(255),
|
||||
},
|
||||
method="multi",
|
||||
index=False,
|
||||
)
|
||||
if not only_metadata and (not table_exists or force):
|
||||
url = get_example_url("countries.json.gz")
|
||||
pdf = pd.read_json(url, compression="gzip")
|
||||
pdf.columns = [col.replace(".", "_") for col in pdf.columns]
|
||||
if database.backend == "presto":
|
||||
pdf.year = pd.to_datetime(pdf.year)
|
||||
pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S")
|
||||
else:
|
||||
pdf.year = pd.to_datetime(pdf.year)
|
||||
pdf = pdf.head(100) if sample else pdf
|
||||
|
||||
pdf.to_sql(
|
||||
tbl_name,
|
||||
engine,
|
||||
schema=schema,
|
||||
if_exists="replace",
|
||||
chunksize=50,
|
||||
dtype={
|
||||
# TODO(bkyryliuk): use TIMESTAMP type for presto
|
||||
"year": DateTime if database.backend != "presto" else String(255),
|
||||
"country_code": String(3),
|
||||
"country_name": String(255),
|
||||
"region": String(255),
|
||||
},
|
||||
method="multi",
|
||||
index=False,
|
||||
)
|
||||
|
||||
print("Creating table [wb_health_population] reference")
|
||||
table = get_table_connector_registry()
|
||||
|
|
|
|||
|
|
@ -369,12 +369,9 @@ class Database(
|
|||
nullpool: bool = True,
|
||||
source: Optional[utils.QuerySource] = None,
|
||||
) -> Engine:
|
||||
try:
|
||||
yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
|
||||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
||||
yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
|
||||
|
||||
def get_sqla_engine(
|
||||
def _get_sqla_engine(
|
||||
self,
|
||||
schema: Optional[str] = None,
|
||||
nullpool: bool = True,
|
||||
|
|
@ -392,7 +389,7 @@ class Database(
|
|||
)
|
||||
|
||||
masked_url = self.get_password_masked_url(sqlalchemy_url)
|
||||
logger.debug("Database.get_sqla_engine(). Masked URL: %s", str(masked_url))
|
||||
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
|
||||
|
||||
params = extra.get("engine_params", {})
|
||||
if nullpool:
|
||||
|
|
@ -442,7 +439,7 @@ class Database(
|
|||
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
|
||||
) -> pd.DataFrame:
|
||||
sqls = self.db_engine_spec.parse_sql(sql)
|
||||
engine = self.get_sqla_engine(schema)
|
||||
engine = self._get_sqla_engine(schema)
|
||||
|
||||
def needs_conversion(df_series: pd.Series) -> bool:
|
||||
return (
|
||||
|
|
@ -487,7 +484,7 @@ class Database(
|
|||
return df
|
||||
|
||||
def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
|
||||
engine = self.get_sqla_engine(schema=schema)
|
||||
engine = self._get_sqla_engine(schema=schema)
|
||||
|
||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
|
||||
|
|
@ -508,7 +505,7 @@ class Database(
|
|||
cols: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> str:
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
|
||||
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
|
||||
return self.db_engine_spec.select_star(
|
||||
self,
|
||||
table_name,
|
||||
|
|
@ -533,7 +530,7 @@ class Database(
|
|||
|
||||
@property
|
||||
def inspector(self) -> Inspector:
|
||||
engine = self.get_sqla_engine()
|
||||
engine = self._get_sqla_engine()
|
||||
return sqla.inspect(engine)
|
||||
|
||||
@cache_util.memoized_func(
|
||||
|
|
@ -674,7 +671,7 @@ class Database(
|
|||
meta,
|
||||
schema=schema or None,
|
||||
autoload=True,
|
||||
autoload_with=self.get_sqla_engine(),
|
||||
autoload_with=self._get_sqla_engine(),
|
||||
)
|
||||
|
||||
def get_table_comment(
|
||||
|
|
@ -765,11 +762,11 @@ class Database(
|
|||
return self.perm # type: ignore
|
||||
|
||||
def has_table(self, table: Table) -> bool:
|
||||
engine = self.get_sqla_engine()
|
||||
engine = self._get_sqla_engine()
|
||||
return engine.has_table(table.table_name, table.schema or None)
|
||||
|
||||
def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
|
||||
engine = self.get_sqla_engine()
|
||||
engine = self._get_sqla_engine()
|
||||
return engine.has_table(table_name, schema)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -788,7 +785,7 @@ class Database(
|
|||
return view_name in view_names
|
||||
|
||||
def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
|
||||
engine = self.get_sqla_engine()
|
||||
engine = self._get_sqla_engine()
|
||||
return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
|
||||
|
||||
def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
|
||||
|
|
|
|||
|
|
@ -224,8 +224,9 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
|
|||
@property
|
||||
def sqla_metadata(self) -> None:
|
||||
# pylint: disable=no-member
|
||||
meta = MetaData(bind=self.get_sqla_engine())
|
||||
meta.reflect()
|
||||
with self.get_sqla_engine_with_context() as engine:
|
||||
meta = MetaData(bind=engine)
|
||||
meta.reflect()
|
||||
|
||||
@property
|
||||
def status(self) -> utils.DashboardStatus:
|
||||
|
|
|
|||
|
|
@ -55,8 +55,9 @@ class FilterSet(Model, AuditMixinNullable):
|
|||
@property
|
||||
def sqla_metadata(self) -> None:
|
||||
# pylint: disable=no-member
|
||||
meta = MetaData(bind=self.get_sqla_engine())
|
||||
meta.reflect()
|
||||
with self.get_sqla_engine_with_context() as engine:
|
||||
meta = MetaData(bind=engine)
|
||||
meta.reflect()
|
||||
|
||||
@property
|
||||
def changed_by_name(self) -> str:
|
||||
|
|
|
|||
|
|
@ -1281,13 +1281,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
engine = self.database.get_sqla_engine() # type: ignore
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
with self.database.get_sqla_engine_with_context() as engine: # type: ignore
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
|
||||
df = pd.read_sql_query(sql=sql, con=engine)
|
||||
return df[column_name].to_list()
|
||||
df = pd.read_sql_query(sql=sql, con=engine)
|
||||
return df[column_name].to_list()
|
||||
|
||||
def get_timestamp_expression(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -463,61 +463,66 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
)
|
||||
)
|
||||
|
||||
engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
# closing the connection closes the cursor as well
|
||||
cursor = conn.cursor()
|
||||
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
|
||||
if cancel_query_id is not None:
|
||||
query.set_extra_json_key(cancel_query_key, cancel_query_id)
|
||||
session.commit()
|
||||
statement_count = len(statements)
|
||||
for i, statement in enumerate(statements):
|
||||
# Check if stopped
|
||||
session.refresh(query)
|
||||
if query.status == QueryStatus.STOPPED:
|
||||
payload.update({"status": query.status})
|
||||
return payload
|
||||
with database.get_sqla_engine_with_context(
|
||||
query.schema, source=QuerySource.SQL_LAB
|
||||
) as engine:
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
# closing the connection closes the cursor as well
|
||||
cursor = conn.cursor()
|
||||
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
|
||||
if cancel_query_id is not None:
|
||||
query.set_extra_json_key(cancel_query_key, cancel_query_id)
|
||||
session.commit()
|
||||
statement_count = len(statements)
|
||||
for i, statement in enumerate(statements):
|
||||
# Check if stopped
|
||||
session.refresh(query)
|
||||
if query.status == QueryStatus.STOPPED:
|
||||
payload.update({"status": query.status})
|
||||
return payload
|
||||
|
||||
# For CTAS we create the table only on the last statement
|
||||
apply_ctas = query.select_as_cta and (
|
||||
query.ctas_method == CtasMethod.VIEW
|
||||
or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
|
||||
)
|
||||
# For CTAS we create the table only on the last statement
|
||||
apply_ctas = query.select_as_cta and (
|
||||
query.ctas_method == CtasMethod.VIEW
|
||||
or (
|
||||
query.ctas_method == CtasMethod.TABLE
|
||||
and i == len(statements) - 1
|
||||
)
|
||||
)
|
||||
|
||||
# Run statement
|
||||
msg = f"Running statement {i+1} out of {statement_count}"
|
||||
logger.info("Query %s: %s", str(query_id), msg)
|
||||
query.set_extra_json_key("progress", msg)
|
||||
session.commit()
|
||||
try:
|
||||
result_set = execute_sql_statement(
|
||||
statement,
|
||||
query,
|
||||
session,
|
||||
cursor,
|
||||
log_params,
|
||||
apply_ctas,
|
||||
)
|
||||
except SqlLabQueryStoppedException:
|
||||
payload.update({"status": QueryStatus.STOPPED})
|
||||
return payload
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
msg = str(ex)
|
||||
prefix_message = (
|
||||
f"[Statement {i+1} out of {statement_count}]"
|
||||
if statement_count > 1
|
||||
else ""
|
||||
)
|
||||
payload = handle_query_error(
|
||||
ex, query, session, payload, prefix_message
|
||||
)
|
||||
return payload
|
||||
# Run statement
|
||||
msg = f"Running statement {i+1} out of {statement_count}"
|
||||
logger.info("Query %s: %s", str(query_id), msg)
|
||||
query.set_extra_json_key("progress", msg)
|
||||
session.commit()
|
||||
try:
|
||||
result_set = execute_sql_statement(
|
||||
statement,
|
||||
query,
|
||||
session,
|
||||
cursor,
|
||||
log_params,
|
||||
apply_ctas,
|
||||
)
|
||||
except SqlLabQueryStoppedException:
|
||||
payload.update({"status": QueryStatus.STOPPED})
|
||||
return payload
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
msg = str(ex)
|
||||
prefix_message = (
|
||||
f"[Statement {i+1} out of {statement_count}]"
|
||||
if statement_count > 1
|
||||
else ""
|
||||
)
|
||||
payload = handle_query_error(
|
||||
ex, query, session, payload, prefix_message
|
||||
)
|
||||
return payload
|
||||
|
||||
# Commit the connection so CTA queries will create the table.
|
||||
conn.commit()
|
||||
# Commit the connection so CTA queries will create the table.
|
||||
conn.commit()
|
||||
|
||||
# Success, updating the query entry in database
|
||||
query.rows = result_set.size
|
||||
|
|
@ -622,10 +627,11 @@ def cancel_query(query: Query) -> bool:
|
|||
if cancel_query_id is None:
|
||||
return False
|
||||
|
||||
engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
|
||||
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
with closing(conn.cursor()) as cursor:
|
||||
return query.database.db_engine_spec.cancel_query(
|
||||
cursor, query, cancel_query_id
|
||||
)
|
||||
with query.database.get_sqla_engine_with_context(
|
||||
query.schema, source=QuerySource.SQL_LAB
|
||||
) as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
with closing(conn.cursor()) as cursor:
|
||||
return query.database.db_engine_spec.cancel_query(
|
||||
cursor, query, cancel_query_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -162,16 +162,18 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
statements = parsed_query.get_statements()
|
||||
|
||||
logger.info("Validating %i statement(s)", len(statements))
|
||||
engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB)
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
annotations: List[SQLValidationAnnotation] = []
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in parsed_query.get_statements():
|
||||
annotation = cls.validate_statement(statement, database, cursor)
|
||||
if annotation:
|
||||
annotations.append(annotation)
|
||||
logger.debug("Validation found %i error(s)", len(annotations))
|
||||
with database.get_sqla_engine_with_context(
|
||||
schema, source=QuerySource.SQL_LAB
|
||||
) as engine:
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
annotations: List[SQLValidationAnnotation] = []
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in parsed_query.get_statements():
|
||||
annotation = cls.validate_statement(statement, database, cursor)
|
||||
if annotation:
|
||||
annotations.append(annotation)
|
||||
logger.debug("Validation found %i error(s)", len(annotations))
|
||||
|
||||
return annotations
|
||||
|
|
|
|||
|
|
@ -1284,8 +1284,8 @@ def get_example_default_schema() -> Optional[str]:
|
|||
Return the default schema of the examples database, if any.
|
||||
"""
|
||||
database = get_example_database()
|
||||
engine = database.get_sqla_engine()
|
||||
return inspect(engine).default_schema_name
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
return inspect(engine).default_schema_name
|
||||
|
||||
|
||||
def backend() -> str:
|
||||
|
|
|
|||
|
|
@ -187,29 +187,29 @@ def add_data(
|
|||
|
||||
database = get_example_database()
|
||||
table_exists = database.has_table_by_name(table_name)
|
||||
engine = database.get_sqla_engine()
|
||||
|
||||
if columns is None:
|
||||
if not table_exists:
|
||||
raise Exception(
|
||||
f"The table {table_name} does not exist. To create it you need to "
|
||||
"pass a list of column names and types."
|
||||
)
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
if columns is None:
|
||||
if not table_exists:
|
||||
raise Exception(
|
||||
f"The table {table_name} does not exist. To create it you need to "
|
||||
"pass a list of column names and types."
|
||||
)
|
||||
|
||||
inspector = inspect(engine)
|
||||
columns = inspector.get_columns(table_name)
|
||||
inspector = inspect(engine)
|
||||
columns = inspector.get_columns(table_name)
|
||||
|
||||
# create table if needed
|
||||
column_objects = get_column_objects(columns)
|
||||
metadata = MetaData()
|
||||
table = Table(table_name, metadata, *column_objects)
|
||||
metadata.create_all(engine)
|
||||
# create table if needed
|
||||
column_objects = get_column_objects(columns)
|
||||
metadata = MetaData()
|
||||
table = Table(table_name, metadata, *column_objects)
|
||||
metadata.create_all(engine)
|
||||
|
||||
if not append:
|
||||
engine.execute(table.delete())
|
||||
if not append:
|
||||
engine.execute(table.delete())
|
||||
|
||||
data = generate_data(columns, num_rows)
|
||||
engine.execute(table.insert(), data)
|
||||
data = generate_data(columns, num_rows)
|
||||
engine.execute(table.insert(), data)
|
||||
|
||||
|
||||
def get_column_objects(columns: List[ColumnInfo]) -> List[Column]:
|
||||
|
|
|
|||
|
|
@ -1379,11 +1379,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
)
|
||||
database.set_sqlalchemy_uri(uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
engine = database.get_sqla_engine()
|
||||
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
if engine.dialect.do_ping(conn):
|
||||
return json_success('"OK"')
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
if engine.dialect.do_ping(conn):
|
||||
return json_success('"OK"')
|
||||
|
||||
raise DBAPIError(None, None, None)
|
||||
except CertificateException as ex:
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ def example_db_provider() -> Callable[[], Database]: # type: ignore
|
|||
return self._db
|
||||
|
||||
def _load_lazy_data_to_decouple_from_session(self) -> None:
|
||||
self._db.get_sqla_engine() # type: ignore
|
||||
self._db._get_sqla_engine() # type: ignore
|
||||
self._db.backend # type: ignore
|
||||
|
||||
def remove(self) -> None:
|
||||
|
|
@ -336,37 +336,38 @@ def physical_dataset():
|
|||
from superset.connectors.sqla.utils import get_identifier_quoter
|
||||
|
||||
example_database = get_example_database()
|
||||
engine = example_database.get_sqla_engine()
|
||||
quoter = get_identifier_quoter(engine.name)
|
||||
# sqlite can only execute one statement at a time
|
||||
engine.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS physical_dataset(
|
||||
col1 INTEGER,
|
||||
col2 VARCHAR(255),
|
||||
col3 DECIMAL(4,2),
|
||||
col4 VARCHAR(255),
|
||||
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
|
||||
);
|
||||
|
||||
with example_database.get_sqla_engine_with_context() as engine:
|
||||
quoter = get_identifier_quoter(engine.name)
|
||||
# sqlite can only execute one statement at a time
|
||||
engine.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS physical_dataset(
|
||||
col1 INTEGER,
|
||||
col2 VARCHAR(255),
|
||||
col3 DECIMAL(4,2),
|
||||
col4 VARCHAR(255),
|
||||
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
|
||||
);
|
||||
"""
|
||||
)
|
||||
engine.execute(
|
||||
"""
|
||||
INSERT INTO physical_dataset values
|
||||
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
|
||||
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
|
||||
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
|
||||
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
|
||||
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
|
||||
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
|
||||
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
|
||||
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
|
||||
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
|
||||
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
|
||||
"""
|
||||
)
|
||||
engine.execute(
|
||||
"""
|
||||
INSERT INTO physical_dataset values
|
||||
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
|
||||
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
|
||||
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
|
||||
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
|
||||
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
|
||||
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
|
||||
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
|
||||
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
|
||||
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
|
||||
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="physical_dataset",
|
||||
|
|
|
|||
|
|
@ -641,7 +641,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
|
|||
|
||||
|
||||
class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
|
|
@ -664,7 +664,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
|||
)
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
|
|
@ -713,7 +713,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
|||
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
|
||||
)
|
||||
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
|
|
@ -738,7 +738,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
|||
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -227,8 +227,10 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
return_value="account_info"
|
||||
)
|
||||
|
||||
mock_get_engine.return_value.url.host = "google-host"
|
||||
mock_get_engine.return_value.dialect.credentials_info = "secrets"
|
||||
mock_get_engine.return_value.__enter__.return_value.url.host = "google-host"
|
||||
mock_get_engine.return_value.__enter__.return_value.dialect.credentials_info = (
|
||||
"secrets"
|
||||
)
|
||||
|
||||
BigQueryEngineSpec.df_to_sql(
|
||||
database=database,
|
||||
|
|
|
|||
|
|
@ -204,7 +204,9 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
|
|||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine.return_value.execute = mock_execute
|
||||
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
|
||||
with app.app_context():
|
||||
|
|
@ -229,7 +231,9 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
|||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine.return_value.execute = mock_execute
|
||||
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
schema = "schema"
|
||||
|
||||
|
|
|
|||
|
|
@ -37,12 +37,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
def test_get_view_names_with_schema(self):
|
||||
database = mock.MagicMock()
|
||||
mock_execute = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
return_value=[["a", "b,", "c"], ["d", "e"]]
|
||||
)
|
||||
|
||||
schema = "schema"
|
||||
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
|
||||
mock_execute.assert_called_once_with(
|
||||
|
|
@ -60,10 +61,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
def test_get_view_names_without_schema(self):
|
||||
database = mock.MagicMock()
|
||||
mock_execute = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
return_value=[["a", "b,", "c"], ["d", "e"]]
|
||||
)
|
||||
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
|
||||
|
|
@ -821,13 +822,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
mock_execute = mock.MagicMock()
|
||||
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
|
||||
database = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
|
||||
mock_fetchall
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
|
||||
False
|
||||
)
|
||||
schema = "schema"
|
||||
|
|
@ -839,7 +840,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
def test_get_create_view_exception(self):
|
||||
mock_execute = mock.MagicMock(side_effect=Exception())
|
||||
database = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
schema = "schema"
|
||||
|
|
@ -852,7 +853,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
|
||||
mock_execute = mock.MagicMock(side_effect=DatabaseError())
|
||||
database = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
schema = "schema"
|
||||
|
|
|
|||
|
|
@ -51,8 +51,8 @@ def load_unicode_data():
|
|||
|
||||
yield
|
||||
with app.app_context():
|
||||
engine = get_example_database().get_sqla_engine()
|
||||
engine.execute("DROP TABLE IF EXISTS unicode_test")
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS unicode_test")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
|
|||
|
|
@ -64,8 +64,8 @@ def load_world_bank_data():
|
|||
|
||||
yield
|
||||
with app.app_context():
|
||||
engine = get_example_database().get_sqla_engine()
|
||||
engine.execute("DROP TABLE IF EXISTS wb_health_population")
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS wb_health_population")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://gamma@localhost"
|
||||
|
|
@ -177,7 +177,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://localhost"
|
||||
|
|
@ -197,7 +197,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
database_name="test_database", sqlalchemy_uri="trino://localhost"
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "trino://localhost"
|
||||
|
|
@ -209,7 +209,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
)
|
||||
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert (
|
||||
|
|
@ -242,7 +242,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
|
|
@ -255,7 +255,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
|
|
@ -380,21 +380,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
)
|
||||
mocked_create_engine.side_effect = Exception()
|
||||
with self.assertRaises(SupersetException):
|
||||
model.get_sqla_engine()
|
||||
|
||||
# todo(hughhh): update this test
|
||||
# @mock.patch("superset.models.core.create_engine")
|
||||
# def test_get_sqla_engine_with_context(self, mocked_create_engine):
|
||||
# model = Database(
|
||||
# database_name="test_database",
|
||||
# sqlalchemy_uri="mysql://root@localhost",
|
||||
# )
|
||||
# model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
|
||||
# return_value={Exception: SupersetException}
|
||||
# )
|
||||
# mocked_create_engine.side_effect = Exception()
|
||||
# with self.assertRaises(SupersetException):
|
||||
# model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
|
||||
|
||||
class TestSqlaTableModel(SupersetTestCase):
|
||||
|
|
|
|||
|
|
@ -174,7 +174,9 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
def setUp(self):
|
||||
self.validator = PrestoDBSQLValidator
|
||||
self.database = MagicMock()
|
||||
self.database_engine = self.database.get_sqla_engine.return_value
|
||||
self.database_engine = (
|
||||
self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
|
||||
)
|
||||
self.database_conn = self.database_engine.raw_connection.return_value
|
||||
self.database_cursor = self.database_conn.cursor.return_value
|
||||
self.database_cursor.poll.return_value = None
|
||||
|
|
|
|||
|
|
@ -733,7 +733,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
mock_query = mock.MagicMock()
|
||||
mock_query.database.allow_run_async = False
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
|
||||
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
|
||||
|
|
@ -786,7 +786,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
mock_query = mock.MagicMock()
|
||||
mock_query.database.allow_run_async = True
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
|
||||
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
|
||||
|
|
@ -836,7 +836,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
mock_query = mock.MagicMock()
|
||||
mock_query.database.allow_run_async = False
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
|
||||
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
|
||||
|
|
|
|||
Loading…
Reference in New Issue