fix: pass string to `process_template` (#31329)

This commit is contained in:
Beto Dealmeida 2024-12-07 09:49:49 -05:00 committed by GitHub
parent 592564b623
commit 9315a8838c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 5 deletions

View File

@ -26,7 +26,7 @@ from typing import Any, cast, TYPE_CHECKING
import sqlparse import sqlparse
from flask_babel import gettext as __ from flask_babel import gettext as __
from jinja2 import nodes from jinja2 import nodes, Template
from sqlalchemy import and_ from sqlalchemy import and_
from sqlparse import keywords from sqlparse import keywords
from sqlparse.lexer import Lexer from sqlparse.lexer import Lexer
@ -999,10 +999,13 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
node.fields = nodes.TemplateData.fields node.fields = nodes.TemplateData.fields
node.data = "NULL" node.data = "NULL"
# re-render template back into a string
rendered_template = Template(template).render()
return ( return (
tables tables
| ParsedQuery( | ParsedQuery(
sql_statement=processor.process_template(template), sql_statement=processor.process_template(rendered_template),
engine=database.db_engine_spec.engine, engine=database.db_engine_spec.engine,
).tables ).tables
) )

View File

@ -17,7 +17,7 @@
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines # pylint: disable=invalid-name, redefined-outer-name, too-many-lines
from typing import Optional from typing import Optional
from unittest.mock import Mock from unittest import mock
import pytest import pytest
import sqlparse import sqlparse
@ -1888,12 +1888,33 @@ SELECT * FROM t"""
], ],
) )
def test_extract_tables_from_jinja_sql( def test_extract_tables_from_jinja_sql(
engine: str, macro: str, expected: set[Table] mocker: MockerFixture,
engine: str,
macro: str,
expected: set[Table],
) -> None: ) -> None:
assert ( assert (
extract_tables_from_jinja_sql( extract_tables_from_jinja_sql(
sql=f"'{{{{ {engine}.{macro} }}}}'", sql=f"'{{{{ {engine}.{macro} }}}}'",
database=Mock(), database=mocker.Mock(),
) )
== expected == expected
) )
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
{"ENABLE_TEMPLATE_PROCESSING": False},
clear=True,
)
def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None:
"""
Test the function when the feature flag is disabled.
"""
database = mocker.Mock()
database.db_engine_spec.engine = "mssql"
assert extract_tables_from_jinja_sql(
sql="SELECT 1 FROM t",
database=database,
) == {Table("t")}