Force lowercase column names for Snowflake and Oracle (#4994)
* Force lowercase column names for Snowflake and Oracle * Force lowercase column names for Snowflake and Oracle * Remove lowercasing of DB2 columns * Remove DB2 lowercasing * Fix test cases
This commit is contained in:
parent
071c6a6c03
commit
b391676544
|
|
@ -281,6 +281,15 @@ class BaseEngineSpec(object):
|
|||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_normalized_column_names(cls, cursor_description):
|
||||
columns = cursor_description if cursor_description else []
|
||||
return [cls.normalize_column_name(col[0]) for col in columns]
|
||||
|
||||
@staticmethod
|
||||
def normalize_column_name(column_name):
|
||||
return column_name
|
||||
|
||||
|
||||
class PostgresBaseEngineSpec(BaseEngineSpec):
|
||||
""" Abstract class for Postgres 'like' databases """
|
||||
|
|
@ -350,6 +359,10 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||
Grain('year', _('year'), "DATE_TRUNC('YEAR', {col})", 'P1Y'),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def normalize_column_name(column_name):
|
||||
return column_name.lower()
|
||||
|
||||
|
||||
class VerticaEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = 'vertica'
|
||||
|
|
@ -379,6 +392,10 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
|
|||
"""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
|
||||
).format(dttm.isoformat())
|
||||
|
||||
@staticmethod
|
||||
def normalize_column_name(column_name):
|
||||
return column_name.lower()
|
||||
|
||||
|
||||
class Db2EngineSpec(BaseEngineSpec):
|
||||
engine = 'ibm_db_sa'
|
||||
|
|
|
|||
|
|
@ -97,10 +97,8 @@ def session_scope(nullpool):
|
|||
session.close()
|
||||
|
||||
|
||||
def convert_results_to_df(cursor_description, data):
|
||||
def convert_results_to_df(column_names, data):
|
||||
"""Convert raw query results to a DataFrame."""
|
||||
column_names = (
|
||||
[col[0] for col in cursor_description] if cursor_description else [])
|
||||
column_names = dedup(column_names)
|
||||
|
||||
# check whether the result set has any nested dict columns
|
||||
|
|
@ -236,7 +234,7 @@ def execute_sql(
|
|||
return handle_error(db_engine_spec.extract_error_message(e))
|
||||
|
||||
logging.info('Fetching cursor description')
|
||||
cursor_description = cursor.description
|
||||
column_names = db_engine_spec.get_normalized_column_names(cursor.description)
|
||||
|
||||
if conn is not None:
|
||||
conn.commit()
|
||||
|
|
@ -245,7 +243,7 @@ def execute_sql(
|
|||
if query.status == utils.QueryStatus.STOPPED:
|
||||
return handle_error('The query has been stopped')
|
||||
|
||||
cdf = convert_results_to_df(cursor_description, data)
|
||||
cdf = convert_results_to_df(column_names, data)
|
||||
|
||||
query.rows = cdf.size
|
||||
query.progress = 100
|
||||
|
|
|
|||
|
|
@ -203,7 +203,7 @@ class SqlLabTests(SupersetTestCase):
|
|||
raise_on_error=True)
|
||||
|
||||
def test_df_conversion_no_dict(self):
|
||||
cols = [['string_col'], ['int_col'], ['float_col']]
|
||||
cols = ['string_col', 'int_col', 'float_col']
|
||||
data = [['a', 4, 4.0]]
|
||||
cdf = convert_results_to_df(cols, data)
|
||||
|
||||
|
|
@ -211,7 +211,7 @@ class SqlLabTests(SupersetTestCase):
|
|||
self.assertEquals(len(cols), len(cdf.columns))
|
||||
|
||||
def test_df_conversion_tuple(self):
|
||||
cols = [['string_col'], ['int_col'], ['list_col'], ['float_col']]
|
||||
cols = ['string_col', 'int_col', 'list_col', 'float_col']
|
||||
data = [(u'Text', 111, [123], 1.0)]
|
||||
cdf = convert_results_to_df(cols, data)
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class SqlLabTests(SupersetTestCase):
|
|||
self.assertEquals(len(cols), len(cdf.columns))
|
||||
|
||||
def test_df_conversion_dict(self):
|
||||
cols = [['string_col'], ['dict_col'], ['int_col']]
|
||||
cols = ['string_col', 'dict_col', 'int_col']
|
||||
data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]]
|
||||
cdf = convert_results_to_df(cols, data)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue