fix: do not drop calculated column on metadata sync (#11731)
This commit is contained in:
parent
676e0bb282
commit
7ae8cd07cc
|
|
@ -128,10 +128,8 @@ class BaseDatasource(
|
|||
),
|
||||
)
|
||||
|
||||
# placeholder for a relationship to a derivative of BaseColumn
|
||||
columns: List[Any] = []
|
||||
# placeholder for a relationship to a derivative of BaseMetric
|
||||
metrics: List[Any] = []
|
||||
columns: List["BaseColumn"] = []
|
||||
metrics: List["BaseMetric"] = []
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
|
|
|
|||
|
|
@ -453,6 +453,8 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
type = "druid"
|
||||
query_language = "json"
|
||||
cluster_class = DruidCluster
|
||||
columns: List[DruidColumn] = []
|
||||
metrics: List[DruidMetric] = []
|
||||
metric_class = DruidMetric
|
||||
column_class = DruidColumn
|
||||
owner_class = security_manager.user_model
|
||||
|
|
|
|||
|
|
@ -438,6 +438,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
type = "table"
|
||||
query_language = "sql"
|
||||
is_rls_supported = True
|
||||
columns: List[TableColumn] = []
|
||||
metrics: List[SqlMetric] = []
|
||||
metric_class = SqlMetric
|
||||
column_class = TableColumn
|
||||
owner_class = security_manager.user_model
|
||||
|
|
@ -1333,7 +1335,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
db_engine_spec = self.database.db_engine_spec
|
||||
old_columns = db.session.query(TableColumn).filter(TableColumn.table == self)
|
||||
|
||||
old_columns_by_name = {col.column_name: col for col in old_columns}
|
||||
old_columns_by_name: Dict[str, TableColumn] = {
|
||||
col.column_name: col for col in old_columns
|
||||
}
|
||||
results = MetadataResult(
|
||||
removed=[
|
||||
col
|
||||
|
|
@ -1345,7 +1349,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
# clear old columns before adding modified columns back
|
||||
self.columns = []
|
||||
for col in new_columns:
|
||||
old_column = old_columns_by_name.get(col["name"], None)
|
||||
old_column = old_columns_by_name.pop(col["name"], None)
|
||||
if not old_column:
|
||||
results.added.append(col["name"])
|
||||
new_column = TableColumn(
|
||||
|
|
@ -1358,11 +1362,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
if new_column.type != col["type"]:
|
||||
results.modified.append(col["name"])
|
||||
new_column.type = col["type"]
|
||||
new_column.expression = ""
|
||||
new_column.groupby = True
|
||||
new_column.filterable = True
|
||||
self.columns.append(new_column)
|
||||
if not any_date_col and new_column.is_temporal:
|
||||
any_date_col = col["name"]
|
||||
self.columns.extend(
|
||||
[col for col in old_columns_by_name.values() if col.expression]
|
||||
)
|
||||
metrics.append(
|
||||
SqlMetric(
|
||||
metric_name="count",
|
||||
|
|
|
|||
|
|
@ -454,9 +454,13 @@ class TableModelView( # pylint: disable=too-many-ancestors
|
|||
validate_sqlatable(item)
|
||||
|
||||
def post_add( # pylint: disable=arguments-differ
|
||||
self, item: "TableModelView", flash_message: bool = True
|
||||
self,
|
||||
item: "TableModelView",
|
||||
flash_message: bool = True,
|
||||
fetch_metadata: bool = True,
|
||||
) -> None:
|
||||
item.fetch_metadata()
|
||||
if fetch_metadata:
|
||||
item.fetch_metadata()
|
||||
create_table_permissions(item)
|
||||
if flash_message:
|
||||
flash(
|
||||
|
|
@ -470,7 +474,7 @@ class TableModelView( # pylint: disable=too-many-ancestors
|
|||
)
|
||||
|
||||
def post_update(self, item: "TableModelView") -> None:
|
||||
self.post_add(item, flash_message=False)
|
||||
self.post_add(item, flash_message=False, fetch_metadata=False)
|
||||
|
||||
def _delete(self, pk: int) -> None:
|
||||
DeleteMixin._delete(self, pk)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
from typing import Any, Dict, NamedTuple, List, Tuple, Union
|
||||
import re
|
||||
from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
|
|
@ -30,6 +31,23 @@ from superset.utils.core import DbColumnType, get_example_database, FilterOperat
|
|||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
||||
VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = {
|
||||
"hive": re.compile(r"^INT_TYPE$"),
|
||||
"mysql": re.compile("^LONGLONG$"),
|
||||
"postgresql": re.compile(r"^INT$"),
|
||||
"presto": re.compile(r"^INTEGER$"),
|
||||
"sqlite": re.compile(r"^INT$"),
|
||||
}
|
||||
|
||||
VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = {
|
||||
"hive": re.compile(r"^STRING_TYPE$"),
|
||||
"mysql": re.compile(r"^VAR_STRING$"),
|
||||
"postgresql": re.compile(r"^STRING$"),
|
||||
"presto": re.compile(r"^VARCHAR*"),
|
||||
"sqlite": re.compile(r"^STRING$"),
|
||||
}
|
||||
|
||||
|
||||
class TestDatabaseModel(SupersetTestCase):
|
||||
def test_is_time_druid_time_col(self):
|
||||
"""Druid has a special __time column"""
|
||||
|
|
@ -247,3 +265,44 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
query_obj = dict(**base_query_obj, extras={})
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
table.get_sqla_query(**query_obj)
|
||||
|
||||
def test_fetch_metadata_for_updated_virtual_table(self):
|
||||
table = SqlaTable(
|
||||
table_name="updated_sql_table",
|
||||
database=get_example_database(),
|
||||
sql="select 123 as intcol, 'abc' as strcol, 'abc' as mycase",
|
||||
)
|
||||
TableColumn(column_name="intcol", type="FLOAT", table=table)
|
||||
TableColumn(column_name="oldcol", type="INT", table=table)
|
||||
TableColumn(
|
||||
column_name="expr",
|
||||
expression="case when 1 then 1 else 0 end",
|
||||
type="INT",
|
||||
table=table,
|
||||
)
|
||||
TableColumn(
|
||||
column_name="mycase",
|
||||
expression="case when 1 then 1 else 0 end",
|
||||
type="INT",
|
||||
table=table,
|
||||
)
|
||||
|
||||
# make sure the columns have been mapped properly
|
||||
assert len(table.columns) == 4
|
||||
table.fetch_metadata()
|
||||
# assert that the removed column has been dropped and
|
||||
# the physical and calculated columns are present
|
||||
assert {col.column_name for col in table.columns} == {
|
||||
"intcol",
|
||||
"strcol",
|
||||
"mycase",
|
||||
"expr",
|
||||
}
|
||||
cols: Dict[str, TableColumn] = {col.column_name: col for col in table.columns}
|
||||
# assert that the type for intcol has been updated (asserting CI types)
|
||||
backend = get_example_database().backend
|
||||
assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type)
|
||||
# assert that the expression has been replaced with the new physical column
|
||||
assert cols["mycase"].expression == ""
|
||||
assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type)
|
||||
assert cols["expr"].expression == "case when 1 then 1 else 0 end"
|
||||
|
|
|
|||
Loading…
Reference in New Issue