feat: implement cache invalidation api (#10761)
* Add cache endpoints * Implement cache endpoint * Tests and address feedback * Set cache config * Address feedback * Expose only invalidate endpoint Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
parent
838a70ea8d
commit
9c420d6efe
|
|
@ -125,6 +125,7 @@ class SupersetAppInitializer:
|
|||
#
|
||||
# pylint: disable=too-many-locals
|
||||
# pylint: disable=too-many-statements
|
||||
from superset.cachekeys.api import CacheRestApi
|
||||
from superset.charts.api import ChartRestApi
|
||||
from superset.connectors.druid.views import (
|
||||
Druid,
|
||||
|
|
@ -194,6 +195,7 @@ class SupersetAppInitializer:
|
|||
#
|
||||
# Setup API views
|
||||
#
|
||||
appbuilder.add_api(CacheRestApi)
|
||||
appbuilder.add_api(ChartRestApi)
|
||||
appbuilder.add_api(DashboardRestApi)
|
||||
appbuilder.add_api(DatabaseRestApi)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
|
||||
from flask import request, Response
|
||||
from flask_appbuilder import expose
|
||||
from flask_appbuilder.api import safe
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from flask_appbuilder.security.decorators import protect
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.cachekeys.schemas import CacheInvalidationRequestSchema
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.extensions import cache_manager, db, event_logger
|
||||
from superset.models.cache import CacheKey
|
||||
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheRestApi(BaseSupersetModelRestApi):
    """REST API for invalidating cached chart/query results.

    Only the ``invalidate`` route is exposed; the CRUD endpoints normally
    generated by ``BaseSupersetModelRestApi`` are disabled via
    ``include_route_methods``.
    """

    datamodel = SQLAInterface(CacheKey)
    resource_name = "cachekey"
    allow_browser_login = True
    class_permission_name = "CacheRestApi"
    include_route_methods = {
        "invalidate",
    }

    openapi_spec_component_schemas = (CacheInvalidationRequestSchema,)

    @expose("/invalidate", methods=["POST"])
    @event_logger.log_this
    @protect()
    @safe
    @statsd_metrics
    def invalidate(self) -> Response:
        """
        Takes a list of datasources, finds the associated cache records and
        invalidates them and removes the database records
        ---
        post:
          description: >-
            Takes a list of datasources, finds the associated cache records and
            invalidates them and removes the database records
          requestBody:
            description: >-
              A list of datasources uuid or the tuples of database and datasource names
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/CacheInvalidationRequestSchema"
          responses:
            201:
              description: cache was successfully invalidated
            400:
              $ref: '#/components/responses/400'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            datasources = CacheInvalidationRequestSchema().load(request.json)
        except KeyError:
            return self.response_400(message="Request is incorrect")
        except ValidationError as error:
            return self.response_400(message=str(error))

        # Merge explicitly supplied uids with uids resolved from
        # (database, schema, datasource) name tuples.
        datasource_uids = set(datasources.get("datasource_uids", []))
        for ds in datasources.get("datasources", []):
            ds_obj = ConnectorRegistry.get_datasource_by_name(
                session=db.session,
                datasource_type=ds.get("datasource_type"),
                datasource_name=ds.get("datasource_name"),
                schema=ds.get("schema"),
                database_name=ds.get("database_name"),
            )
            if ds_obj:
                datasource_uids.add(ds_obj.uid)

        cache_key_objs = (
            db.session.query(CacheKey)
            .filter(CacheKey.datasource_uid.in_(datasource_uids))
            .all()
        )
        cache_keys = [c.cache_key for c in cache_key_objs]
        if cache_key_objs:
            all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)

            if not all_keys_deleted:
                # expected behavior as keys may expire and cache is not a
                # persistent storage
                logger.info(
                    "Some of the cache keys were not deleted in the list %s", cache_keys
                )

        try:
            delete_stmt = CacheKey.__table__.delete().where(  # pylint: disable=no-member
                CacheKey.cache_key.in_(cache_keys)
            )
            db.session.execute(delete_stmt)
            db.session.commit()
        except SQLAlchemyError as ex:  # pragma: no cover
            logger.error(ex)
            db.session.rollback()
            return self.response_500(str(ex))
        # NOTE: the redundant second db.session.commit() that followed the
        # try/except was removed — the transaction is already committed on
        # success and rolled back (with an early return) on failure.
        return self.response(201)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# RISON/JSON schemas for query parameters
|
||||
from marshmallow import fields, Schema, validate
|
||||
|
||||
from superset.charts.schemas import (
|
||||
datasource_name_description,
|
||||
datasource_type_description,
|
||||
datasource_uid_description,
|
||||
)
|
||||
|
||||
|
||||
class Datasource(Schema):
    """Identifies a datasource by database/schema/name rather than by uid."""

    database_name = fields.String(description="Datasource name",)
    datasource_name = fields.String(description=datasource_name_description,)
    schema = fields.String(description="Datasource schema",)
    # Type is the only mandatory field; it is restricted to the known
    # connector kinds.
    datasource_type = fields.String(
        required=True,
        description=datasource_type_description,
        validate=validate.OneOf(choices=("druid", "table", "view")),
    )
|
||||
|
||||
|
||||
class CacheInvalidationRequestSchema(Schema):
    """Payload accepted by the ``/cachekey/invalidate`` endpoint."""

    # Datasources referenced directly by uid.
    datasource_uids = fields.List(
        fields.String(), description=datasource_uid_description,
    )
    # Datasources referenced by (database, schema, name) tuples.
    datasources = fields.List(
        fields.Nested(Datasource),
        description="A list of the data source and database names",
    )
|
||||
|
|
@ -71,6 +71,10 @@ datasource_id_description = (
|
|||
"A complete datasource identification needs `datasouce_id` "
|
||||
"and `datasource_type`."
|
||||
)
|
||||
# Shared OpenAPI description for datasource-uid fields.
# Fixed typo: "datasouce_uid" -> "datasource_uid".
datasource_uid_description = (
    "The uid of the dataset/datasource this new chart will use. "
    "A complete datasource identification needs `datasource_uid` "
)
|
||||
datasource_type_description = (
|
||||
"The type of dataset/datasource identified on `datasource_id`."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from typing import Any, Dict, Union, List, Optional
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from flask import Response
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
from flask_testing import TestCase
|
||||
|
|
@ -42,6 +43,7 @@ from superset.utils.core import get_example_database
|
|||
from superset.views.base_api import BaseSupersetModelRestApi
|
||||
|
||||
FAKE_DB_NAME = "fake_db_100"
|
||||
test_client = app.test_client()
|
||||
|
||||
|
||||
def login(client: Any, username: str = "admin", password: str = "general"):
|
||||
|
|
@ -69,6 +71,39 @@ def get_resp(
|
|||
return resp.data.decode("utf-8")
|
||||
|
||||
|
||||
def post_assert_metric(
    client: Any, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
    """
    POST ``data`` to ``uri`` and assert the statsd metric hook fired once.

    :param client: test client for superset api requests
    :param uri: The URI to use for the HTTP POST
    :param data: The JSON data payload to be posted
    :param func_name: The function name that the HTTP POST triggers
        for the statsd metric assertion
    :return: HTTP Response
    """
    with patch.object(
        BaseSupersetModelRestApi, "incr_stats", return_value=None
    ) as incr_stats_mock:
        response = client.post(uri, json=data)
        # 2xx/3xx responses must record a "success" metric, everything else
        # an "error" metric.
        outcome = "success" if 200 <= response.status_code < 400 else "error"
        incr_stats_mock.assert_called_once_with(outcome, func_name)
        return response
|
||||
|
||||
|
||||
@pytest.fixture
def logged_in_admin():
    """Provide an app context with the admin user logged in.

    Logs the test client out again after the dependent test finishes.
    """
    with app.app_context():
        login(test_client, username="admin")
        yield
        test_client.get("/logout/", follow_redirects=True)
|
||||
|
||||
|
||||
class SupersetTestCase(TestCase):
|
||||
|
||||
default_schema_backend_map = {
|
||||
|
|
@ -84,6 +119,15 @@ class SupersetTestCase(TestCase):
|
|||
def create_app(self):
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def get_birth_names_dataset():
|
||||
example_db = get_example_database()
|
||||
return (
|
||||
db.session.query(SqlaTable)
|
||||
.filter_by(database=example_db, table_name="birth_names")
|
||||
.one()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_user_with_roles(username: str, roles: List[str]):
|
||||
user_to_create = security_manager.find_user(username)
|
||||
|
|
@ -422,24 +466,7 @@ class SupersetTestCase(TestCase):
|
|||
def post_assert_metric(
|
||||
self, uri: str, data: Dict[str, Any], func_name: str
|
||||
) -> Response:
|
||||
"""
|
||||
Simple client post with an extra assertion for statsd metrics
|
||||
|
||||
:param uri: The URI to use for the HTTP POST
|
||||
:param data: The JSON data payload to be posted
|
||||
:param func_name: The function name that the HTTP POST triggers
|
||||
for the statsd metric assertion
|
||||
:return: HTTP Response
|
||||
"""
|
||||
with patch.object(
|
||||
BaseSupersetModelRestApi, "incr_stats", return_value=None
|
||||
) as mock_method:
|
||||
rv = self.client.post(uri, json=data)
|
||||
if 200 <= rv.status_code < 400:
|
||||
mock_method.assert_called_once_with("success", func_name)
|
||||
else:
|
||||
mock_method.assert_called_once_with("error", func_name)
|
||||
return rv
|
||||
return post_assert_metric(self.client, uri, data, func_name)
|
||||
|
||||
def put_assert_metric(
|
||||
self, uri: str, data: Dict[str, Any], func_name: str
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
"""Unit tests for Superset"""
|
||||
from typing import Dict, Any
|
||||
|
||||
from tests.test_app import app # noqa
|
||||
|
||||
from superset.extensions import cache_manager, db
|
||||
from superset.models.cache import CacheKey
|
||||
from tests.base_tests import (
|
||||
SupersetTestCase,
|
||||
post_assert_metric,
|
||||
test_client,
|
||||
logged_in_admin,
|
||||
) # noqa
|
||||
|
||||
|
||||
def invalidate(params: Dict[str, Any]):
    """POST ``params`` to the cache-invalidation endpoint, asserting metrics."""
    uri = "api/v1/cachekey/invalidate"
    return post_assert_metric(test_client, uri, params, "invalidate")
|
||||
|
||||
|
||||
def test_invalidate_cache(logged_in_admin):
    """A uid with no cached entries is still accepted and returns 201."""
    response = invalidate({"datasource_uids": ["3__table"]})
    assert response.status_code == 201
|
||||
|
||||
|
||||
def test_invalidate_existing_cache(logged_in_admin):
    """Invalidation removes both the cached value and its bookkeeping row."""
    db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
    db.session.commit()
    cache_manager.cache.set("cache_key", "value")

    rv = invalidate({"datasource_uids": ["3__table"]})

    assert rv.status_code == 201
    # Use identity comparison with None (PEP 8) instead of `== None`.
    assert cache_manager.cache.get("cache_key") is None
    assert (
        not db.session.query(CacheKey).filter(CacheKey.cache_key == "cache_key").first()
    )
|
||||
|
||||
|
||||
def test_invalidate_cache_empty_input(logged_in_admin):
    """Empty uid/datasource lists are valid input and yield 201."""
    empty_payloads = (
        {"datasource_uids": []},
        {"datasources": []},
        {"datasource_uids": [], "datasources": []},
    )
    for payload in empty_payloads:
        assert invalidate(payload).status_code == 201
|
||||
|
||||
|
||||
def test_invalidate_cache_bad_request(logged_in_admin):
    """Payloads that violate the request schema are rejected with 400."""
    bad_payloads = (
        # datasource_type may not be null
        {
            "datasource_uids": [],
            "datasources": [{"datasource_name": "", "datasource_type": None}],
        },
        # "bla" is not an allowed datasource_type
        {
            "datasource_uids": [],
            "datasources": [{"datasource_name": "", "datasource_type": "bla"}],
        },
        # datasource_uids must be a list, not a bare string
        {
            "datasource_uids": "datasource",
            "datasources": [{"datasource_name": "", "datasource_type": "bla"}],
        },
    )
    for payload in bad_payloads:
        assert invalidate(payload).status_code == 400
|
||||
|
||||
|
||||
def test_invalidate_existing_caches(logged_in_admin):
    """Only cache entries for matched datasources are invalidated."""
    bn = SupersetTestCase.get_birth_names_dataset()

    # Seed bookkeeping rows and matching cache entries.
    seeded = {
        "cache_key1": "3__druid",
        "cache_key2": "3__druid",
        "cache_key4": f"{bn.id}__table",
        "cache_keyX": "X__table",
    }
    for cache_key, datasource_uid in seeded.items():
        db.session.add(CacheKey(cache_key=cache_key, datasource_uid=datasource_uid))
    db.session.commit()
    for cache_key in seeded:
        cache_manager.cache.set(cache_key, "value")

    rv = invalidate(
        {
            "datasource_uids": ["3__druid", "4__druid"],
            "datasources": [
                {
                    "datasource_name": "birth_names",
                    "database_name": "examples",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # table exists, no cache to invalidate
                    "datasource_name": "energy_usage",
                    "database_name": "examples",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # table doesn't exist
                    "datasource_name": "does_not_exist",
                    "database_name": "examples",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # database doesn't exist
                    "datasource_name": "birth_names",
                    "database_name": "does_not_exist",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # schema doesn't exist
                    "datasource_name": "birth_names",
                    "database_name": "examples",
                    "schema": "does_not_exist",
                    "datasource_type": "table",
                },
            ],
        }
    )

    assert rv.status_code == 201
    # Matched datasources lose their cache entries...
    for cache_key in ("cache_key1", "cache_key2", "cache_key4"):
        assert cache_manager.cache.get(cache_key) is None
    # ...while the unmatched one survives in cache and in the DB.
    assert cache_manager.cache.get("cache_keyX") == "value"
    assert (
        not db.session.query(CacheKey)
        .filter(CacheKey.cache_key.in_({"cache_key1", "cache_key2", "cache_key4"}))
        .first()
    )
    assert (
        db.session.query(CacheKey)
        .filter(CacheKey.cache_key == "cache_keyX")
        .first()
        .datasource_uid
        == "X__table"
    )
|
||||
|
|
@ -60,15 +60,6 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"ab_permission", "", [self.get_user("admin").id], get_main_database()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_birth_names_dataset():
|
||||
example_db = get_example_database()
|
||||
return (
|
||||
db.session.query(SqlaTable)
|
||||
.filter_by(database=example_db, table_name="birth_names")
|
||||
.one()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_energy_usage_dataset():
|
||||
example_db = get_example_database()
|
||||
|
|
|
|||
|
|
@ -76,6 +76,14 @@ REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
|
|||
REDIS_PORT = os.environ.get("REDIS_PORT", "6379")
|
||||
REDIS_CELERY_DB = os.environ.get("REDIS_CELERY_DB", 2)
|
||||
REDIS_RESULTS_DB = os.environ.get("REDIS_RESULTS_DB", 3)
|
||||
REDIS_CACHE_DB = os.environ.get("REDIS_CACHE_DB", 4)
|
||||
|
||||
# Back the test suite's cache with Redis so the cache-invalidation API can be
# exercised end to end.
CACHE_CONFIG = {
    "CACHE_TYPE": "redis",
    "CACHE_KEY_PREFIX": "superset_cache",
    "CACHE_DEFAULT_TIMEOUT": 60 * 60 * 24,  # 1 day default (in secs)
    "CACHE_REDIS_URL": f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CACHE_DB}",
}
|
||||
|
||||
|
||||
class CeleryConfig(object):
|
||||
|
|
|
|||
Loading…
Reference in New Issue