diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 7ea73e30f..3f91d58dd 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -40,7 +40,7 @@ import pandas as pd import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin -from flask import g +from flask import current_app, g from flask_babel import gettext as __, lazy_gettext as _ from marshmallow import fields, Schema from sqlalchemy import column, DateTime, select, types @@ -55,7 +55,7 @@ from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom from sqlalchemy.types import String, TypeEngine, UnicodeText from typing_extensions import TypedDict -from superset import app, security_manager, sql_parse +from superset import security_manager, sql_parse from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query from superset.models.sql_types.base import literal_dttm_type_factory @@ -80,7 +80,6 @@ class TimeGrain(NamedTuple): # pylint: disable=too-few-public-methods QueryStatus = utils.QueryStatus -config = app.config builtin_time_grains: Dict[Optional[str], str] = { None: __("Original value"), @@ -369,7 +368,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ret_list = [] time_grains = builtin_time_grains.copy() - time_grains.update(config["TIME_GRAIN_ADDONS"]) + time_grains.update(current_app.config["TIME_GRAIN_ADDONS"]) for duration, func in cls.get_time_grain_expressions().items(): if duration in time_grains: name = time_grains[duration] @@ -448,9 +447,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ # TODO: use @memoize decorator or similar to avoid recomputation on every call time_grain_expressions = cls._time_grain_expressions.copy() - grain_addon_expressions = config["TIME_GRAIN_ADDON_EXPRESSIONS"] + grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] 
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {})) - denylist: List[str] = config["TIME_GRAIN_DENYLIST"] + denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"] for key in denylist: time_grain_expressions.pop(key) @@ -977,7 +976,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ parsed_query = ParsedQuery(statement) sql = parsed_query.stripped() - sql_query_mutator = config["SQL_QUERY_MUTATOR"] + sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"] if sql_query_mutator: sql = sql_query_mutator(sql, user_name, security_manager, database) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 6996beb76..66c68b3c2 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -27,7 +27,7 @@ import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq -from flask import g +from flask import current_app, g from sqlalchemy import Column, text from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector @@ -35,7 +35,6 @@ from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select -from superset import app, conf from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.presto import PrestoEngineSpec from superset.exceptions import SupersetException @@ -50,12 +49,8 @@ if TYPE_CHECKING: QueryStatus = utils.QueryStatus -config = app.config logger = logging.getLogger(__name__) -tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER") -hive_poll_interval = conf.get("HIVE_POLL_INTERVAL") - def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: """ @@ -70,7 +65,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: # Optional dependency import boto3 # pylint: disable=import-error - bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] + bucket_path = 
current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] if not bucket_path: logger.info("No upload bucket specified") @@ -229,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec): ) with tempfile.NamedTemporaryFile( - dir=config["UPLOAD_FOLDER"], suffix=".parquet" + dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet" ) as file: pq.write_table(pa.Table.from_pandas(df), where=file.name) @@ -243,9 +238,9 @@ class HiveEngineSpec(PrestoEngineSpec): ), location=upload_to_s3( filename=file.name, - upload_prefix=config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]( - database, g.user, table.schema - ), + upload_prefix=current_app.config[ + "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC" + ](database, g.user, table.schema), table=table, ), ) @@ -356,7 +351,7 @@ class HiveEngineSpec(PrestoEngineSpec): str(query_id), tracking_url, ) - tracking_url = tracking_url_trans(tracking_url) + tracking_url = current_app.config["TRACKING_URL_TRANSFORMER"](tracking_url) logger.info( "Query %s: Transformation applied: %s", str(query_id), @@ -374,7 +369,7 @@ class HiveEngineSpec(PrestoEngineSpec): last_log_line = len(log_lines) if needs_commit: session.commit() - time.sleep(hive_poll_interval) + time.sleep(current_app.config["HIVE_POLL_INTERVAL"]) polled = cursor.poll() @classmethod diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index c6cec6a54..32741b150 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -39,6 +39,7 @@ from urllib import parse import pandas as pd import simplejson as json +from flask import current_app from flask_babel import gettext as __, lazy_gettext as _ from sqlalchemy import Column, literal_column, types from sqlalchemy.engine.base import Engine @@ -49,7 +50,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select from sqlalchemy.types import TypeEngine -from superset import app, cache_manager, is_feature_enabled +from superset import cache_manager, is_feature_enabled from 
superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType from superset.exceptions import SupersetTemplateException @@ -94,7 +95,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile( QueryStatus = utils.QueryStatus -config = app.config logger = logging.getLogger(__name__) @@ -940,7 +940,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho """Updates progress information""" query_id = query.id poll_interval = query.database.connect_args.get( - "poll_interval", config["PRESTO_POLL_INTERVAL"] + "poll_interval", current_app.config["PRESTO_POLL_INTERVAL"] ) logger.info("Query %i: Polling the cursor for progress", query_id) polled = cursor.poll() diff --git a/tests/conftest.py b/tests/conftest.py index b04a76c90..456c8fb65 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,6 @@ import pytest from sqlalchemy.engine import Engine from tests.test_app import app - from superset import db from superset.utils.core import get_example_database, json_dumps_w_dates @@ -38,6 +37,9 @@ def app_context(): @pytest.fixture(autouse=True, scope="session") def setup_sample_data() -> Any: + # TODO(john-bodley): Determine a cleaner way of setting up the sample data without + # relying on `tests.test_app.app` leveraging an `app` fixture which is purposely + # scoped to the function level to ensure tests remain idempotent. with app.app_context(): setup_presto_if_needed() diff --git a/tests/db_engine_specs/athena_tests.py b/tests/db_engine_specs/athena_tests.py index 92160dbd0..d928a986d 100644 --- a/tests/db_engine_specs/athena_tests.py +++ b/tests/db_engine_specs/athena_tests.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from tests.test_app import app # isort:skip - from superset.db_engine_specs.athena import AthenaEngineSpec from tests.db_engine_specs.base_tests import TestDbEngineSpec diff --git a/tests/db_engine_specs/base_engine_spec_tests.py b/tests/db_engine_specs/base_engine_spec_tests.py index b8b82ab60..097e0ba71 100644 --- a/tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/db_engine_specs/base_engine_spec_tests.py @@ -167,25 +167,6 @@ class TestDbEngineSpecs(TestDbEngineSpec): "SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec ) - def test_time_grain_denylist(self): - with app.app_context(): - app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"] - time_grain_functions = SqliteEngineSpec.get_time_grain_expressions() - self.assertNotIn("PT1M", time_grain_functions) - - def test_time_grain_addons(self): - with app.app_context(): - app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"} - app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = { - "sqlite": {"PTXM": "ABC({col})"} - } - time_grains = SqliteEngineSpec.get_time_grains() - time_grain_addon = time_grains[-1] - self.assertEqual("PTXM", time_grain_addon.duration) - self.assertEqual("x seconds", time_grain_addon.label) - app.config["TIME_GRAIN_ADDONS"] = {} - app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {} - def test_engine_time_grain_validity(self): time_grains = set(builtin_time_grains.keys()) # loop over all subclasses of BaseEngineSpec @@ -198,43 +179,6 @@ class TestDbEngineSpecs(TestDbEngineSpec): intersection = time_grains.intersection(defined_grains) self.assertSetEqual(defined_grains, intersection, engine) - def test_get_time_grain_with_config(self): - """ Should concatenate from configs and then sort in the proper order """ - app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = { - "mysql": { - "PT2H": "foo", - "PT4H": "foo", - "PT6H": "foo", - "PT8H": "foo", - "PT10H": "foo", - "PT12H": "foo", - "PT1S": "foo", - } - } - time_grains = MySQLEngineSpec.get_time_grain_expressions() - self.assertEqual( - 
list(time_grains.keys()), - [ - None, - "PT1S", - "PT1M", - "PT1H", - "PT2H", - "PT4H", - "PT6H", - "PT8H", - "PT10H", - "PT12H", - "P1D", - "P1W", - "P1M", - "P0.25Y", - "P1Y", - "1969-12-29T00:00:00Z/P1W", - ], - ) - app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {} - def test_get_time_grain_expressions(self): time_grains = MySQLEngineSpec.get_time_grain_expressions() self.assertEqual( @@ -253,18 +197,6 @@ class TestDbEngineSpecs(TestDbEngineSpec): ], ) - def test_get_time_grain_with_unkown_values(self): - """Should concatenate from configs and then sort in the proper order - putting unknown patterns at the end""" - app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = { - "mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",} - } - time_grains = MySQLEngineSpec.get_time_grain_expressions() - self.assertEqual( - list(time_grains)[-1], "weird", - ) - app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {} - def test_get_table_names(self): inspector = mock.Mock() inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"]) @@ -339,3 +271,84 @@ def test_is_readonly(): assert is_readonly("WITH (SELECT 1) bla SELECT * from bla") assert is_readonly("SHOW CATALOGS") assert is_readonly("SHOW TABLES") + + +def test_time_grain_denylist(): + config = app.config.copy() + app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"] + + with app.app_context(): + time_grain_functions = SqliteEngineSpec.get_time_grain_expressions() + assert not "PT1M" in time_grain_functions + + app.config = config + + +def test_time_grain_addons(): + config = app.config.copy() + app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"} + app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {"sqlite": {"PTXM": "ABC({col})"}} + + with app.app_context(): + time_grains = SqliteEngineSpec.get_time_grains() + time_grain_addon = time_grains[-1] + assert "PTXM" == time_grain_addon.duration + assert "x seconds" == time_grain_addon.label + + app.config = config + + +def test_get_time_grain_with_config(): + """ Should 
concatenate from configs and then sort in the proper order """ + config = app.config.copy() + + app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = { + "mysql": { + "PT2H": "foo", + "PT4H": "foo", + "PT6H": "foo", + "PT8H": "foo", + "PT10H": "foo", + "PT12H": "foo", + "PT1S": "foo", + } + } + + with app.app_context(): + time_grains = MySQLEngineSpec.get_time_grain_expressions() + assert set(time_grains.keys()) == { + None, + "PT1S", + "PT1M", + "PT1H", + "PT2H", + "PT4H", + "PT6H", + "PT8H", + "PT10H", + "PT12H", + "P1D", + "P1W", + "P1M", + "P0.25Y", + "P1Y", + "1969-12-29T00:00:00Z/P1W", + } + + app.config = config + + +def test_get_time_grain_with_unkown_values(): + """Should concatenate from configs and then sort in the proper order + putting unknown patterns at the end""" + config = app.config.copy() + + app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = { + "mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",} + } + + with app.app_context(): + time_grains = MySQLEngineSpec.get_time_grain_expressions() + assert list(time_grains)[-1] == "weird" + + app.config = config diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py index 4d5051824..1e978b71b 100644 --- a/tests/db_engine_specs/hive_tests.py +++ b/tests/db_engine_specs/hive_tests.py @@ -21,12 +21,11 @@ from unittest import mock import pytest import pandas as pd from sqlalchemy.sql import select -from tests.test_app import app -with app.app_context(): - from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3 +from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3 from superset.exceptions import SupersetException from superset.sql_parse import Table, ParsedQuery +from tests.test_app import app def test_0_progress(): @@ -170,10 +169,6 @@ def test_df_to_csv() -> None: ) -@mock.patch( - "superset.db_engine_specs.hive.config", - {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""}, -) @mock.patch("superset.db_engine_specs.hive.g", spec={}) 
def test_df_to_sql_if_exists_fail(mock_g): mock_g.user = True @@ -185,10 +180,6 @@ ) -@mock.patch( - "superset.db_engine_specs.hive.config", - {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""}, -) @mock.patch("superset.db_engine_specs.hive.g", spec={}) def test_df_to_sql_if_exists_fail_with_schema(mock_g): mock_g.user = True @@ -203,13 +194,11 @@ ) -@mock.patch( - "superset.db_engine_specs.hive.config", - {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""}, -) @mock.patch("superset.db_engine_specs.hive.g", spec={}) @mock.patch("superset.db_engine_specs.hive.upload_to_s3") def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g): + config = app.config.copy() + app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"] = lambda *args: "" mock_upload_to_s3.return_value = "mock-location" mock_g.user = True mock_database = mock.MagicMock() @@ -218,23 +207,23 @@ mock_database.get_sqla_engine.return_value.execute = mock_execute table_name = "foobar" - HiveEngineSpec.df_to_sql( - mock_database, - Table(table=table_name), - pd.DataFrame(), - {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, - ) + with app.app_context(): + HiveEngineSpec.df_to_sql( + mock_database, + Table(table=table_name), + pd.DataFrame(), + {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, + ) mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}") + app.config = config -@mock.patch( - "superset.db_engine_specs.hive.config", - {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""}, -) @mock.patch("superset.db_engine_specs.hive.g", spec={}) @mock.patch("superset.db_engine_specs.hive.upload_to_s3") def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g): + config = app.config.copy() + 
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"] = lambda *args: "" mock_upload_to_s3.return_value = "mock-location" mock_g.user = True mock_database = mock.MagicMock() @@ -244,14 +233,16 @@ table_name = "foobar" schema = "schema" - HiveEngineSpec.df_to_sql( - mock_database, - Table(table=table_name, schema=schema), - pd.DataFrame(), - {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, - ) + with app.app_context(): + HiveEngineSpec.df_to_sql( + mock_database, + Table(table=table_name, schema=schema), + pd.DataFrame(), + {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, + ) mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}") + app.config = config def test_is_readonly(): @@ -284,39 +275,42 @@ def test_s3_upload_prefix(schema: str, upload_prefix: str) -> None: def test_upload_to_s3_no_bucket_path(): - with pytest.raises( - Exception, - match="No upload bucket specified. You can specify one in the config file.", - ): - upload_to_s3("filename", "prefix", Table("table")) + with app.app_context(): + with pytest.raises( + Exception, + match="No upload bucket specified. 
You can specify one in the config file.", + ): + upload_to_s3("filename", "prefix", Table("table")) @mock.patch("boto3.client") -@mock.patch( - "superset.db_engine_specs.hive.config", - {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"}, -) def test_upload_to_s3_client_error(client): + config = app.config.copy() + app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket" from botocore.exceptions import ClientError client.return_value.upload_file.side_effect = ClientError( {"Error": {}}, "operation_name" ) - with pytest.raises(ClientError): - upload_to_s3("filename", "prefix", Table("table")) + with app.app_context(): + with pytest.raises(ClientError): + upload_to_s3("filename", "prefix", Table("table")) + + app.config = config @mock.patch("boto3.client") -@mock.patch( - "superset.db_engine_specs.hive.config", - {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"}, -) def test_upload_to_s3_success(client): + config = app.config.copy() + app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket" client.return_value.upload_file.return_value = True - location = upload_to_s3("filename", "prefix", Table("table")) - assert f"s3a://bucket/prefix/table" == location + with app.app_context(): + location = upload_to_s3("filename", "prefix", Table("table")) + assert f"s3a://bucket/prefix/table" == location + + app.config = config def test_fetch_data_query_error():