Escaping the user's SQL in the explore view (#3186)
* Escaping the user's SQL in the explore view When executing SQL from SQL Lab, we use a lower level API to the database which doesn't require escaping the SQL. When going through the explore view, the stack chain leading to the same method may need escaping depending on how the DBAPI driver is written, and that is the case for Presto (and perhaps other drivers). * Using regex to avoid doubling doubles
This commit is contained in:
parent
fb866a937b
commit
25c599d040
|
|
@ -285,10 +285,12 @@ class SqlaTable(Model, BaseDatasource):
|
|||
"""
|
||||
cols = {col.column_name: col for col in self.columns}
|
||||
target_col = cols[column_name]
|
||||
tp = self.get_template_processor()
|
||||
db_engine_spec = self.database.db_engine_spec
|
||||
|
||||
qry = (
|
||||
select([target_col.sqla_col])
|
||||
.select_from(self.get_from_clause())
|
||||
.select_from(self.get_from_clause(tp, db_engine_spec))
|
||||
.distinct(column_name)
|
||||
)
|
||||
if limit:
|
||||
|
|
@ -322,7 +324,6 @@ class SqlaTable(Model, BaseDatasource):
|
|||
)
|
||||
logging.info(sql)
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
sql = self.database.db_engine_spec.sql_preprocessor(sql)
|
||||
return sql
|
||||
|
||||
def get_sqla_table(self):
|
||||
|
|
@ -331,12 +332,14 @@ class SqlaTable(Model, BaseDatasource):
|
|||
tbl.schema = self.schema
|
||||
return tbl
|
||||
|
||||
def get_from_clause(self, template_processor=None):
|
||||
def get_from_clause(self, template_processor=None, db_engine_spec=None):
|
||||
# Supporting arbitrary SQL statements in place of tables
|
||||
if self.sql:
|
||||
from_sql = self.sql
|
||||
if template_processor:
|
||||
from_sql = template_processor.process_template(from_sql)
|
||||
if db_engine_spec:
|
||||
from_sql = db_engine_spec.escape_sql(from_sql)
|
||||
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
|
||||
return self.get_sqla_table()
|
||||
|
||||
|
|
@ -367,13 +370,14 @@ class SqlaTable(Model, BaseDatasource):
|
|||
'form_data': form_data,
|
||||
}
|
||||
template_processor = self.get_template_processor(**template_kwargs)
|
||||
db_engine_spec = self.database.db_engine_spec
|
||||
|
||||
# For backward compatibility
|
||||
if granularity not in self.dttm_cols:
|
||||
granularity = self.main_dttm_col
|
||||
|
||||
# Database spec supports join-free timeslot grouping
|
||||
time_groupby_inline = self.database.db_engine_spec.time_groupby_inline
|
||||
time_groupby_inline = db_engine_spec.time_groupby_inline
|
||||
|
||||
cols = {col.column_name: col for col in self.columns}
|
||||
metrics_dict = {m.metric_name: m for m in self.metrics}
|
||||
|
|
@ -428,7 +432,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
groupby_exprs += [timestamp]
|
||||
|
||||
# Use main dttm column to support index with secondary dttm columns
|
||||
if self.database.db_engine_spec.time_secondary_columns and \
|
||||
if db_engine_spec.time_secondary_columns and \
|
||||
self.main_dttm_col in self.dttm_cols and \
|
||||
self.main_dttm_col != dttm_col.column_name:
|
||||
time_filters.append(cols[self.main_dttm_col].
|
||||
|
|
@ -438,7 +442,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
select_exprs += metrics_exprs
|
||||
qry = sa.select(select_exprs)
|
||||
|
||||
tbl = self.get_from_clause(template_processor)
|
||||
tbl = self.get_from_clause(template_processor, db_engine_spec)
|
||||
|
||||
if not columns:
|
||||
qry = qry.group_by(*groupby_exprs)
|
||||
|
|
|
|||
|
|
@ -73,6 +73,11 @@ class BaseEngineSpec(object):
|
|||
"""Returns engine-specific table metadata"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def escape_sql(cls, sql):
|
||||
"""Escapes the raw SQL"""
|
||||
return sql
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(cls, target_type, dttm):
|
||||
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
|
|
@ -139,14 +144,6 @@ class BaseEngineSpec(object):
|
|||
"""
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def sql_preprocessor(cls, sql):
|
||||
"""If the SQL needs to be altered prior to running it
|
||||
|
||||
For example Presto needs to double `%` characters
|
||||
"""
|
||||
return sql
|
||||
|
||||
@classmethod
|
||||
def patch(cls):
|
||||
pass
|
||||
|
|
@ -399,6 +396,10 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
uri.database = database
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def escape_sql(cls, sql):
|
||||
return re.sub(r'%%|%', "%%", sql)
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(cls, target_type, dttm):
|
||||
tt = target_type.upper()
|
||||
|
|
|
|||
|
|
@ -154,7 +154,6 @@ def execute_sql(ctask, query_id, return_results=True, store_results=False):
|
|||
template_processor = get_template_processor(
|
||||
database=database, query=query)
|
||||
executed_sql = template_processor.process_template(executed_sql)
|
||||
executed_sql = db_engine_spec.sql_preprocessor(executed_sql)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
msg = "Template rendering failed: " + utils.error_msg_from_exception(e)
|
||||
|
|
|
|||
Loading…
Reference in New Issue