fix: pass string to `process_template` (#31329)
This commit is contained in:
parent
592564b623
commit
9315a8838c
|
|
@ -26,7 +26,7 @@ from typing import Any, cast, TYPE_CHECKING
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
from jinja2 import nodes
|
from jinja2 import nodes, Template
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
from sqlparse import keywords
|
from sqlparse import keywords
|
||||||
from sqlparse.lexer import Lexer
|
from sqlparse.lexer import Lexer
|
||||||
|
|
@ -999,10 +999,13 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
|
||||||
node.fields = nodes.TemplateData.fields
|
node.fields = nodes.TemplateData.fields
|
||||||
node.data = "NULL"
|
node.data = "NULL"
|
||||||
|
|
||||||
|
# re-render template back into a string
|
||||||
|
rendered_template = Template(template).render()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
tables
|
tables
|
||||||
| ParsedQuery(
|
| ParsedQuery(
|
||||||
sql_statement=processor.process_template(template),
|
sql_statement=processor.process_template(rendered_template),
|
||||||
engine=database.db_engine_spec.engine,
|
engine=database.db_engine_spec.engine,
|
||||||
).tables
|
).tables
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@
|
||||||
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
|
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from unittest.mock import Mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import sqlparse
|
import sqlparse
|
||||||
|
|
@ -1888,12 +1888,33 @@ SELECT * FROM t"""
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_extract_tables_from_jinja_sql(
|
def test_extract_tables_from_jinja_sql(
|
||||||
engine: str, macro: str, expected: set[Table]
|
mocker: MockerFixture,
|
||||||
|
engine: str,
|
||||||
|
macro: str,
|
||||||
|
expected: set[Table],
|
||||||
) -> None:
|
) -> None:
|
||||||
assert (
|
assert (
|
||||||
extract_tables_from_jinja_sql(
|
extract_tables_from_jinja_sql(
|
||||||
sql=f"'{{{{ {engine}.{macro} }}}}'",
|
sql=f"'{{{{ {engine}.{macro} }}}}'",
|
||||||
database=Mock(),
|
database=mocker.Mock(),
|
||||||
)
|
)
|
||||||
== expected
|
== expected
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch.dict(
|
||||||
|
"superset.extensions.feature_flag_manager._feature_flags",
|
||||||
|
{"ENABLE_TEMPLATE_PROCESSING": False},
|
||||||
|
clear=True,
|
||||||
|
)
|
||||||
|
def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test the function when the feature flag is disabled.
|
||||||
|
"""
|
||||||
|
database = mocker.Mock()
|
||||||
|
database.db_engine_spec.engine = "mssql"
|
||||||
|
|
||||||
|
assert extract_tables_from_jinja_sql(
|
||||||
|
sql="SELECT 1 FROM t",
|
||||||
|
database=database,
|
||||||
|
) == {Table("t")}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue