refactor: rename DbColumnType to GenericDataType (#12617)

This commit is contained in:
Jesse Yang 2021-01-20 10:07:42 -08:00 committed by GitHub
parent 2463215d73
commit c14ed80f28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 56 additions and 56 deletions

View File

@ -189,7 +189,7 @@ class TableColumn(Model, BaseColumn):
"""
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.DbColumnType.NUMERIC
self.type, utils.GenericDataType.NUMERIC
)
@property
@ -199,7 +199,7 @@ class TableColumn(Model, BaseColumn):
"""
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.DbColumnType.STRING
self.type, utils.GenericDataType.STRING
)
@property
@ -214,7 +214,7 @@ class TableColumn(Model, BaseColumn):
return self.is_dttm
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.DbColumnType.TEMPORAL
self.type, utils.GenericDataType.TEMPORAL
)
def get_sqla_col(self, label: Optional[str] = None) -> Column:

View File

@ -159,8 +159,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
run_multiple_statements_as_one = False
# default matching patterns for identifying column types
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = {
utils.DbColumnType.NUMERIC: (
db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[Any], ...]] = {
utils.GenericDataType.NUMERIC: (
re.compile(r"BIT", re.IGNORECASE),
re.compile(r".*DOUBLE.*", re.IGNORECASE),
re.compile(r".*FLOAT.*", re.IGNORECASE),
@ -172,12 +172,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
re.compile(r".*DECIMAL.*", re.IGNORECASE),
re.compile(r".*MONEY.*", re.IGNORECASE),
),
utils.DbColumnType.STRING: (
utils.GenericDataType.STRING: (
re.compile(r".*CHAR.*", re.IGNORECASE),
re.compile(r".*STRING.*", re.IGNORECASE),
re.compile(r".*TEXT.*", re.IGNORECASE),
),
utils.DbColumnType.TEMPORAL: (
utils.GenericDataType.TEMPORAL: (
re.compile(r".*DATE.*", re.IGNORECASE),
re.compile(r".*TIME.*", re.IGNORECASE),
),
@ -185,7 +185,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def is_db_column_type_match(
cls, db_column_type: Optional[str], target_column_type: utils.DbColumnType
cls, db_column_type: Optional[str], target_column_type: utils.GenericDataType
) -> bool:
"""
Check if a column type satisfies a pattern in a collection of regexes found in

View File

@ -123,7 +123,7 @@ class SupersetResultSet:
if pa.types.is_nested(pa_data[i].type):
# TODO: revisit nested column serialization once nested types
# are added as a natively supported column type in Superset
# (superset.utils.core.DbColumnType).
# (superset.utils.core.GenericDataType).
stringified_arr = stringify_values(array[column])
pa_data[i] = pa.array(stringified_arr.tolist())
@ -182,7 +182,7 @@ class SupersetResultSet:
def is_temporal(self, db_type_str: Optional[str]) -> bool:
return self.db_engine_spec.is_db_column_type_match(
db_type_str, utils.DbColumnType.TEMPORAL
db_type_str, utils.GenericDataType.TEMPORAL
)
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:

View File

@ -129,7 +129,7 @@ class AnnotationType(str, Enum):
TIME_SERIES = "TIME_SERIES"
class DbColumnType(Enum):
class GenericDataType(Enum):
"""
Generic database column type
"""

View File

@ -20,7 +20,7 @@ from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.utils.core import DbColumnType
from superset.utils.core import GenericDataType
from tests.db_engine_specs.base_tests import TestDbEngineSpec
@ -67,40 +67,40 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
def test_is_db_column_type_match(self):
type_expectations = (
# Numeric
("TINYINT", DbColumnType.NUMERIC),
("SMALLINT", DbColumnType.NUMERIC),
("MEDIUMINT", DbColumnType.NUMERIC),
("INT", DbColumnType.NUMERIC),
("BIGINT", DbColumnType.NUMERIC),
("DECIMAL", DbColumnType.NUMERIC),
("FLOAT", DbColumnType.NUMERIC),
("DOUBLE", DbColumnType.NUMERIC),
("BIT", DbColumnType.NUMERIC),
("TINYINT", GenericDataType.NUMERIC),
("SMALLINT", GenericDataType.NUMERIC),
("MEDIUMINT", GenericDataType.NUMERIC),
("INT", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("FLOAT", GenericDataType.NUMERIC),
("DOUBLE", GenericDataType.NUMERIC),
("BIT", GenericDataType.NUMERIC),
# String
("CHAR", DbColumnType.STRING),
("VARCHAR", DbColumnType.STRING),
("TINYTEXT", DbColumnType.STRING),
("MEDIUMTEXT", DbColumnType.STRING),
("LONGTEXT", DbColumnType.STRING),
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TINYTEXT", GenericDataType.STRING),
("MEDIUMTEXT", GenericDataType.STRING),
("LONGTEXT", GenericDataType.STRING),
# Temporal
("DATE", DbColumnType.TEMPORAL),
("DATETIME", DbColumnType.TEMPORAL),
("TIMESTAMP", DbColumnType.TEMPORAL),
("TIME", DbColumnType.TEMPORAL),
("DATE", GenericDataType.TEMPORAL),
("DATETIME", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
)
for type_expectation in type_expectations:
type_str = type_expectation[0]
col_type = type_expectation[1]
assert MySQLEngineSpec.is_db_column_type_match(
type_str, DbColumnType.NUMERIC
) is (col_type == DbColumnType.NUMERIC)
type_str, GenericDataType.NUMERIC
) is (col_type == GenericDataType.NUMERIC)
assert MySQLEngineSpec.is_db_column_type_match(
type_str, DbColumnType.STRING
) is (col_type == DbColumnType.STRING)
type_str, GenericDataType.STRING
) is (col_type == GenericDataType.STRING)
assert MySQLEngineSpec.is_db_column_type_match(
type_str, DbColumnType.TEMPORAL
) is (col_type == DbColumnType.TEMPORAL)
type_str, GenericDataType.TEMPORAL
) is (col_type == GenericDataType.TEMPORAL)
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError

View File

@ -26,7 +26,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError
from superset.models.core import Database
from superset.utils.core import DbColumnType, get_example_database, FilterOperator
from superset.utils.core import GenericDataType, get_example_database, FilterOperator
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
from .base_tests import SupersetTestCase
@ -76,33 +76,33 @@ class TestDatabaseModel(SupersetTestCase):
assert col.is_temporal is True
def test_db_column_types(self):
test_cases: Dict[str, DbColumnType] = {
test_cases: Dict[str, GenericDataType] = {
# string
"CHAR": DbColumnType.STRING,
"VARCHAR": DbColumnType.STRING,
"NVARCHAR": DbColumnType.STRING,
"STRING": DbColumnType.STRING,
"TEXT": DbColumnType.STRING,
"NTEXT": DbColumnType.STRING,
"CHAR": GenericDataType.STRING,
"VARCHAR": GenericDataType.STRING,
"NVARCHAR": GenericDataType.STRING,
"STRING": GenericDataType.STRING,
"TEXT": GenericDataType.STRING,
"NTEXT": GenericDataType.STRING,
# numeric
"INT": DbColumnType.NUMERIC,
"BIGINT": DbColumnType.NUMERIC,
"FLOAT": DbColumnType.NUMERIC,
"DECIMAL": DbColumnType.NUMERIC,
"MONEY": DbColumnType.NUMERIC,
"INT": GenericDataType.NUMERIC,
"BIGINT": GenericDataType.NUMERIC,
"FLOAT": GenericDataType.NUMERIC,
"DECIMAL": GenericDataType.NUMERIC,
"MONEY": GenericDataType.NUMERIC,
# temporal
"DATE": DbColumnType.TEMPORAL,
"DATETIME": DbColumnType.TEMPORAL,
"TIME": DbColumnType.TEMPORAL,
"TIMESTAMP": DbColumnType.TEMPORAL,
"DATE": GenericDataType.TEMPORAL,
"DATETIME": GenericDataType.TEMPORAL,
"TIME": GenericDataType.TEMPORAL,
"TIMESTAMP": GenericDataType.TEMPORAL,
}
tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())
for str_type, db_col_type in test_cases.items():
col = TableColumn(column_name="foo", type=str_type, table=tbl)
self.assertEqual(col.is_temporal, db_col_type == DbColumnType.TEMPORAL)
self.assertEqual(col.is_numeric, db_col_type == DbColumnType.NUMERIC)
self.assertEqual(col.is_string, db_col_type == DbColumnType.STRING)
self.assertEqual(col.is_temporal, db_col_type == GenericDataType.TEMPORAL)
self.assertEqual(col.is_numeric, db_col_type == GenericDataType.NUMERIC)
self.assertEqual(col.is_string, db_col_type == GenericDataType.STRING)
@patch("superset.jinja_context.g")
def test_extra_cache_keys(self, flask_g):