feat: recursive metric definitions (#32228)
This commit is contained in:
parent
15fbb195e9
commit
2c583d1584
|
|
@ -27,7 +27,8 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
|
||||||
import dateutil
|
import dateutil
|
||||||
from flask import current_app, g, has_request_context, request
|
from flask import current_app, g, has_request_context, request
|
||||||
from flask_babel import gettext as _
|
from flask_babel import gettext as _
|
||||||
from jinja2 import DebugUndefined, Environment
|
from jinja2 import DebugUndefined, Environment, nodes
|
||||||
|
from jinja2.nodes import Call, Node
|
||||||
from jinja2.sandbox import SandboxedEnvironment
|
from jinja2.sandbox import SandboxedEnvironment
|
||||||
from sqlalchemy.engine.interfaces import Dialect
|
from sqlalchemy.engine.interfaces import Dialect
|
||||||
from sqlalchemy.sql.expression import bindparam
|
from sqlalchemy.sql.expression import bindparam
|
||||||
|
|
@ -888,6 +889,26 @@ def get_dataset_id_from_context(metric_key: str) -> int:
|
||||||
raise SupersetTemplateException(exc_message)
|
raise SupersetTemplateException(exc_message)
|
||||||
|
|
||||||
|
|
||||||
|
def has_metric_macro(template_string: str, env: Environment) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if a template string contains a metric macro.
|
||||||
|
|
||||||
|
>>> has_metric_macro("{{ metric('my_metric') }}")
|
||||||
|
True
|
||||||
|
|
||||||
|
"""
|
||||||
|
ast = env.parse(template_string)
|
||||||
|
|
||||||
|
def visit_node(node: Node) -> bool:
|
||||||
|
return (
|
||||||
|
isinstance(node, Call)
|
||||||
|
and isinstance(node.node, nodes.Name)
|
||||||
|
and node.node.name == "metric"
|
||||||
|
) or any(visit_node(child) for child in node.iter_child_nodes())
|
||||||
|
|
||||||
|
return visit_node(ast)
|
||||||
|
|
||||||
|
|
||||||
def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
|
def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Given a metric key, returns its syntax.
|
Given a metric key, returns its syntax.
|
||||||
|
|
@ -908,16 +929,32 @@ def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
|
||||||
dataset = DatasetDAO.find_by_id(dataset_id)
|
dataset = DatasetDAO.find_by_id(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise DatasetNotFoundError(f"Dataset ID {dataset_id} not found.")
|
raise DatasetNotFoundError(f"Dataset ID {dataset_id} not found.")
|
||||||
|
|
||||||
metrics: dict[str, str] = {
|
metrics: dict[str, str] = {
|
||||||
metric.metric_name: metric.expression for metric in dataset.metrics
|
metric.metric_name: metric.expression for metric in dataset.metrics
|
||||||
}
|
}
|
||||||
dataset_name = dataset.table_name
|
if metric_key not in metrics:
|
||||||
if metric := metrics.get(metric_key):
|
|
||||||
return metric
|
|
||||||
raise SupersetTemplateException(
|
raise SupersetTemplateException(
|
||||||
_(
|
_(
|
||||||
"Metric ``%(metric_name)s`` not found in %(dataset_name)s.",
|
"Metric ``%(metric_name)s`` not found in %(dataset_name)s.",
|
||||||
metric_name=metric_key,
|
metric_name=metric_key,
|
||||||
dataset_name=dataset_name,
|
dataset_name=dataset.table_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
definition = metrics[metric_key]
|
||||||
|
|
||||||
|
env = SandboxedEnvironment(undefined=DebugUndefined)
|
||||||
|
context = {"metric": partial(safe_proxy, metric_macro)}
|
||||||
|
while has_metric_macro(definition, env):
|
||||||
|
old_definition = definition
|
||||||
|
template = env.from_string(definition)
|
||||||
|
try:
|
||||||
|
definition = template.render(context)
|
||||||
|
except RecursionError as ex:
|
||||||
|
raise SupersetTemplateException("Cyclic metric macro detected") from ex
|
||||||
|
|
||||||
|
if definition == old_definition:
|
||||||
|
break
|
||||||
|
|
||||||
|
return definition
|
||||||
|
|
|
||||||
|
|
@ -544,6 +544,99 @@ def test_metric_macro_with_dataset_id(mocker: MockerFixture) -> None:
|
||||||
mock_get_form_data.assert_not_called()
|
mock_get_form_data.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_macro_recursive(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test the ``metric_macro`` when the definition is recursive.
|
||||||
|
"""
|
||||||
|
mock_g = mocker.patch("superset.jinja_context.g")
|
||||||
|
mock_g.form_data = {"datasource": {"id": 1}}
|
||||||
|
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
|
||||||
|
DatasetDAO.find_by_id.return_value = SqlaTable(
|
||||||
|
table_name="test_dataset",
|
||||||
|
metrics=[
|
||||||
|
SqlMetric(metric_name="a", expression="COUNT(*)"),
|
||||||
|
SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
|
||||||
|
SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
|
||||||
|
],
|
||||||
|
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
|
||||||
|
schema="my_schema",
|
||||||
|
sql=None,
|
||||||
|
)
|
||||||
|
assert metric_macro("c", 1) == "COUNT(*)"
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test the ``metric_macro`` when the definition is compound.
|
||||||
|
"""
|
||||||
|
mock_g = mocker.patch("superset.jinja_context.g")
|
||||||
|
mock_g.form_data = {"datasource": {"id": 1}}
|
||||||
|
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
|
||||||
|
DatasetDAO.find_by_id.return_value = SqlaTable(
|
||||||
|
table_name="test_dataset",
|
||||||
|
metrics=[
|
||||||
|
SqlMetric(metric_name="a", expression="SUM(*)"),
|
||||||
|
SqlMetric(metric_name="b", expression="COUNT(*)"),
|
||||||
|
SqlMetric(
|
||||||
|
metric_name="c",
|
||||||
|
expression="{{ metric('a') }} / {{ metric('b') }}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
|
||||||
|
schema="my_schema",
|
||||||
|
sql=None,
|
||||||
|
)
|
||||||
|
assert metric_macro("c", 1) == "SUM(*) / COUNT(*)"
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test the ``metric_macro`` when the definition is cyclic.
|
||||||
|
|
||||||
|
In this case it should stop, and not go into an infinite loop.
|
||||||
|
"""
|
||||||
|
mock_g = mocker.patch("superset.jinja_context.g")
|
||||||
|
mock_g.form_data = {"datasource": {"id": 1}}
|
||||||
|
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
|
||||||
|
DatasetDAO.find_by_id.return_value = SqlaTable(
|
||||||
|
table_name="test_dataset",
|
||||||
|
metrics=[
|
||||||
|
SqlMetric(metric_name="a", expression="{{ metric('c') }}"),
|
||||||
|
SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
|
||||||
|
SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
|
||||||
|
],
|
||||||
|
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
|
||||||
|
schema="my_schema",
|
||||||
|
sql=None,
|
||||||
|
)
|
||||||
|
with pytest.raises(SupersetTemplateException) as excinfo:
|
||||||
|
metric_macro("c", 1)
|
||||||
|
assert str(excinfo.value) == "Cyclic metric macro detected"
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test the ``metric_macro`` when the definition is cyclic.
|
||||||
|
|
||||||
|
In this case it should stop, and not go into an infinite loop.
|
||||||
|
"""
|
||||||
|
mock_g = mocker.patch("superset.jinja_context.g")
|
||||||
|
mock_g.form_data = {"datasource": {"id": 1}}
|
||||||
|
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
|
||||||
|
DatasetDAO.find_by_id.return_value = SqlaTable(
|
||||||
|
table_name="test_dataset",
|
||||||
|
metrics=[
|
||||||
|
SqlMetric(metric_name="a", expression="{{ metric('a') }}"),
|
||||||
|
],
|
||||||
|
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
|
||||||
|
schema="my_schema",
|
||||||
|
sql=None,
|
||||||
|
)
|
||||||
|
with pytest.raises(SupersetTemplateException) as excinfo:
|
||||||
|
metric_macro("a", 1)
|
||||||
|
assert str(excinfo.value) == "Cyclic metric macro detected"
|
||||||
|
|
||||||
|
|
||||||
def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None:
|
def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None:
|
||||||
"""
|
"""
|
||||||
Test the ``metric_macro`` when passing a dataset ID and an invalid key.
|
Test the ``metric_macro`` when passing a dataset ID and an invalid key.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue