fix: % replace in `values_for_column` (#28271)
This commit is contained in:
parent
51da5adbc7
commit
fe37d914e5
|
|
@ -1377,10 +1377,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
|
||||
|
||||
with self.database.get_sqla_engine() as engine:
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.database.mutate_sql_based_on_config(sql)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if engine.dialect.identifier_preparer._double_percents:
|
||||
sql = sql.replace("%%", "%")
|
||||
|
||||
df = pd.read_sql_query(sql=sql, con=engine)
|
||||
# replace NaN with None to ensure it can be serialized to JSON
|
||||
df = df.replace({np.nan: None})
|
||||
|
|
|
|||
|
|
@ -17,22 +17,24 @@
|
|||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
||||
from contextlib import contextmanager
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
|
||||
"""
|
||||
Test the `values_for_column` method.
|
||||
|
||||
NULL values should be returned as `None`, not `np.nan`, since NaN cannot be
|
||||
serialized to JSON.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
@pytest.fixture()
|
||||
def database(mocker: MockerFixture, session: Session) -> Database:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
|
||||
SqlaTable.metadata.create_all(session.get_bind())
|
||||
|
|
@ -42,13 +44,12 @@ def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
|
|||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
database = Database(database_name="db", sqlalchemy_uri="sqlite://")
|
||||
|
||||
connection = engine.raw_connection()
|
||||
connection.execute("CREATE TABLE t (c INTEGER)")
|
||||
connection.execute("INSERT INTO t VALUES (1)")
|
||||
connection.execute("INSERT INTO t VALUES (NULL)")
|
||||
connection.execute("CREATE TABLE t (a INTEGER, b TEXT)")
|
||||
connection.execute("INSERT INTO t VALUES (1, 'Alice')")
|
||||
connection.execute("INSERT INTO t VALUES (NULL, 'Bob')")
|
||||
connection.commit()
|
||||
|
||||
# since we're using an in-memory SQLite database, make sure we always
|
||||
|
|
@ -63,10 +64,94 @@ def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
|
|||
new=mock_get_sqla_engine,
|
||||
)
|
||||
|
||||
return database
|
||||
|
||||
|
||||
def test_values_for_column(database: Database) -> None:
|
||||
"""
|
||||
Test the `values_for_column` method.
|
||||
|
||||
NULL values should be returned as `None`, not `np.nan`, since NaN cannot be
|
||||
serialized to JSON.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
columns=[TableColumn(column_name="c")],
|
||||
columns=[TableColumn(column_name="a")],
|
||||
)
|
||||
assert table.values_for_column("c") == [1, None]
|
||||
assert table.values_for_column("a") == [1, None]
|
||||
|
||||
|
||||
def test_values_for_column_calculated(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test that calculated columns work.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
columns=[
|
||||
TableColumn(
|
||||
column_name="starts_with_A",
|
||||
expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
|
||||
)
|
||||
],
|
||||
)
|
||||
assert table.values_for_column("starts_with_A") == ["yes", "nope"]
|
||||
|
||||
|
||||
def test_values_for_column_double_percents(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test the behavior of `double_percents`.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.dialect.identifier_preparer._double_percents = "pyformat"
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
columns=[
|
||||
TableColumn(
|
||||
column_name="starts_with_A",
|
||||
expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mutate_sql_based_on_config = mocker.patch.object(
|
||||
database,
|
||||
"mutate_sql_based_on_config",
|
||||
side_effect=lambda sql: sql,
|
||||
)
|
||||
pd = mocker.patch("superset.models.helpers.pd")
|
||||
|
||||
table.values_for_column("starts_with_A")
|
||||
|
||||
# make sure the SQL originally had double percents
|
||||
mutate_sql_based_on_config.assert_called_with(
|
||||
"SELECT DISTINCT CASE WHEN b LIKE 'A%%' THEN 'yes' ELSE 'nope' END "
|
||||
"AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
|
||||
)
|
||||
# make sure final query has single percents
|
||||
with database.get_sqla_engine() as engine:
|
||||
pd.read_sql_query.assert_called_with(
|
||||
sql=(
|
||||
"SELECT DISTINCT CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END "
|
||||
"AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
|
||||
),
|
||||
con=engine,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue