[bugfix] temporal columns with expression fail (#4890)
* [bugfix] temporal columns with expression fail error msg: "local variable 'literal' referenced before assignment" Error occurs [only] when using temporal column defined as a SQL expression. Also noticed that examples were using `granularity` instead of using `granularity_sqla` as they should. Fixed that here. * Add tests
This commit is contained in:
parent
fa3da8c888
commit
3f48c005df
|
|
@ -241,6 +241,11 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
|
|||
def default_query(qry):
|
||||
return qry
|
||||
|
||||
def get_column(self, column_name):
|
||||
for col in self.columns:
|
||||
if col.column_name == column_name:
|
||||
return col
|
||||
|
||||
|
||||
class BaseColumn(AuditMixinNullable, ImportMixin):
|
||||
"""Interface for column"""
|
||||
|
|
|
|||
|
|
@ -117,22 +117,24 @@ class TableColumn(Model, BaseColumn):
|
|||
|
||||
def get_timestamp_expression(self, time_grain):
|
||||
"""Getting the time component of the query"""
|
||||
pdf = self.python_date_format
|
||||
is_epoch = pdf in ('epoch_s', 'epoch_ms')
|
||||
if not self.expression and not time_grain and not is_epoch:
|
||||
return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)
|
||||
|
||||
expr = self.expression or self.column_name
|
||||
if not self.expression and not time_grain:
|
||||
return column(expr, type_=DateTime).label(DTTM_ALIAS)
|
||||
if is_epoch:
|
||||
# if epoch, translate to DATE using db specific conf
|
||||
db_spec = self.table.database.db_engine_spec
|
||||
if pdf == 'epoch_s':
|
||||
expr = db_spec.epoch_to_dttm().format(col=expr)
|
||||
elif pdf == 'epoch_ms':
|
||||
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
|
||||
if time_grain:
|
||||
pdf = self.python_date_format
|
||||
if pdf in ('epoch_s', 'epoch_ms'):
|
||||
# if epoch, translate to DATE using db specific conf
|
||||
db_spec = self.table.database.db_engine_spec
|
||||
if pdf == 'epoch_s':
|
||||
expr = db_spec.epoch_to_dttm().format(col=expr)
|
||||
elif pdf == 'epoch_ms':
|
||||
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
|
||||
grain = self.table.database.grains_dict().get(time_grain)
|
||||
literal = grain.function if grain else '{col}'
|
||||
literal = expr.format(col=expr)
|
||||
return literal_column(literal, type_=DateTime).label(DTTM_ALIAS)
|
||||
if grain:
|
||||
expr = grain.function.format(col=expr)
|
||||
return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, i_column):
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ def load_world_bank_health_n_pop():
|
|||
"compare_lag": "10",
|
||||
"compare_suffix": "o10Y",
|
||||
"limit": "25",
|
||||
"granularity": "year",
|
||||
"granularity_sqla": "year",
|
||||
"groupby": [],
|
||||
"metric": 'sum__SP_POP_TOTL',
|
||||
"metrics": ["sum__SP_POP_TOTL"],
|
||||
|
|
@ -593,7 +593,7 @@ def load_birth_names():
|
|||
"compare_lag": "10",
|
||||
"compare_suffix": "o10Y",
|
||||
"limit": "25",
|
||||
"granularity": "ds",
|
||||
"granularity_sqla": "ds",
|
||||
"groupby": [],
|
||||
"metric": 'sum__num',
|
||||
"metrics": ["sum__num"],
|
||||
|
|
@ -642,7 +642,7 @@ def load_birth_names():
|
|||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
viz_type="big_number", granularity="ds",
|
||||
viz_type="big_number", granularity_sqla="ds",
|
||||
compare_lag="5", compare_suffix="over 5Y")),
|
||||
Slice(
|
||||
slice_name="Genders",
|
||||
|
|
@ -675,7 +675,7 @@ def load_birth_names():
|
|||
params=get_slice_json(
|
||||
defaults,
|
||||
viz_type="line", groupby=['name'],
|
||||
granularity='ds', rich_tooltip=True, show_legend=True)),
|
||||
granularity_sqla='ds', rich_tooltip=True, show_legend=True)),
|
||||
Slice(
|
||||
slice_name="Average and Sum Trends",
|
||||
viz_type='dual_line',
|
||||
|
|
@ -684,7 +684,7 @@ def load_birth_names():
|
|||
params=get_slice_json(
|
||||
defaults,
|
||||
viz_type="dual_line", metric='avg__num', metric_2='sum__num',
|
||||
granularity='ds')),
|
||||
granularity_sqla='ds')),
|
||||
Slice(
|
||||
slice_name="Title",
|
||||
viz_type='markup',
|
||||
|
|
@ -729,7 +729,7 @@ def load_birth_names():
|
|||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
viz_type="big_number_total", granularity="ds",
|
||||
viz_type="big_number_total", granularity_sqla="ds",
|
||||
filters=[{
|
||||
'col': 'gender',
|
||||
'op': 'in',
|
||||
|
|
@ -876,7 +876,7 @@ def load_unicode_test_data():
|
|||
tbl = obj
|
||||
|
||||
slice_data = {
|
||||
"granularity": "dttm",
|
||||
"granularity_sqla": "dttm",
|
||||
"groupby": [],
|
||||
"metric": 'sum__value',
|
||||
"row_limit": config.get("ROW_LIMIT"),
|
||||
|
|
@ -954,7 +954,7 @@ def load_random_time_series_data():
|
|||
tbl = obj
|
||||
|
||||
slice_data = {
|
||||
"granularity": "day",
|
||||
"granularity_sqla": "day",
|
||||
"row_limit": config.get("ROW_LIMIT"),
|
||||
"since": "1 year ago",
|
||||
"until": "now",
|
||||
|
|
@ -1017,7 +1017,7 @@ def load_country_map_data():
|
|||
tbl = obj
|
||||
|
||||
slice_data = {
|
||||
"granularity": "",
|
||||
"granularity_sqla": "",
|
||||
"since": "",
|
||||
"until": "",
|
||||
"where": "",
|
||||
|
|
@ -1092,7 +1092,7 @@ def load_long_lat_data():
|
|||
tbl = obj
|
||||
|
||||
slice_data = {
|
||||
"granularity": "day",
|
||||
"granularity_sqla": "day",
|
||||
"since": "2014-01-01",
|
||||
"until": "now",
|
||||
"where": "",
|
||||
|
|
@ -1172,7 +1172,7 @@ def load_multiformat_time_series_data():
|
|||
slice_data = {
|
||||
"metric": 'count',
|
||||
"granularity_sqla": col.column_name,
|
||||
"granularity": "day",
|
||||
"granularity_sqla": "day",
|
||||
"row_limit": config.get("ROW_LIMIT"),
|
||||
"since": "1 year ago",
|
||||
"until": "now",
|
||||
|
|
|
|||
|
|
@ -105,3 +105,64 @@ class DatabaseModelTestCase(SupersetTestCase):
|
|||
self.assertEquals(d.get('day').function, 'DATE({col})')
|
||||
self.assertEquals(d.get('P1D').function, 'DATE({col})')
|
||||
self.assertEquals(d.get('Time Column').function, '{col}')
|
||||
|
||||
|
||||
class SqlaTableModelTestCase(SupersetTestCase):
|
||||
|
||||
def test_get_timestamp_expression(self):
|
||||
tbl = self.get_table_by_name('birth_names')
|
||||
ds_col = tbl.get_column('ds')
|
||||
sqla_literal = ds_col.get_timestamp_expression(None)
|
||||
self.assertEquals(str(sqla_literal.compile()), 'ds')
|
||||
|
||||
sqla_literal = ds_col.get_timestamp_expression('P1D')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'DATE(ds)')
|
||||
|
||||
ds_col.expression = 'DATE_ADD(ds, 1)'
|
||||
sqla_literal = ds_col.get_timestamp_expression('P1D')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))')
|
||||
|
||||
def test_get_timestamp_expression_epoch(self):
|
||||
tbl = self.get_table_by_name('birth_names')
|
||||
ds_col = tbl.get_column('ds')
|
||||
|
||||
ds_col.expression = None
|
||||
ds_col.python_date_format = 'epoch_s'
|
||||
sqla_literal = ds_col.get_timestamp_expression(None)
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'from_unixtime(ds)')
|
||||
|
||||
ds_col.python_date_format = 'epoch_s'
|
||||
sqla_literal = ds_col.get_timestamp_expression('P1D')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'DATE(from_unixtime(ds))')
|
||||
|
||||
ds_col.expression = 'DATE_ADD(ds, 1)'
|
||||
sqla_literal = ds_col.get_timestamp_expression('P1D')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
|
||||
|
||||
def test_get_timestamp_expression_backward(self):
|
||||
tbl = self.get_table_by_name('birth_names')
|
||||
ds_col = tbl.get_column('ds')
|
||||
|
||||
ds_col.expression = None
|
||||
ds_col.python_date_format = None
|
||||
sqla_literal = ds_col.get_timestamp_expression('day')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'DATE(ds)')
|
||||
|
||||
ds_col.expression = None
|
||||
ds_col.python_date_format = None
|
||||
sqla_literal = ds_col.get_timestamp_expression('Time Column')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'ds')
|
||||
|
|
|
|||
Loading…
Reference in New Issue