[bugfix] temporal columns with expression fail (#4890)

* [bugfix] temporal columns with expression fail

error msg: "local variable 'literal' referenced before assignment"

Error occurs [only] when using a temporal column defined as a SQL
expression.

Also noticed that examples were using `granularity` instead of using
`granularity_sqla` as they should. Fixed that here.

* Add tests
This commit is contained in:
Maxime Beauchemin 2018-04-26 21:13:52 -07:00 committed by GitHub
parent fa3da8c888
commit 3f48c005df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 24 deletions

View File

@@ -241,6 +241,11 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
def default_query(qry):
return qry
def get_column(self, column_name):
for col in self.columns:
if col.column_name == column_name:
return col
class BaseColumn(AuditMixinNullable, ImportMixin):
"""Interface for column"""

View File

@@ -117,22 +117,24 @@ class TableColumn(Model, BaseColumn):
def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
pdf = self.python_date_format
is_epoch = pdf in ('epoch_s', 'epoch_ms')
if not self.expression and not time_grain and not is_epoch:
return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)
expr = self.expression or self.column_name
if not self.expression and not time_grain:
return column(expr, type_=DateTime).label(DTTM_ALIAS)
if is_epoch:
# if epoch, translate to DATE using db specific conf
db_spec = self.table.database.db_engine_spec
if pdf == 'epoch_s':
expr = db_spec.epoch_to_dttm().format(col=expr)
elif pdf == 'epoch_ms':
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
if time_grain:
pdf = self.python_date_format
if pdf in ('epoch_s', 'epoch_ms'):
# if epoch, translate to DATE using db specific conf
db_spec = self.table.database.db_engine_spec
if pdf == 'epoch_s':
expr = db_spec.epoch_to_dttm().format(col=expr)
elif pdf == 'epoch_ms':
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
grain = self.table.database.grains_dict().get(time_grain)
literal = grain.function if grain else '{col}'
literal = expr.format(col=expr)
return literal_column(literal, type_=DateTime).label(DTTM_ALIAS)
if grain:
expr = grain.function.format(col=expr)
return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)
@classmethod
def import_obj(cls, i_column):

View File

@@ -188,7 +188,7 @@ def load_world_bank_health_n_pop():
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity": "year",
"granularity_sqla": "year",
"groupby": [],
"metric": 'sum__SP_POP_TOTL',
"metrics": ["sum__SP_POP_TOTL"],
@@ -593,7 +593,7 @@ def load_birth_names():
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity": "ds",
"granularity_sqla": "ds",
"groupby": [],
"metric": 'sum__num',
"metrics": ["sum__num"],
@@ -642,7 +642,7 @@ def load_birth_names():
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="big_number", granularity="ds",
viz_type="big_number", granularity_sqla="ds",
compare_lag="5", compare_suffix="over 5Y")),
Slice(
slice_name="Genders",
@@ -675,7 +675,7 @@ def load_birth_names():
params=get_slice_json(
defaults,
viz_type="line", groupby=['name'],
granularity='ds', rich_tooltip=True, show_legend=True)),
granularity_sqla='ds', rich_tooltip=True, show_legend=True)),
Slice(
slice_name="Average and Sum Trends",
viz_type='dual_line',
@@ -684,7 +684,7 @@ def load_birth_names():
params=get_slice_json(
defaults,
viz_type="dual_line", metric='avg__num', metric_2='sum__num',
granularity='ds')),
granularity_sqla='ds')),
Slice(
slice_name="Title",
viz_type='markup',
@@ -729,7 +729,7 @@ def load_birth_names():
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="big_number_total", granularity="ds",
viz_type="big_number_total", granularity_sqla="ds",
filters=[{
'col': 'gender',
'op': 'in',
@@ -876,7 +876,7 @@ def load_unicode_test_data():
tbl = obj
slice_data = {
"granularity": "dttm",
"granularity_sqla": "dttm",
"groupby": [],
"metric": 'sum__value',
"row_limit": config.get("ROW_LIMIT"),
@@ -954,7 +954,7 @@ def load_random_time_series_data():
tbl = obj
slice_data = {
"granularity": "day",
"granularity_sqla": "day",
"row_limit": config.get("ROW_LIMIT"),
"since": "1 year ago",
"until": "now",
@@ -1017,7 +1017,7 @@ def load_country_map_data():
tbl = obj
slice_data = {
"granularity": "",
"granularity_sqla": "",
"since": "",
"until": "",
"where": "",
@@ -1092,7 +1092,7 @@ def load_long_lat_data():
tbl = obj
slice_data = {
"granularity": "day",
"granularity_sqla": "day",
"since": "2014-01-01",
"until": "now",
"where": "",
@@ -1172,7 +1172,7 @@ def load_multiformat_time_series_data():
slice_data = {
"metric": 'count',
"granularity_sqla": col.column_name,
"granularity": "day",
"granularity_sqla": "day",
"row_limit": config.get("ROW_LIMIT"),
"since": "1 year ago",
"until": "now",

View File

@@ -105,3 +105,64 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertEquals(d.get('day').function, 'DATE({col})')
self.assertEquals(d.get('P1D').function, 'DATE({col})')
self.assertEquals(d.get('Time Column').function, '{col}')
class SqlaTableModelTestCase(SupersetTestCase):
def test_get_timestamp_expression(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')
sqla_literal = ds_col.get_timestamp_expression(None)
self.assertEquals(str(sqla_literal.compile()), 'ds')
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(ds)')
ds_col.expression = 'DATE_ADD(ds, 1)'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))')
def test_get_timestamp_expression_epoch(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')
ds_col.expression = None
ds_col.python_date_format = 'epoch_s'
sqla_literal = ds_col.get_timestamp_expression(None)
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'from_unixtime(ds)')
ds_col.python_date_format = 'epoch_s'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(from_unixtime(ds))')
ds_col.expression = 'DATE_ADD(ds, 1)'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
def test_get_timestamp_expression_backward(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')
ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('day')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(ds)')
ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('Time Column')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'ds')