fix: avoid filters containing null value (#17168)

This commit is contained in:
Yongjie Zhao 2021-10-21 11:20:38 +01:00 committed by GitHub
parent cd9e99402d
commit 4c708af710
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 5 deletions

View File

@ -419,7 +419,9 @@ class BaseDatasource(
"""
raise NotImplementedError()
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of

View File

@ -948,7 +948,9 @@ class DruidDatasource(Model, BaseDatasource):
)
return aggs, post_aggs
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
"""Retrieve some values for the given column"""
logger.info(
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)

View File

@ -173,7 +173,9 @@ class AnnotationDatasource(BaseDatasource):
def get_query_str(self, query_obj: QueryObjectDict) -> str:
raise NotImplementedError()
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
raise NotImplementedError()
@ -712,7 +714,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
) from ex
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@ -728,6 +732,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if limit:
qry = qry.limit(limit)
if not contain_null:
qry = qry.where(target_col.get_sqla_col().isnot(None))
if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())

View File

@ -921,7 +921,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasource.raise_for_access()
row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
payload = json.dumps(
datasource.values_for_column(column, row_limit),
datasource.values_for_column(
column_name=column, limit=row_limit, contain_null=False,
),
default=utils.json_int_dttm_ser,
ignore_nan=True,
)

View File

@ -463,3 +463,23 @@ class TestDatabaseModel(SupersetTestCase):
db.session.delete(table)
db.session.delete(database)
db.session.commit()
def test_values_for_column(self):
table = SqlaTable(
table_name="test_null_in_column",
sql="SELECT 'foo' as foo UNION SELECT 'bar' UNION SELECT NULL",
database=get_example_database(),
)
TableColumn(column_name="foo", type="VARCHAR(255)", table=table)
with_null = table.values_for_column(
column_name="foo", limit=10000, contain_null=True
)
assert None in with_null
assert len(with_null) == 3
without_null = table.values_for_column(
column_name="foo", limit=10000, contain_null=False
)
assert None not in without_null
assert len(without_null) == 2