Improve database type inference (#4724)
* Improve database type inference

  Python's DBAPI isn't very clear or homogeneous about the cursor.description
  specification, and this PR attempts to improve how the datatypes returned in
  the cursor are inferred. The work started with Presto's TIMESTAMP type being
  mishandled as a string because the database driver (pyhive) returns it as a
  string. This change fixes that bug and does a better job of inferring MySQL
  and Presto types. It also adds a new method to db_engine_specs so that other
  database engines can implement it and make type inference more precise as
  needed.

* Fixing tests
* Addressing comments
* Using infer_objects
* Removing faulty line
* Addressing PrestoSpec redundant method comment
* Fix rebase issue
* Fix tests
This commit is contained in:
parent 04fc1d1089
commit 777d876a52
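As background for the diff below, here is a minimal sketch of the engine-spec hook this commit introduces: each engine spec can translate whatever the driver puts into cursor.description[i][1] into a generic type name, and SupersetDataFrame falls back to pandas dtype inference when the hook returns nothing. This is an illustration only, with an abridged type_code_map; the real implementations are in the db_engine_specs hunks further down.

# Illustration only -- simplified from the db_engine_specs changes below.
class BaseEngineSpec(object):
    @classmethod
    def get_datatype(cls, type_code):
        # Most drivers already return the type name as a string.
        if isinstance(type_code, str) and type_code:
            return type_code.upper()


class MySQLEngineSpec(BaseEngineSpec):
    # MySQLdb returns integer type codes; map them back to names (abridged).
    type_code_map = {0: 'DECIMAL', 1: 'TINY', 15: 'VARCHAR'}

    @classmethod
    def get_datatype(cls, type_code):
        if isinstance(type_code, int):
            return cls.type_code_map.get(type_code)
        return super(MySQLEngineSpec, cls).get_datatype(type_code)


print(BaseEngineSpec.get_datatype('timestamp'))  # TIMESTAMP
print(MySQLEngineSpec.get_datatype(15))          # VARCHAR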
@@ -13,6 +13,7 @@ from __future__ import print_function
from __future__ import unicode_literals

from datetime import date, datetime
+import logging

import numpy as np
import pandas as pd
@@ -26,6 +27,27 @@ INFER_COL_TYPES_THRESHOLD = 95
INFER_COL_TYPES_SAMPLE_SIZE = 100


+def dedup(l, suffix='__'):
+    """De-duplicates a list of string by suffixing a counter
+
+    Always returns the same number of entries as provided, and always returns
+    unique values.
+
+    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
+    foo,bar,bar__1,bar__2
+    """
+    new_l = []
+    seen = {}
+    for s in l:
+        if s in seen:
+            seen[s] += 1
+            s += suffix + str(seen[s])
+        else:
+            seen[s] = 0
+        new_l.append(s)
+    return new_l
+
+
class SupersetDataFrame(object):
    # Mapping numpy dtype.char to generic database types
    type_map = {
@@ -43,19 +65,39 @@ class SupersetDataFrame(object):
        'V': None,   # raw data (void)
    }

-    def __init__(self, df):
-        self.__df = df.where((pd.notnull(df)), None)
+    def __init__(self, data, cursor_description, db_engine_spec):
+        column_names = []
+        if cursor_description:
+            column_names = [col[0] for col in cursor_description]
+
+        self.column_names = dedup(
+            db_engine_spec.get_normalized_column_names(cursor_description))
+
+        data = data or []
+        self.df = (
+            pd.DataFrame(list(data), columns=column_names).infer_objects())
+
+        self._type_dict = {}
+        try:
+            # The driver may not be passing a cursor.description
+            self._type_dict = {
+                col: db_engine_spec.get_datatype(cursor_description[i][1])
+                for i, col in enumerate(self.column_names)
+                if cursor_description
+            }
+        except Exception as e:
+            logging.exception(e)

    @property
    def size(self):
-        return len(self.__df.index)
+        return len(self.df.index)

    @property
    def data(self):
        # work around for https://github.com/pandas-dev/pandas/issues/18372
        data = [dict((k, _maybe_box_datetimelike(v))
-                     for k, v in zip(self.__df.columns, np.atleast_1d(row)))
-                for row in self.__df.values]
+                     for k, v in zip(self.df.columns, np.atleast_1d(row)))
+                for row in self.df.values]
        for d in data:
            for k, v in list(d.items()):
                # if an int is too big for Java Script to handle
@@ -70,7 +112,8 @@ class SupersetDataFrame(object):
        """Given a numpy dtype, Returns a generic database type"""
        if isinstance(dtype, ExtensionDtype):
            return cls.type_map.get(dtype.kind)
-        return cls.type_map.get(dtype.char)
+        elif hasattr(dtype, 'char'):
+            return cls.type_map.get(dtype.char)

    @classmethod
    def datetime_conversion_rate(cls, data_series):
@@ -105,7 +148,7 @@ class SupersetDataFrame(object):
        # consider checking for key substring too.
        if cls.is_id(column_name):
            return 'count_distinct'
-        if (issubclass(dtype.type, np.generic) and
+        if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and
                np.issubdtype(dtype, np.number)):
            return 'sum'
        return None
@@ -116,22 +159,25 @@ class SupersetDataFrame(object):

        :return: dict, with the fields name, type, is_date, is_dim and agg.
        """
-        if self.__df.empty:
+        if self.df.empty:
            return None

        columns = []
-        sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.__df.index))
-        sample = self.__df
+        sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index))
+        sample = self.df
        if sample_size:
-            sample = self.__df.sample(sample_size)
-        for col in self.__df.dtypes.keys():
-            col_db_type = self.db_type(self.__df.dtypes[col])
+            sample = self.df.sample(sample_size)
+        for col in self.df.dtypes.keys():
+            col_db_type = (
+                self._type_dict.get(col) or
+                self.db_type(self.df.dtypes[col])
+            )
            column = {
                'name': col,
-                'agg': self.agg_func(self.__df.dtypes[col], col),
+                'agg': self.agg_func(self.df.dtypes[col], col),
                'type': col_db_type,
-                'is_date': self.is_date(self.__df.dtypes[col]),
-                'is_dim': self.is_dimension(self.__df.dtypes[col], col),
+                'is_date': self.is_date(self.df.dtypes[col]),
+                'is_dim': self.is_dimension(self.df.dtypes[col], col),
            }

            if column['type'] in ('OBJECT', None):
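For readers skimming the diff, a small usage sketch of the new SupersetDataFrame constructor, mirroring the tests added further down (not part of the commit itself): raw result rows plus the DBAPI cursor.description and an engine spec are enough to get typed column metadata.

# Usage sketch mirroring the new tests below; assumes superset is importable.
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec

data = [('a1', 1), ('a2', 2)]
cursor_description = (('a', 'string'), ('b', 'int'))

cdf = SupersetDataFrame(data, cursor_description, BaseEngineSpec)
print(cdf.columns)  # column metadata, e.g. 'a' typed STRING and 'b' typed INT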
@@ -30,6 +30,7 @@ import boto3
from flask import g
from flask_babel import lazy_gettext as _
import pandas
+from past.builtins import basestring
import sqlalchemy as sqla
from sqlalchemy import select
from sqlalchemy.engine import create_engine
@@ -85,6 +86,11 @@ class BaseEngineSpec(object):
    def epoch_ms_to_dttm(cls):
        return cls.epoch_to_dttm().replace('{col}', '({col}/1000.0)')

+    @classmethod
+    def get_datatype(cls, type_code):
+        if isinstance(type_code, basestring) and len(type_code):
+            return type_code.upper()
+
    @classmethod
    def extra_table_metadata(cls, database, table_name, schema_name):
        """Returns engine-specific table metadata"""
@@ -592,6 +598,7 @@ class MySQLEngineSpec(BaseEngineSpec):
              'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
              'P1W'),
    )
+    type_code_map = {}  # loaded from get_datatype only if needed

    @classmethod
    def convert_dttm(cls, target_type, dttm):
@@ -606,6 +613,23 @@ class MySQLEngineSpec(BaseEngineSpec):
            uri.database = selected_schema
        return uri

+    @classmethod
+    def get_datatype(cls, type_code):
+        if not cls.type_code_map:
+            # only import and store if needed at least once
+            import MySQLdb
+            ft = MySQLdb.constants.FIELD_TYPE
+            cls.type_code_map = {
+                getattr(ft, k): k
+                for k in dir(ft)
+                if not k.startswith('_')
+            }
+        datatype = type_code
+        if isinstance(type_code, int):
+            datatype = cls.type_code_map.get(type_code)
+        if datatype and isinstance(datatype, basestring) and len(datatype):
+            return datatype
+
    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'
@@ -10,8 +10,6 @@ import uuid

from celery.exceptions import SoftTimeLimitExceeded
from contextlib2 import contextmanager
-import numpy as np
-import pandas as pd
import sqlalchemy
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
@@ -31,27 +29,6 @@ class SqlLabException(Exception):
    pass


-def dedup(l, suffix='__'):
-    """De-duplicates a list of string by suffixing a counter
-
-    Always returns the same number of entries as provided, and always returns
-    unique values.
-
-    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
-    foo,bar,bar__1,bar__2
-    """
-    new_l = []
-    seen = {}
-    for s in l:
-        if s in seen:
-            seen[s] += 1
-            s += suffix + str(seen[s])
-        else:
-            seen[s] = 0
-        new_l.append(s)
-    return new_l
-
-
def get_query(query_id, session, retry_count=5):
    """attemps to get the query and retry if it cannot"""
    query = None
@@ -96,24 +73,6 @@ def session_scope(nullpool):
        session.close()


-def convert_results_to_df(column_names, data):
-    """Convert raw query results to a DataFrame."""
-    column_names = dedup(column_names)
-
-    # check whether the result set has any nested dict columns
-    if data:
-        first_row = data[0]
-        has_dict_col = any([isinstance(c, dict) for c in first_row])
-        df_data = list(data) if has_dict_col else np.array(data, dtype=object)
-    else:
-        df_data = []
-
-    cdf = dataframe.SupersetDataFrame(
-        pd.DataFrame(df_data, columns=column_names))
-
-    return cdf
-
-
@celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
def get_sql_results(
        ctask, query_id, rendered_query, return_results=True, store_results=False,
@@ -233,7 +192,6 @@ def execute_sql(
        return handle_error(db_engine_spec.extract_error_message(e))

    logging.info('Fetching cursor description')
-    column_names = db_engine_spec.get_normalized_column_names(cursor.description)

    if conn is not None:
        conn.commit()
@@ -242,7 +200,7 @@ def execute_sql(
    if query.status == utils.QueryStatus.STOPPED:
        return handle_error('The query has been stopped')

-    cdf = convert_results_to_df(column_names, data)
+    cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec)

    query.rows = cdf.size
    query.progress = 100
@@ -14,7 +14,7 @@ import unittest
import pandas as pd
from past.builtins import basestring

-from superset import app, cli, dataframe, db, security_manager
+from superset import app, cli, db, security_manager
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
@@ -245,55 +245,6 @@ class CeleryTestCase(SupersetTestCase):
    def dictify_list_of_dicts(cls, l, k):
        return {str(o[k]): cls.de_unicode_dict(o) for o in l}

-    def test_get_columns(self):
-        main_db = self.get_main_database(db.session)
-        df = main_db.get_df('SELECT * FROM multiformat_time_series', None)
-        cdf = dataframe.SupersetDataFrame(df)
-
-        # Making ordering non-deterministic
-        cols = self.dictify_list_of_dicts(cdf.columns, 'name')
-
-        if main_db.sqlalchemy_uri.startswith('sqlite'):
-            self.assertEqual(self.dictify_list_of_dicts([
-                {'is_date': True, 'type': 'STRING', 'name': 'ds',
-                 'is_dim': False},
-                {'is_date': True, 'type': 'STRING', 'name': 'ds2',
-                 'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                 'name': 'epoch_ms', 'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                 'name': 'epoch_s', 'is_dim': False},
-                {'is_date': True, 'type': 'STRING', 'name': 'string0',
-                 'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                 'name': 'string1', 'is_dim': True},
-                {'is_date': True, 'type': 'STRING', 'name': 'string2',
-                 'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                 'name': 'string3', 'is_dim': True}], 'name'),
-                cols,
-            )
-        else:
-            self.assertEqual(self.dictify_list_of_dicts([
-                {'is_date': True, 'type': 'DATETIME', 'name': 'ds',
-                 'is_dim': False},
-                {'is_date': True, 'type': 'DATETIME',
-                 'name': 'ds2', 'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                 'name': 'epoch_ms', 'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                 'name': 'epoch_s', 'is_dim': False},
-                {'is_date': True, 'type': 'STRING', 'name': 'string0',
-                 'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                 'name': 'string1', 'is_dim': True},
-                {'is_date': True, 'type': 'STRING', 'name': 'string2',
-                 'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                 'name': 'string3', 'is_dim': True}], 'name'),
-                cols,
-            )
-

if __name__ == '__main__':
    unittest.main()
@@ -24,6 +24,7 @@ import sqlalchemy as sqla

from superset import dataframe, db, jinja_context, security_manager, sql_lab, utils
from superset.connectors.sqla.models import SqlaTable
+from superset.db_engine_specs import BaseEngineSpec
from superset.models import core as models
from superset.models.sql_lab import Query
from superset.views.core import DatabaseView
@@ -626,8 +627,7 @@ class CoreTests(SupersetTestCase):
            (datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),),
            (datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),),
        ]
-        df = dataframe.SupersetDataFrame(pd.DataFrame(data=list(data),
-                                                      columns=['data']))
+        df = dataframe.SupersetDataFrame(list(data), [['data']], BaseEngineSpec)
        data = df.data
        self.assertDictEqual(
            data[0],
@@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from superset.dataframe import dedup, SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec
from .base_tests import SupersetTestCase


class SupersetDataFrameTestCase(SupersetTestCase):
    def test_dedup(self):
        self.assertEquals(
            dedup(['foo', 'bar']),
            ['foo', 'bar'],
        )
        self.assertEquals(
            dedup(['foo', 'bar', 'foo', 'bar']),
            ['foo', 'bar', 'foo__1', 'bar__1'],
        )
        self.assertEquals(
            dedup(['foo', 'bar', 'bar', 'bar']),
            ['foo', 'bar', 'bar__1', 'bar__2'],
        )

    def test_get_columns_basic(self):
        data = [
            ('a1', 'b1', 'c1'),
            ('a2', 'b2', 'c2'),
        ]
        cursor_descr = (
            ('a', 'string'),
            ('b', 'string'),
            ('c', 'string'),
        )
        cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
        self.assertEqual(
            cdf.columns,
            [
                {
                    'is_date': False,
                    'type': 'STRING',
                    'name': 'a',
                    'is_dim': True,
                }, {
                    'is_date': False,
                    'type': 'STRING',
                    'name': 'b',
                    'is_dim': True,
                }, {
                    'is_date': False,
                    'type': 'STRING',
                    'name': 'c',
                    'is_dim': True,
                },
            ],
        )

    def test_get_columns_with_int(self):
        data = [
            ('a1', 1),
            ('a2', 2),
        ]
        cursor_descr = (
            ('a', 'string'),
            ('b', 'int'),
        )
        cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
        self.assertEqual(
            cdf.columns,
            [
                {
                    'is_date': False,
                    'type': 'STRING',
                    'name': 'a',
                    'is_dim': True,
                }, {
                    'is_date': False,
                    'type': 'INT',
                    'name': 'b',
                    'is_dim': False,
                    'agg': 'sum',
                },
            ],
        )

    def test_get_columns_type_inference(self):
        data = [
            (1.2, 1),
            (3.14, 2),
        ]
        cursor_descr = (
            ('a', None),
            ('b', None),
        )
        cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
        self.assertEqual(
            cdf.columns,
            [
                {
                    'is_date': False,
                    'type': 'FLOAT',
                    'name': 'a',
                    'is_dim': False,
                    'agg': 'sum',
                }, {
                    'is_date': False,
                    'type': 'INT',
                    'name': 'b',
                    'is_dim': False,
                    'agg': 'sum',
                },
            ],
        )
@@ -7,7 +7,9 @@ from __future__ import unicode_literals
import textwrap

from superset.db_engine_specs import (
-    HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec)
+    BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
+    MySQLEngineSpec, PrestoEngineSpec,
+)
from superset.models.core import Database
from .base_tests import SupersetTestCase

@@ -193,3 +195,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
            FROM
            table LIMIT 1000"""),
        )
+
+    def test_get_datatype(self):
+        self.assertEquals('STRING', PrestoEngineSpec.get_datatype('string'))
+        self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1))
+        self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15))
+        self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))
@@ -12,8 +12,9 @@ import unittest
from flask_appbuilder.security.sqla import models as ab_models

from superset import db, security_manager, utils
+from superset.dataframe import SupersetDataFrame
+from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query
-from superset.sql_lab import convert_results_to_df
from .base_tests import SupersetTestCase


@@ -203,9 +204,13 @@ class SqlLabTests(SupersetTestCase):
            raise_on_error=True)

    def test_df_conversion_no_dict(self):
-        cols = ['string_col', 'int_col', 'float_col']
+        cols = [
+            ['string_col', 'string'],
+            ['int_col', 'int'],
+            ['float_col', 'float'],
+        ]
        data = [['a', 4, 4.0]]
-        cdf = convert_results_to_df(cols, data)
+        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)

        self.assertEquals(len(data), cdf.size)
        self.assertEquals(len(cols), len(cdf.columns))
@@ -213,7 +218,7 @@ class SqlLabTests(SupersetTestCase):
    def test_df_conversion_tuple(self):
        cols = ['string_col', 'int_col', 'list_col', 'float_col']
        data = [(u'Text', 111, [123], 1.0)]
-        cdf = convert_results_to_df(cols, data)
+        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)

        self.assertEquals(len(data), cdf.size)
        self.assertEquals(len(cols), len(cdf.columns))
@@ -221,7 +226,7 @@ class SqlLabTests(SupersetTestCase):
    def test_df_conversion_dict(self):
        cols = ['string_col', 'dict_col', 'int_col']
        data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]]
-        cdf = convert_results_to_df(cols, data)
+        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)

        self.assertEquals(len(data), cdf.size)
        self.assertEquals(len(cols), len(cdf.columns))