From 4c708af71081eef3454e7f0ac2bba5d0588bfa87 Mon Sep 17 00:00:00 2001
From: Yongjie Zhao
Date: Thu, 21 Oct 2021 11:20:38 +0100
Subject: [PATCH] fix: avoid filters containing null value (#17168)

---
Reviewer notes (this region is ignored when the patch is applied):

* Every values_for_column() override gains a `contain_null: bool = True`
  keyword, so existing positional callers keep their behaviour.
* Only SqlaTable.values_for_column() acts on the flag in this patch: it adds
  an IS NOT NULL predicate on the target column when contain_null is False.
  The BaseDatasource, DruidDatasource and AnnotationDatasource hunks change
  the signature only.
* The call site in superset/views/core.py now passes contain_null=False.
* NOTE(review): the leading indentation of context lines was rebuilt from
  the Python structure after the line layout was lost; confirm with
  `git apply --check` against the target tree before relying on it.

 superset/connectors/base/models.py           |  4 +++-
 superset/connectors/druid/models.py          |  4 +++-
 superset/connectors/sqla/models.py           | 11 +++++++++--
 superset/views/core.py                       |  4 +++-
 tests/integration_tests/sqla_models_tests.py | 20 ++++++++++++++++++++
 5 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 711f1b72b..0bd94885a 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -419,7 +419,9 @@ class BaseDatasource(
         """
         raise NotImplementedError()
 
-    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+    def values_for_column(
+        self, column_name: str, limit: int = 10000, contain_null: bool = True,
+    ) -> List[Any]:
         """Given a column, returns an iterable of distinct values
 
         This is used to populate the dropdown showing a list of
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 32edb6952..aa86d7ab8 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -948,7 +948,9 @@ class DruidDatasource(Model, BaseDatasource):
         )
         return aggs, post_aggs
 
-    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+    def values_for_column(
+        self, column_name: str, limit: int = 10000, contain_null: bool = True,
+    ) -> List[Any]:
         """Retrieve some values for the given column"""
         logger.info(
             "Getting values for columns [{}] limited to [{}]".format(column_name, limit)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index fcb40f219..a08050707 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -173,7 +173,9 @@ class AnnotationDatasource(BaseDatasource):
     def get_query_str(self, query_obj: QueryObjectDict) -> str:
         raise NotImplementedError()
 
-    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+    def values_for_column(
+        self, column_name: str, limit: int = 10000, contain_null: bool = True,
+    ) -> List[Any]:
         raise NotImplementedError()
 
 
@@ -712,7 +714,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 )
             ) from ex
 
-    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+    def values_for_column(
+        self, column_name: str, limit: int = 10000, contain_null: bool = True,
+    ) -> List[Any]:
         """Runs query against sqla to retrieve some
         sample values for the given column.
         """
@@ -728,6 +732,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         if limit:
             qry = qry.limit(limit)
 
+        if not contain_null:
+            qry = qry.where(target_col.get_sqla_col().isnot(None))
+
         if self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
 
diff --git a/superset/views/core.py b/superset/views/core.py
index 6dfa63050..5a7dd18f9 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -921,7 +921,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         datasource.raise_for_access()
         row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
         payload = json.dumps(
-            datasource.values_for_column(column, row_limit),
+            datasource.values_for_column(
+                column_name=column, limit=row_limit, contain_null=False,
+            ),
             default=utils.json_int_dttm_ser,
             ignore_nan=True,
         )
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index f17cedb7c..7155735b5 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -463,3 +463,23 @@ class TestDatabaseModel(SupersetTestCase):
         db.session.delete(table)
         db.session.delete(database)
         db.session.commit()
+
+    def test_values_for_column(self):
+        table = SqlaTable(
+            table_name="test_null_in_column",
+            sql="SELECT 'foo' as foo UNION SELECT 'bar' UNION SELECT NULL",
+            database=get_example_database(),
+        )
+        TableColumn(column_name="foo", type="VARCHAR(255)", table=table)
+
+        with_null = table.values_for_column(
+            column_name="foo", limit=10000, contain_null=True
+        )
+        assert None in with_null
+        assert len(with_null) == 3
+
+        without_null = table.values_for_column(
+            column_name="foo", limit=10000, contain_null=False
+        )
+        assert None not in without_null
+        assert len(without_null) == 2