refactor(db_engine_specs): Removing top-level import of app (#14366)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
77d17152bc
commit
d8bb2d3e62
|
|
@ -40,7 +40,7 @@ import pandas as pd
|
|||
import sqlparse
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask import g
|
||||
from flask import current_app, g
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from marshmallow import fields, Schema
|
||||
from sqlalchemy import column, DateTime, select, types
|
||||
|
|
@ -55,7 +55,7 @@ from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
|
|||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from superset import app, security_manager, sql_parse
|
||||
from superset import security_manager, sql_parse
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.models.sql_types.base import literal_dttm_type_factory
|
||||
|
|
@ -80,7 +80,6 @@ class TimeGrain(NamedTuple): # pylint: disable=too-few-public-methods
|
|||
|
||||
|
||||
QueryStatus = utils.QueryStatus
|
||||
config = app.config
|
||||
|
||||
builtin_time_grains: Dict[Optional[str], str] = {
|
||||
None: __("Original value"),
|
||||
|
|
@ -369,7 +368,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
ret_list = []
|
||||
time_grains = builtin_time_grains.copy()
|
||||
time_grains.update(config["TIME_GRAIN_ADDONS"])
|
||||
time_grains.update(current_app.config["TIME_GRAIN_ADDONS"])
|
||||
for duration, func in cls.get_time_grain_expressions().items():
|
||||
if duration in time_grains:
|
||||
name = time_grains[duration]
|
||||
|
|
@ -448,9 +447,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
# TODO: use @memoize decorator or similar to avoid recomputation on every call
|
||||
time_grain_expressions = cls._time_grain_expressions.copy()
|
||||
grain_addon_expressions = config["TIME_GRAIN_ADDON_EXPRESSIONS"]
|
||||
grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
|
||||
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
|
||||
denylist: List[str] = config["TIME_GRAIN_DENYLIST"]
|
||||
denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"]
|
||||
for key in denylist:
|
||||
time_grain_expressions.pop(key)
|
||||
|
||||
|
|
@ -977,7 +976,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
parsed_query = ParsedQuery(statement)
|
||||
sql = parsed_query.stripped()
|
||||
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
|
||||
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
|
||||
if sql_query_mutator:
|
||||
sql = sql_query_mutator(sql, user_name, security_manager, database)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from flask import g
|
||||
from flask import current_app, g
|
||||
from sqlalchemy import Column, text
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
|
@ -35,7 +35,6 @@ from sqlalchemy.engine.url import make_url, URL
|
|||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
|
||||
from superset import app, conf
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.exceptions import SupersetException
|
||||
|
|
@ -50,12 +49,8 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
QueryStatus = utils.QueryStatus
|
||||
config = app.config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
|
||||
hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
|
||||
|
||||
|
||||
def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
|
||||
"""
|
||||
|
|
@ -70,7 +65,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
|
|||
# Optional dependency
|
||||
import boto3 # pylint: disable=import-error
|
||||
|
||||
bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
|
||||
bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
|
||||
|
||||
if not bucket_path:
|
||||
logger.info("No upload bucket specified")
|
||||
|
|
@ -229,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=config["UPLOAD_FOLDER"], suffix=".parquet"
|
||||
dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet"
|
||||
) as file:
|
||||
pq.write_table(pa.Table.from_pandas(df), where=file.name)
|
||||
|
||||
|
|
@ -243,9 +238,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
),
|
||||
location=upload_to_s3(
|
||||
filename=file.name,
|
||||
upload_prefix=config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
|
||||
database, g.user, table.schema
|
||||
),
|
||||
upload_prefix=current_app.config[
|
||||
"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
|
||||
](database, g.user, table.schema),
|
||||
table=table,
|
||||
),
|
||||
)
|
||||
|
|
@ -356,7 +351,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
str(query_id),
|
||||
tracking_url,
|
||||
)
|
||||
tracking_url = tracking_url_trans(tracking_url)
|
||||
tracking_url = current_app.config["TRACKING_URL_TRANSFORMER"]
|
||||
logger.info(
|
||||
"Query %s: Transformation applied: %s",
|
||||
str(query_id),
|
||||
|
|
@ -374,7 +369,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
last_log_line = len(log_lines)
|
||||
if needs_commit:
|
||||
session.commit()
|
||||
time.sleep(hive_poll_interval)
|
||||
time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
|
||||
polled = cursor.poll()
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ from urllib import parse
|
|||
|
||||
import pandas as pd
|
||||
import simplejson as json
|
||||
from flask import current_app
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from sqlalchemy import Column, literal_column, types
|
||||
from sqlalchemy.engine.base import Engine
|
||||
|
|
@ -49,7 +50,7 @@ from sqlalchemy.orm import Session
|
|||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from superset import app, cache_manager, is_feature_enabled
|
||||
from superset import cache_manager, is_feature_enabled
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.exceptions import SupersetTemplateException
|
||||
|
|
@ -94,7 +95,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
|
|||
|
||||
|
||||
QueryStatus = utils.QueryStatus
|
||||
config = app.config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -940,7 +940,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
|
|||
"""Updates progress information"""
|
||||
query_id = query.id
|
||||
poll_interval = query.database.connect_args.get(
|
||||
"poll_interval", config["PRESTO_POLL_INTERVAL"]
|
||||
"poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
|
||||
)
|
||||
logger.info("Query %i: Polling the cursor for progress", query_id)
|
||||
polled = cursor.poll()
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import pytest
|
|||
from sqlalchemy.engine import Engine
|
||||
|
||||
from tests.test_app import app
|
||||
|
||||
from superset import db
|
||||
from superset.utils.core import get_example_database, json_dumps_w_dates
|
||||
|
||||
|
|
@ -38,6 +37,9 @@ def app_context():
|
|||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def setup_sample_data() -> Any:
|
||||
# TODO(john-bodley): Determine a cleaner way of setting up the sample data without
|
||||
# relying on `tests.test_app.app` leveraging an `app` fixture which is purposely
|
||||
# scoped to the function level to ensure tests remain idempotent.
|
||||
with app.app_context():
|
||||
setup_presto_if_needed()
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from tests.test_app import app # isort:skip
|
||||
|
||||
from superset.db_engine_specs.athena import AthenaEngineSpec
|
||||
from tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
|
|
|||
|
|
@ -167,25 +167,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
"SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec
|
||||
)
|
||||
|
||||
def test_time_grain_denylist(self):
|
||||
with app.app_context():
|
||||
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"]
|
||||
time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
|
||||
self.assertNotIn("PT1M", time_grain_functions)
|
||||
|
||||
def test_time_grain_addons(self):
|
||||
with app.app_context():
|
||||
app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
|
||||
"sqlite": {"PTXM": "ABC({col})"}
|
||||
}
|
||||
time_grains = SqliteEngineSpec.get_time_grains()
|
||||
time_grain_addon = time_grains[-1]
|
||||
self.assertEqual("PTXM", time_grain_addon.duration)
|
||||
self.assertEqual("x seconds", time_grain_addon.label)
|
||||
app.config["TIME_GRAIN_ADDONS"] = {}
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
|
||||
|
||||
def test_engine_time_grain_validity(self):
|
||||
time_grains = set(builtin_time_grains.keys())
|
||||
# loop over all subclasses of BaseEngineSpec
|
||||
|
|
@ -198,43 +179,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
intersection = time_grains.intersection(defined_grains)
|
||||
self.assertSetEqual(defined_grains, intersection, engine)
|
||||
|
||||
def test_get_time_grain_with_config(self):
|
||||
""" Should concatenate from configs and then sort in the proper order """
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
|
||||
"mysql": {
|
||||
"PT2H": "foo",
|
||||
"PT4H": "foo",
|
||||
"PT6H": "foo",
|
||||
"PT8H": "foo",
|
||||
"PT10H": "foo",
|
||||
"PT12H": "foo",
|
||||
"PT1S": "foo",
|
||||
}
|
||||
}
|
||||
time_grains = MySQLEngineSpec.get_time_grain_expressions()
|
||||
self.assertEqual(
|
||||
list(time_grains.keys()),
|
||||
[
|
||||
None,
|
||||
"PT1S",
|
||||
"PT1M",
|
||||
"PT1H",
|
||||
"PT2H",
|
||||
"PT4H",
|
||||
"PT6H",
|
||||
"PT8H",
|
||||
"PT10H",
|
||||
"PT12H",
|
||||
"P1D",
|
||||
"P1W",
|
||||
"P1M",
|
||||
"P0.25Y",
|
||||
"P1Y",
|
||||
"1969-12-29T00:00:00Z/P1W",
|
||||
],
|
||||
)
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
|
||||
|
||||
def test_get_time_grain_expressions(self):
|
||||
time_grains = MySQLEngineSpec.get_time_grain_expressions()
|
||||
self.assertEqual(
|
||||
|
|
@ -253,18 +197,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
],
|
||||
)
|
||||
|
||||
def test_get_time_grain_with_unkown_values(self):
|
||||
"""Should concatenate from configs and then sort in the proper order
|
||||
putting unknown patterns at the end"""
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
|
||||
"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
|
||||
}
|
||||
time_grains = MySQLEngineSpec.get_time_grain_expressions()
|
||||
self.assertEqual(
|
||||
list(time_grains)[-1], "weird",
|
||||
)
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
|
||||
|
||||
def test_get_table_names(self):
|
||||
inspector = mock.Mock()
|
||||
inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
|
||||
|
|
@ -339,3 +271,84 @@ def test_is_readonly():
|
|||
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
|
||||
assert is_readonly("SHOW CATALOGS")
|
||||
assert is_readonly("SHOW TABLES")
|
||||
|
||||
|
||||
def test_time_grain_denylist():
|
||||
config = app.config.copy()
|
||||
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"]
|
||||
|
||||
with app.app_context():
|
||||
time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
|
||||
assert not "PT1M" in time_grain_functions
|
||||
|
||||
app.config = config
|
||||
|
||||
|
||||
def test_time_grain_addons():
|
||||
config = app.config.copy()
|
||||
app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {"sqlite": {"PTXM": "ABC({col})"}}
|
||||
|
||||
with app.app_context():
|
||||
time_grains = SqliteEngineSpec.get_time_grains()
|
||||
time_grain_addon = time_grains[-1]
|
||||
assert "PTXM" == time_grain_addon.duration
|
||||
assert "x seconds" == time_grain_addon.label
|
||||
|
||||
app.config = config
|
||||
|
||||
|
||||
def test_get_time_grain_with_config():
|
||||
""" Should concatenate from configs and then sort in the proper order """
|
||||
config = app.config.copy()
|
||||
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
|
||||
"mysql": {
|
||||
"PT2H": "foo",
|
||||
"PT4H": "foo",
|
||||
"PT6H": "foo",
|
||||
"PT8H": "foo",
|
||||
"PT10H": "foo",
|
||||
"PT12H": "foo",
|
||||
"PT1S": "foo",
|
||||
}
|
||||
}
|
||||
|
||||
with app.app_context():
|
||||
time_grains = MySQLEngineSpec.get_time_grain_expressions()
|
||||
assert set(time_grains.keys()) == {
|
||||
None,
|
||||
"PT1S",
|
||||
"PT1M",
|
||||
"PT1H",
|
||||
"PT2H",
|
||||
"PT4H",
|
||||
"PT6H",
|
||||
"PT8H",
|
||||
"PT10H",
|
||||
"PT12H",
|
||||
"P1D",
|
||||
"P1W",
|
||||
"P1M",
|
||||
"P0.25Y",
|
||||
"P1Y",
|
||||
"1969-12-29T00:00:00Z/P1W",
|
||||
}
|
||||
|
||||
app.config = config
|
||||
|
||||
|
||||
def test_get_time_grain_with_unkown_values():
|
||||
"""Should concatenate from configs and then sort in the proper order
|
||||
putting unknown patterns at the end"""
|
||||
config = app.config.copy()
|
||||
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
|
||||
"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
|
||||
}
|
||||
|
||||
with app.app_context():
|
||||
time_grains = MySQLEngineSpec.get_time_grain_expressions()
|
||||
assert list(time_grains)[-1] == "weird"
|
||||
|
||||
app.config = config
|
||||
|
|
|
|||
|
|
@ -21,12 +21,11 @@ from unittest import mock
|
|||
import pytest
|
||||
import pandas as pd
|
||||
from sqlalchemy.sql import select
|
||||
from tests.test_app import app
|
||||
|
||||
with app.app_context():
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.sql_parse import Table, ParsedQuery
|
||||
from tests.test_app import app
|
||||
|
||||
|
||||
def test_0_progress():
|
||||
|
|
@ -170,10 +169,6 @@ def test_df_to_csv() -> None:
|
|||
)
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"superset.db_engine_specs.hive.config",
|
||||
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
|
||||
)
|
||||
@mock.patch("superset.db_engine_specs.hive.g", spec={})
|
||||
def test_df_to_sql_if_exists_fail(mock_g):
|
||||
mock_g.user = True
|
||||
|
|
@ -185,10 +180,6 @@ def test_df_to_sql_if_exists_fail(mock_g):
|
|||
)
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"superset.db_engine_specs.hive.config",
|
||||
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
|
||||
)
|
||||
@mock.patch("superset.db_engine_specs.hive.g", spec={})
|
||||
def test_df_to_sql_if_exists_fail_with_schema(mock_g):
|
||||
mock_g.user = True
|
||||
|
|
@ -203,13 +194,11 @@ def test_df_to_sql_if_exists_fail_with_schema(mock_g):
|
|||
)
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"superset.db_engine_specs.hive.config",
|
||||
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
|
||||
)
|
||||
@mock.patch("superset.db_engine_specs.hive.g", spec={})
|
||||
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
|
||||
def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
|
||||
config = app.config.copy()
|
||||
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: ""
|
||||
mock_upload_to_s3.return_value = "mock-location"
|
||||
mock_g.user = True
|
||||
mock_database = mock.MagicMock()
|
||||
|
|
@ -218,23 +207,23 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
|
|||
mock_database.get_sqla_engine.return_value.execute = mock_execute
|
||||
table_name = "foobar"
|
||||
|
||||
HiveEngineSpec.df_to_sql(
|
||||
mock_database,
|
||||
Table(table=table_name),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
|
||||
)
|
||||
with app.app_context():
|
||||
HiveEngineSpec.df_to_sql(
|
||||
mock_database,
|
||||
Table(table=table_name),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
|
||||
)
|
||||
|
||||
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
|
||||
app.config = config
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"superset.db_engine_specs.hive.config",
|
||||
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
|
||||
)
|
||||
@mock.patch("superset.db_engine_specs.hive.g", spec={})
|
||||
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
|
||||
def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
||||
config = app.config.copy()
|
||||
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: ""
|
||||
mock_upload_to_s3.return_value = "mock-location"
|
||||
mock_g.user = True
|
||||
mock_database = mock.MagicMock()
|
||||
|
|
@ -244,14 +233,16 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
|||
table_name = "foobar"
|
||||
schema = "schema"
|
||||
|
||||
HiveEngineSpec.df_to_sql(
|
||||
mock_database,
|
||||
Table(table=table_name, schema=schema),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
|
||||
)
|
||||
with app.app_context():
|
||||
HiveEngineSpec.df_to_sql(
|
||||
mock_database,
|
||||
Table(table=table_name, schema=schema),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
|
||||
)
|
||||
|
||||
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
|
||||
app.config = config
|
||||
|
||||
|
||||
def test_is_readonly():
|
||||
|
|
@ -284,39 +275,42 @@ def test_s3_upload_prefix(schema: str, upload_prefix: str) -> None:
|
|||
|
||||
|
||||
def test_upload_to_s3_no_bucket_path():
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match="No upload bucket specified. You can specify one in the config file.",
|
||||
):
|
||||
upload_to_s3("filename", "prefix", Table("table"))
|
||||
with app.app_context():
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match="No upload bucket specified. You can specify one in the config file.",
|
||||
):
|
||||
upload_to_s3("filename", "prefix", Table("table"))
|
||||
|
||||
|
||||
@mock.patch("boto3.client")
|
||||
@mock.patch(
|
||||
"superset.db_engine_specs.hive.config",
|
||||
{**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
|
||||
)
|
||||
def test_upload_to_s3_client_error(client):
|
||||
config = app.config.copy()
|
||||
app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket"
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
client.return_value.upload_file.side_effect = ClientError(
|
||||
{"Error": {}}, "operation_name"
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError):
|
||||
upload_to_s3("filename", "prefix", Table("table"))
|
||||
with app.app_context():
|
||||
with pytest.raises(ClientError):
|
||||
upload_to_s3("filename", "prefix", Table("table"))
|
||||
|
||||
app.config = config
|
||||
|
||||
|
||||
@mock.patch("boto3.client")
|
||||
@mock.patch(
|
||||
"superset.db_engine_specs.hive.config",
|
||||
{**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
|
||||
)
|
||||
def test_upload_to_s3_success(client):
|
||||
config = app.config.copy()
|
||||
app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket"
|
||||
client.return_value.upload_file.return_value = True
|
||||
|
||||
location = upload_to_s3("filename", "prefix", Table("table"))
|
||||
assert f"s3a://bucket/prefix/table" == location
|
||||
with app.app_context():
|
||||
location = upload_to_s3("filename", "prefix", Table("table"))
|
||||
assert f"s3a://bucket/prefix/table" == location
|
||||
|
||||
app.config = config
|
||||
|
||||
|
||||
def test_fetch_data_query_error():
|
||||
|
|
|
|||
Loading…
Reference in New Issue