diff --git a/superset/dataframe.py b/superset/dataframe.py index 79a2c3d56..5fba4ffed 100644 --- a/superset/dataframe.py +++ b/superset/dataframe.py @@ -13,6 +13,7 @@ from __future__ import print_function from __future__ import unicode_literals from datetime import date, datetime +import logging import numpy as np import pandas as pd @@ -26,6 +27,27 @@ INFER_COL_TYPES_THRESHOLD = 95 INFER_COL_TYPES_SAMPLE_SIZE = 100 +def dedup(l, suffix='__'): + """De-duplicates a list of string by suffixing a counter + + Always returns the same number of entries as provided, and always returns + unique values. + + >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar']))) + foo,bar,bar__1,bar__2 + """ + new_l = [] + seen = {} + for s in l: + if s in seen: + seen[s] += 1 + s += suffix + str(seen[s]) + else: + seen[s] = 0 + new_l.append(s) + return new_l + + class SupersetDataFrame(object): # Mapping numpy dtype.char to generic database types type_map = { @@ -43,19 +65,39 @@ class SupersetDataFrame(object): 'V': None, # raw data (void) } - def __init__(self, df): - self.__df = df.where((pd.notnull(df)), None) + def __init__(self, data, cursor_description, db_engine_spec): + column_names = [] + if cursor_description: + column_names = [col[0] for col in cursor_description] + + self.column_names = dedup( + db_engine_spec.get_normalized_column_names(cursor_description)) + + data = data or [] + self.df = ( + pd.DataFrame(list(data), columns=column_names).infer_objects()) + + self._type_dict = {} + try: + # The driver may not be passing a cursor.description + self._type_dict = { + col: db_engine_spec.get_datatype(cursor_description[i][1]) + for i, col in enumerate(self.column_names) + if cursor_description + } + except Exception as e: + logging.exception(e) @property def size(self): - return len(self.__df.index) + return len(self.df.index) @property def data(self): # work around for https://github.com/pandas-dev/pandas/issues/18372 data = [dict((k, _maybe_box_datetimelike(v)) - for k, v in zip(self.__df.columns, np.atleast_1d(row))) - for row in self.__df.values] + for k, v in zip(self.df.columns, np.atleast_1d(row))) + for row in self.df.values] for d in data: for k, v in list(d.items()): # if an int is too big for Java Script to handle @@ -70,7 +112,8 @@ class SupersetDataFrame(object): """Given a numpy dtype, Returns a generic database type""" if isinstance(dtype, ExtensionDtype): return cls.type_map.get(dtype.kind) - return cls.type_map.get(dtype.char) + elif hasattr(dtype, 'char'): + return cls.type_map.get(dtype.char) @classmethod def datetime_conversion_rate(cls, data_series): @@ -105,7 +148,7 @@ class SupersetDataFrame(object): # consider checking for key substring too. if cls.is_id(column_name): return 'count_distinct' - if (issubclass(dtype.type, np.generic) and + if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and np.issubdtype(dtype, np.number)): return 'sum' return None @@ -116,22 +159,25 @@ class SupersetDataFrame(object): :return: dict, with the fields name, type, is_date, is_dim and agg. """ - if self.__df.empty: + if self.df.empty: return None columns = [] - sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.__df.index)) - sample = self.__df + sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index)) + sample = self.df if sample_size: - sample = self.__df.sample(sample_size) - for col in self.__df.dtypes.keys(): - col_db_type = self.db_type(self.__df.dtypes[col]) + sample = self.df.sample(sample_size) + for col in self.df.dtypes.keys(): + col_db_type = ( + self._type_dict.get(col) or + self.db_type(self.df.dtypes[col]) + ) column = { 'name': col, - 'agg': self.agg_func(self.__df.dtypes[col], col), + 'agg': self.agg_func(self.df.dtypes[col], col), 'type': col_db_type, - 'is_date': self.is_date(self.__df.dtypes[col]), - 'is_dim': self.is_dimension(self.__df.dtypes[col], col), + 'is_date': self.is_date(self.df.dtypes[col]), + 'is_dim': self.is_dimension(self.df.dtypes[col], col), } if column['type'] in ('OBJECT', None): diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 4f6b22e30..4181c49d6 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -30,6 +30,7 @@ import boto3 from flask import g from flask_babel import lazy_gettext as _ import pandas +from past.builtins import basestring import sqlalchemy as sqla from sqlalchemy import select from sqlalchemy.engine import create_engine @@ -85,6 +86,11 @@ class BaseEngineSpec(object): def epoch_ms_to_dttm(cls): return cls.epoch_to_dttm().replace('{col}', '({col}/1000.0)') + @classmethod + def get_datatype(cls, type_code): + if isinstance(type_code, basestring) and len(type_code): + return type_code.upper() + @classmethod def extra_table_metadata(cls, database, table_name, schema_name): """Returns engine-specific table metadata""" @@ -592,6 +598,7 @@ class MySQLEngineSpec(BaseEngineSpec): 'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))', 'P1W'), ) + type_code_map = {} # loaded from get_datatype only if needed @classmethod def convert_dttm(cls, target_type, dttm): @@ -606,6 +613,23 @@ class MySQLEngineSpec(BaseEngineSpec): uri.database = selected_schema return uri + @classmethod + def get_datatype(cls, type_code): + if not cls.type_code_map: + # only import and store if needed at least once + import MySQLdb + ft = MySQLdb.constants.FIELD_TYPE + cls.type_code_map = { + getattr(ft, k): k + for k in dir(ft) + if not k.startswith('_') + } + datatype = type_code + if isinstance(type_code, int): + datatype = cls.type_code_map.get(type_code) + if datatype and isinstance(datatype, basestring) and len(datatype): + return datatype + @classmethod def epoch_to_dttm(cls): return 'from_unixtime({col})' diff --git a/superset/sql_lab.py b/superset/sql_lab.py index df00a2b6b..34a9eeb9e 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -10,8 +10,6 @@ import uuid from celery.exceptions import SoftTimeLimitExceeded from contextlib2 import contextmanager -import numpy as np -import pandas as pd import sqlalchemy from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool @@ -31,27 +29,6 @@ class SqlLabException(Exception): pass -def dedup(l, suffix='__'): - """De-duplicates a list of string by suffixing a counter - - Always returns the same number of entries as provided, and always returns - unique values. - - >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar']))) - foo,bar,bar__1,bar__2 - """ - new_l = [] - seen = {} - for s in l: - if s in seen: - seen[s] += 1 - s += suffix + str(seen[s]) - else: - seen[s] = 0 - new_l.append(s) - return new_l - - def get_query(query_id, session, retry_count=5): """attemps to get the query and retry if it cannot""" query = None @@ -96,24 +73,6 @@ def session_scope(nullpool): session.close() -def convert_results_to_df(column_names, data): - """Convert raw query results to a DataFrame.""" - column_names = dedup(column_names) - - # check whether the result set has any nested dict columns - if data: - first_row = data[0] - has_dict_col = any([isinstance(c, dict) for c in first_row]) - df_data = list(data) if has_dict_col else np.array(data, dtype=object) - else: - df_data = [] - - cdf = dataframe.SupersetDataFrame( - pd.DataFrame(df_data, columns=column_names)) - - return cdf - - @celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT) def get_sql_results( ctask, query_id, rendered_query, return_results=True, store_results=False, @@ -233,7 +192,6 @@ def execute_sql( return handle_error(db_engine_spec.extract_error_message(e)) logging.info('Fetching cursor description') - column_names = db_engine_spec.get_normalized_column_names(cursor.description) if conn is not None: conn.commit() @@ -242,7 +200,7 @@ def execute_sql( if query.status == utils.QueryStatus.STOPPED: return handle_error('The query has been stopped') - cdf = convert_results_to_df(column_names, data) + cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec) query.rows = cdf.size query.progress = 100 diff --git a/tests/celery_tests.py b/tests/celery_tests.py index 39b7749ae..afaeea9df 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -14,7 +14,7 @@ import unittest import pandas as pd from past.builtins import basestring -from superset import app, cli, dataframe, db, security_manager +from superset import app, cli, db, security_manager from superset.models.helpers import QueryStatus from superset.models.sql_lab import Query from superset.sql_parse import SupersetQuery @@ -245,55 +245,6 @@ class CeleryTestCase(SupersetTestCase): def dictify_list_of_dicts(cls, l, k): return {str(o[k]): cls.de_unicode_dict(o) for o in l} - def test_get_columns(self): - main_db = self.get_main_database(db.session) - df = main_db.get_df('SELECT * FROM multiformat_time_series', None) - cdf = dataframe.SupersetDataFrame(df) - - # Making ordering non-deterministic - cols = self.dictify_list_of_dicts(cdf.columns, 'name') - - if main_db.sqlalchemy_uri.startswith('sqlite'): - self.assertEqual(self.dictify_list_of_dicts([ - {'is_date': True, 'type': 'STRING', 'name': 'ds', - 'is_dim': False}, - {'is_date': True, 'type': 'STRING', 'name': 'ds2', - 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_ms', 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_s', 'is_dim': False}, - {'is_date': True, 'type': 'STRING', 'name': 'string0', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string1', 'is_dim': True}, - {'is_date': True, 'type': 'STRING', 'name': 'string2', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string3', 'is_dim': True}], 'name'), - cols, - ) - else: - self.assertEqual(self.dictify_list_of_dicts([ - {'is_date': True, 'type': 'DATETIME', 'name': 'ds', - 'is_dim': False}, - {'is_date': True, 'type': 'DATETIME', - 'name': 'ds2', 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_ms', 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_s', 'is_dim': False}, - {'is_date': True, 'type': 'STRING', 'name': 'string0', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string1', 'is_dim': True}, - {'is_date': True, 'type': 'STRING', 'name': 'string2', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string3', 'is_dim': True}], 'name'), - cols, - ) - if __name__ == '__main__': unittest.main() diff --git a/tests/core_tests.py b/tests/core_tests.py index 6a4f153eb..f1a01796b 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -24,6 +24,7 @@ import sqlalchemy as sqla from superset import dataframe, db, jinja_context, security_manager, sql_lab, utils from superset.connectors.sqla.models import SqlaTable +from superset.db_engine_specs import BaseEngineSpec from superset.models import core as models from superset.models.sql_lab import Query from superset.views.core import DatabaseView @@ -626,8 +627,7 @@ class CoreTests(SupersetTestCase): (datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),), (datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),), ] - df = dataframe.SupersetDataFrame(pd.DataFrame(data=list(data), - columns=['data'])) + df = dataframe.SupersetDataFrame(list(data), [['data']], BaseEngineSpec) data = df.data self.assertDictEqual( data[0], diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py new file mode 100644 index 000000000..b56770240 --- /dev/null +++ b/tests/dataframe_test.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from superset.dataframe import dedup, SupersetDataFrame +from superset.db_engine_specs import BaseEngineSpec +from .base_tests import SupersetTestCase + + +class SupersetDataFrameTestCase(SupersetTestCase): + def test_dedup(self): + self.assertEquals( + dedup(['foo', 'bar']), + ['foo', 'bar'], + ) + self.assertEquals( + dedup(['foo', 'bar', 'foo', 'bar']), + ['foo', 'bar', 'foo__1', 'bar__1'], + ) + self.assertEquals( + dedup(['foo', 'bar', 'bar', 'bar']), + ['foo', 'bar', 'bar__1', 'bar__2'], + ) + + def test_get_columns_basic(self): + data = [ + ('a1', 'b1', 'c1'), + ('a2', 'b2', 'c2'), + ] + cursor_descr = ( + ('a', 'string'), + ('b', 'string'), + ('c', 'string'), + ) + cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) + self.assertEqual( + cdf.columns, + [ + { + 'is_date': False, + 'type': 'STRING', + 'name': 'a', + 'is_dim': True, + }, { + 'is_date': False, + 'type': 'STRING', + 'name': 'b', + 'is_dim': True, + }, { + 'is_date': False, + 'type': 'STRING', + 'name': 'c', + 'is_dim': True, + }, + ], + ) + + def test_get_columns_with_int(self): + data = [ + ('a1', 1), + ('a2', 2), + ] + cursor_descr = ( + ('a', 'string'), + ('b', 'int'), + ) + cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) + self.assertEqual( + cdf.columns, + [ + { + 'is_date': False, + 'type': 'STRING', + 'name': 'a', + 'is_dim': True, + }, { + 'is_date': False, + 'type': 'INT', + 'name': 'b', + 'is_dim': False, + 'agg': 'sum', + }, + ], + ) + + def test_get_columns_type_inference(self): + data = [ + (1.2, 1), + (3.14, 2), + ] + cursor_descr = ( + ('a', None), + ('b', None), + ) + cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) + self.assertEqual( + cdf.columns, + [ + { + 'is_date': False, + 'type': 'FLOAT', + 'name': 'a', + 'is_dim': False, + 'agg': 'sum', + }, { + 'is_date': False, + 'type': 'INT', + 'name': 'b', + 'is_dim': False, + 'agg': 'sum', + }, + ], + ) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index bdce0b060..447914ed5 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -7,7 +7,9 @@ from __future__ import unicode_literals import textwrap from superset.db_engine_specs import ( - HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) + BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec, + MySQLEngineSpec, PrestoEngineSpec, +) from superset.models.core import Database from .base_tests import SupersetTestCase @@ -193,3 +195,9 @@ class DbEngineSpecsTestCase(SupersetTestCase): FROM table LIMIT 1000"""), ) + + def test_get_datatype(self): + self.assertEquals('STRING', PrestoEngineSpec.get_datatype('string')) + self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1)) + self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15)) + self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR')) diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 49926f80d..a3bb564dd 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -12,8 +12,9 @@ import unittest from flask_appbuilder.security.sqla import models as ab_models from superset import db, security_manager, utils +from superset.dataframe import SupersetDataFrame +from superset.db_engine_specs import BaseEngineSpec from superset.models.sql_lab import Query -from superset.sql_lab import convert_results_to_df from .base_tests import SupersetTestCase @@ -203,9 +204,13 @@ class SqlLabTests(SupersetTestCase): raise_on_error=True) def test_df_conversion_no_dict(self): - cols = ['string_col', 'int_col', 'float_col'] + cols = [ + ['string_col', 'string'], + ['int_col', 'int'], + ['float_col', 'float'], + ] data = [['a', 4, 4.0]] - cdf = convert_results_to_df(cols, data) + cdf = SupersetDataFrame(data, cols, BaseEngineSpec) self.assertEquals(len(data), cdf.size) self.assertEquals(len(cols), len(cdf.columns)) @@ -213,7 +218,7 @@ class SqlLabTests(SupersetTestCase): def test_df_conversion_tuple(self): cols = ['string_col', 'int_col', 'list_col', 'float_col'] data = [(u'Text', 111, [123], 1.0)] - cdf = convert_results_to_df(cols, data) + cdf = SupersetDataFrame(data, cols, BaseEngineSpec) self.assertEquals(len(data), cdf.size) self.assertEquals(len(cols), len(cdf.columns)) @@ -221,7 +226,7 @@ class SqlLabTests(SupersetTestCase): def test_df_conversion_dict(self): cols = ['string_col', 'dict_col', 'int_col'] data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]] - cdf = convert_results_to_df(cols, data) + cdf = SupersetDataFrame(data, cols, BaseEngineSpec) self.assertEquals(len(data), cdf.size) self.assertEquals(len(cols), len(cdf.columns))