diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 1745da3dd..c4cc6f406 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -128,10 +128,8 @@ class BaseDatasource( ), ) - # placeholder for a relationship to a derivative of BaseColumn - columns: List[Any] = [] - # placeholder for a relationship to a derivative of BaseMetric - metrics: List[Any] = [] + columns: List["BaseColumn"] = [] + metrics: List["BaseMetric"] = [] @property def type(self) -> str: diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 644d6a345..fe15b2473 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -453,6 +453,8 @@ class DruidDatasource(Model, BaseDatasource): type = "druid" query_language = "json" cluster_class = DruidCluster + columns: List[DruidColumn] = [] + metrics: List[DruidMetric] = [] metric_class = DruidMetric column_class = DruidColumn owner_class = security_manager.user_model diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index f2ed5f1f4..c75c081cb 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -438,6 +438,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at type = "table" query_language = "sql" is_rls_supported = True + columns: List[TableColumn] = [] + metrics: List[SqlMetric] = [] metric_class = SqlMetric column_class = TableColumn owner_class = security_manager.user_model @@ -1333,7 +1335,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at db_engine_spec = self.database.db_engine_spec old_columns = db.session.query(TableColumn).filter(TableColumn.table == self) - old_columns_by_name = {col.column_name: col for col in old_columns} + old_columns_by_name: Dict[str, TableColumn] = { + col.column_name: col for col in old_columns + } results = MetadataResult( removed=[ col @@ -1345,7 +1349,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at # clear old columns before adding modified columns back self.columns = [] for col in new_columns: - old_column = old_columns_by_name.get(col["name"], None) + old_column = old_columns_by_name.pop(col["name"], None) if not old_column: results.added.append(col["name"]) new_column = TableColumn( @@ -1358,11 +1362,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at if new_column.type != col["type"]: results.modified.append(col["name"]) new_column.type = col["type"] + new_column.expression = "" new_column.groupby = True new_column.filterable = True self.columns.append(new_column) if not any_date_col and new_column.is_temporal: any_date_col = col["name"] + self.columns.extend( + [col for col in old_columns_by_name.values() if col.expression] + ) metrics.append( SqlMetric( metric_name="count", diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index e2ef7945c..3533fc50c 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -454,9 +454,13 @@ class TableModelView( # pylint: disable=too-many-ancestors validate_sqlatable(item) def post_add( # pylint: disable=arguments-differ - self, item: "TableModelView", flash_message: bool = True + self, + item: "TableModelView", + flash_message: bool = True, + fetch_metadata: bool = True, ) -> None: - item.fetch_metadata() + if fetch_metadata: + item.fetch_metadata() create_table_permissions(item) if flash_message: flash( @@ -470,7 +474,7 @@ class TableModelView( # pylint: disable=too-many-ancestors ) def post_update(self, item: "TableModelView") -> None: - self.post_add(item, flash_message=False) + self.post_add(item, flash_message=False, fetch_metadata=False) def _delete(self, pk: int) -> None: DeleteMixin._delete(self, pk) diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 1f573f52f..f6bb4fc76 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -from typing import Any, Dict, NamedTuple, List, Tuple, Union +import re +from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union from unittest.mock import patch import pytest @@ -30,6 +31,23 @@ from superset.utils.core import DbColumnType, get_example_database, FilterOperat from .base_tests import SupersetTestCase +VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = { + "hive": re.compile(r"^INT_TYPE$"), + "mysql": re.compile("^LONGLONG$"), + "postgresql": re.compile(r"^INT$"), + "presto": re.compile(r"^INTEGER$"), + "sqlite": re.compile(r"^INT$"), +} + +VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = { + "hive": re.compile(r"^STRING_TYPE$"), + "mysql": re.compile(r"^VAR_STRING$"), + "postgresql": re.compile(r"^STRING$"), + "presto": re.compile(r"^VARCHAR*"), + "sqlite": re.compile(r"^STRING$"), +} + + class TestDatabaseModel(SupersetTestCase): def test_is_time_druid_time_col(self): """Druid has a special __time column""" @@ -247,3 +265,44 @@ class TestDatabaseModel(SupersetTestCase): query_obj = dict(**base_query_obj, extras={}) with pytest.raises(QueryObjectValidationError): table.get_sqla_query(**query_obj) + + def test_fetch_metadata_for_updated_virtual_table(self): + table = SqlaTable( + table_name="updated_sql_table", + database=get_example_database(), + sql="select 123 as intcol, 'abc' as strcol, 'abc' as mycase", + ) + TableColumn(column_name="intcol", type="FLOAT", table=table) + TableColumn(column_name="oldcol", type="INT", table=table) + TableColumn( + column_name="expr", + expression="case when 1 then 1 else 0 end", + type="INT", + table=table, + ) + TableColumn( + column_name="mycase", + expression="case when 1 then 1 else 0 end", + type="INT", + table=table, + ) + + # make sure the columns have been mapped properly + assert len(table.columns) == 4 + table.fetch_metadata() + # assert that the removed column has been dropped and + # the physical and calculated columns are present + assert {col.column_name for col in table.columns} == { + "intcol", + "strcol", + "mycase", + "expr", + } + cols: Dict[str, TableColumn] = {col.column_name: col for col in table.columns} + # assert that the type for intcol has been updated (asserting CI types) + backend = get_example_database().backend + assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type) + # assert that the expression has been replaced with the new physical column + assert cols["mycase"].expression == "" + assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type) + assert cols["expr"].expression == "case when 1 then 1 else 0 end"