fix(db_engine_specs): improve Presto column type matching (#10658)

* fix: improve Presto column type matching

* add optional callback to type map and add tests

* lint

* change private to public
This commit is contained in:
Ville Brofeldt 2020-08-24 22:42:07 +03:00 committed by GitHub
parent 0177c2f591
commit 9461f9c1e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 65 deletions

View File

@ -24,8 +24,10 @@ from contextlib import closing
from datetime import datetime
from typing import (
Any,
Callable,
Dict,
List,
Match,
NamedTuple,
Optional,
Pattern,
@ -142,6 +144,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
] = None # used for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
] = ()
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
@ -886,12 +891,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Needs to be overridden if column requires special handling
SQLAlchemy). Override `_column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param type_: Column type returned by inspector
:return: SqlAlchemy column type
"""
for regex, sqla_type in cls.column_type_mappings:
match = regex.match(type_)
if match:
if callable(sqla_type):
return sqla_type(match)
return sqla_type
return None
@staticmethod

View File

@ -19,7 +19,7 @@ import re
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from sqlalchemy.types import String, TypeEngine, UnicodeText
from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
@ -73,18 +73,11 @@ class MssqlEngineSpec(BaseEngineSpec):
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
column_types = (
(String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
(UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
column_type_mappings = (
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
)
@classmethod
def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
for sqla_type, regex in cls.column_types:
if regex.match(type_):
return sqla_type
return None
@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):

View File

@ -28,7 +28,7 @@ from urllib import parse
import pandas as pd
import simplejson as json
from flask_babel import lazy_gettext as _
from sqlalchemy import Column, literal_column
from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
@ -40,7 +40,13 @@ from superset import app, cache, is_feature_enabled, security_manager
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.models.sql_types.presto_sql_types import (
Array,
Interval,
Map,
Row,
TinyInteger,
)
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
@ -260,13 +266,16 @@ class PrestoEngineSpec(BaseEngineSpec):
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
column_type = cls.get_sqla_column_type(field_info[1])
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=field_info[1])
)
if field_info[1] == "array" or field_info[1] == "row":
stack.append((field_info[0], field_info[1]))
full_parent_path = cls._get_full_name(stack)
result.append(
cls._create_column_info(
full_parent_path, presto_type_map[field_info[1]]()
)
cls._create_column_info(full_parent_path, column_type)
)
else: # otherwise this field is a basic data type
full_parent_path = cls._get_full_name(stack)
@ -274,9 +283,7 @@ class PrestoEngineSpec(BaseEngineSpec):
full_parent_path, field_info[0]
)
result.append(
cls._create_column_info(
column_name, presto_type_map[field_info[1]]()
)
cls._create_column_info(column_name, column_type)
)
# If the component type ends with a structural data type, do not pop
# the stack. We have run across a structural data type within the
@ -318,6 +325,34 @@ class PrestoEngineSpec(BaseEngineSpec):
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
return columns
column_type_mappings = (
(re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
(re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
(re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
(re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
(re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
(re.compile(r"^real.*", re.IGNORECASE), types.Float()),
(re.compile(r"^double.*", re.IGNORECASE), types.Float()),
(re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
),
(re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
(re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
(re.compile(r"^time.*", re.IGNORECASE), types.Time()),
(re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
(re.compile(r"^interval.*", re.IGNORECASE), Interval()),
(re.compile(r"^array.*", re.IGNORECASE), Array()),
(re.compile(r"^map.*", re.IGNORECASE), Map()),
(re.compile(r"^row.*", re.IGNORECASE), Row()),
)
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
@ -334,28 +369,24 @@ class PrestoEngineSpec(BaseEngineSpec):
columns = cls._show_columns(inspector, table_name, schema)
result: List[Dict[str, Any]] = []
for column in columns:
try:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue
# otherwise column is a basic data type
column_type = presto_type_map[column.Type]()
except KeyError:
logger.info(
"Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation
column.Type, column.Column
)
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue
# otherwise column is a basic data type
column_type = cls.get_sqla_column_type(column.Type)
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=column_type)
)
column_type = "OTHER"
column_info = cls._create_column_info(column.Column, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None

View File

@ -16,7 +16,6 @@
# under the License.
from typing import Any, Dict, List, Optional, Type
from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.visitors import Visitable
@ -92,26 +91,3 @@ class Row(TypeEngine):
@classmethod
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ROW"
type_map = {
"boolean": types.Boolean,
"tinyint": TinyInteger,
"smallint": types.SmallInteger,
"integer": types.Integer,
"bigint": types.BigInteger,
"real": types.Float,
"double": types.Float,
"decimal": types.DECIMAL,
"varchar": types.String,
"char": types.CHAR,
"varbinary": types.VARBINARY,
"JSON": types.JSON,
"date": types.DATE,
"time": types.Time,
"timestamp": types.TIMESTAMP,
"interval": Interval,
"array": Array,
"map": Map,
"row": Row,
}

View File

@ -17,6 +17,7 @@
from unittest import mock, skipUnless
import pandas as pd
from sqlalchemy import types
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import select
@ -490,3 +491,23 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
def test_get_sqla_column_type(self):
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
assert isinstance(sqla_type, types.VARCHAR)
assert sqla_type.length == 255
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
assert isinstance(sqla_type, types.String)
assert sqla_type.length is None
sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length == 10
sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length is None
sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
assert isinstance(sqla_type, types.Integer)