refactor: rename DbColumnType to GenericDataType (#12617)

This commit is contained in:
Jesse Yang 2021-01-20 10:07:42 -08:00 committed by GitHub
parent 2463215d73
commit c14ed80f28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 56 additions and 56 deletions

View File

@ -189,7 +189,7 @@ class TableColumn(Model, BaseColumn):
""" """
db_engine_spec = self.table.database.db_engine_spec db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match( return db_engine_spec.is_db_column_type_match(
self.type, utils.DbColumnType.NUMERIC self.type, utils.GenericDataType.NUMERIC
) )
@property @property
@ -199,7 +199,7 @@ class TableColumn(Model, BaseColumn):
""" """
db_engine_spec = self.table.database.db_engine_spec db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match( return db_engine_spec.is_db_column_type_match(
self.type, utils.DbColumnType.STRING self.type, utils.GenericDataType.STRING
) )
@property @property
@ -214,7 +214,7 @@ class TableColumn(Model, BaseColumn):
return self.is_dttm return self.is_dttm
db_engine_spec = self.table.database.db_engine_spec db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match( return db_engine_spec.is_db_column_type_match(
self.type, utils.DbColumnType.TEMPORAL self.type, utils.GenericDataType.TEMPORAL
) )
def get_sqla_col(self, label: Optional[str] = None) -> Column: def get_sqla_col(self, label: Optional[str] = None) -> Column:

View File

@ -159,8 +159,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
run_multiple_statements_as_one = False run_multiple_statements_as_one = False
# default matching patterns for identifying column types # default matching patterns for identifying column types
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = { db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[Any], ...]] = {
utils.DbColumnType.NUMERIC: ( utils.GenericDataType.NUMERIC: (
re.compile(r"BIT", re.IGNORECASE), re.compile(r"BIT", re.IGNORECASE),
re.compile(r".*DOUBLE.*", re.IGNORECASE), re.compile(r".*DOUBLE.*", re.IGNORECASE),
re.compile(r".*FLOAT.*", re.IGNORECASE), re.compile(r".*FLOAT.*", re.IGNORECASE),
@ -172,12 +172,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
re.compile(r".*DECIMAL.*", re.IGNORECASE), re.compile(r".*DECIMAL.*", re.IGNORECASE),
re.compile(r".*MONEY.*", re.IGNORECASE), re.compile(r".*MONEY.*", re.IGNORECASE),
), ),
utils.DbColumnType.STRING: ( utils.GenericDataType.STRING: (
re.compile(r".*CHAR.*", re.IGNORECASE), re.compile(r".*CHAR.*", re.IGNORECASE),
re.compile(r".*STRING.*", re.IGNORECASE), re.compile(r".*STRING.*", re.IGNORECASE),
re.compile(r".*TEXT.*", re.IGNORECASE), re.compile(r".*TEXT.*", re.IGNORECASE),
), ),
utils.DbColumnType.TEMPORAL: ( utils.GenericDataType.TEMPORAL: (
re.compile(r".*DATE.*", re.IGNORECASE), re.compile(r".*DATE.*", re.IGNORECASE),
re.compile(r".*TIME.*", re.IGNORECASE), re.compile(r".*TIME.*", re.IGNORECASE),
), ),
@ -185,7 +185,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod @classmethod
def is_db_column_type_match( def is_db_column_type_match(
cls, db_column_type: Optional[str], target_column_type: utils.DbColumnType cls, db_column_type: Optional[str], target_column_type: utils.GenericDataType
) -> bool: ) -> bool:
""" """
Check if a column type satisfies a pattern in a collection of regexes found in Check if a column type satisfies a pattern in a collection of regexes found in

View File

@ -123,7 +123,7 @@ class SupersetResultSet:
if pa.types.is_nested(pa_data[i].type): if pa.types.is_nested(pa_data[i].type):
# TODO: revisit nested column serialization once nested types # TODO: revisit nested column serialization once nested types
# are added as a natively supported column type in Superset # are added as a natively supported column type in Superset
# (superset.utils.core.DbColumnType). # (superset.utils.core.GenericDataType).
stringified_arr = stringify_values(array[column]) stringified_arr = stringify_values(array[column])
pa_data[i] = pa.array(stringified_arr.tolist()) pa_data[i] = pa.array(stringified_arr.tolist())
@ -182,7 +182,7 @@ class SupersetResultSet:
def is_temporal(self, db_type_str: Optional[str]) -> bool: def is_temporal(self, db_type_str: Optional[str]) -> bool:
return self.db_engine_spec.is_db_column_type_match( return self.db_engine_spec.is_db_column_type_match(
db_type_str, utils.DbColumnType.TEMPORAL db_type_str, utils.GenericDataType.TEMPORAL
) )
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]: def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:

