feat: implement cache invalidation api (#10761)

* Add cache endpoints

* Implement cache endpoint

* Tests and address feedback

* Set cache config

* Address feedback

* Expose only invalidate endpoint

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
Bogdan 2020-09-15 11:17:21 -07:00 committed by GitHub
parent 838a70ea8d
commit 9c420d6efe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 422 additions and 27 deletions

View File

@ -125,6 +125,7 @@ class SupersetAppInitializer:
#
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
from superset.cachekeys.api import CacheRestApi
from superset.charts.api import ChartRestApi
from superset.connectors.druid.views import (
Druid,
@ -194,6 +195,7 @@ class SupersetAppInitializer:
#
# Setup API views
#
appbuilder.add_api(CacheRestApi)
appbuilder.add_api(ChartRestApi)
appbuilder.add_api(DashboardRestApi)
appbuilder.add_api(DatabaseRestApi)

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

123
superset/cachekeys/api.py Normal file
View File

@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from flask import request, Response
from flask_appbuilder import expose
from flask_appbuilder.api import safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import protect
from marshmallow.exceptions import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.cachekeys.schemas import CacheInvalidationRequestSchema
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import cache_manager, db, event_logger
from superset.models.cache import CacheKey
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
logger = logging.getLogger(__name__)
class CacheRestApi(BaseSupersetModelRestApi):
    """REST API for invalidating cached chart data by datasource."""

    datamodel = SQLAInterface(CacheKey)
    resource_name = "cachekey"
    allow_browser_login = True
    class_permission_name = "CacheRestApi"
    # Only the custom ``invalidate`` route is exposed; the default FAB CRUD
    # endpoints for this model are intentionally disabled.
    include_route_methods = {
        "invalidate",
    }
    openapi_spec_component_schemas = (CacheInvalidationRequestSchema,)

    @expose("/invalidate", methods=["POST"])
    @event_logger.log_this
    @protect()
    @safe
    @statsd_metrics
    def invalidate(self) -> Response:
        """
        Takes a list of datasources, finds the associated cache records and
        invalidates them and removes the database records
        ---
        post:
          description: >-
            Takes a list of datasources, finds the associated cache records and
            invalidates them and removes the database records
          requestBody:
            description: >-
              A list of datasources uuid or the tuples of database and datasource names
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/CacheInvalidationRequestSchema"
          responses:
            201:
              description: cache was successfully invalidated
            400:
              $ref: '#/components/responses/400'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            datasources = CacheInvalidationRequestSchema().load(request.json)
        except KeyError:
            return self.response_400(message="Request is incorrect")
        except ValidationError as error:
            return self.response_400(message=str(error))

        # Resolve (database, datasource-name) tuples into uids; unknown
        # datasources are silently skipped, matching a best-effort contract.
        datasource_uids = set(datasources.get("datasource_uids", []))
        for ds in datasources.get("datasources", []):
            ds_obj = ConnectorRegistry.get_datasource_by_name(
                session=db.session,
                datasource_type=ds.get("datasource_type"),
                datasource_name=ds.get("datasource_name"),
                schema=ds.get("schema"),
                database_name=ds.get("database_name"),
            )
            if ds_obj:
                datasource_uids.add(ds_obj.uid)

        cache_key_objs = (
            db.session.query(CacheKey)
            .filter(CacheKey.datasource_uid.in_(datasource_uids))
            .all()
        )
        cache_keys = [c.cache_key for c in cache_key_objs]
        if cache_key_objs:
            all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)
            if not all_keys_deleted:
                # expected behavior as keys may expire and cache is not a
                # persistent storage
                logger.info(
                    "Some of the cache keys were not deleted in the list %s", cache_keys
                )

        try:
            delete_stmt = CacheKey.__table__.delete().where(  # pylint: disable=no-member
                CacheKey.cache_key.in_(cache_keys)
            )
            db.session.execute(delete_stmt)
            db.session.commit()
        except SQLAlchemyError as ex:  # pragma: no cover
            logger.error(ex)
            db.session.rollback()
            return self.response_500(str(ex))
        # NOTE: the original code issued a second ``db.session.commit()`` here;
        # the session was already committed (or rolled back) above, so the
        # redundant commit has been removed.
        return self.response(201)

View File

@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# RISON/JSON schemas for query parameters
from marshmallow import fields, Schema, validate
from superset.charts.schemas import (
datasource_name_description,
datasource_type_description,
datasource_uid_description,
)
class Datasource(Schema):
    """Identifies a datasource by database, schema, name and type."""

    # Bug fix: the description previously read "Datasource name", which is
    # wrong for the *database* field and misleading in the OpenAPI spec.
    database_name = fields.String(description="Database name")
    datasource_name = fields.String(description=datasource_name_description)
    schema = fields.String(description="Datasource schema")
    datasource_type = fields.String(
        description=datasource_type_description,
        validate=validate.OneOf(choices=("druid", "table", "view")),
        required=True,
    )
class CacheInvalidationRequestSchema(Schema):
    """Request body for the cache invalidation endpoint.

    Callers may supply datasource uids directly, or describe datasources by
    their database/name tuples (see :class:`Datasource`); both lists are
    optional.
    """

    datasource_uids = fields.List(
        fields.String(),
        description=datasource_uid_description,
    )
    datasources = fields.List(
        fields.Nested(Datasource),
        description="A list of the data source and database names",
    )

View File

@ -71,6 +71,10 @@ datasource_id_description = (
"A complete datasource identification needs `datasouce_id` "
"and `datasource_type`."
)
# Shared OpenAPI description for datasource-uid parameters.
# Fixes the "datasouce_uid" typo and the dangling sentence in the original.
datasource_uid_description = (
    "The uid of the dataset/datasource this new chart will use. "
    "A complete datasource identification needs `datasource_uid`."
)
datasource_type_description = (
"The type of dataset/datasource identified on `datasource_id`."
)

View File

@ -22,6 +22,7 @@ from typing import Any, Dict, Union, List, Optional
from unittest.mock import Mock, patch
import pandas as pd
import pytest
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
@ -42,6 +43,7 @@ from superset.utils.core import get_example_database
from superset.views.base_api import BaseSupersetModelRestApi
FAKE_DB_NAME = "fake_db_100"
test_client = app.test_client()
def login(client: Any, username: str = "admin", password: str = "general"):
@ -69,6 +71,39 @@ def get_resp(
return resp.data.decode("utf-8")
def post_assert_metric(
    client: Any, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
    """
    Simple client post with an extra assertion for statsd metrics

    :param client: test client for superset api requests
    :param uri: The URI to use for the HTTP POST
    :param data: The JSON data payload to be posted
    :param func_name: The function name that the HTTP POST triggers
    for the statsd metric assertion
    :return: HTTP Response
    """
    with patch.object(
        BaseSupersetModelRestApi, "incr_stats", return_value=None
    ) as incr_stats_mock:
        response = client.post(uri, json=data)
    # 2xx/3xx responses must emit a "success" metric, everything else "error".
    outcome = "success" if 200 <= response.status_code < 400 else "error"
    incr_stats_mock.assert_called_once_with(outcome, func_name)
    return response
@pytest.fixture
def logged_in_admin():
    """Fixture with app context and logged in admin user."""
    with app.app_context():
        login(test_client, username="admin")
        # Hand control to the test while the admin session is active.
        yield
        # Teardown: log out so subsequent tests start unauthenticated.
        test_client.get("/logout/", follow_redirects=True)
class SupersetTestCase(TestCase):
default_schema_backend_map = {
@ -84,6 +119,15 @@ class SupersetTestCase(TestCase):
def create_app(self):
    # Flask-Testing hook: return the shared Superset app under test.
    return app
@staticmethod
def get_birth_names_dataset():
    """Return the ``birth_names`` SqlaTable from the examples database."""
    query = db.session.query(SqlaTable).filter_by(
        database=get_example_database(), table_name="birth_names"
    )
    # .one() fails loudly if the example table is missing or duplicated.
    return query.one()
@staticmethod
def create_user_with_roles(username: str, roles: List[str]):
user_to_create = security_manager.find_user(username)
@ -422,24 +466,7 @@ class SupersetTestCase(TestCase):
def post_assert_metric(
    self, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
    """
    Simple client post with an extra assertion for statsd metrics

    :param uri: The URI to use for the HTTP POST
    :param data: The JSON data payload to be posted
    :param func_name: The function name that the HTTP POST triggers
    for the statsd metric assertion
    :return: HTTP Response
    """
    # The old inline patch/assert body was left fused with the new delegating
    # return (unreachable dead code after the inner return); keep only the
    # delegation so the logic lives in the module-level helper.
    return post_assert_metric(self.client, uri, data, func_name)
def put_assert_metric(
self, uri: str, data: Dict[str, Any], func_name: str

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,163 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
from typing import Dict, Any
from tests.test_app import app # noqa
from superset.extensions import cache_manager, db
from superset.models.cache import CacheKey
from tests.base_tests import (
SupersetTestCase,
post_assert_metric,
test_client,
logged_in_admin,
) # noqa
def invalidate(params: Dict[str, Any]):
    """POST *params* to the cache invalidation endpoint, asserting metrics."""
    endpoint = "api/v1/cachekey/invalidate"
    return post_assert_metric(test_client, endpoint, params, "invalidate")
def test_invalidate_cache(logged_in_admin):
    """A well-formed single-uid request is accepted with HTTP 201."""
    response = invalidate({"datasource_uids": ["3__table"]})
    assert response.status_code == 201
def test_invalidate_existing_cache(logged_in_admin):
    """Invalidating a seeded key clears the cache entry and its DB record."""
    db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
    db.session.commit()
    cache_manager.cache.set("cache_key", "value")

    rv = invalidate({"datasource_uids": ["3__table"]})

    assert rv.status_code == 201
    # Fixed: use `is None` rather than `== None` (PEP 8 identity comparison,
    # consistent with test_invalidate_existing_caches below).
    assert cache_manager.cache.get("cache_key") is None
    assert (
        not db.session.query(CacheKey).filter(CacheKey.cache_key == "cache_key").first()
    )
def test_invalidate_cache_empty_input(logged_in_admin):
    """Empty uid/datasource lists are valid payloads and return 201."""
    empty_payloads = (
        {"datasource_uids": []},
        {"datasources": []},
        {"datasource_uids": [], "datasources": []},
    )
    for payload in empty_payloads:
        assert invalidate(payload).status_code == 201
def test_invalidate_cache_bad_request(logged_in_admin):
    """Malformed payloads are rejected with HTTP 400."""
    bad_payloads = (
        # null datasource_type fails the required/OneOf validation
        {
            "datasource_uids": [],
            "datasources": [{"datasource_name": "", "datasource_type": None}],
        },
        # unknown datasource_type value
        {
            "datasource_uids": [],
            "datasources": [{"datasource_name": "", "datasource_type": "bla"}],
        },
        # datasource_uids must be a list, not a string
        {
            "datasource_uids": "datasource",
            "datasources": [{"datasource_name": "", "datasource_type": "bla"}],
        },
    )
    for payload in bad_payloads:
        assert invalidate(payload).status_code == 400
def test_invalidate_existing_caches(logged_in_admin):
    """End-to-end invalidation mixing uids and database/name tuples."""
    bn = SupersetTestCase.get_birth_names_dataset()

    # cache key -> datasource uid it belongs to; cache_keyX must survive.
    seeded_keys = {
        "cache_key1": "3__druid",
        "cache_key2": "3__druid",
        "cache_key4": f"{bn.id}__table",
        "cache_keyX": "X__table",
    }
    for cache_key, uid in seeded_keys.items():
        db.session.add(CacheKey(cache_key=cache_key, datasource_uid=uid))
    db.session.commit()
    for cache_key in seeded_keys:
        cache_manager.cache.set(cache_key, "value")

    rv = invalidate(
        {
            "datasource_uids": ["3__druid", "4__druid"],
            "datasources": [
                {
                    "datasource_name": "birth_names",
                    "database_name": "examples",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # table exists, no cache to invalidate
                    "datasource_name": "energy_usage",
                    "database_name": "examples",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # table doesn't exist
                    "datasource_name": "does_not_exist",
                    "database_name": "examples",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # database doesn't exist
                    "datasource_name": "birth_names",
                    "database_name": "does_not_exist",
                    "schema": "",
                    "datasource_type": "table",
                },
                {  # schema doesn't exist
                    "datasource_name": "birth_names",
                    "database_name": "examples",
                    "schema": "does_not_exist",
                    "datasource_type": "table",
                },
            ],
        }
    )

    assert rv.status_code == 201
    # Keys tied to the invalidated datasources are gone from the cache...
    for cache_key in ("cache_key1", "cache_key2", "cache_key4"):
        assert cache_manager.cache.get(cache_key) is None
    # ...while the unrelated key keeps its value.
    assert cache_manager.cache.get("cache_keyX") == "value"
    # The matching CacheKey rows were deleted; the unrelated row remains.
    assert (
        not db.session.query(CacheKey)
        .filter(CacheKey.cache_key.in_({"cache_key1", "cache_key2", "cache_key4"}))
        .first()
    )
    surviving = (
        db.session.query(CacheKey).filter(CacheKey.cache_key == "cache_keyX").first()
    )
    assert surviving.datasource_uid == "X__table"

View File

@ -60,15 +60,6 @@ class TestDatasetApi(SupersetTestCase):
"ab_permission", "", [self.get_user("admin").id], get_main_database()
)
@staticmethod
def get_birth_names_dataset():
    # Fetch the `birth_names` example table; .one() raises if it is
    # missing or duplicated, failing the test early with a clear error.
    example_db = get_example_database()
    return (
        db.session.query(SqlaTable)
        .filter_by(database=example_db, table_name="birth_names")
        .one()
    )
@staticmethod
def get_energy_usage_dataset():
example_db = get_example_database()

View File

@ -76,6 +76,14 @@ REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = os.environ.get("REDIS_PORT", "6379")
REDIS_CELERY_DB = os.environ.get("REDIS_CELERY_DB", 2)
REDIS_RESULTS_DB = os.environ.get("REDIS_RESULTS_DB", 3)
# Redis logical database used for the Flask-Caching backend.
# NOTE(review): the default is the int 4 while env overrides arrive as str —
# both render fine in the URL below, but confirm nothing compares it as int.
REDIS_CACHE_DB = os.environ.get("REDIS_CACHE_DB", 4)

# Flask-Caching configuration consumed by Superset's cache manager.
CACHE_CONFIG = {
    "CACHE_TYPE": "redis",
    "CACHE_DEFAULT_TIMEOUT": 60 * 60 * 24,  # 1 day default (in secs)
    "CACHE_KEY_PREFIX": "superset_cache",
    "CACHE_REDIS_URL": f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CACHE_DB}",
}
class CeleryConfig(object):