fix(db_engine_specs): improve Presto column type matching (#10658)
* fix: improve Presto column type matching * add optional callback to type map and add tests * lint * change private to public
This commit is contained in:
parent
0177c2f591
commit
9461f9c1e0
|
|
@ -24,8 +24,10 @@ from contextlib import closing
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
Match,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
Pattern,
|
Pattern,
|
||||||
|
|
@ -142,6 +144,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
] = None # used for user messages, overridden in child classes
|
] = None # used for user messages, overridden in child classes
|
||||||
_date_trunc_functions: Dict[str, str] = {}
|
_date_trunc_functions: Dict[str, str] = {}
|
||||||
_time_grain_expressions: Dict[Optional[str], str] = {}
|
_time_grain_expressions: Dict[Optional[str], str] = {}
|
||||||
|
column_type_mappings: Tuple[
|
||||||
|
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
|
||||||
|
] = ()
|
||||||
time_groupby_inline = False
|
time_groupby_inline = False
|
||||||
limit_method = LimitMethod.FORCE_LIMIT
|
limit_method = LimitMethod.FORCE_LIMIT
|
||||||
time_secondary_columns = False
|
time_secondary_columns = False
|
||||||
|
|
@ -886,12 +891,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"""
|
"""
|
||||||
Return a sqlalchemy native column type that corresponds to the column type
|
Return a sqlalchemy native column type that corresponds to the column type
|
||||||
defined in the data source (return None to use default type inferred by
|
defined in the data source (return None to use default type inferred by
|
||||||
SQLAlchemy). Needs to be overridden if column requires special handling
|
SQLAlchemy). Override `column_type_mappings` for specific needs
|
||||||
(see MSSQL for example of NCHAR/NVARCHAR handling).
|
(see MSSQL for example of NCHAR/NVARCHAR handling).
|
||||||
|
|
||||||
:param type_: Column type returned by inspector
|
:param type_: Column type returned by inspector
|
||||||
:return: SqlAlchemy column type
|
:return: SqlAlchemy column type
|
||||||
"""
|
"""
|
||||||
|
for regex, sqla_type in cls.column_type_mappings:
|
||||||
|
match = regex.match(type_)
|
||||||
|
if match:
|
||||||
|
if callable(sqla_type):
|
||||||
|
return sqla_type(match)
|
||||||
|
return sqla_type
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
from sqlalchemy.types import String, UnicodeText
|
||||||
|
|
||||||
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
|
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
|
||||||
from superset.utils import core as utils
|
from superset.utils import core as utils
|
||||||
|
|
@ -73,18 +73,11 @@ class MssqlEngineSpec(BaseEngineSpec):
|
||||||
# Lists of `pyodbc.Row` need to be unpacked further
|
# Lists of `pyodbc.Row` need to be unpacked further
|
||||||
return cls.pyodbc_rows_to_tuples(data)
|
return cls.pyodbc_rows_to_tuples(data)
|
||||||
|
|
||||||
column_types = (
|
column_type_mappings = (
|
||||||
(String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
|
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
|
||||||
(UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
|
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
|
|
||||||
for sqla_type, regex in cls.column_types:
|
|
||||||
if regex.match(type_):
|
|
||||||
return sqla_type
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_error_message(cls, ex: Exception) -> str:
|
def extract_error_message(cls, ex: Exception) -> str:
|
||||||
if str(ex).startswith("(8155,"):
|
if str(ex).startswith("(8155,"):
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ from urllib import parse
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
from sqlalchemy import Column, literal_column
|
from sqlalchemy import Column, literal_column, types
|
||||||
from sqlalchemy.engine.base import Engine
|
from sqlalchemy.engine.base import Engine
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.engine.result import RowProxy
|
from sqlalchemy.engine.result import RowProxy
|
||||||
|
|
@ -40,7 +40,13 @@ from superset import app, cache, is_feature_enabled, security_manager
|
||||||
from superset.db_engine_specs.base import BaseEngineSpec
|
from superset.db_engine_specs.base import BaseEngineSpec
|
||||||
from superset.exceptions import SupersetTemplateException
|
from superset.exceptions import SupersetTemplateException
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
|
from superset.models.sql_types.presto_sql_types import (
|
||||||
|
Array,
|
||||||
|
Interval,
|
||||||
|
Map,
|
||||||
|
Row,
|
||||||
|
TinyInteger,
|
||||||
|
)
|
||||||
from superset.result_set import destringify
|
from superset.result_set import destringify
|
||||||
from superset.sql_parse import ParsedQuery
|
from superset.sql_parse import ParsedQuery
|
||||||
from superset.utils import core as utils
|
from superset.utils import core as utils
|
||||||
|
|
@ -260,13 +266,16 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
field_info = cls._split_data_type(single_field, r"\s")
|
field_info = cls._split_data_type(single_field, r"\s")
|
||||||
# check if there is a structural data type within
|
# check if there is a structural data type within
|
||||||
# overall structural data type
|
# overall structural data type
|
||||||
|
column_type = cls.get_sqla_column_type(field_info[1])
|
||||||
|
if column_type is None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
_("Unknown column type: %(col)s", col=field_info[1])
|
||||||
|
)
|
||||||
if field_info[1] == "array" or field_info[1] == "row":
|
if field_info[1] == "array" or field_info[1] == "row":
|
||||||
stack.append((field_info[0], field_info[1]))
|
stack.append((field_info[0], field_info[1]))
|
||||||
full_parent_path = cls._get_full_name(stack)
|
full_parent_path = cls._get_full_name(stack)
|
||||||
result.append(
|
result.append(
|
||||||
cls._create_column_info(
|
cls._create_column_info(full_parent_path, column_type)
|
||||||
full_parent_path, presto_type_map[field_info[1]]()
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else: # otherwise this field is a basic data type
|
else: # otherwise this field is a basic data type
|
||||||
full_parent_path = cls._get_full_name(stack)
|
full_parent_path = cls._get_full_name(stack)
|
||||||
|
|
@ -274,9 +283,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
full_parent_path, field_info[0]
|
full_parent_path, field_info[0]
|
||||||
)
|
)
|
||||||
result.append(
|
result.append(
|
||||||
cls._create_column_info(
|
cls._create_column_info(column_name, column_type)
|
||||||
column_name, presto_type_map[field_info[1]]()
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# If the component type ends with a structural data type, do not pop
|
# If the component type ends with a structural data type, do not pop
|
||||||
# the stack. We have run across a structural data type within the
|
# the stack. We have run across a structural data type within the
|
||||||
|
|
@ -318,6 +325,34 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
|
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
|
||||||
return columns
|
return columns
|
||||||
|
|
||||||
|
# Ordered (compiled pattern, SQLAlchemy type or callable) pairs consumed by
# `get_sqla_column_type`. The first pattern that matches the Presto type string
# wins, so an entry whose pattern also matches a longer type name must be
# listed AFTER the entry for that longer name.
column_type_mappings = (
    (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
    (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
    (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
    (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
    (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
    (re.compile(r"^real.*", re.IGNORECASE), types.Float()),
    (re.compile(r"^double.*", re.IGNORECASE), types.Float()),
    (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
    (
        # Group 2 captures the optional length, e.g. "255" in "varchar(255)".
        re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
        lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
    ),
    (
        # Group 2 captures the optional length, e.g. "10" in "char(10)".
        re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
        lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
    ),
    (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
    (re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
    (re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
    # `timestamp` must precede `time`: `^time.*` also matches "timestamp", which
    # would otherwise be resolved as Time and make this entry unreachable.
    (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
    (re.compile(r"^time.*", re.IGNORECASE), types.Time()),
    (re.compile(r"^interval.*", re.IGNORECASE), Interval()),
    (re.compile(r"^array.*", re.IGNORECASE), Array()),
    (re.compile(r"^map.*", re.IGNORECASE), Map()),
    (re.compile(r"^row.*", re.IGNORECASE), Row()),
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_columns(
|
def get_columns(
|
||||||
cls, inspector: Inspector, table_name: str, schema: Optional[str]
|
cls, inspector: Inspector, table_name: str, schema: Optional[str]
|
||||||
|
|
@ -334,28 +369,24 @@ class PrestoEngineSpec(BaseEngineSpec):
|
||||||
columns = cls._show_columns(inspector, table_name, schema)
|
columns = cls._show_columns(inspector, table_name, schema)
|
||||||
result: List[Dict[str, Any]] = []
|
result: List[Dict[str, Any]] = []
|
||||||
for column in columns:
|
for column in columns:
|
||||||
try:
|
# parse column if it is a row or array
|
||||||
# parse column if it is a row or array
|
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
|
||||||
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
|
"array" in column.Type or "row" in column.Type
|
||||||
"array" in column.Type or "row" in column.Type
|
):
|
||||||
):
|
structural_column_index = len(result)
|
||||||
structural_column_index = len(result)
|
cls._parse_structural_column(column.Column, column.Type, result)
|
||||||
cls._parse_structural_column(column.Column, column.Type, result)
|
result[structural_column_index]["nullable"] = getattr(
|
||||||
result[structural_column_index]["nullable"] = getattr(
|
column, "Null", True
|
||||||
column, "Null", True
|
)
|
||||||
)
|
result[structural_column_index]["default"] = None
|
||||||
result[structural_column_index]["default"] = None
|
continue
|
||||||
continue
|
|
||||||
|
# otherwise column is a basic data type
|
||||||
# otherwise column is a basic data type
|
column_type = cls.get_sqla_column_type(column.Type)
|
||||||
column_type = presto_type_map[column.Type]()
|
if column_type is None:
|
||||||
except KeyError:
|
raise NotImplementedError(
|
||||||
logger.info(
|
# column_type is None on this path; report the raw type string from Presto
_("Unknown column type: %(col)s", col=column.Type)
|
||||||
"Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation
|
|
||||||
column.Type, column.Column
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
column_type = "OTHER"
|
|
||||||
column_info = cls._create_column_info(column.Column, column_type)
|
column_info = cls._create_column_info(column.Column, column_type)
|
||||||
column_info["nullable"] = getattr(column, "Null", True)
|
column_info["nullable"] = getattr(column, "Null", True)
|
||||||
column_info["default"] = None
|
column_info["default"] = None
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
from sqlalchemy import types
|
|
||||||
from sqlalchemy.sql.sqltypes import Integer
|
from sqlalchemy.sql.sqltypes import Integer
|
||||||
from sqlalchemy.sql.type_api import TypeEngine
|
from sqlalchemy.sql.type_api import TypeEngine
|
||||||
from sqlalchemy.sql.visitors import Visitable
|
from sqlalchemy.sql.visitors import Visitable
|
||||||
|
|
@ -92,26 +91,3 @@ class Row(TypeEngine):
|
||||||
@classmethod
|
@classmethod
|
||||||
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
|
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
|
||||||
return "ROW"
|
return "ROW"
|
||||||
|
|
||||||
|
|
||||||
type_map = {
|
|
||||||
"boolean": types.Boolean,
|
|
||||||
"tinyint": TinyInteger,
|
|
||||||
"smallint": types.SmallInteger,
|
|
||||||
"integer": types.Integer,
|
|
||||||
"bigint": types.BigInteger,
|
|
||||||
"real": types.Float,
|
|
||||||
"double": types.Float,
|
|
||||||
"decimal": types.DECIMAL,
|
|
||||||
"varchar": types.String,
|
|
||||||
"char": types.CHAR,
|
|
||||||
"varbinary": types.VARBINARY,
|
|
||||||
"JSON": types.JSON,
|
|
||||||
"date": types.DATE,
|
|
||||||
"time": types.Time,
|
|
||||||
"timestamp": types.TIMESTAMP,
|
|
||||||
"interval": Interval,
|
|
||||||
"array": Array,
|
|
||||||
"map": Map,
|
|
||||||
"row": Row,
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
from unittest import mock, skipUnless
|
from unittest import mock, skipUnless
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from sqlalchemy import types
|
||||||
from sqlalchemy.engine.result import RowProxy
|
from sqlalchemy.engine.result import RowProxy
|
||||||
from sqlalchemy.sql import select
|
from sqlalchemy.sql import select
|
||||||
|
|
||||||
|
|
@ -490,3 +491,23 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||||
self.assertEqual(actual_cols, expected_cols)
|
self.assertEqual(actual_cols, expected_cols)
|
||||||
self.assertEqual(actual_data, expected_data)
|
self.assertEqual(actual_data, expected_data)
|
||||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||||
|
|
||||||
|
def test_get_sqla_column_type(self):
    """Presto type strings resolve to the expected SQLAlchemy type and length."""
    # (Presto type string, expected SQLAlchemy class, expected length)
    sized_cases = (
        ("varchar(255)", types.VARCHAR, 255),
        ("varchar", types.String, None),
        ("char(10)", types.CHAR, 10),
        ("char", types.CHAR, None),
    )
    for presto_type, expected_cls, expected_length in sized_cases:
        resolved = PrestoEngineSpec.get_sqla_column_type(presto_type)
        assert isinstance(resolved, expected_cls)
        if expected_length is None:
            assert resolved.length is None
        else:
            assert resolved.length == expected_length

    # For "integer" only the resolved class is checked.
    resolved = PrestoEngineSpec.get_sqla_column_type("integer")
    assert isinstance(resolved, types.Integer)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue