feat: refactor on DBEventLogger to allow for context management (#13441)

Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>
This commit is contained in:
Hugh A. Miles II 2021-03-05 15:12:42 -05:00 committed by GitHub
parent 8d48d2e37b
commit b17e7aa5c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 162 additions and 22 deletions

View File

@ -19,9 +19,9 @@ import inspect
import json
import logging
import textwrap
import time
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Callable, cast, Dict, Iterator, Optional, Type, Union
from flask import current_app, g, request
@ -58,6 +58,35 @@ def collect_request_payload() -> Dict[str, Any]:
class AbstractEventLogger(ABC):
def __call__(
self,
action: str,
object_ref: Optional[str] = None,
log_to_statsd: bool = True,
duration: Optional[timedelta] = None,
**payload_override: Dict[str, Any],
) -> object:
# pylint: disable=W0201
self.action = action
self.object_ref = object_ref
self.log_to_statsd = log_to_statsd
self.payload_override = payload_override
return self
def __enter__(self) -> None:
# pylint: disable=W0201
self.start = datetime.now()
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# Log data w/ arguments being passed in
self.log_with_context(
action=self.action,
object_ref=self.object_ref,
log_to_statsd=self.log_to_statsd,
duration=datetime.now() - self.start,
**self.payload_override,
)
@abstractmethod
def log( # pylint: disable=too-many-arguments
self,
@ -72,32 +101,28 @@ class AbstractEventLogger(ABC):
) -> None:
pass
@contextmanager
def log_context( # pylint: disable=too-many-locals
self, action: str, object_ref: Optional[str] = None, log_to_statsd: bool = True,
) -> Iterator[Callable[..., None]]:
"""
Log an event with additional information from the request context.
:param action: a name to identify the event
:param object_ref: reference to the Python object that triggered this action
:param log_to_statsd: whether to update statsd counter for the action
"""
def log_with_context( # pylint: disable=too-many-locals
self,
action: str,
duration: timedelta,
object_ref: Optional[str] = None,
log_to_statsd: bool = True,
**payload_override: Optional[Dict[str, Any]],
) -> None:
from superset.views.core import get_form_data
start_time = time.time()
referrer = request.referrer[:1000] if request.referrer else None
user_id = g.user.get_id() if hasattr(g, "user") and g.user else None
payload_override = {}
# yield a helper to add additional payload
yield lambda **kwargs: payload_override.update(kwargs)
try:
user_id = g.user.get_id()
except Exception as ex: # pylint: disable=broad-except
logging.warning(ex)
user_id = None
payload = collect_request_payload()
if object_ref:
payload["object_ref"] = object_ref
# manual updates from context comes the last
payload.update(payload_override)
if payload_override:
payload.update(payload_override)
dashboard_id: Optional[int] = None
try:
@ -133,10 +158,32 @@ class AbstractEventLogger(ABC):
records=records,
dashboard_id=dashboard_id,
slice_id=slice_id,
duration_ms=round((time.time() - start_time) * 1000),
duration_ms=int(duration.total_seconds() * 1000),
referrer=referrer,
)
@contextmanager
def log_context( # pylint: disable=too-many-locals
self, action: str, object_ref: Optional[str] = None, log_to_statsd: bool = True,
) -> Iterator[Callable[..., None]]:
"""
Log an event with additional information from the request context.
:param action: a name to identify the event
:param object_ref: reference to the Python object that triggered this action
:param log_to_statsd: whether to update statsd counter for the action
"""
payload_override = {}
start = datetime.now()
# yield a helper to add additional payload
yield lambda **kwargs: payload_override.update(kwargs)
duration = datetime.now() - start
# take the action from payload_override else take the function param action
action_str = payload_override.pop("action", action)
self.log_with_context(
action_str, duration, object_ref, log_to_statsd, **payload_override
)
def _wrapper(
self,
f: Callable[..., Any],

View File

@ -17,9 +17,18 @@
import logging
import time
import unittest
from datetime import datetime, timedelta
from typing import Any, Callable, cast, Dict, Iterator, Optional, Type, Union
from unittest.mock import patch
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
from freezegun import freeze_time
from superset import security_manager
from superset.utils.log import (
AbstractEventLogger,
DBEventLogger,
get_event_logger_from_cfg_value,
)
from tests.test_app import app
@ -101,3 +110,87 @@ class TestEventLogger(unittest.TestCase):
],
)
self.assertGreaterEqual(payload["duration_ms"], 100)
@patch("superset.utils.log.g", spec={})
@freeze_time("Jan 14th, 2020", auto_tick_seconds=15)
def test_context_manager_log(self, mock_g):
class DummyEventLogger(AbstractEventLogger):
def __init__(self):
self.records = []
def log(
self,
user_id: Optional[int],
action: str,
dashboard_id: Optional[int],
duration_ms: Optional[int],
slice_id: Optional[int],
referrer: Optional[str],
*args: Any,
**kwargs: Any,
):
self.records.append(
{**kwargs, "user_id": user_id, "duration": duration_ms}
)
logger = DummyEventLogger()
with app.test_request_context():
mock_g.user = security_manager.find_user("gamma")
with logger(action="foo", engine="bar"):
pass
assert logger.records == [
{
"records": [{"path": "/", "engine": "bar"}],
"user_id": "2",
"duration": 15000.0,
}
]
@patch("superset.utils.log.g", spec={})
def test_context_manager_log_with_context(self, mock_g):
class DummyEventLogger(AbstractEventLogger):
def __init__(self):
self.records = []
def log(
self,
user_id: Optional[int],
action: str,
dashboard_id: Optional[int],
duration_ms: Optional[int],
slice_id: Optional[int],
referrer: Optional[str],
*args: Any,
**kwargs: Any,
):
self.records.append(
{**kwargs, "user_id": user_id, "duration": duration_ms}
)
logger = DummyEventLogger()
with app.test_request_context():
mock_g.user = security_manager.find_user("gamma")
logger.log_with_context(
action="foo",
duration=timedelta(days=64, seconds=29156, microseconds=10),
object_ref={"baz": "food"},
log_to_statsd=False,
payload_override={"engine": "sqllite"},
)
assert logger.records == [
{
"records": [
{
"path": "/",
"object_ref": {"baz": "food"},
"payload_override": {"engine": "sqllite"},
}
],
"user_id": "2",
"duration": 5558756000,
}
]