View File

@ -129,7 +129,7 @@ class AnnotationType(str, Enum):
TIME_SERIES = "TIME_SERIES" TIME_SERIES = "TIME_SERIES"
class DbColumnType(Enum): class GenericDataType(Enum):
""" """
Generic database column type Generic database column type
""" """

View File

@ -20,7 +20,7 @@ from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.utils.core import DbColumnType from superset.utils.core import GenericDataType
from tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.db_engine_specs.base_tests import TestDbEngineSpec
@ -67,40 +67,40 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
def test_is_db_column_type_match(self): def test_is_db_column_type_match(self):
type_expectations = ( type_expectations = (
# Numeric # Numeric
("TINYINT", DbColumnType.NUMERIC), ("TINYINT", GenericDataType.NUMERIC),
("SMALLINT", DbColumnType.NUMERIC), ("SMALLINT", GenericDataType.NUMERIC),
("MEDIUMINT", DbColumnType.NUMERIC), ("MEDIUMINT", GenericDataType.NUMERIC),
("INT", DbColumnType.NUMERIC), ("INT", GenericDataType.NUMERIC),
("BIGINT", DbColumnType.NUMERIC), ("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", DbColumnType.NUMERIC), ("DECIMAL", GenericDataType.NUMERIC),
("FLOAT", DbColumnType.NUMERIC), ("FLOAT", GenericDataType.NUMERIC),
("DOUBLE", DbColumnType.NUMERIC), ("DOUBLE", GenericDataType.NUMERIC),
("BIT", DbColumnType.NUMERIC), ("BIT", GenericDataType.NUMERIC),
# String # String
("CHAR", DbColumnType.STRING), ("CHAR", GenericDataType.STRING),
("VARCHAR", DbColumnType.STRING), ("VARCHAR", GenericDataType.STRING),
("TINYTEXT", DbColumnType.STRING), ("TINYTEXT", GenericDataType.STRING),
("MEDIUMTEXT", DbColumnType.STRING), ("MEDIUMTEXT", GenericDataType.STRING),
("LONGTEXT", DbColumnType.STRING), ("LONGTEXT", GenericDataType.STRING),
# Temporal # Temporal
("DATE", DbColumnType.TEMPORAL), ("DATE", GenericDataType.TEMPORAL),
("DATETIME", DbColumnType.TEMPORAL), ("DATETIME", GenericDataType.TEMPORAL),
("TIMESTAMP", DbColumnType.TEMPORAL), ("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", DbColumnType.TEMPORAL), ("TIME", GenericDataType.TEMPORAL),
) )
for type_expectation in type_expectations: for type_expectation in type_expectations:
type_str = type_expectation[0] type_str = type_expectation[0]
col_type = type_expectation[1] col_type = type_expectation[1]
assert MySQLEngineSpec.is_db_column_type_match( assert MySQLEngineSpec.is_db_column_type_match(
type_str, DbColumnType.NUMERIC type_str, GenericDataType.NUMERIC
) is (col_type == DbColumnType.NUMERIC) ) is (col_type == GenericDataType.NUMERIC)
assert MySQLEngineSpec.is_db_column_type_match( assert MySQLEngineSpec.is_db_column_type_match(
type_str, DbColumnType.STRING type_str, GenericDataType.STRING
) is (col_type == DbColumnType.STRING) ) is (col_type == GenericDataType.STRING)
assert MySQLEngineSpec.is_db_column_type_match( assert MySQLEngineSpec.is_db_column_type_match(
type_str, DbColumnType.TEMPORAL type_str, GenericDataType.TEMPORAL
) is (col_type == DbColumnType.TEMPORAL) ) is (col_type == GenericDataType.TEMPORAL)
def test_extract_error_message(self): def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError from MySQLdb._exceptions import OperationalError

