diff --git a/superset/app.py b/superset/app.py index 2f41d3747..c3731a4a1 100644 --- a/superset/app.py +++ b/superset/app.py @@ -125,6 +125,7 @@ class SupersetAppInitializer: # # pylint: disable=too-many-locals # pylint: disable=too-many-statements + from superset.cachekeys.api import CacheRestApi from superset.charts.api import ChartRestApi from superset.connectors.druid.views import ( Druid, @@ -194,6 +195,7 @@ class SupersetAppInitializer: # # Setup API views # + appbuilder.add_api(CacheRestApi) appbuilder.add_api(ChartRestApi) appbuilder.add_api(DashboardRestApi) appbuilder.add_api(DatabaseRestApi) diff --git a/superset/cachekeys/__init__.py b/superset/cachekeys/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/superset/cachekeys/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/cachekeys/api.py b/superset/cachekeys/api.py new file mode 100644 index 000000000..92c51bdac --- /dev/null +++ b/superset/cachekeys/api.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from flask import request, Response +from flask_appbuilder import expose +from flask_appbuilder.api import safe +from flask_appbuilder.models.sqla.interface import SQLAInterface +from flask_appbuilder.security.decorators import protect +from marshmallow.exceptions import ValidationError +from sqlalchemy.exc import SQLAlchemyError + +from superset.cachekeys.schemas import CacheInvalidationRequestSchema +from superset.connectors.connector_registry import ConnectorRegistry +from superset.extensions import cache_manager, db, event_logger +from superset.models.cache import CacheKey +from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics + +logger = logging.getLogger(__name__) + + +class CacheRestApi(BaseSupersetModelRestApi): + datamodel = SQLAInterface(CacheKey) + resource_name = "cachekey" + allow_browser_login = True + class_permission_name = "CacheRestApi" + include_route_methods = { + "invalidate", + } + + openapi_spec_component_schemas = (CacheInvalidationRequestSchema,) + + @expose("/invalidate", methods=["POST"]) + @event_logger.log_this + @protect() + @safe + @statsd_metrics + def invalidate(self) -> Response: + """ + Takes a list of datasources, finds the associated cache records and + invalidates them and removes the database records + + --- + post: + description: >- + Takes a list of datasources, finds the associated cache records and + invalidates 
them and removes the database records + requestBody: + description: >- + A list of datasources uuid or the tuples of database and datasource names + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CacheInvalidationRequestSchema" + responses: + 201: + description: cache was successfully invalidated + 400: + $ref: '#/components/responses/400' + 500: + $ref: '#/components/responses/500' + """ + try: + datasources = CacheInvalidationRequestSchema().load(request.json) + except KeyError: + return self.response_400(message="Request is incorrect") + except ValidationError as error: + return self.response_400(message=str(error)) + datasource_uids = set(datasources.get("datasource_uids", [])) + for ds in datasources.get("datasources", []): + ds_obj = ConnectorRegistry.get_datasource_by_name( + session=db.session, + datasource_type=ds.get("datasource_type"), + datasource_name=ds.get("datasource_name"), + schema=ds.get("schema"), + database_name=ds.get("database_name"), + ) + if ds_obj: + datasource_uids.add(ds_obj.uid) + + cache_key_objs = ( + db.session.query(CacheKey) + .filter(CacheKey.datasource_uid.in_(datasource_uids)) + .all() + ) + cache_keys = [c.cache_key for c in cache_key_objs] + if cache_key_objs: + all_keys_deleted = cache_manager.cache.delete_many(*cache_keys) + + if not all_keys_deleted: + # expected behavior as keys may expire and cache is not a + # persistent storage + logger.info( + "Some of the cache keys were not deleted in the list %s", cache_keys + ) + + try: + delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member + CacheKey.cache_key.in_(cache_keys) + ) + db.session.execute(delete_stmt) + db.session.commit() + except SQLAlchemyError as ex: # pragma: no cover + logger.error(ex) + db.session.rollback() + return self.response_500(str(ex)) + db.session.commit() + return self.response(201) diff --git a/superset/cachekeys/schemas.py b/superset/cachekeys/schemas.py new file mode 100644 index 
000000000..a97aebdf2 --- /dev/null +++ b/superset/cachekeys/schemas.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# RISON/JSON schemas for query parameters +from marshmallow import fields, Schema, validate + +from superset.charts.schemas import ( + datasource_name_description, + datasource_type_description, + datasource_uid_description, +) + + +class Datasource(Schema): + database_name = fields.String(description="Database name",) + datasource_name = fields.String(description=datasource_name_description,) + schema = fields.String(description="Datasource schema",) + datasource_type = fields.String( + description=datasource_type_description, + validate=validate.OneOf(choices=("druid", "table", "view")), + required=True, + ) + + +class CacheInvalidationRequestSchema(Schema): + datasource_uids = fields.List( + fields.String(), description=datasource_uid_description, + ) + datasources = fields.List( + fields.Nested(Datasource), + description="A list of the data source and database names", + ) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 1fba09c9a..5500271c5 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -71,6 +71,10 @@ datasource_id_description = ( "A 
complete datasource identification needs `datasouce_id` " "and `datasource_type`." ) +datasource_uid_description = ( + "The uid of the dataset/datasource this new chart will use. " + "A complete datasource identification needs `datasource_uid` " +) datasource_type_description = ( "The type of dataset/datasource identified on `datasource_id`." ) diff --git a/tests/base_tests.py b/tests/base_tests.py index 1c593320c..51bae4883 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -22,6 +22,7 @@ from typing import Any, Dict, Union, List, Optional from unittest.mock import Mock, patch import pandas as pd +import pytest from flask import Response from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase @@ -42,6 +43,7 @@ from superset.utils.core import get_example_database from superset.views.base_api import BaseSupersetModelRestApi FAKE_DB_NAME = "fake_db_100" +test_client = app.test_client() def login(client: Any, username: str = "admin", password: str = "general"): @@ -69,6 +71,39 @@ def get_resp( return resp.data.decode("utf-8") + +def post_assert_metric( + client: Any, uri: str, data: Dict[str, Any], func_name: str +) -> Response: + """ + Simple client post with an extra assertion for statsd metrics + + :param client: test client for superset api requests + :param uri: The URI to use for the HTTP POST + :param data: The JSON data payload to be posted + :param func_name: The function name that the HTTP POST triggers + for the statsd metric assertion + :return: HTTP Response + """ + with patch.object( + BaseSupersetModelRestApi, "incr_stats", return_value=None + ) as mock_method: + rv = client.post(uri, json=data) + if 200 <= rv.status_code < 400: + mock_method.assert_called_once_with("success", func_name) + else: + mock_method.assert_called_once_with("error", func_name) + return rv + + +@pytest.fixture +def logged_in_admin(): + """Fixture with app context and logged in admin user.""" + with app.app_context(): + 
login(test_client, username="admin") + yield + test_client.get("/logout/", follow_redirects=True) + + class SupersetTestCase(TestCase): default_schema_backend_map = { @@ -84,6 +119,15 @@ class SupersetTestCase(TestCase): def create_app(self): return app + @staticmethod + def get_birth_names_dataset(): + example_db = get_example_database() + return ( + db.session.query(SqlaTable) + .filter_by(database=example_db, table_name="birth_names") + .one() + ) + @staticmethod def create_user_with_roles(username: str, roles: List[str]): user_to_create = security_manager.find_user(username) @@ -422,24 +466,7 @@ class SupersetTestCase(TestCase): def post_assert_metric( self, uri: str, data: Dict[str, Any], func_name: str ) -> Response: - """ - Simple client post with an extra assertion for statsd metrics - - :param uri: The URI to use for the HTTP POST - :param data: The JSON data payload to be posted - :param func_name: The function name that the HTTP POST triggers - for the statsd metric assertion - :return: HTTP Response - """ - with patch.object( - BaseSupersetModelRestApi, "incr_stats", return_value=None - ) as mock_method: - rv = self.client.post(uri, json=data) - if 200 <= rv.status_code < 400: - mock_method.assert_called_once_with("success", func_name) - else: - mock_method.assert_called_once_with("error", func_name) - return rv + return post_assert_metric(self.client, uri, data, func_name) def put_assert_metric( self, uri: str, data: Dict[str, Any], func_name: str diff --git a/tests/cachekeys/__init__.py b/tests/cachekeys/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/tests/cachekeys/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/cachekeys/api_tests.py b/tests/cachekeys/api_tests.py new file mode 100644 index 000000000..3f08750d8 --- /dev/null +++ b/tests/cachekeys/api_tests.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# isort:skip_file +"""Unit tests for Superset""" +from typing import Dict, Any + +from tests.test_app import app # noqa + +from superset.extensions import cache_manager, db +from superset.models.cache import CacheKey +from tests.base_tests import ( + SupersetTestCase, + post_assert_metric, + test_client, + logged_in_admin, +) # noqa + + +def invalidate(params: Dict[str, Any]): + return post_assert_metric( + test_client, "api/v1/cachekey/invalidate", params, "invalidate" + ) + + +def test_invalidate_cache(logged_in_admin): + rv = invalidate({"datasource_uids": ["3__table"]}) + assert rv.status_code == 201 + + +def test_invalidate_existing_cache(logged_in_admin): + db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table")) + db.session.commit() + cache_manager.cache.set("cache_key", "value") + + rv = invalidate({"datasource_uids": ["3__table"]}) + + assert rv.status_code == 201 + assert cache_manager.cache.get("cache_key") is None + assert ( + not db.session.query(CacheKey).filter(CacheKey.cache_key == "cache_key").first() + ) + + +def test_invalidate_cache_empty_input(logged_in_admin): + rv = invalidate({"datasource_uids": []}) + assert rv.status_code == 201 + + rv = invalidate({"datasources": []}) + assert rv.status_code == 201 + + rv = invalidate({"datasource_uids": [], "datasources": []}) + assert rv.status_code == 201 + + +def test_invalidate_cache_bad_request(logged_in_admin): + rv = invalidate( + { + "datasource_uids": [], + "datasources": [{"datasource_name": "", "datasource_type": None}], + } + ) + assert rv.status_code == 400 + + rv = invalidate( + { + "datasource_uids": [], + "datasources": [{"datasource_name": "", "datasource_type": "bla"}], + } + ) + assert rv.status_code == 400 + + rv = invalidate( + { + "datasource_uids": "datasource", + "datasources": [{"datasource_name": "", "datasource_type": "bla"}], + } + ) + assert rv.status_code == 400 + + +def test_invalidate_existing_caches(logged_in_admin): + bn = 
SupersetTestCase.get_birth_names_dataset() + + db.session.add(CacheKey(cache_key="cache_key1", datasource_uid="3__druid")) + db.session.add(CacheKey(cache_key="cache_key2", datasource_uid="3__druid")) + db.session.add(CacheKey(cache_key="cache_key4", datasource_uid=f"{bn.id}__table")) + db.session.add(CacheKey(cache_key="cache_keyX", datasource_uid="X__table")) + db.session.commit() + + cache_manager.cache.set("cache_key1", "value") + cache_manager.cache.set("cache_key2", "value") + cache_manager.cache.set("cache_key4", "value") + cache_manager.cache.set("cache_keyX", "value") + + rv = invalidate( + { + "datasource_uids": ["3__druid", "4__druid"], + "datasources": [ + { + "datasource_name": "birth_names", + "database_name": "examples", + "schema": "", + "datasource_type": "table", + }, + { # table exists, no cache to invalidate + "datasource_name": "energy_usage", + "database_name": "examples", + "schema": "", + "datasource_type": "table", + }, + { # table doesn't exist + "datasource_name": "does_not_exist", + "database_name": "examples", + "schema": "", + "datasource_type": "table", + }, + { # database doesn't exist + "datasource_name": "birth_names", + "database_name": "does_not_exist", + "schema": "", + "datasource_type": "table", + }, + { # schema doesn't exist + "datasource_name": "birth_names", + "database_name": "examples", + "schema": "does_not_exist", + "datasource_type": "table", + }, + ], + } + ) + + assert rv.status_code == 201 + assert cache_manager.cache.get("cache_key1") is None + assert cache_manager.cache.get("cache_key2") is None + assert cache_manager.cache.get("cache_key4") is None + assert cache_manager.cache.get("cache_keyX") == "value" + assert ( + not db.session.query(CacheKey) + .filter(CacheKey.cache_key.in_({"cache_key1", "cache_key2", "cache_key4"})) + .first() + ) + assert ( + db.session.query(CacheKey) + .filter(CacheKey.cache_key == "cache_keyX") + .first() + .datasource_uid + == "X__table" + ) diff --git 
a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index 3d54785f5..56f9ef846 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -60,15 +60,6 @@ class TestDatasetApi(SupersetTestCase): "ab_permission", "", [self.get_user("admin").id], get_main_database() ) - @staticmethod - def get_birth_names_dataset(): - example_db = get_example_database() - return ( - db.session.query(SqlaTable) - .filter_by(database=example_db, table_name="birth_names") - .one() - ) - @staticmethod def get_energy_usage_dataset(): example_db = get_example_database() diff --git a/tests/superset_test_config.py b/tests/superset_test_config.py index f0259ab61..513d3d90c 100644 --- a/tests/superset_test_config.py +++ b/tests/superset_test_config.py @@ -76,6 +76,14 @@ REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") REDIS_PORT = os.environ.get("REDIS_PORT", "6379") REDIS_CELERY_DB = os.environ.get("REDIS_CELERY_DB", 2) REDIS_RESULTS_DB = os.environ.get("REDIS_RESULTS_DB", 3) +REDIS_CACHE_DB = os.environ.get("REDIS_CACHE_DB", 4) + +CACHE_CONFIG = { + "CACHE_TYPE": "redis", + "CACHE_DEFAULT_TIMEOUT": 60 * 60 * 24, # 1 day default (in secs) + "CACHE_KEY_PREFIX": "superset_cache", + "CACHE_REDIS_URL": f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CACHE_DB}", +} class CeleryConfig(object):