fix(db_engine_specs): improve Presto column type matching (#10658)

* fix: improve Presto column type matching

* add optional callback to type map and add tests

* lint

* change private to public
Ville Brofeldt 2020-08-24 22:42:07 +03:00 committed by GitHub
parent 0177c2f591
commit 9461f9c1e0
5 changed files with 97 additions and 65 deletions

superset/db_engine_specs/base.py

@@ -24,8 +24,10 @@ from contextlib import closing
 from datetime import datetime
 from typing import (
     Any,
+    Callable,
     Dict,
     List,
+    Match,
     NamedTuple,
     Optional,
     Pattern,
@@ -142,6 +144,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     ] = None  # used for user messages, overridden in child classes
     _date_trunc_functions: Dict[str, str] = {}
     _time_grain_expressions: Dict[Optional[str], str] = {}
+    column_type_mappings: Tuple[
+        Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
+    ] = ()
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
     time_secondary_columns = False
@@ -886,12 +891,18 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         Return a sqlalchemy native column type that corresponds to the column type
         defined in the data source (return None to use default type inferred by
-        SQLAlchemy). Needs to be overridden if column requires special handling
+        SQLAlchemy). Override `column_type_mappings` for specific needs
         (see MSSQL for example of NCHAR/NVARCHAR handling).

         :param type_: Column type returned by inspector
         :return: SqlAlchemy column type
         """
+        for regex, sqla_type in cls.column_type_mappings:
+            match = regex.match(type_)
+            if match:
+                if callable(sqla_type):
+                    return sqla_type(match)
+                return sqla_type
         return None

     @staticmethod
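This loop is the entire dispatch mechanism: entries of `column_type_mappings` are tried in order, and the second element of each entry is either a ready-made `TypeEngine` instance or a callable taking the regex match. A minimal self-contained sketch of that behavior (the mapping and standalone function below are illustrative, not part of the commit):

    import re
    from typing import Optional
    from sqlalchemy import types
    from sqlalchemy.sql.type_api import TypeEngine

    # illustrative mapping: one fixed TypeEngine instance, one callable entry
    COLUMN_TYPE_MAPPINGS = (
        (re.compile(r"^boolean", re.IGNORECASE), types.Boolean()),
        (
            re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
            # the match object is passed through, so capture group 2 (the
            # declared length) can parameterize the returned type
            lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
        ),
    )

    def get_sqla_column_type(type_: str) -> Optional[TypeEngine]:
        for regex, sqla_type in COLUMN_TYPE_MAPPINGS:
            match = regex.match(type_)
            if match:
                return sqla_type(match) if callable(sqla_type) else sqla_type
        return None

    assert get_sqla_column_type("varchar(255)").length == 255  # callable entry
    assert isinstance(get_sqla_column_type("BOOLEAN"), types.Boolean)  # fixed entry
    assert get_sqla_column_type("ipaddress") is None  # no match: default inference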

superset/db_engine_specs/mssql.py

@@ -19,7 +19,7 @@ import re
 from datetime import datetime
 from typing import Any, List, Optional, Tuple, TYPE_CHECKING

-from sqlalchemy.types import String, TypeEngine, UnicodeText
+from sqlalchemy.types import String, UnicodeText

 from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
 from superset.utils import core as utils
@@ -73,18 +73,11 @@ class MssqlEngineSpec(BaseEngineSpec):
         # Lists of `pyodbc.Row` need to be unpacked further
         return cls.pyodbc_rows_to_tuples(data)

-    column_types = (
-        (String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
-        (UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
-    )
-
-    @classmethod
-    def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
-        for sqla_type, regex in cls.column_types:
-            if regex.match(type_):
-                return sqla_type
-        return None
+    column_type_mappings = (
+        (re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
+        (re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
+    )

     @classmethod
     def extract_error_message(cls, ex: Exception) -> str:
         if str(ex).startswith("(8155,"):
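With the bespoke `get_sqla_column_type` override removed, MSSQL relies on the inherited lookup. Pattern order replaces the old `(?<!N)` negative lookbehind: the `N`-prefixed pattern is tried first, so the plain pattern never sees an `NVARCHAR`. A hedged usage sketch of how these mappings should resolve:

    from sqlalchemy.types import String, UnicodeText
    from superset.db_engine_specs.mssql import MssqlEngineSpec

    # NVARCHAR hits the Unicode pattern before the plain one can match
    assert isinstance(MssqlEngineSpec.get_sqla_column_type("NVARCHAR(10)"), UnicodeText)
    assert isinstance(MssqlEngineSpec.get_sqla_column_type("VARCHAR(10)"), String)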

superset/db_engine_specs/presto.py

@@ -28,7 +28,7 @@ from urllib import parse
 import pandas as pd
 import simplejson as json
 from flask_babel import lazy_gettext as _
-from sqlalchemy import Column, literal_column
+from sqlalchemy import Column, literal_column, types
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.result import RowProxy
@@ -40,7 +40,13 @@ from superset import app, cache, is_feature_enabled, security_manager
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.exceptions import SupersetTemplateException
 from superset.models.sql_lab import Query
-from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
+from superset.models.sql_types.presto_sql_types import (
+    Array,
+    Interval,
+    Map,
+    Row,
+    TinyInteger,
+)
 from superset.result_set import destringify
 from superset.sql_parse import ParsedQuery
 from superset.utils import core as utils
@@ -260,13 +266,16 @@ class PrestoEngineSpec(BaseEngineSpec):
             field_info = cls._split_data_type(single_field, r"\s")
             # check if there is a structural data type within
             # overall structural data type
+            column_type = cls.get_sqla_column_type(field_info[1])
+            if column_type is None:
+                raise NotImplementedError(
+                    _("Unknown column type: %(col)s", col=field_info[1])
+                )
             if field_info[1] == "array" or field_info[1] == "row":
                 stack.append((field_info[0], field_info[1]))
                 full_parent_path = cls._get_full_name(stack)
                 result.append(
-                    cls._create_column_info(
-                        full_parent_path, presto_type_map[field_info[1]]()
-                    )
+                    cls._create_column_info(full_parent_path, column_type)
                 )
             else:  # otherwise this field is a basic data type
                 full_parent_path = cls._get_full_name(stack)
@@ -274,9 +283,7 @@ class PrestoEngineSpec(BaseEngineSpec):
                     full_parent_path, field_info[0]
                 )
                 result.append(
-                    cls._create_column_info(
-                        column_name, presto_type_map[field_info[1]]()
-                    )
+                    cls._create_column_info(column_name, column_type)
                 )
             # If the component type ends with a structural data type, do not pop
             # the stack. We have run across a structural data type within the
@@ -318,6 +325,34 @@ class PrestoEngineSpec(BaseEngineSpec):
         columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
         return columns

+    column_type_mappings = (
+        (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
+        (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
+        (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
+        (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
+        (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
+        (re.compile(r"^real.*", re.IGNORECASE), types.Float()),
+        (re.compile(r"^double.*", re.IGNORECASE), types.Float()),
+        (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
+        (
+            re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
+            lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
+        ),
+        (
+            re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
+            lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
+        ),
+        (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
+        (re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
+        (re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
+        (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
+        (re.compile(r"^time.*", re.IGNORECASE), types.Time()),
+        (re.compile(r"^interval.*", re.IGNORECASE), Interval()),
+        (re.compile(r"^array.*", re.IGNORECASE), Array()),
+        (re.compile(r"^map.*", re.IGNORECASE), Map()),
+        (re.compile(r"^row.*", re.IGNORECASE), Row()),
+    )
+
     @classmethod
     def get_columns(
         cls, inspector: Inspector, table_name: str, schema: Optional[str]
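Every pattern above carries a permissive tail (`.*` or an optional parenthesized length), so parameterized Presto type strings such as `decimal(10, 2)` still resolve to their base type; where prefixes overlap, order decides, which is why `^timestamp` must be tried before the looser `^time`. A hedged sketch of the expected resolutions:

    from sqlalchemy import types
    from superset.db_engine_specs.presto import PrestoEngineSpec
    from superset.models.sql_types.presto_sql_types import Array

    assert isinstance(PrestoEngineSpec.get_sqla_column_type("decimal(10, 2)"), types.DECIMAL)
    assert isinstance(PrestoEngineSpec.get_sqla_column_type("timestamp with time zone"), types.TIMESTAMP)
    assert isinstance(PrestoEngineSpec.get_sqla_column_type("array(varchar)"), Array)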
@@ -334,28 +369,24 @@ class PrestoEngineSpec(BaseEngineSpec):
         columns = cls._show_columns(inspector, table_name, schema)
         result: List[Dict[str, Any]] = []
         for column in columns:
-            try:
-                # parse column if it is a row or array
-                if is_feature_enabled("PRESTO_EXPAND_DATA") and (
-                    "array" in column.Type or "row" in column.Type
-                ):
-                    structural_column_index = len(result)
-                    cls._parse_structural_column(column.Column, column.Type, result)
-                    result[structural_column_index]["nullable"] = getattr(
-                        column, "Null", True
-                    )
-                    result[structural_column_index]["default"] = None
-                    continue
-
-                # otherwise column is a basic data type
-                column_type = presto_type_map[column.Type]()
-            except KeyError:
-                logger.info(
-                    "Did not recognize type {} of column {}".format(  # pylint: disable=logging-format-interpolation
-                        column.Type, column.Column
-                    )
-                )
-                column_type = "OTHER"
+            # parse column if it is a row or array
+            if is_feature_enabled("PRESTO_EXPAND_DATA") and (
+                "array" in column.Type or "row" in column.Type
+            ):
+                structural_column_index = len(result)
+                cls._parse_structural_column(column.Column, column.Type, result)
+                result[structural_column_index]["nullable"] = getattr(
+                    column, "Null", True
+                )
+                result[structural_column_index]["default"] = None
+                continue
+
+            # otherwise column is a basic data type
+            column_type = cls.get_sqla_column_type(column.Type)
+            if column_type is None:
+                raise NotImplementedError(
+                    _("Unknown column type: %(col)s", col=column.Type)
+                )
             column_info = cls._create_column_info(column.Column, column_type)
             column_info["nullable"] = getattr(column, "Null", True)
             column_info["default"] = None

superset/models/sql_types/presto_sql_types.py

@@ -16,7 +16,6 @@
 # under the License.
 from typing import Any, Dict, List, Optional, Type

-from sqlalchemy import types
 from sqlalchemy.sql.sqltypes import Integer
 from sqlalchemy.sql.type_api import TypeEngine
 from sqlalchemy.sql.visitors import Visitable
@@ -92,26 +91,3 @@ class Row(TypeEngine):
     @classmethod
     def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
         return "ROW"
-
-
-type_map = {
-    "boolean": types.Boolean,
-    "tinyint": TinyInteger,
-    "smallint": types.SmallInteger,
-    "integer": types.Integer,
-    "bigint": types.BigInteger,
-    "real": types.Float,
-    "double": types.Float,
-    "decimal": types.DECIMAL,
-    "varchar": types.String,
-    "char": types.CHAR,
-    "varbinary": types.VARBINARY,
-    "JSON": types.JSON,
-    "date": types.DATE,
-    "time": types.Time,
-    "timestamp": types.TIMESTAMP,
-    "interval": Interval,
-    "array": Array,
-    "map": Map,
-    "row": Row,
-}

tests/db_engine_specs/presto_tests.py

@@ -17,6 +17,7 @@
 from unittest import mock, skipUnless

 import pandas as pd
+from sqlalchemy import types
 from sqlalchemy.engine.result import RowProxy
 from sqlalchemy.sql import select

@@ -490,3 +491,23 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
         self.assertEqual(actual_cols, expected_cols)
         self.assertEqual(actual_data, expected_data)
         self.assertEqual(actual_expanded_cols, expected_expanded_cols)
+
+    def test_get_sqla_column_type(self):
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
+        assert isinstance(sqla_type, types.VARCHAR)
+        assert sqla_type.length == 255
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
+        assert isinstance(sqla_type, types.String)
+        assert sqla_type.length is None
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
+        assert isinstance(sqla_type, types.CHAR)
+        assert sqla_type.length == 10
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
+        assert isinstance(sqla_type, types.CHAR)
+        assert sqla_type.length is None
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
+        assert isinstance(sqla_type, types.Integer)
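The new test does not pin down the fallthrough case: a type string matching no pattern returns None, which `get_columns` now surfaces as a NotImplementedError instead of the old silent "OTHER" placeholder. A hedged sketch of such an extra test (hypothetical name, not part of the commit):

    def test_get_sqla_column_type_unknown(self):
        # no pattern matches, so the base implementation falls through to None
        sqla_type = PrestoEngineSpec.get_sqla_column_type("ipaddress")
        assert sqla_type is None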