fix(explore): strip semicolons in virtual table SQL (#13801)

* add method to strip semicolon

* address comments

* test the test

* Update tests/sqla_models_tests.py

Co-authored-by: Jesse Yang <jesse.yang@airbnb.com>

* Update tests/sqla_models_tests.py

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>

* fix test

* add suggestion

* fix trailing space

* remove logger

* fix unit test

Co-authored-by: Jesse Yang <jesse.yang@airbnb.com>
Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
This commit is contained in:
Phillip Kelley-Dotson 2021-04-06 13:40:34 -07:00 committed by GitHub
parent c0888dc16d
commit 34991f5fab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 2 deletions

View File

@ -773,7 +773,6 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
logger.info(sql)
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
@ -818,6 +817,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
"""
Render sql with template engine (Jinja).
"""
sql = self.sql
if template_processor:
try:
@ -829,7 +829,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
msg=ex.message,
)
)
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:

View File

@ -223,6 +223,28 @@ class TestDatabaseModel(SupersetTestCase):
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)
def test_query_format_strip_trailing_semicolon(self):
query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["user"],
"metrics": [],
"is_timeseries": False,
"filter": [],
"extras": {},
}
# Table with Jinja callable.
table = SqlaTable(
table_name="test_table",
sql="SELECT * from test_table;",
database=get_example_database(),
)
sqlaq = table.get_sqla_query(**query_obj)
sql = table.database.compile_sqla_query(sqlaq.sqla_query)
assert sql[-1] != ";"
def test_multiple_sql_statements_raises_exception(self):
base_query_obj = {
"granularity": None,