diff --git a/superset/config.py b/superset/config.py index a794635ed..7ccb69542 100644 --- a/superset/config.py +++ b/superset/config.py @@ -427,7 +427,14 @@ FEATURE_FLAGS: Dict[str, bool] = {} # feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5 # return feature_flags_dict GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None - +# A function that receives a feature flag name and an optional default value. +# Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the +# evaluation of all feature flags when just evaluating a single one. +# +# Note that the default `get_feature_flags` will evaluate each feature with this +# callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC +# and IS_FEATURE_ENABLED_FUNC in conjunction. +IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None # A function that expands/overrides the frontend `bootstrap_data.common` object. # Can be used to implement custom frontend functionality, # or dynamically change certain configs. diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py index 88f19c2f4..86d2487f8 100644 --- a/superset/utils/feature_flag_manager.py +++ b/superset/utils/feature_flag_manager.py @@ -24,24 +24,36 @@ class FeatureFlagManager: def __init__(self) -> None: super().__init__() self._get_feature_flags_func = None + self._is_feature_enabled_func = None self._feature_flags: Dict[str, Any] = {} def init_app(self, app: Flask) -> None: self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"] + self._is_feature_enabled_func = app.config["IS_FEATURE_ENABLED_FUNC"] self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"] self._feature_flags.update(app.config["FEATURE_FLAGS"]) def get_feature_flags(self) -> Dict[str, Any]: if self._get_feature_flags_func: return self._get_feature_flags_func(deepcopy(self._feature_flags)) - + if callable(self._is_feature_enabled_func): + return dict( + map( + lambda kv: (kv[0], self._is_feature_enabled_func(kv[0], kv[1])), + self._feature_flags.items(), + ) + ) return self._feature_flags def is_feature_enabled(self, feature: str) -> bool: """Utility function for checking whether a feature is turned on""" + if self._is_feature_enabled_func: + return ( + self._is_feature_enabled_func(feature, self._feature_flags[feature]) + if feature in self._feature_flags + else False + ) feature_flags = self.get_feature_flags() - if feature_flags and feature in feature_flags: return feature_flags[feature] - return False diff --git a/tests/integration_tests/feature_flag_tests.py b/tests/integration_tests/feature_flag_tests.py index 500cb572e..a5818bd8c 100644 --- a/tests/integration_tests/feature_flag_tests.py +++ b/tests/integration_tests/feature_flag_tests.py @@ -16,10 +16,16 @@ # under the License. from unittest.mock import patch -from superset import is_feature_enabled +from parameterized import parameterized + +from superset import get_feature_flags, is_feature_enabled from tests.integration_tests.base_tests import SupersetTestCase +def dummy_is_feature_enabled(feature_flag_name: str, default: bool = True) -> bool: + return True if feature_flag_name.startswith("True_") else default + + class TestFeatureFlag(SupersetTestCase): @patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -38,3 +44,40 @@ class TestFeatureFlag(SupersetTestCase): def test_feature_flags(self): self.assertEqual(is_feature_enabled("foo"), "bar") self.assertEqual(is_feature_enabled("super"), "set") + + +@patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"True_Flag1": False, "True_Flag2": True, "Flag3": False, "Flag4": True}, + clear=True, +) +class TestFeatureFlagBackend(SupersetTestCase): + @parameterized.expand( + [ + ("True_Flag1", True), + ("True_Flag2", True), + ("Flag3", False), + ("Flag4", True), + ("True_DoesNotExist", False), + ] + ) + @patch( + "superset.extensions.feature_flag_manager._is_feature_enabled_func", + dummy_is_feature_enabled, + ) + def test_feature_flags_override(self, feature_flag_name, expected): + self.assertEqual(is_feature_enabled(feature_flag_name), expected) + + @patch( + "superset.extensions.feature_flag_manager._is_feature_enabled_func", + dummy_is_feature_enabled, + ) + @patch( + "superset.extensions.feature_flag_manager._get_feature_flags_func", None, + ) + def test_get_feature_flags(self): + feature_flags = get_feature_flags() + self.assertEqual( + feature_flags, + {"True_Flag1": True, "True_Flag2": True, "Flag3": False, "Flag4": True}, + )