fix: CTAS on multiple statements (#12188)
* WIP * Add unit tests for sql_parse * Add unit tests for sql_lab
This commit is contained in:
parent
ff0fe434e4
commit
164db3e5a1
|
|
@ -36,7 +36,7 @@ from superset.db_engine_specs import BaseEngineSpec
|
|||
from superset.extensions import celery_app
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sql_parse import CtasMethod, ParsedQuery
|
||||
from superset.utils.celery import session_scope
|
||||
from superset.utils.core import (
|
||||
json_iso_dttm_ser,
|
||||
|
|
@ -160,6 +160,7 @@ def execute_sql_statement(
|
|||
session: Session,
|
||||
cursor: Any,
|
||||
log_params: Optional[Dict[str, Any]],
|
||||
apply_ctas: bool = False,
|
||||
) -> SupersetResultSet:
|
||||
"""Executes a single SQL statement"""
|
||||
database = query.database
|
||||
|
|
@ -171,14 +172,7 @@ def execute_sql_statement(
|
|||
raise SqlLabSecurityException(
|
||||
_("Only `SELECT` statements are allowed against this database")
|
||||
)
|
||||
if query.select_as_cta:
|
||||
if not parsed_query.is_select():
|
||||
raise SqlLabException(
|
||||
_(
|
||||
"Only `SELECT` statements can be used with the CREATE TABLE "
|
||||
"feature."
|
||||
)
|
||||
)
|
||||
if apply_ctas:
|
||||
if not query.tmp_table_name:
|
||||
start_dttm = datetime.fromtimestamp(query.start_time)
|
||||
query.tmp_table_name = "tmp_{}_table_{}".format(
|
||||
|
|
@ -322,8 +316,8 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
raise SqlLabException("Results backend isn't configured.")
|
||||
|
||||
# Breaking down into multiple statements
|
||||
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
|
||||
if not db_engine_spec.run_multiple_statements_as_one:
|
||||
parsed_query = ParsedQuery(rendered_query)
|
||||
statements = parsed_query.get_statements()
|
||||
logger.info(
|
||||
"Query %s: Executing %i statement(s)", str(query_id), len(statements)
|
||||
|
|
@ -337,6 +331,32 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
query.start_running_time = now_as_float()
|
||||
session.commit()
|
||||
|
||||
# Should we create a table or view from the select?
|
||||
if (
|
||||
query.select_as_cta
|
||||
and query.ctas_method == CtasMethod.TABLE
|
||||
and not parsed_query.is_valid_ctas()
|
||||
):
|
||||
raise SqlLabException(
|
||||
_(
|
||||
"CTAS (create table as select) can only be run with a query where "
|
||||
"the last statement is a SELECT. Please make sure your query has "
|
||||
"a SELECT as its last statement. Then, try running your query again."
|
||||
)
|
||||
)
|
||||
if (
|
||||
query.select_as_cta
|
||||
and query.ctas_method == CtasMethod.VIEW
|
||||
and not parsed_query.is_valid_cvas()
|
||||
):
|
||||
raise SqlLabException(
|
||||
_(
|
||||
"CVAS (create view as select) can only be run with a query with "
|
||||
"a single SELECT statement. Please make sure your query has only "
|
||||
"a SELECT statement. Then, try running your query again."
|
||||
)
|
||||
)
|
||||
|
||||
engine = database.get_sqla_engine(
|
||||
schema=query.schema,
|
||||
nullpool=True,
|
||||
|
|
@ -354,6 +374,15 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
if query.status == QueryStatus.STOPPED:
|
||||
return None
|
||||
|
||||
# For CTAS we create the table only on the last statement
|
||||
apply_ctas = query.select_as_cta and (
|
||||
query.ctas_method == CtasMethod.VIEW
|
||||
or (
|
||||
query.ctas_method == CtasMethod.TABLE
|
||||
and i == len(statements) - 1
|
||||
)
|
||||
)
|
||||
|
||||
# Run statement
|
||||
msg = f"Running statement {i+1} out of {statement_count}"
|
||||
logger.info("Query %s: %s", str(query_id), msg)
|
||||
|
|
@ -361,7 +390,13 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
session.commit()
|
||||
try:
|
||||
result_set = execute_sql_statement(
|
||||
statement, query, user_name, session, cursor, log_params
|
||||
statement,
|
||||
query,
|
||||
user_name,
|
||||
session,
|
||||
cursor,
|
||||
log_params,
|
||||
apply_ctas,
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
msg = str(ex)
|
||||
|
|
|
|||
|
|
@ -81,7 +81,10 @@ class Table: # pylint: disable=too-few-public-methods
|
|||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(self, sql_statement: str):
|
||||
def __init__(self, sql_statement: str, strip_comments: bool = False):
|
||||
if strip_comments:
|
||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||
|
||||
self.sql: str = sql_statement
|
||||
self._tables: Set[Table] = set()
|
||||
self._alias_names: Set[str] = set()
|
||||
|
|
@ -110,6 +113,12 @@ class ParsedQuery:
|
|||
def is_select(self) -> bool:
    """Report whether the first parsed statement is a ``SELECT``.

    Only the first entry of ``self._parsed`` is inspected; any statements
    that follow it are not considered.
    """
    first_statement = self._parsed[0]
    statement_type = first_statement.get_type()
    return statement_type == "SELECT"
|
||||
|
||||
def is_valid_ctas(self) -> bool:
    """Return whether the query can be run as CREATE TABLE AS SELECT.

    Only the last parsed statement is checked: the query qualifies when
    that statement is a ``SELECT``. Earlier statements may be of any type.

    :returns: ``True`` if the last statement is a ``SELECT``; ``False``
        otherwise, including when no statement was parsed at all.
    """
    # An empty or comment-only query yields no parsed statements; indexing
    # [-1] would raise IndexError, so report "not valid" instead.
    if not self._parsed:
        return False
    return self._parsed[-1].get_type() == "SELECT"
|
||||
|
||||
def is_valid_cvas(self) -> bool:
    """Return whether the query can be run as CREATE VIEW AS SELECT.

    The query qualifies only when it consists of exactly one parsed
    statement and that statement is a ``SELECT``.
    """
    if len(self._parsed) != 1:
        return False
    only_statement = self._parsed[0]
    return only_statement.get_type() == "SELECT"
|
||||
|
||||
def is_explain(self) -> bool:
|
||||
# Remove comments
|
||||
statements_without_comments = sqlparse.format(
|
||||
|
|
|
|||
|
|
@ -656,3 +656,79 @@ class TestSupersetSqlParse(unittest.TestCase):
|
|||
"""
|
||||
parsed = ParsedQuery(query)
|
||||
self.assertEqual(parsed.is_explain(), False)
|
||||
|
||||
def test_is_valid_ctas(self):
    """CTAS is accepted only when the final statement is a SELECT."""
    accepted = [
        # a lone SELECT
        "SELECT * FROM table",
        # comments around the SELECT are stripped before parsing
        """
        -- comment
        SELECT * FROM table
        -- comment 2
        """,
        # leading non-SELECT statements are fine as long as SELECT is last
        """
        -- comment
        SET @value = 42;
        SELECT @value as foo;
        -- comment 2
        """,
    ]
    rejected = [
        # EXPLAIN is not reported as a SELECT
        """
        -- comment
        EXPLAIN SELECT * FROM table
        -- comment 2
        """,
        # the last statement is an INSERT
        """
        SELECT * FROM table;
        INSERT INTO TABLE (foo) VALUES (42);
        """,
    ]

    for sql in accepted:
        assert ParsedQuery(sql, strip_comments=True).is_valid_ctas()

    for sql in rejected:
        assert not ParsedQuery(sql, strip_comments=True).is_valid_ctas()
|
||||
|
||||
def test_is_valid_cvas(self):
    """CVAS is accepted only for a query made of a single SELECT statement."""
    accepted = [
        # a lone SELECT
        "SELECT * FROM table",
        # comments around the SELECT are stripped before parsing
        """
        -- comment
        SELECT * FROM table
        -- comment 2
        """,
    ]
    rejected = [
        # more than one statement, even though the last one is a SELECT
        """
        -- comment
        SET @value = 42;
        SELECT @value as foo;
        -- comment 2
        """,
        # a single statement, but EXPLAIN is not reported as a SELECT
        """
        -- comment
        EXPLAIN SELECT * FROM table
        -- comment 2
        """,
        # more than one statement, and the last one is an INSERT
        """
        SELECT * FROM table;
        INSERT INTO TABLE (foo) VALUES (42);
        """,
    ]

    for sql in accepted:
        assert ParsedQuery(sql, strip_comments=True).is_valid_cvas()

    # Every rejected case must go through is_valid_cvas(); the last two used
    # to call is_valid_ctas(), so they never exercised the method under test.
    for sql in rejected:
        assert not ParsedQuery(sql, strip_comments=True).is_valid_cvas()
|
||||
|
|
|
|||
|
|
@ -18,11 +18,12 @@
|
|||
"""Unit tests for Sql Lab"""
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from parameterized import parameterized
|
||||
from random import random
|
||||
from unittest import mock
|
||||
|
||||
from parameterized import parameterized
|
||||
import prison
|
||||
import pytest
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
|
@ -30,6 +31,7 @@ from superset.db_engine_specs import BaseEngineSpec
|
|||
from superset.errors import ErrorLevel, SupersetErrorType
|
||||
from superset.models.sql_lab import Query, SavedQuery
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql_lab import execute_sql_statements, SqlLabException
|
||||
from superset.sql_parse import CtasMethod
|
||||
from superset.utils.core import (
|
||||
datetime_to_epoch,
|
||||
|
|
@ -618,3 +620,163 @@ class TestSqlLab(SupersetTestCase):
|
|||
"template_parameters": {"state": "CA"},
|
||||
"undefined_parameters": ["stat"],
|
||||
}
|
||||
|
||||
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query):
    """A commented, two-statement query is executed statement by statement.

    ``execute_sql_statement`` is expected to be called once per statement,
    with the comments and trailing semicolons removed and with ``False`` as
    the final (``apply_ctas``) argument.
    """
    sql = """
        -- comment
        SET @value = 42;
        SELECT @value AS foo;
        -- comment
    """
    mock_session = mock.MagicMock()
    mock_cursor = mock.MagicMock()

    # Mocked query whose database hands back ``mock_cursor`` and runs
    # statements one at a time.
    mock_query = mock.MagicMock()
    database = mock_query.database
    database.allow_run_async = False
    engine = database.get_sqla_engine.return_value
    engine.raw_connection.return_value.cursor.return_value = mock_cursor
    database.db_engine_spec.run_multiple_statements_as_one = False
    mock_get_query.return_value = mock_query

    execute_sql_statements(
        query_id=1,
        rendered_query=sql,
        return_results=True,
        store_results=False,
        user_name="admin",
        session=mock_session,
        start_time=None,
        expand_data=False,
        log_params=None,
    )

    # Positional arguments shared by every expected call after the statement.
    shared_args = (mock_query, "admin", mock_session, mock_cursor, None)
    expected_calls = [
        mock.call("SET @value = 42", *shared_args, False),
        mock.call("SELECT @value AS foo", *shared_args, False),
    ]
    mock_execute_sql_statement.assert_has_calls(expected_calls)
|
||||
|
||||
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements_ctas(
    self, mock_execute_sql_statement, mock_get_query
):
    """CTAS is applied to the last statement only; invalid CTAS/CVAS raise.

    Three scenarios are run against the same mocked query:

    1. CTAS (table) on ``SET ...; SELECT ...`` passes ``apply_ctas=True``
       only for the final ``SELECT``.
    2. CTAS (table) on a lone ``DROP TABLE`` raises ``SqlLabException``.
    3. CVAS (view) on the two-statement query raises ``SqlLabException``.
    """
    multi_statement_sql = """
        -- comment
        SET @value = 42;
        SELECT @value AS foo;
        -- comment
    """
    mock_session = mock.MagicMock()
    mock_cursor = mock.MagicMock()

    # Mocked query whose database hands back ``mock_cursor`` and runs
    # statements one at a time.
    mock_query = mock.MagicMock()
    database = mock_query.database
    database.allow_run_async = False
    engine = database.get_sqla_engine.return_value
    engine.raw_connection.return_value.cursor.return_value = mock_cursor
    database.db_engine_spec.run_multiple_statements_as_one = False
    mock_get_query.return_value = mock_query

    # set the query to CTAS
    mock_query.select_as_cta = True
    mock_query.ctas_method = CtasMethod.TABLE

    # Keyword arguments common to every execute_sql_statements() call below.
    run_kwargs = {
        "query_id": 1,
        "return_results": True,
        "store_results": False,
        "user_name": "admin",
        "session": mock_session,
        "start_time": None,
        "expand_data": False,
        "log_params": None,
    }

    execute_sql_statements(rendered_query=multi_statement_sql, **run_kwargs)

    shared_args = (mock_query, "admin", mock_session, mock_cursor, None)
    mock_execute_sql_statement.assert_has_calls(
        [
            mock.call("SET @value = 42", *shared_args, False),
            mock.call("SELECT @value AS foo", *shared_args, True),  # apply_ctas
        ]
    )

    # try invalid CTAS
    with pytest.raises(SqlLabException) as excinfo:
        execute_sql_statements(rendered_query="DROP TABLE my_table", **run_kwargs)
    assert str(excinfo.value) == (
        "CTAS (create table as select) can only be run with "
        "a query where the last statement is a SELECT. Please "
        "make sure your query has a SELECT as its last "
        "statement. Then, try running your query again."
    )

    # try invalid CVAS
    mock_query.ctas_method = CtasMethod.VIEW
    with pytest.raises(SqlLabException) as excinfo:
        execute_sql_statements(rendered_query=multi_statement_sql, **run_kwargs)
    assert str(excinfo.value) == (
        "CVAS (create view as select) can only be run with a "
        "query with a single SELECT statement. Please make "
        "sure your query has only a SELECT statement. Then, "
        "try running your query again."
    )
|
||||
|
|
|
|||
Loading…
Reference in New Issue