diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 00a9310b2..89e677b01 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -945,11 +945,20 @@ class PrestoEngineSpec(BaseEngineSpec): r'{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'.format(delimiter), data_type) @classmethod - def _parse_structural_column(cls, full_data_type: str, result: List[dict]) -> None: + def _parse_structural_column(cls, + parent_column_name: str, + parent_data_type: str, + result: List[dict]) -> None: """ Parse a row or array column :param result: list tracking the results """ + formatted_parent_column_name = parent_column_name + # Quote the column name if there is a space + if ' ' in parent_column_name: + formatted_parent_column_name = f'"{parent_column_name}"' + full_data_type = f'{formatted_parent_column_name} {parent_data_type}' + original_result_len = len(result) # split on open parenthesis ( to get the structural # data type and its component types data_types = cls._split_data_type(full_data_type, r'\(') @@ -1001,6 +1010,11 @@ class PrestoEngineSpec(BaseEngineSpec): # Because it is an array of a basic data type. We have finished # parsing the structural data type and can move on. stack.pop() + # Unquote the column name if necessary + if formatted_parent_column_name != parent_column_name: + for index in range(original_result_len, len(result)): + result[index]['name'] = result[index]['name'].replace( + formatted_parent_column_name, parent_column_name) @classmethod def _show_columns( @@ -1037,9 +1051,8 @@ class PrestoEngineSpec(BaseEngineSpec): try: # parse column if it is a row or array if 'array' in column.Type or 'row' in column.Type: - full_data_type = '{} {}'.format(column.Column, column.Type) structural_column_index = len(result) - cls._parse_structural_column(full_data_type, result) + cls._parse_structural_column(column.Column, column.Type, result) result[structural_column_index]['nullable'] = getattr( column, 'Null', True) result[structural_column_index]['default'] = None @@ -1244,8 +1257,9 @@ class PrestoEngineSpec(BaseEngineSpec): for column in selected_columns: if column['type'].startswith('ROW'): parsed_row_columns: List[dict] = [] - full_data_type = '{} {}'.format(column['name'], column['type'].lower()) - cls._parse_structural_column(full_data_type, parsed_row_columns) + cls._parse_structural_column(column['name'], + column['type'].lower(), + parsed_row_columns) expanded_columns = expanded_columns + parsed_row_columns[1:] filtered_row_columns, array_columns = cls._filter_out_array_nested_cols( parsed_row_columns) @@ -1257,8 +1271,9 @@ class PrestoEngineSpec(BaseEngineSpec): array_column_hierarchy) elif column['type'].startswith('ARRAY'): parsed_array_columns: List[dict] = [] - full_data_type = '{} {}'.format(column['name'], column['type'].lower()) - cls._parse_structural_column(full_data_type, parsed_array_columns) + cls._parse_structural_column(column['name'], + column['type'].lower(), + parsed_array_columns) expanded_columns = expanded_columns + parsed_array_columns[1:] cls._build_column_hierarchy(parsed_array_columns, ['ROW', 'ARRAY'], @@ -1523,8 +1538,9 @@ class PrestoEngineSpec(BaseEngineSpec): # Get the list of all columns (selected fields and their nested fields) for column in columns: if column['type'].startswith('ARRAY') or column['type'].startswith('ROW'): - full_data_type = '{} {}'.format(column['name'], column['type'].lower()) - cls._parse_structural_column(full_data_type, all_columns) + cls._parse_structural_column(column['name'], + column['type'].lower(), + all_columns) else: all_columns.append(column) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 02dbbae8b..44919143d 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -350,7 +350,14 @@ class DbEngineSpecsTestCase(SupersetTestCase): ('column_name.nested_obj', 'FLOAT')] self.verify_presto_column(presto_column, expected_results) - def test_presto_get_simple_row_column_with_tricky_name(self): + def test_presto_get_simple_row_column_with_name_containing_whitespace(self): + presto_column = ('column name', 'row(nested_obj double)', '') + expected_results = [ + ('column name', 'ROW'), + ('column name.nested_obj', 'FLOAT')] + self.verify_presto_column(presto_column, expected_results) + + def test_presto_get_simple_row_column_with_tricky_nested_field_name(self): presto_column = ('column_name', 'row("Field Name(Tricky, Name)" double)', '') expected_results = [ ('column_name', 'ROW'),