Force lowercase column names for Snowflake and Oracle (#4994)

* Force lowercase column names for Snowflake and Oracle

* Force lowercase column names for Snowflake and Oracle

* Remove lowercasing of DB2 columns

* Remove DB2 lowercasing

* Fix test cases
This commit is contained in:
Ville Brofeldt 2018-05-14 21:43:13 +03:00 committed by Maxime Beauchemin
parent 071c6a6c03
commit b391676544
3 changed files with 23 additions and 8 deletions

View File

@ -281,6 +281,15 @@ class BaseEngineSpec(object):
"""
return {}
@classmethod
def get_normalized_column_names(cls, cursor_description):
columns = cursor_description if cursor_description else []
return [cls.normalize_column_name(col[0]) for col in columns]
@staticmethod
def normalize_column_name(column_name):
return column_name
class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
@ -350,6 +359,10 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
Grain('year', _('year'), "DATE_TRUNC('YEAR', {col})", 'P1Y'),
)
@staticmethod
def normalize_column_name(column_name):
return column_name.lower()
class VerticaEngineSpec(PostgresBaseEngineSpec):
engine = 'vertica'
@ -379,6 +392,10 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
"""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
).format(dttm.isoformat())
@staticmethod
def normalize_column_name(column_name):
return column_name.lower()
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'

View File

@ -97,10 +97,8 @@ def session_scope(nullpool):
session.close()
def convert_results_to_df(cursor_description, data):
def convert_results_to_df(column_names, data):
"""Convert raw query results to a DataFrame."""
column_names = (
[col[0] for col in cursor_description] if cursor_description else [])
column_names = dedup(column_names)
# check whether the result set has any nested dict columns
@ -236,7 +234,7 @@ def execute_sql(
return handle_error(db_engine_spec.extract_error_message(e))
logging.info('Fetching cursor description')
cursor_description = cursor.description
column_names = db_engine_spec.get_normalized_column_names(cursor.description)
if conn is not None:
conn.commit()
@ -245,7 +243,7 @@ def execute_sql(
if query.status == utils.QueryStatus.STOPPED:
return handle_error('The query has been stopped')
cdf = convert_results_to_df(cursor_description, data)
cdf = convert_results_to_df(column_names, data)
query.rows = cdf.size
query.progress = 100

View File

@ -203,7 +203,7 @@ class SqlLabTests(SupersetTestCase):
raise_on_error=True)
def test_df_conversion_no_dict(self):
cols = [['string_col'], ['int_col'], ['float_col']]
cols = ['string_col', 'int_col', 'float_col']
data = [['a', 4, 4.0]]
cdf = convert_results_to_df(cols, data)
@ -211,7 +211,7 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(len(cols), len(cdf.columns))
def test_df_conversion_tuple(self):
cols = [['string_col'], ['int_col'], ['list_col'], ['float_col']]
cols = ['string_col', 'int_col', 'list_col', 'float_col']
data = [(u'Text', 111, [123], 1.0)]
cdf = convert_results_to_df(cols, data)
@ -219,7 +219,7 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(len(cols), len(cdf.columns))
def test_df_conversion_dict(self):
cols = [['string_col'], ['dict_col'], ['int_col']]
cols = ['string_col', 'dict_col', 'int_col']
data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]]
cdf = convert_results_to_df(cols, data)