fix(presto): Handle ROW data stored as string (#10456)

* Handle ROW data stored as string

* Use destringify

* Fix mypy

* Fix mypy with cast

* Bypass pylint
This commit is contained in:
Beto Dealmeida 2020-07-28 16:05:58 -07:00 committed by GitHub
parent 39fad8575c
commit 4f678272d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 5 deletions

View File

@@ -22,7 +22,7 @@ from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from urllib import parse
import pandas as pd
@@ -40,6 +40,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
@@ -568,7 +569,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return datasource_names
@classmethod
def expand_data( # pylint: disable=too-many-locals
def expand_data( # pylint: disable=too-many-locals,too-many-branches
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
"""
@@ -616,6 +617,7 @@ class PrestoEngineSpec(BaseEngineSpec):
current_array_level = level
name = column["name"]
values: Optional[Union[str, List[Any]]]
if column["type"].startswith("ARRAY("):
# keep processing array children; we append to the right so that
@@ -627,6 +629,8 @@ class PrestoEngineSpec(BaseEngineSpec):
while i < len(data):
row = data[i]
values = row.get(name)
if isinstance(values, str):
row[name] = values = destringify(values)
if values:
# how many extra rows we need to unnest the data?
extra_rows = len(values) - 1
@@ -653,12 +657,15 @@ class PrestoEngineSpec(BaseEngineSpec):
# expand columns; we append them to the left so they are added
# immediately after the parent
expanded = get_children(column)
to_process.extendleft((column, level) for column in expanded)
to_process.extendleft((column, level) for column in expanded[::-1])
expanded_columns.extend(expanded)
# expand row objects into new columns
for row in data:
for value, col in zip(row.get(name) or [], expanded):
values = row.get(name) or []
if isinstance(values, str):
row[name] = values = cast(List[Any], destringify(values))
for value, col in zip(values, expanded):
row[col["name"]] = value
data = [

View File

@@ -67,6 +67,10 @@ def stringify_values(array: np.ndarray) -> np.ndarray:
return vstringify(array)
def destringify(obj: str) -> Any:
return json.loads(obj)
class SupersetResultSet:
def __init__( # pylint: disable=too-many-locals,too-many-branches
self,

View File

@@ -214,9 +214,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"name": "row_column",
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
},
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
{"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
{"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
]
expected_data = [
{
@@ -433,3 +433,60 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
}
]
self.assertEqual(formatted_cost, expected)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_expand_data_array(self):
cols = [
{"name": "event_id", "type": "VARCHAR", "is_date": False},
{"name": "timestamp", "type": "BIGINT", "is_date": False},
{
"name": "user",
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
"is_date": False,
},
]
data = [
{
"event_id": "abcdef01-2345-6789-abcd-ef0123456789",
"timestamp": "1595895506219",
"user": '[1, "JOHN", "DOE"]',
}
]
actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
cols, data
)
expected_cols = [
{"name": "event_id", "type": "VARCHAR", "is_date": False},
{"name": "timestamp", "type": "BIGINT", "is_date": False},
{
"name": "user",
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
"is_date": False,
},
{"name": "user.id", "type": "BIGINT"},
{"name": "user.first_name", "type": "VARCHAR"},
{"name": "user.last_name", "type": "VARCHAR"},
]
expected_data = [
{
"event_id": "abcdef01-2345-6789-abcd-ef0123456789",
"timestamp": "1595895506219",
"user": [1, "JOHN", "DOE"],
"user.id": 1,
"user.first_name": "JOHN",
"user.last_name": "DOE",
}
]
expected_expanded_cols = [
{"name": "user.id", "type": "BIGINT"},
{"name": "user.first_name", "type": "VARCHAR"},
{"name": "user.last_name", "type": "VARCHAR"},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)