fix: memoized decorator memory leak (#23139)
This commit is contained in:
parent
ad5ee1ce38
commit
79274eb5bc
|
|
@ -17,6 +17,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
|
|
@ -40,6 +41,7 @@ from sqlalchemy.orm import Session
|
|||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
from superset.constants import LRU_CACHE_MAX_SIZE
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
SupersetGenericDBErrorException,
|
||||
|
|
@ -49,7 +51,6 @@ from superset.models.core import Database
|
|||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
|
@ -200,12 +201,12 @@ def validate_adhoc_subquery(
|
|||
return ";\n".join(str(statement) for statement in statements)
|
||||
|
||||
|
||||
@memoized
|
||||
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
||||
def get_dialect_name(drivername: str) -> str:
|
||||
return SqlaURL.create(drivername).get_dialect().name
|
||||
|
||||
|
||||
@memoized
|
||||
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
||||
def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
|
||||
return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ NO_TIME_RANGE = "No filter"
|
|||
QUERY_CANCEL_KEY = "cancel_query"
|
||||
QUERY_EARLY_CANCEL_KEY = "early_cancel_query"
|
||||
|
||||
LRU_CACHE_MAX_SIZE = 256
|
||||
|
||||
|
||||
class RouteMethod: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
"""Defines the templating context for SQL Lab"""
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
from functools import lru_cache, partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
|
|
@ -38,6 +38,7 @@ from sqlalchemy.engine.interfaces import Dialect
|
|||
from sqlalchemy.types import String
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from superset.constants import LRU_CACHE_MAX_SIZE
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.exceptions import SupersetTemplateException
|
||||
from superset.extensions import feature_flag_manager
|
||||
|
|
@ -46,7 +47,6 @@ from superset.utils.core import (
|
|||
get_user_id,
|
||||
merge_extra_filters,
|
||||
)
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
|
@ -70,7 +70,7 @@ ALLOWED_TYPES = (
|
|||
COLLECTION_TYPES = ("list", "dict", "tuple", "set")
|
||||
|
||||
|
||||
@memoized
|
||||
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
||||
def context_addons() -> Dict[str, Any]:
|
||||
return current_app.config.get("JINJA_CONTEXT_ADDONS", {})
|
||||
|
||||
|
|
@ -602,7 +602,7 @@ DEFAULT_PROCESSORS = {
|
|||
}
|
||||
|
||||
|
||||
@memoized
|
||||
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
||||
def get_template_processors() -> Dict[str, Any]:
|
||||
processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {})
|
||||
for engine, processor in DEFAULT_PROCESSORS.items():
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ from sqlalchemy.ext.declarative import declarative_base
|
|||
|
||||
from superset import db, db_engine_specs
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
|
@ -70,7 +69,6 @@ class Slice(Base):
|
|||
datasource_id = Column(Integer)
|
||||
|
||||
|
||||
@memoized
|
||||
def duration_by_name(database: Database):
|
||||
return {grain.name: grain.duration for grain in database.grains()}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from ast import literal_eval
|
|||
from contextlib import closing, contextmanager, nullcontext
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
import numpy
|
||||
|
|
@ -54,7 +55,7 @@ from sqlalchemy.schema import UniqueConstraint
|
|||
from sqlalchemy.sql import expression, Select
|
||||
|
||||
from superset import app, db_engine_specs
|
||||
from superset.constants import PASSWORD_MASK
|
||||
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.base import MetricType, TimeGrain
|
||||
from superset.extensions import (
|
||||
|
|
@ -67,7 +68,6 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
|||
from superset.result_set import SupersetResultSet
|
||||
from superset.utils import cache as cache_util, core as utils
|
||||
from superset.utils.core import get_username
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
config = app.config
|
||||
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
|
||||
|
|
@ -723,7 +723,7 @@ class Database(
|
|||
return self.get_db_engine_spec(url)
|
||||
|
||||
@classmethod
|
||||
@memoized
|
||||
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
||||
def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
backend = url.get_backend_name()
|
||||
try:
|
||||
|
|
@ -897,7 +897,6 @@ class Database(
|
|||
def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
|
||||
return self.has_view(view_name=view_name, schema=schema)
|
||||
|
||||
@memoized
|
||||
def get_dialect(self) -> Dialect:
|
||||
sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
||||
return sqla_url.get_dialect()()
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from sqlalchemy import Column, Integer, String
|
|||
|
||||
from superset import app, db, security_manager
|
||||
from superset.models.helpers import AuditMixinNullable
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
|
@ -57,7 +56,6 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
|
|||
return self.get_datasource
|
||||
|
||||
@datasource.getter # type: ignore
|
||||
@memoized
|
||||
def get_datasource(self) -> "BaseDatasource":
|
||||
ds = db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
|
||||
return ds
|
||||
|
|
|
|||
|
|
@ -46,7 +46,6 @@ from superset.tasks.thumbnails import cache_chart_thumbnail
|
|||
from superset.tasks.utils import get_current_user
|
||||
from superset.thumbnails.digest import get_chart_digest
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.memoized import memoized
|
||||
from superset.viz import BaseViz, viz_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -159,9 +158,12 @@ class Slice( # pylint: disable=too-many-public-methods
|
|||
|
||||
# pylint: disable=using-constant-test
|
||||
@datasource.getter # type: ignore
|
||||
@memoized
|
||||
def get_datasource(self) -> Optional["BaseDatasource"]:
|
||||
return db.session.query(self.cls_model).filter_by(id=self.datasource_id).first()
|
||||
return (
|
||||
db.session.query(self.cls_model)
|
||||
.filter_by(id=self.datasource_id)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
@renders("datasource_name")
|
||||
def datasource_link(self) -> Optional[Markup]:
|
||||
|
|
@ -197,8 +199,7 @@ class Slice( # pylint: disable=too-many-public-methods
|
|||
|
||||
# pylint: enable=using-constant-test
|
||||
|
||||
@property # type: ignore
|
||||
@memoized
|
||||
@property
|
||||
def viz(self) -> Optional[BaseViz]:
|
||||
form_data = json.loads(self.params)
|
||||
viz_class = viz_types.get(self.viz_type)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import calendar
|
|||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache
|
||||
from time import struct_time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
|
@ -45,8 +46,7 @@ from superset.charts.commands.exceptions import (
|
|||
TimeRangeAmbiguousError,
|
||||
TimeRangeParseFailError,
|
||||
)
|
||||
from superset.constants import NO_TIME_RANGE
|
||||
from superset.utils.memoized import memoized
|
||||
from superset.constants import LRU_CACHE_MAX_SIZE, NO_TIME_RANGE
|
||||
|
||||
ParserElement.enablePackrat()
|
||||
|
||||
|
|
@ -394,7 +394,7 @@ class EvalHolidayFunc: # pylint: disable=too-few-public-methods
|
|||
)
|
||||
|
||||
|
||||
@memoized
|
||||
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
||||
def datetime_parser() -> ParseResults: # pylint: disable=too-many-locals
|
||||
( # pylint: disable=invalid-name
|
||||
DATETIME,
|
||||
|
|
|
|||
|
|
@ -1,81 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
|
||||
|
||||
class _memoized:
|
||||
"""Decorator that caches a function's return value each time it is called
|
||||
|
||||
If called later with the same arguments, the cached value is returned, and
|
||||
not re-evaluated.
|
||||
|
||||
Define ``watch`` as a tuple of attribute names if this Decorator
|
||||
should account for instance variable changes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
|
||||
) -> None:
|
||||
self.func = func
|
||||
self.cache: Dict[Any, Any] = {}
|
||||
self.is_method = False
|
||||
self.watch = watch or ()
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
key = [args, frozenset(kwargs.items())]
|
||||
if self.is_method:
|
||||
key.append(tuple(getattr(args[0], v, None) for v in self.watch))
|
||||
key = tuple(key) # type: ignore
|
||||
try:
|
||||
if key in self.cache:
|
||||
return self.cache[key]
|
||||
except TypeError as ex:
|
||||
# Uncachable -- for instance, passing a list as an argument.
|
||||
raise TypeError("Function cannot be memoized") from ex
|
||||
value = self.func(*args, **kwargs)
|
||||
try:
|
||||
self.cache[key] = value
|
||||
except TypeError as ex:
|
||||
raise TypeError("Function cannot be memoized") from ex
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the function's docstring."""
|
||||
return self.func.__doc__ or ""
|
||||
|
||||
def __get__(
|
||||
self, obj: Any, objtype: Type[Any]
|
||||
) -> functools.partial: # type: ignore
|
||||
if not self.is_method:
|
||||
self.is_method = True
|
||||
# Support instance methods.
|
||||
func = functools.partial(self.__call__, obj)
|
||||
func.__func__ = self.func # type: ignore
|
||||
return func
|
||||
|
||||
|
||||
def memoized(
|
||||
func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
|
||||
) -> Callable[..., Any]:
|
||||
if func:
|
||||
return _memoized(func)
|
||||
|
||||
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
return _memoized(f, watch)
|
||||
|
||||
return wrapper
|
||||
|
|
@ -991,6 +991,7 @@ class TestUtils(SupersetTestCase):
|
|||
slc = self.get_slice("Girls", db.session)
|
||||
dashboard_id = 1
|
||||
|
||||
assert slc.viz is not None
|
||||
resp = self.get_json_resp(
|
||||
f"/superset/explore_json/{slc.datasource_type}/{slc.datasource_id}/"
|
||||
+ f'?form_data={{"slice_id": {slc.id}}}&dashboard_id={dashboard_id}',
|
||||
|
|
|
|||
|
|
@ -1,96 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from pytest import mark
|
||||
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
|
||||
@mark.unittest
|
||||
class TestMemoized:
|
||||
def test_memoized_on_functions(self):
|
||||
watcher = {"val": 0}
|
||||
|
||||
@memoized
|
||||
def test_function(a, b, c):
|
||||
watcher["val"] += 1
|
||||
return a * b * c
|
||||
|
||||
result1 = test_function(1, 2, 3)
|
||||
result2 = test_function(1, 2, 3)
|
||||
assert result1 == result2
|
||||
assert watcher["val"] == 1
|
||||
|
||||
def test_memoized_on_methods(self):
|
||||
class test_class:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.watcher = 0
|
||||
|
||||
@memoized
|
||||
def test_method(self, a, b, c):
|
||||
self.watcher += 1
|
||||
return a * b * c * self.num
|
||||
|
||||
instance = test_class(5)
|
||||
result1 = instance.test_method(1, 2, 3)
|
||||
result2 = instance.test_method(1, 2, 3)
|
||||
assert result1 == result2
|
||||
assert instance.watcher == 1
|
||||
instance.num = 10
|
||||
assert result2 == instance.test_method(1, 2, 3)
|
||||
|
||||
def test_memoized_on_methods_with_watches(self):
|
||||
class test_class:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.watcher = 0
|
||||
|
||||
@memoized(watch=("x", "y"))
|
||||
def test_method(self, a, b, c):
|
||||
self.watcher += 1
|
||||
return a * b * c * self.x * self.y
|
||||
|
||||
instance = test_class(3, 12)
|
||||
result1 = instance.test_method(1, 2, 3)
|
||||
result2 = instance.test_method(1, 2, 3)
|
||||
assert result1 == result2
|
||||
assert instance.watcher == 1
|
||||
result3 = instance.test_method(2, 3, 4)
|
||||
assert instance.watcher == 2
|
||||
result4 = instance.test_method(2, 3, 4)
|
||||
assert instance.watcher == 2
|
||||
assert result3 == result4
|
||||
assert result3 != result1
|
||||
instance.x = 1
|
||||
result5 = instance.test_method(2, 3, 4)
|
||||
assert instance.watcher == 3
|
||||
assert result5 != result4
|
||||
result6 = instance.test_method(2, 3, 4)
|
||||
assert instance.watcher == 3
|
||||
assert result6 == result5
|
||||
instance.x = 10
|
||||
instance.y = 10
|
||||
result7 = instance.test_method(2, 3, 4)
|
||||
assert instance.watcher == 4
|
||||
assert result7 != result6
|
||||
instance.x = 3
|
||||
instance.y = 12
|
||||
result8 = instance.test_method(1, 2, 3)
|
||||
assert instance.watcher == 4
|
||||
assert result1 == result8
|
||||
|
|
@ -59,9 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
|
|||
},
|
||||
]
|
||||
|
||||
database.get_db_engine_spec = mocker.MagicMock( # type: ignore
|
||||
return_value=CustomSqliteEngineSpec
|
||||
)
|
||||
database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec)
|
||||
assert database.get_metrics("table") == [
|
||||
{
|
||||
"expression": "COUNT(DISTINCT user_id)",
|
||||
|
|
|
|||
Loading…
Reference in New Issue