Improve database type inference (#4724)

* Improve database type inference

Python's DBAPI isn't super clear and homogeneous on the
cursor.description specification, and this PR attempts to improve
inferring the datatypes returned in the cursor.

This work started around Presto's TIMESTAMP type being mishandled as
string as the database driver (pyhive) returns it as a string. The work
here fixes this bug and does a better job at inferring MySQL and Presto types.
It also creates a new method in db_engine_specs allowing for other
databases engines to implement and become more precise on type-inference
as needed.

* Fixing tests

* Adressing comments

* Using infer_objects

* Removing faulty line

* Addressing PrestoSpec redundant method comment

* Fix rebase issue

* Fix tests
This commit is contained in:
Maxime Beauchemin 2018-06-27 21:35:12 -07:00 committed by GitHub
parent 04fc1d1089
commit 777d876a52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 224 additions and 117 deletions

View File

@ -13,6 +13,7 @@ from __future__ import print_function
from __future__ import unicode_literals
from datetime import date, datetime
import logging
import numpy as np
import pandas as pd
@ -26,6 +27,27 @@ INFER_COL_TYPES_THRESHOLD = 95
INFER_COL_TYPES_SAMPLE_SIZE = 100
def dedup(l, suffix='__'):
"""De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns
unique values.
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
foo,bar,bar__1,bar__2
"""
new_l = []
seen = {}
for s in l:
if s in seen:
seen[s] += 1
s += suffix + str(seen[s])
else:
seen[s] = 0
new_l.append(s)
return new_l
class SupersetDataFrame(object):
# Mapping numpy dtype.char to generic database types
type_map = {
@ -43,19 +65,39 @@ class SupersetDataFrame(object):
'V': None, # raw data (void)
}
def __init__(self, df):
self.__df = df.where((pd.notnull(df)), None)
def __init__(self, data, cursor_description, db_engine_spec):
column_names = []
if cursor_description:
column_names = [col[0] for col in cursor_description]
self.column_names = dedup(
db_engine_spec.get_normalized_column_names(cursor_description))
data = data or []
self.df = (
pd.DataFrame(list(data), columns=column_names).infer_objects())
self._type_dict = {}
try:
# The driver may not be passing a cursor.description
self._type_dict = {
col: db_engine_spec.get_datatype(cursor_description[i][1])
for i, col in enumerate(self.column_names)
if cursor_description
}
except Exception as e:
logging.exception(e)
@property
def size(self):
return len(self.__df.index)
return len(self.df.index)
@property
def data(self):
# work around for https://github.com/pandas-dev/pandas/issues/18372
data = [dict((k, _maybe_box_datetimelike(v))
for k, v in zip(self.__df.columns, np.atleast_1d(row)))
for row in self.__df.values]
for k, v in zip(self.df.columns, np.atleast_1d(row)))
for row in self.df.values]
for d in data:
for k, v in list(d.items()):
# if an int is too big for Java Script to handle
@ -70,7 +112,8 @@ class SupersetDataFrame(object):
"""Given a numpy dtype, Returns a generic database type"""
if isinstance(dtype, ExtensionDtype):
return cls.type_map.get(dtype.kind)
return cls.type_map.get(dtype.char)
elif hasattr(dtype, 'char'):
return cls.type_map.get(dtype.char)
@classmethod
def datetime_conversion_rate(cls, data_series):
@ -105,7 +148,7 @@ class SupersetDataFrame(object):
# consider checking for key substring too.
if cls.is_id(column_name):
return 'count_distinct'
if (issubclass(dtype.type, np.generic) and
if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and
np.issubdtype(dtype, np.number)):
return 'sum'
return None
@ -116,22 +159,25 @@ class SupersetDataFrame(object):
:return: dict, with the fields name, type, is_date, is_dim and agg.
"""
if self.__df.empty:
if self.df.empty:
return None
columns = []
sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.__df.index))
sample = self.__df
sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index))
sample = self.df
if sample_size:
sample = self.__df.sample(sample_size)
for col in self.__df.dtypes.keys():
col_db_type = self.db_type(self.__df.dtypes[col])
sample = self.df.sample(sample_size)
for col in self.df.dtypes.keys():
col_db_type = (
self._type_dict.get(col) or
self.db_type(self.df.dtypes[col])
)
column = {
'name': col,
'agg': self.agg_func(self.__df.dtypes[col], col),
'agg': self.agg_func(self.df.dtypes[col], col),
'type': col_db_type,
'is_date': self.is_date(self.__df.dtypes[col]),
'is_dim': self.is_dimension(self.__df.dtypes[col], col),
'is_date': self.is_date(self.df.dtypes[col]),
'is_dim': self.is_dimension(self.df.dtypes[col], col),
}
if column['type'] in ('OBJECT', None):

View File

@ -30,6 +30,7 @@ import boto3
from flask import g
from flask_babel import lazy_gettext as _
import pandas
from past.builtins import basestring
import sqlalchemy as sqla
from sqlalchemy import select
from sqlalchemy.engine import create_engine
@ -85,6 +86,11 @@ class BaseEngineSpec(object):
def epoch_ms_to_dttm(cls):
return cls.epoch_to_dttm().replace('{col}', '({col}/1000.0)')
@classmethod
def get_datatype(cls, type_code):
if isinstance(type_code, basestring) and len(type_code):
return type_code.upper()
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
"""Returns engine-specific table metadata"""
@ -592,6 +598,7 @@ class MySQLEngineSpec(BaseEngineSpec):
'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
'P1W'),
)
type_code_map = {} # loaded from get_datatype only if needed
@classmethod
def convert_dttm(cls, target_type, dttm):
@ -606,6 +613,23 @@ class MySQLEngineSpec(BaseEngineSpec):
uri.database = selected_schema
return uri
@classmethod
def get_datatype(cls, type_code):
if not cls.type_code_map:
# only import and store if needed at least once
import MySQLdb
ft = MySQLdb.constants.FIELD_TYPE
cls.type_code_map = {
getattr(ft, k): k
for k in dir(ft)
if not k.startswith('_')
}
datatype = type_code
if isinstance(type_code, int):
datatype = cls.type_code_map.get(type_code)
if datatype and isinstance(datatype, basestring) and len(datatype):
return datatype
@classmethod
def epoch_to_dttm(cls):
return 'from_unixtime({col})'

View File

@ -10,8 +10,6 @@ import uuid
from celery.exceptions import SoftTimeLimitExceeded
from contextlib2 import contextmanager
import numpy as np
import pandas as pd
import sqlalchemy
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
@ -31,27 +29,6 @@ class SqlLabException(Exception):
pass
def dedup(l, suffix='__'):
"""De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns
unique values.
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
foo,bar,bar__1,bar__2
"""
new_l = []
seen = {}
for s in l:
if s in seen:
seen[s] += 1
s += suffix + str(seen[s])
else:
seen[s] = 0
new_l.append(s)
return new_l
def get_query(query_id, session, retry_count=5):
"""attemps to get the query and retry if it cannot"""
query = None
@ -96,24 +73,6 @@ def session_scope(nullpool):
session.close()
def convert_results_to_df(column_names, data):
"""Convert raw query results to a DataFrame."""
column_names = dedup(column_names)
# check whether the result set has any nested dict columns
if data:
first_row = data[0]
has_dict_col = any([isinstance(c, dict) for c in first_row])
df_data = list(data) if has_dict_col else np.array(data, dtype=object)
else:
df_data = []
cdf = dataframe.SupersetDataFrame(
pd.DataFrame(df_data, columns=column_names))
return cdf
@celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
def get_sql_results(
ctask, query_id, rendered_query, return_results=True, store_results=False,
@ -233,7 +192,6 @@ def execute_sql(
return handle_error(db_engine_spec.extract_error_message(e))
logging.info('Fetching cursor description')
column_names = db_engine_spec.get_normalized_column_names(cursor.description)
if conn is not None:
conn.commit()
@ -242,7 +200,7 @@ def execute_sql(
if query.status == utils.QueryStatus.STOPPED:
return handle_error('The query has been stopped')
cdf = convert_results_to_df(column_names, data)
cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec)
query.rows = cdf.size
query.progress = 100

View File

@ -14,7 +14,7 @@ import unittest
import pandas as pd
from past.builtins import basestring
from superset import app, cli, dataframe, db, security_manager
from superset import app, cli, db, security_manager
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
@ -245,55 +245,6 @@ class CeleryTestCase(SupersetTestCase):
def dictify_list_of_dicts(cls, l, k):
return {str(o[k]): cls.de_unicode_dict(o) for o in l}
def test_get_columns(self):
main_db = self.get_main_database(db.session)
df = main_db.get_df('SELECT * FROM multiformat_time_series', None)
cdf = dataframe.SupersetDataFrame(df)
# Making ordering non-deterministic
cols = self.dictify_list_of_dicts(cdf.columns, 'name')
if main_db.sqlalchemy_uri.startswith('sqlite'):
self.assertEqual(self.dictify_list_of_dicts([
{'is_date': True, 'type': 'STRING', 'name': 'ds',
'is_dim': False},
{'is_date': True, 'type': 'STRING', 'name': 'ds2',
'is_dim': False},
{'agg': 'sum', 'is_date': False, 'type': 'INT',
'name': 'epoch_ms', 'is_dim': False},
{'agg': 'sum', 'is_date': False, 'type': 'INT',
'name': 'epoch_s', 'is_dim': False},
{'is_date': True, 'type': 'STRING', 'name': 'string0',
'is_dim': False},
{'is_date': False, 'type': 'STRING',
'name': 'string1', 'is_dim': True},
{'is_date': True, 'type': 'STRING', 'name': 'string2',
'is_dim': False},
{'is_date': False, 'type': 'STRING',
'name': 'string3', 'is_dim': True}], 'name'),
cols,
)
else:
self.assertEqual(self.dictify_list_of_dicts([
{'is_date': True, 'type': 'DATETIME', 'name': 'ds',
'is_dim': False},
{'is_date': True, 'type': 'DATETIME',
'name': 'ds2', 'is_dim': False},
{'agg': 'sum', 'is_date': False, 'type': 'INT',
'name': 'epoch_ms', 'is_dim': False},
{'agg': 'sum', 'is_date': False, 'type': 'INT',
'name': 'epoch_s', 'is_dim': False},
{'is_date': True, 'type': 'STRING', 'name': 'string0',
'is_dim': False},
{'is_date': False, 'type': 'STRING',
'name': 'string1', 'is_dim': True},
{'is_date': True, 'type': 'STRING', 'name': 'string2',
'is_dim': False},
{'is_date': False, 'type': 'STRING',
'name': 'string3', 'is_dim': True}], 'name'),
cols,
)
if __name__ == '__main__':
unittest.main()

View File

@ -24,6 +24,7 @@ import sqlalchemy as sqla
from superset import dataframe, db, jinja_context, security_manager, sql_lab, utils
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs import BaseEngineSpec
from superset.models import core as models
from superset.models.sql_lab import Query
from superset.views.core import DatabaseView
@ -626,8 +627,7 @@ class CoreTests(SupersetTestCase):
(datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),),
(datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),),
]
df = dataframe.SupersetDataFrame(pd.DataFrame(data=list(data),
columns=['data']))
df = dataframe.SupersetDataFrame(list(data), [['data']], BaseEngineSpec)
data = df.data
self.assertDictEqual(
data[0],

115
tests/dataframe_test.py Normal file
View File

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from superset.dataframe import dedup, SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec
from .base_tests import SupersetTestCase
class SupersetDataFrameTestCase(SupersetTestCase):
def test_dedup(self):
self.assertEquals(
dedup(['foo', 'bar']),
['foo', 'bar'],
)
self.assertEquals(
dedup(['foo', 'bar', 'foo', 'bar']),
['foo', 'bar', 'foo__1', 'bar__1'],
)
self.assertEquals(
dedup(['foo', 'bar', 'bar', 'bar']),
['foo', 'bar', 'bar__1', 'bar__2'],
)
def test_get_columns_basic(self):
data = [
('a1', 'b1', 'c1'),
('a2', 'b2', 'c2'),
]
cursor_descr = (
('a', 'string'),
('b', 'string'),
('c', 'string'),
)
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
cdf.columns,
[
{
'is_date': False,
'type': 'STRING',
'name': 'a',
'is_dim': True,
}, {
'is_date': False,
'type': 'STRING',
'name': 'b',
'is_dim': True,
}, {
'is_date': False,
'type': 'STRING',
'name': 'c',
'is_dim': True,
},
],
)
def test_get_columns_with_int(self):
data = [
('a1', 1),
('a2', 2),
]
cursor_descr = (
('a', 'string'),
('b', 'int'),
)
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
cdf.columns,
[
{
'is_date': False,
'type': 'STRING',
'name': 'a',
'is_dim': True,
}, {
'is_date': False,
'type': 'INT',
'name': 'b',
'is_dim': False,
'agg': 'sum',
},
],
)
def test_get_columns_type_inference(self):
data = [
(1.2, 1),
(3.14, 2),
]
cursor_descr = (
('a', None),
('b', None),
)
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
cdf.columns,
[
{
'is_date': False,
'type': 'FLOAT',
'name': 'a',
'is_dim': False,
'agg': 'sum',
}, {
'is_date': False,
'type': 'INT',
'name': 'b',
'is_dim': False,
'agg': 'sum',
},
],
)

View File

@ -7,7 +7,9 @@ from __future__ import unicode_literals
import textwrap
from superset.db_engine_specs import (
HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec)
BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
MySQLEngineSpec, PrestoEngineSpec,
)
from superset.models.core import Database
from .base_tests import SupersetTestCase
@ -193,3 +195,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
FROM
table LIMIT 1000"""),
)
def test_get_datatype(self):
self.assertEquals('STRING', PrestoEngineSpec.get_datatype('string'))
self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1))
self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15))
self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))

