fix: memoized decorator memory leak (#23139)

This commit is contained in:
Daniel Vaz Gaspar 2023-02-27 15:59:11 +00:00 committed by GitHub
parent ad5ee1ce38
commit 79274eb5bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 24 additions and 203 deletions

View File

@ -17,6 +17,7 @@
from __future__ import annotations
import logging
from functools import lru_cache
from typing import (
Any,
Callable,
@ -40,6 +41,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from sqlalchemy.sql.type_api import TypeEngine
from superset.constants import LRU_CACHE_MAX_SIZE
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetGenericDBErrorException,
@ -49,7 +51,6 @@ from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery
from superset.superset_typing import ResultSetColumnType
from superset.utils.memoized import memoized
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@ -200,12 +201,12 @@ def validate_adhoc_subquery(
return ";\n".join(str(statement) for statement in statements)
@memoized
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
def get_dialect_name(drivername: str) -> str:
return SqlaURL.create(drivername).get_dialect().name
@memoized
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote

View File

@ -37,6 +37,8 @@ NO_TIME_RANGE = "No filter"
QUERY_CANCEL_KEY = "cancel_query"
QUERY_EARLY_CANCEL_KEY = "early_cancel_query"
LRU_CACHE_MAX_SIZE = 256
class RouteMethod: # pylint: disable=too-few-public-methods
"""

View File

@ -17,7 +17,7 @@
"""Defines the templating context for SQL Lab"""
import json
import re
from functools import partial
from functools import lru_cache, partial
from typing import (
Any,
Callable,
@ -38,6 +38,7 @@ from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import String
from typing_extensions import TypedDict
from superset.constants import LRU_CACHE_MAX_SIZE
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetTemplateException
from superset.extensions import feature_flag_manager
@ -46,7 +47,6 @@ from superset.utils.core import (
get_user_id,
merge_extra_filters,
)
from superset.utils.memoized import memoized
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@ -70,7 +70,7 @@ ALLOWED_TYPES = (
COLLECTION_TYPES = ("list", "dict", "tuple", "set")
@memoized
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
def context_addons() -> Dict[str, Any]:
return current_app.config.get("JINJA_CONTEXT_ADDONS", {})
@ -602,7 +602,7 @@ DEFAULT_PROCESSORS = {
}
@memoized
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
def get_template_processors() -> Dict[str, Any]:
processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {})
for engine, processor in DEFAULT_PROCESSORS.items():

View File

@ -34,7 +34,6 @@ from sqlalchemy.ext.declarative import declarative_base
from superset import db, db_engine_specs
from superset.databases.utils import make_url_safe
from superset.utils.memoized import memoized
Base = declarative_base()
@ -70,7 +69,6 @@ class Slice(Base):
datasource_id = Column(Integer)
@memoized
def duration_by_name(database: Database):
return {grain.name: grain.duration for grain in database.grains()}

View File

@ -24,6 +24,7 @@ from ast import literal_eval
from contextlib import closing, contextmanager, nullcontext
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
import numpy
@ -54,7 +55,7 @@ from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import expression, Select
from superset import app, db_engine_specs
from superset.constants import PASSWORD_MASK
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import MetricType, TimeGrain
from superset.extensions import (
@ -67,7 +68,6 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.result_set import SupersetResultSet
from superset.utils import cache as cache_util, core as utils
from superset.utils.core import get_username
from superset.utils.memoized import memoized
config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@ -723,7 +723,7 @@ class Database(
return self.get_db_engine_spec(url)
@classmethod
@memoized
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
backend = url.get_backend_name()
try:
@ -897,7 +897,6 @@ class Database(
def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
return self.has_view(view_name=view_name, schema=schema)
@memoized
def get_dialect(self) -> Dialect:
sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted)
return sqla_url.get_dialect()()

View File

@ -22,7 +22,6 @@ from sqlalchemy import Column, Integer, String
from superset import app, db, security_manager
from superset.models.helpers import AuditMixinNullable
from superset.utils.memoized import memoized
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@ -57,7 +56,6 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
return self.get_datasource
@datasource.getter # type: ignore
@memoized
def get_datasource(self) -> "BaseDatasource":
ds = db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
return ds

View File

@ -46,7 +46,6 @@ from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.tasks.utils import get_current_user
from superset.thumbnails.digest import get_chart_digest
from superset.utils import core as utils
from superset.utils.memoized import memoized
from superset.viz import BaseViz, viz_types
if TYPE_CHECKING:
@ -159,9 +158,12 @@ class Slice( # pylint: disable=too-many-public-methods
# pylint: disable=using-constant-test
@datasource.getter # type: ignore
@memoized
def get_datasource(self) -> Optional["BaseDatasource"]:
return db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
return (
db.session.query(self.cls_model)
.filter_by(id=self.datasource_id)
.one_or_none()
)
@renders("datasource_name")
def datasource_link(self) -> Optional[Markup]:
@ -197,8 +199,7 @@ class Slice( # pylint: disable=too-many-public-methods
# pylint: enable=using-constant-test
@property # type: ignore
@memoized
@property
def viz(self) -> Optional[BaseViz]:
form_data = json.loads(self.params)
viz_class = viz_types.get(self.viz_type)

View File

@ -18,6 +18,7 @@ import calendar
import logging
import re
from datetime import datetime, timedelta
from functools import lru_cache
from time import struct_time
from typing import Dict, List, Optional, Tuple
@ -45,8 +46,7 @@ from superset.charts.commands.exceptions import (
TimeRangeAmbiguousError,
TimeRangeParseFailError,
)
from superset.constants import NO_TIME_RANGE
from superset.utils.memoized import memoized
from superset.constants import LRU_CACHE_MAX_SIZE, NO_TIME_RANGE
ParserElement.enablePackrat()
@ -394,7 +394,7 @@ class EvalHolidayFunc: # pylint: disable=too-few-public-methods
)
@memoized
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
def datetime_parser() -> ParseResults: # pylint: disable=too-many-locals
( # pylint: disable=invalid-name
DATETIME,

View File

@ -1,81 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
from typing import Any, Callable, Dict, Optional, Tuple, Type
class _memoized:
"""Decorator that caches a function's return value each time it is called
If called later with the same arguments, the cached value is returned, and
not re-evaluated.
Define ``watch`` as a tuple of attribute names if this Decorator
should account for instance variable changes.
"""
def __init__(
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
) -> None:
self.func = func
self.cache: Dict[Any, Any] = {}
self.is_method = False
self.watch = watch or ()
def __call__(self, *args: Any, **kwargs: Any) -> Any:
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple(getattr(args[0], v, None) for v in self.watch))
key = tuple(key) # type: ignore
try:
if key in self.cache:
return self.cache[key]
except TypeError as ex:
# Uncachable -- for instance, passing a list as an argument.
raise TypeError("Function cannot be memoized") from ex
value = self.func(*args, **kwargs)
try:
self.cache[key] = value
except TypeError as ex:
raise TypeError("Function cannot be memoized") from ex
return value
def __repr__(self) -> str:
"""Return the function's docstring."""
return self.func.__doc__ or ""
def __get__(
self, obj: Any, objtype: Type[Any]
) -> functools.partial: # type: ignore
if not self.is_method:
self.is_method = True
# Support instance methods.
func = functools.partial(self.__call__, obj)
func.__func__ = self.func # type: ignore
return func
def memoized(
func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
) -> Callable[..., Any]:
if func:
return _memoized(func)
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch)
return wrapper

View File

@ -991,6 +991,7 @@ class TestUtils(SupersetTestCase):
slc = self.get_slice("Girls", db.session)
dashboard_id = 1
assert slc.viz is not None
resp = self.get_json_resp(
f"/superset/explore_json/{slc.datasource_type}/{slc.datasource_id}/"
+ f'?form_data={{"slice_id": {slc.id}}}&dashboard_id={dashboard_id}',

View File

@ -1,96 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pytest import mark
from superset.utils.memoized import memoized
@mark.unittest
class TestMemoized:
def test_memoized_on_functions(self):
watcher = {"val": 0}
@memoized
def test_function(a, b, c):
watcher["val"] += 1
return a * b * c
result1 = test_function(1, 2, 3)
result2 = test_function(1, 2, 3)
assert result1 == result2
assert watcher["val"] == 1
def test_memoized_on_methods(self):
class test_class:
def __init__(self, num):
self.num = num
self.watcher = 0
@memoized
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.num
instance = test_class(5)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
assert result1 == result2
assert instance.watcher == 1
instance.num = 10
assert result2 == instance.test_method(1, 2, 3)
def test_memoized_on_methods_with_watches(self):
class test_class:
def __init__(self, x, y):
self.x = x
self.y = y
self.watcher = 0
@memoized(watch=("x", "y"))
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.x * self.y
instance = test_class(3, 12)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
assert result1 == result2
assert instance.watcher == 1
result3 = instance.test_method(2, 3, 4)
assert instance.watcher == 2
result4 = instance.test_method(2, 3, 4)
assert instance.watcher == 2
assert result3 == result4
assert result3 != result1
instance.x = 1
result5 = instance.test_method(2, 3, 4)
assert instance.watcher == 3
assert result5 != result4
result6 = instance.test_method(2, 3, 4)
assert instance.watcher == 3
assert result6 == result5
instance.x = 10
instance.y = 10
result7 = instance.test_method(2, 3, 4)
assert instance.watcher == 4
assert result7 != result6
instance.x = 3
instance.y = 12
result8 = instance.test_method(1, 2, 3)
assert instance.watcher == 4
assert result1 == result8

View File

@ -59,9 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
},
]
database.get_db_engine_spec = mocker.MagicMock( # type: ignore
return_value=CustomSqliteEngineSpec
)
database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec)
assert database.get_metrics("table") == [
{
"expression": "COUNT(DISTINCT user_id)",