View File

@ -26,7 +26,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError from superset.exceptions import QueryObjectValidationError
from superset.models.core import Database from superset.models.core import Database
from superset.utils.core import DbColumnType, get_example_database, FilterOperator from superset.utils.core import GenericDataType, get_example_database, FilterOperator
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -76,33 +76,33 @@ class TestDatabaseModel(SupersetTestCase):
assert col.is_temporal is True assert col.is_temporal is True
def test_db_column_types(self): def test_db_column_types(self):
test_cases: Dict[str, DbColumnType] = { test_cases: Dict[str, GenericDataType] = {
# string # string
"CHAR": DbColumnType.STRING, "CHAR": GenericDataType.STRING,
"VARCHAR": DbColumnType.STRING, "VARCHAR": GenericDataType.STRING,
"NVARCHAR": DbColumnType.STRING, "NVARCHAR": GenericDataType.STRING,
"STRING": DbColumnType.STRING, "STRING": GenericDataType.STRING,
"TEXT": DbColumnType.STRING, "TEXT": GenericDataType.STRING,
"NTEXT": DbColumnType.STRING, "NTEXT": GenericDataType.STRING,
# numeric # numeric
"INT": DbColumnType.NUMERIC, "INT": GenericDataType.NUMERIC,
"BIGINT": DbColumnType.NUMERIC, "BIGINT": GenericDataType.NUMERIC,
"FLOAT": DbColumnType.NUMERIC, "FLOAT": GenericDataType.NUMERIC,
"DECIMAL": DbColumnType.NUMERIC, "DECIMAL": GenericDataType.NUMERIC,
"MONEY": DbColumnType.NUMERIC, "MONEY": GenericDataType.NUMERIC,
# temporal # temporal
"DATE": DbColumnType.TEMPORAL, "DATE": GenericDataType.TEMPORAL,
"DATETIME": DbColumnType.TEMPORAL, "DATETIME": GenericDataType.TEMPORAL,
"TIME": DbColumnType.TEMPORAL, "TIME": GenericDataType.TEMPORAL,
"TIMESTAMP": DbColumnType.TEMPORAL, "TIMESTAMP": GenericDataType.TEMPORAL,
} }
tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database()) tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())
for str_type, db_col_type in test_cases.items(): for str_type, db_col_type in test_cases.items():
col = TableColumn(column_name="foo", type=str_type, table=tbl) col = TableColumn(column_name="foo", type=str_type, table=tbl)
self.assertEqual(col.is_temporal, db_col_type == DbColumnType.TEMPORAL) self.assertEqual(col.is_temporal, db_col_type == GenericDataType.TEMPORAL)
self.assertEqual(col.is_numeric, db_col_type == DbColumnType.NUMERIC) self.assertEqual(col.is_numeric, db_col_type == GenericDataType.NUMERIC)
self.assertEqual(col.is_string, db_col_type == DbColumnType.STRING) self.assertEqual(col.is_string, db_col_type == GenericDataType.STRING)
@patch("superset.jinja_context.g") @patch("superset.jinja_context.g")
def test_extra_cache_keys(self, flask_g): def test_extra_cache_keys(self, flask_g):