View File

@ -12,8 +12,9 @@ import unittest
from flask_appbuilder.security.sqla import models as ab_models
from superset import db, security_manager, utils
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query
from superset.sql_lab import convert_results_to_df
from .base_tests import SupersetTestCase
@ -203,9 +204,13 @@ class SqlLabTests(SupersetTestCase):
raise_on_error=True)
def test_df_conversion_no_dict(self):
cols = ['string_col', 'int_col', 'float_col']
cols = [
['string_col', 'string'],
['int_col', 'int'],
['float_col', 'float'],
]
data = [['a', 4, 4.0]]
cdf = convert_results_to_df(cols, data)
cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
self.assertEquals(len(cols), len(cdf.columns))
@ -213,7 +218,7 @@ class SqlLabTests(SupersetTestCase):
def test_df_conversion_tuple(self):
cols = ['string_col', 'int_col', 'list_col', 'float_col']
data = [(u'Text', 111, [123], 1.0)]
cdf = convert_results_to_df(cols, data)
cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
self.assertEquals(len(cols), len(cdf.columns))
@ -221,7 +226,7 @@ class SqlLabTests(SupersetTestCase):
def test_df_conversion_dict(self):
cols = ['string_col', 'dict_col', 'int_col']
data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]]
cdf = convert_results_to_df(cols, data)
cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
self.assertEquals(len(cols), len(cdf.columns))