From a12ccf2c1d3fe48f3e5e6fe5a08810aa3f57df37 Mon Sep 17 00:00:00 2001 From: Geido <60598000+geido@users.noreply.github.com> Date: Fri, 25 Oct 2024 19:11:28 +0300 Subject: [PATCH] fix(Jinja): Extra cache keys for Jinja columns (#30715) --- superset/connectors/sqla/models.py | 6 ++- tests/integration_tests/sqla_models_tests.py | 50 ++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e8aa0d705..75354ad35 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -116,7 +116,7 @@ from superset.superset_typing import ( ) from superset.utils import core as utils, json from superset.utils.backports import StrEnum -from superset.utils.core import GenericDataType, MediumText +from superset.utils.core import GenericDataType, is_adhoc_column, MediumText config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -1980,6 +1980,10 @@ class SqlaTable( templatable_statements.append(extras["where"]) if "having" in extras: templatable_statements.append(extras["having"]) + if "columns" in query_obj: + templatable_statements += [ + c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c) + ] if self.is_rls_supported: templatable_statements += [ f.clause for f in security_manager.get_rls_filters(self) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 2d7f6bf04..79d4bf00e 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -911,6 +911,56 @@ def test_extra_cache_keys_in_sql_expression( assert extra_cache_keys == expected_cache_keys +@pytest.mark.usefixtures("app_context") +@pytest.mark.parametrize( + "sql_expression,expected_cache_keys,has_extra_cache_keys", + [ + ("'{{ current_username() }}'", ["abc"], True), + ("(user != 'abc')", [], False), + ], +) +@patch("superset.jinja_context.get_user_id", return_value=1) +@patch("superset.jinja_context.get_username", return_value="abc") +@patch("superset.jinja_context.get_user_email", return_value="abc@test.com") +def test_extra_cache_keys_in_columns( + mock_user_email, + mock_username, + mock_user_id, + sql_expression, + expected_cache_keys, + has_extra_cache_keys, +): + table = SqlaTable( + table_name="test_has_no_extra_cache_keys_table", + sql="SELECT 'abc' as user", + database=get_example_database(), + ) + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": [], + "metrics": [], + "is_timeseries": False, + "filter": [], + } + + query_obj = dict( + **base_query_obj, + columns=[ + { + "label": None, + "expressionType": "SQL", + "sqlExpression": sql_expression, + } + ], + ) + + extra_cache_keys = table.get_extra_cache_keys(query_obj) + assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys + assert extra_cache_keys == expected_cache_keys + + @pytest.mark.usefixtures("app_context") @pytest.mark.parametrize( "row,dimension,result",