refactor(db_engine_specs): Removing top-level import of app (#14366)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2021-04-28 15:47:32 +12:00 committed by GitHub
parent 77d17152bc
commit d8bb2d3e62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 143 additions and 142 deletions

View File

@ -40,7 +40,7 @@ import pandas as pd
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask import g
from flask import current_app, g
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
from sqlalchemy import column, DateTime, select, types
@ -55,7 +55,7 @@ from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
from sqlalchemy.types import String, TypeEngine, UnicodeText
from typing_extensions import TypedDict
from superset import app, security_manager, sql_parse
from superset import security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.models.sql_types.base import literal_dttm_type_factory
@ -80,7 +80,6 @@ class TimeGrain(NamedTuple): # pylint: disable=too-few-public-methods
QueryStatus = utils.QueryStatus
config = app.config
builtin_time_grains: Dict[Optional[str], str] = {
None: __("Original value"),
@ -369,7 +368,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
ret_list = []
time_grains = builtin_time_grains.copy()
time_grains.update(config["TIME_GRAIN_ADDONS"])
time_grains.update(current_app.config["TIME_GRAIN_ADDONS"])
for duration, func in cls.get_time_grain_expressions().items():
if duration in time_grains:
name = time_grains[duration]
@ -448,9 +447,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
# TODO: use @memoize decorator or similar to avoid recomputation on every call
time_grain_expressions = cls._time_grain_expressions.copy()
grain_addon_expressions = config["TIME_GRAIN_ADDON_EXPRESSIONS"]
grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
denylist: List[str] = config["TIME_GRAIN_DENYLIST"]
denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"]
for key in denylist:
time_grain_expressions.pop(key)
@ -977,7 +976,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
parsed_query = ParsedQuery(statement)
sql = parsed_query.stripped()
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
sql = sql_query_mutator(sql, user_name, security_manager, database)

View File

@ -27,7 +27,7 @@ import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from flask import g
from flask import current_app, g
from sqlalchemy import Column, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
@ -35,7 +35,6 @@ from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from superset import app, conf
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.exceptions import SupersetException
@ -50,12 +49,8 @@ if TYPE_CHECKING:
QueryStatus = utils.QueryStatus
config = app.config
logger = logging.getLogger(__name__)
tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
"""
@ -70,7 +65,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
# Optional dependency
import boto3 # pylint: disable=import-error
bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
if not bucket_path:
logger.info("No upload bucket specified")
@ -229,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec):
)
with tempfile.NamedTemporaryFile(
dir=config["UPLOAD_FOLDER"], suffix=".parquet"
dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet"
) as file:
pq.write_table(pa.Table.from_pandas(df), where=file.name)
@ -243,9 +238,9 @@ class HiveEngineSpec(PrestoEngineSpec):
),
location=upload_to_s3(
filename=file.name,
upload_prefix=config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
database, g.user, table.schema
),
upload_prefix=current_app.config[
"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
](database, g.user, table.schema),
table=table,
),
)
@ -356,7 +351,7 @@ class HiveEngineSpec(PrestoEngineSpec):
str(query_id),
tracking_url,
)
tracking_url = tracking_url_trans(tracking_url)
tracking_url = current_app.config["TRACKING_URL_TRANSFORMER"]
logger.info(
"Query %s: Transformation applied: %s",
str(query_id),
@ -374,7 +369,7 @@ class HiveEngineSpec(PrestoEngineSpec):
last_log_line = len(log_lines)
if needs_commit:
session.commit()
time.sleep(hive_poll_interval)
time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
polled = cursor.poll()
@classmethod

View File

@ -39,6 +39,7 @@ from urllib import parse
import pandas as pd
import simplejson as json
from flask import current_app
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
@ -49,7 +50,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from sqlalchemy.types import TypeEngine
from superset import app, cache_manager, is_feature_enabled
from superset import cache_manager, is_feature_enabled
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetTemplateException
@ -94,7 +95,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
QueryStatus = utils.QueryStatus
config = app.config
logger = logging.getLogger(__name__)
@ -940,7 +940,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
"""Updates progress information"""
query_id = query.id
poll_interval = query.database.connect_args.get(
"poll_interval", config["PRESTO_POLL_INTERVAL"]
"poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
)
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()

View File

@ -21,7 +21,6 @@ import pytest
from sqlalchemy.engine import Engine
from tests.test_app import app
from superset import db
from superset.utils.core import get_example_database, json_dumps_w_dates
@ -38,6 +37,9 @@ def app_context():
@pytest.fixture(autouse=True, scope="session")
def setup_sample_data() -> Any:
# TODO(john-bodley): Determine a cleaner way of setting up the sample data without
# relying on `tests.test_app.app` leveraging an `app` fixture which is purposely
# scoped to the function level to ensure tests remain idempotent.
with app.app_context():
setup_presto_if_needed()

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tests.test_app import app # isort:skip
from superset.db_engine_specs.athena import AthenaEngineSpec
from tests.db_engine_specs.base_tests import TestDbEngineSpec

View File

@ -167,25 +167,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
"SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec
)
def test_time_grain_denylist(self):
with app.app_context():
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"]
time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
self.assertNotIn("PT1M", time_grain_functions)
def test_time_grain_addons(self):
with app.app_context():
app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"sqlite": {"PTXM": "ABC({col})"}
}
time_grains = SqliteEngineSpec.get_time_grains()
time_grain_addon = time_grains[-1]
self.assertEqual("PTXM", time_grain_addon.duration)
self.assertEqual("x seconds", time_grain_addon.label)
app.config["TIME_GRAIN_ADDONS"] = {}
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
@ -198,43 +179,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
intersection = time_grains.intersection(defined_grains)
self.assertSetEqual(defined_grains, intersection, engine)
def test_get_time_grain_with_config(self):
""" Should concatenate from configs and then sort in the proper order """
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"mysql": {
"PT2H": "foo",
"PT4H": "foo",
"PT6H": "foo",
"PT8H": "foo",
"PT10H": "foo",
"PT12H": "foo",
"PT1S": "foo",
}
}
time_grains = MySQLEngineSpec.get_time_grain_expressions()
self.assertEqual(
list(time_grains.keys()),
[
None,
"PT1S",
"PT1M",
"PT1H",
"PT2H",
"PT4H",
"PT6H",
"PT8H",
"PT10H",
"PT12H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
],
)
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
def test_get_time_grain_expressions(self):
time_grains = MySQLEngineSpec.get_time_grain_expressions()
self.assertEqual(
@ -253,18 +197,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
],
)
def test_get_time_grain_with_unkown_values(self):
"""Should concatenate from configs and then sort in the proper order
putting unknown patterns at the end"""
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
}
time_grains = MySQLEngineSpec.get_time_grain_expressions()
self.assertEqual(
list(time_grains)[-1], "weird",
)
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
def test_get_table_names(self):
inspector = mock.Mock()
inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
@ -339,3 +271,84 @@ def test_is_readonly():
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
assert is_readonly("SHOW CATALOGS")
assert is_readonly("SHOW TABLES")
def test_time_grain_denylist():
config = app.config.copy()
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"]
with app.app_context():
time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
assert not "PT1M" in time_grain_functions
app.config = config
def test_time_grain_addons():
config = app.config.copy()
app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {"sqlite": {"PTXM": "ABC({col})"}}
with app.app_context():
time_grains = SqliteEngineSpec.get_time_grains()
time_grain_addon = time_grains[-1]
assert "PTXM" == time_grain_addon.duration
assert "x seconds" == time_grain_addon.label
app.config = config
def test_get_time_grain_with_config():
""" Should concatenate from configs and then sort in the proper order """
config = app.config.copy()
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"mysql": {
"PT2H": "foo",
"PT4H": "foo",
"PT6H": "foo",
"PT8H": "foo",
"PT10H": "foo",
"PT12H": "foo",
"PT1S": "foo",
}
}
with app.app_context():
time_grains = MySQLEngineSpec.get_time_grain_expressions()
assert set(time_grains.keys()) == {
None,
"PT1S",
"PT1M",
"PT1H",
"PT2H",
"PT4H",
"PT6H",
"PT8H",
"PT10H",
"PT12H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
}
app.config = config
def test_get_time_grain_with_unkown_values():
"""Should concatenate from configs and then sort in the proper order
putting unknown patterns at the end"""
config = app.config.copy()
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
}
with app.app_context():
time_grains = MySQLEngineSpec.get_time_grain_expressions()
assert list(time_grains)[-1] == "weird"
app.config = config

View File

@ -21,12 +21,11 @@ from unittest import mock
import pytest
import pandas as pd
from sqlalchemy.sql import select
from tests.test_app import app
with app.app_context():
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import Table, ParsedQuery
from tests.test_app import app
def test_0_progress():
@ -170,10 +169,6 @@ def test_df_to_csv() -> None:
)
@mock.patch(
"superset.db_engine_specs.hive.config",
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
def test_df_to_sql_if_exists_fail(mock_g):
mock_g.user = True
@ -185,10 +180,6 @@ def test_df_to_sql_if_exists_fail(mock_g):
)
@mock.patch(
"superset.db_engine_specs.hive.config",
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
def test_df_to_sql_if_exists_fail_with_schema(mock_g):
mock_g.user = True
@ -203,13 +194,11 @@ def test_df_to_sql_if_exists_fail_with_schema(mock_g):
)
@mock.patch(
"superset.db_engine_specs.hive.config",
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: ""
mock_upload_to_s3.return_value = "mock-location"
mock_g.user = True
mock_database = mock.MagicMock()
@ -218,23 +207,23 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
mock_database.get_sqla_engine.return_value.execute = mock_execute
table_name = "foobar"
HiveEngineSpec.df_to_sql(
mock_database,
Table(table=table_name),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
with app.app_context():
HiveEngineSpec.df_to_sql(
mock_database,
Table(table=table_name),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
app.config = config
@mock.patch(
"superset.db_engine_specs.hive.config",
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: ""
mock_upload_to_s3.return_value = "mock-location"
mock_g.user = True
mock_database = mock.MagicMock()
@ -244,14 +233,16 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
table_name = "foobar"
schema = "schema"
HiveEngineSpec.df_to_sql(
mock_database,
Table(table=table_name, schema=schema),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
with app.app_context():
HiveEngineSpec.df_to_sql(
mock_database,
Table(table=table_name, schema=schema),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
app.config = config
def test_is_readonly():
@ -284,39 +275,42 @@ def test_s3_upload_prefix(schema: str, upload_prefix: str) -> None:
def test_upload_to_s3_no_bucket_path():
with pytest.raises(
Exception,
match="No upload bucket specified. You can specify one in the config file.",
):
upload_to_s3("filename", "prefix", Table("table"))
with app.app_context():
with pytest.raises(
Exception,
match="No upload bucket specified. You can specify one in the config file.",
):
upload_to_s3("filename", "prefix", Table("table"))
@mock.patch("boto3.client")
@mock.patch(
"superset.db_engine_specs.hive.config",
{**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
)
def test_upload_to_s3_client_error(client):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket"
from botocore.exceptions import ClientError
client.return_value.upload_file.side_effect = ClientError(
{"Error": {}}, "operation_name"
)
with pytest.raises(ClientError):
upload_to_s3("filename", "prefix", Table("table"))
with app.app_context():
with pytest.raises(ClientError):
upload_to_s3("filename", "prefix", Table("table"))
app.config = config
@mock.patch("boto3.client")
@mock.patch(
"superset.db_engine_specs.hive.config",
{**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
)
def test_upload_to_s3_success(client):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket"
client.return_value.upload_file.return_value = True
location = upload_to_s3("filename", "prefix", Table("table"))
assert f"s3a://bucket/prefix/table" == location
with app.app_context():
location = upload_to_s3("filename", "prefix", Table("table"))
assert f"s3a://bucket/prefix/table" == location
app.config = config
def test_fetch_data_query_error():