fix: CTAS on multiple statements (#12188)

* WIP

* Add unit tests for sql_parse

* Add unit tests for sql_lab
This commit is contained in:
Beto Dealmeida 2021-01-04 09:22:35 -08:00 committed by GitHub
parent ff0fe434e4
commit 164db3e5a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 295 additions and 13 deletions

View File

@ -36,7 +36,7 @@ from superset.db_engine_specs import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.sql_parse import CtasMethod, ParsedQuery
from superset.utils.celery import session_scope
from superset.utils.core import (
json_iso_dttm_ser,
@ -160,6 +160,7 @@ def execute_sql_statement(
session: Session,
cursor: Any,
log_params: Optional[Dict[str, Any]],
apply_ctas: bool = False,
) -> SupersetResultSet:
"""Executes a single SQL statement"""
database = query.database
@ -171,14 +172,7 @@ def execute_sql_statement(
raise SqlLabSecurityException(
_("Only `SELECT` statements are allowed against this database")
)
if query.select_as_cta:
if not parsed_query.is_select():
raise SqlLabException(
_(
"Only `SELECT` statements can be used with the CREATE TABLE "
"feature."
)
)
if apply_ctas:
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = "tmp_{}_table_{}".format(
@ -322,8 +316,8 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
raise SqlLabException("Results backend isn't configured.")
# Breaking down into multiple statements
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
if not db_engine_spec.run_multiple_statements_as_one:
parsed_query = ParsedQuery(rendered_query)
statements = parsed_query.get_statements()
logger.info(
"Query %s: Executing %i statement(s)", str(query_id), len(statements)
@ -337,6 +331,32 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
query.start_running_time = now_as_float()
session.commit()
# Should we create a table or view from the select?
if (
query.select_as_cta
and query.ctas_method == CtasMethod.TABLE
and not parsed_query.is_valid_ctas()
):
raise SqlLabException(
_(
"CTAS (create table as select) can only be run with a query where "
"the last statement is a SELECT. Please make sure your query has "
"a SELECT as its last statement. Then, try running your query again."
)
)
if (
query.select_as_cta
and query.ctas_method == CtasMethod.VIEW
and not parsed_query.is_valid_cvas()
):
raise SqlLabException(
_(
"CVAS (create view as select) can only be run with a query with "
"a single SELECT statement. Please make sure your query has only "
"a SELECT statement. Then, try running your query again."
)
)
engine = database.get_sqla_engine(
schema=query.schema,
nullpool=True,
@ -354,6 +374,15 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
if query.status == QueryStatus.STOPPED:
return None
# For CTAS we create the table only on the last statement
apply_ctas = query.select_as_cta and (
query.ctas_method == CtasMethod.VIEW
or (
query.ctas_method == CtasMethod.TABLE
and i == len(statements) - 1
)
)
# Run statement
msg = f"Running statement {i+1} out of {statement_count}"
logger.info("Query %s: %s", str(query_id), msg)
@ -361,7 +390,13 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
session.commit()
try:
result_set = execute_sql_statement(
statement, query, user_name, session, cursor, log_params
statement,
query,
user_name,
session,
cursor,
log_params,
apply_ctas,
)
except Exception as ex: # pylint: disable=broad-except
msg = str(ex)

View File

@ -81,7 +81,10 @@ class Table: # pylint: disable=too-few-public-methods
class ParsedQuery:
def __init__(self, sql_statement: str):
def __init__(self, sql_statement: str, strip_comments: bool = False):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
self._tables: Set[Table] = set()
self._alias_names: Set[str] = set()
@ -110,6 +113,12 @@ class ParsedQuery:
def is_select(self) -> bool:
return self._parsed[0].get_type() == "SELECT"
def is_valid_ctas(self) -> bool:
return self._parsed[-1].get_type() == "SELECT"
def is_valid_cvas(self) -> bool:
return len(self._parsed) == 1 and self._parsed[0].get_type() == "SELECT"
def is_explain(self) -> bool:
# Remove comments
statements_without_comments = sqlparse.format(

View File

@ -656,3 +656,79 @@ class TestSupersetSqlParse(unittest.TestCase):
"""
parsed = ParsedQuery(query)
self.assertEqual(parsed.is_explain(), False)
def test_is_valid_ctas(self):
"""A valid CTAS has a SELECT as its last statement"""
query = "SELECT * FROM table"
parsed = ParsedQuery(query, strip_comments=True)
assert parsed.is_valid_ctas()
query = """
-- comment
SELECT * FROM table
-- comment 2
"""
parsed = ParsedQuery(query, strip_comments=True)
assert parsed.is_valid_ctas()
query = """
-- comment
SET @value = 42;
SELECT @value as foo;
-- comment 2
"""
parsed = ParsedQuery(query, strip_comments=True)
assert parsed.is_valid_ctas()
query = """
-- comment
EXPLAIN SELECT * FROM table
-- comment 2
"""
parsed = ParsedQuery(query, strip_comments=True)
assert not parsed.is_valid_ctas()
query = """
SELECT * FROM table;
INSERT INTO TABLE (foo) VALUES (42);
"""
parsed = ParsedQuery(query, strip_comments=True)
assert not parsed.is_valid_ctas()
def test_is_valid_cvas(self):
"""A valid CVAS has a single SELECT statement"""
query = "SELECT * FROM table"
parsed = ParsedQuery(query, strip_comments=True)
assert parsed.is_valid_cvas()
query = """
-- comment
SELECT * FROM table
-- comment 2
"""
parsed = ParsedQuery(query, strip_comments=True)
assert parsed.is_valid_cvas()
query = """
-- comment
SET @value = 42;
SELECT @value as foo;
-- comment 2
"""
parsed = ParsedQuery(query, strip_comments=True)
assert not parsed.is_valid_cvas()
query = """
-- comment
EXPLAIN SELECT * FROM table
-- comment 2
"""
parsed = ParsedQuery(query, strip_comments=True)
assert not parsed.is_valid_ctas()
query = """
SELECT * FROM table;
INSERT INTO TABLE (foo) VALUES (42);
"""
parsed = ParsedQuery(query, strip_comments=True)
assert not parsed.is_valid_ctas()

View File

@ -18,11 +18,12 @@
"""Unit tests for Sql Lab"""
import json
from datetime import datetime, timedelta
from parameterized import parameterized
from random import random
from unittest import mock
from parameterized import parameterized
import prison
import pytest
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
@ -30,6 +31,7 @@ from superset.db_engine_specs import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetErrorType
from superset.models.sql_lab import Query, SavedQuery
from superset.result_set import SupersetResultSet
from superset.sql_lab import execute_sql_statements, SqlLabException
from superset.sql_parse import CtasMethod
from superset.utils.core import (
datetime_to_epoch,
@ -618,3 +620,163 @@ class TestSqlLab(SupersetTestCase):
"template_parameters": {"state": "CA"},
"undefined_parameters": ["stat"],
}
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query):
sql = """
-- comment
SET @value = 42;
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
mock_get_query.return_value = mock_query
execute_sql_statements(
query_id=1,
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
)
mock_execute_sql_statement.assert_has_calls(
[
mock.call(
"SET @value = 42",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
False,
),
mock.call(
"SELECT @value AS foo",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
False,
),
]
)
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements_ctas(
self, mock_execute_sql_statement, mock_get_query
):
sql = """
-- comment
SET @value = 42;
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
mock_get_query.return_value = mock_query
# set the query to CTAS
mock_query.select_as_cta = True
mock_query.ctas_method = CtasMethod.TABLE
execute_sql_statements(
query_id=1,
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
)
mock_execute_sql_statement.assert_has_calls(
[
mock.call(
"SET @value = 42",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
False,
),
mock.call(
"SELECT @value AS foo",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
True, # apply_ctas
),
]
)
# try invalid CTAS
sql = "DROP TABLE my_table"
with pytest.raises(SqlLabException) as excinfo:
execute_sql_statements(
query_id=1,
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
)
assert str(excinfo.value) == (
"CTAS (create table as select) can only be run with "
"a query where the last statement is a SELECT. Please "
"make sure your query has a SELECT as its last "
"statement. Then, try running your query again."
)
# try invalid CVAS
mock_query.ctas_method = CtasMethod.VIEW
sql = """
-- comment
SET @value = 42;
SELECT @value AS foo;
-- comment
"""
with pytest.raises(SqlLabException) as excinfo:
execute_sql_statements(
query_id=1,
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
)
assert str(excinfo.value) == (
"CVAS (create view as select) can only be run with a "
"query with a single SELECT statement. Please make "
"sure your query has only a SELECT statement. Then, "
"try running your query again."
)