From 164db3e5a13c21137afb56a3044ef3f1aaf89e11 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 4 Jan 2021 09:22:35 -0800 Subject: [PATCH] fix: CTAS on multiple statements (#12188) * WIP * Add unit tests for sql_parse * Add unit tests for sql_lab --- superset/sql_lab.py | 57 +++++++++++--- superset/sql_parse.py | 11 ++- tests/sql_parse_tests.py | 76 ++++++++++++++++++ tests/sqllab_tests.py | 164 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 295 insertions(+), 13 deletions(-) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 800149889..1153a2b7d 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -36,7 +36,7 @@ from superset.db_engine_specs import BaseEngineSpec from superset.extensions import celery_app from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet -from superset.sql_parse import ParsedQuery +from superset.sql_parse import CtasMethod, ParsedQuery from superset.utils.celery import session_scope from superset.utils.core import ( json_iso_dttm_ser, @@ -160,6 +160,7 @@ def execute_sql_statement( session: Session, cursor: Any, log_params: Optional[Dict[str, Any]], + apply_ctas: bool = False, ) -> SupersetResultSet: """Executes a single SQL statement""" database = query.database @@ -171,14 +172,7 @@ def execute_sql_statement( raise SqlLabSecurityException( _("Only `SELECT` statements are allowed against this database") ) - if query.select_as_cta: - if not parsed_query.is_select(): - raise SqlLabException( - _( - "Only `SELECT` statements can be used with the CREATE TABLE " - "feature." 
- ) - ) + if apply_ctas: if not query.tmp_table_name: start_dttm = datetime.fromtimestamp(query.start_time) query.tmp_table_name = "tmp_{}_table_{}".format( @@ -322,8 +316,8 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca raise SqlLabException("Results backend isn't configured.") # Breaking down into multiple statements + parsed_query = ParsedQuery(rendered_query, strip_comments=True) if not db_engine_spec.run_multiple_statements_as_one: - parsed_query = ParsedQuery(rendered_query) statements = parsed_query.get_statements() logger.info( "Query %s: Executing %i statement(s)", str(query_id), len(statements) @@ -337,6 +331,32 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca query.start_running_time = now_as_float() session.commit() + # Should we create a table or view from the select? + if ( + query.select_as_cta + and query.ctas_method == CtasMethod.TABLE + and not parsed_query.is_valid_ctas() + ): + raise SqlLabException( + _( + "CTAS (create table as select) can only be run with a query where " + "the last statement is a SELECT. Please make sure your query has " + "a SELECT as its last statement. Then, try running your query again." + ) + ) + if ( + query.select_as_cta + and query.ctas_method == CtasMethod.VIEW + and not parsed_query.is_valid_cvas() + ): + raise SqlLabException( + _( + "CVAS (create view as select) can only be run with a query with " + "a single SELECT statement. Please make sure your query has only " + "a SELECT statement. Then, try running your query again." 
+ ) + ) + engine = database.get_sqla_engine( schema=query.schema, nullpool=True, @@ -354,6 +374,15 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca if query.status == QueryStatus.STOPPED: return None + # For CTAS we create the table only on the last statement + apply_ctas = query.select_as_cta and ( + query.ctas_method == CtasMethod.VIEW + or ( + query.ctas_method == CtasMethod.TABLE + and i == len(statements) - 1 + ) + ) + # Run statement msg = f"Running statement {i+1} out of {statement_count}" logger.info("Query %s: %s", str(query_id), msg) @@ -361,7 +390,13 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca session.commit() try: result_set = execute_sql_statement( - statement, query, user_name, session, cursor, log_params + statement, + query, + user_name, + session, + cursor, + log_params, + apply_ctas, ) except Exception as ex: # pylint: disable=broad-except msg = str(ex) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 8343f4273..dd345dbd1 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -81,7 +81,10 @@ class Table: # pylint: disable=too-few-public-methods class ParsedQuery: - def __init__(self, sql_statement: str): + def __init__(self, sql_statement: str, strip_comments: bool = False): + if strip_comments: + sql_statement = sqlparse.format(sql_statement, strip_comments=True) + self.sql: str = sql_statement self._tables: Set[Table] = set() self._alias_names: Set[str] = set() @@ -110,6 +113,12 @@ class ParsedQuery: def is_select(self) -> bool: return self._parsed[0].get_type() == "SELECT" + def is_valid_ctas(self) -> bool: + return self._parsed[-1].get_type() == "SELECT" + + def is_valid_cvas(self) -> bool: + return len(self._parsed) == 1 and self._parsed[0].get_type() == "SELECT" + def is_explain(self) -> bool: # Remove comments statements_without_comments = sqlparse.format( diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py index 
e46c2b8f6..b54a9ef90 100644 --- a/tests/sql_parse_tests.py +++ b/tests/sql_parse_tests.py @@ -656,3 +656,79 @@ class TestSupersetSqlParse(unittest.TestCase): """ parsed = ParsedQuery(query) self.assertEqual(parsed.is_explain(), False) + + def test_is_valid_ctas(self): + """A valid CTAS has a SELECT as its last statement""" + query = "SELECT * FROM table" + parsed = ParsedQuery(query, strip_comments=True) + assert parsed.is_valid_ctas() + + query = """ + -- comment + SELECT * FROM table + -- comment 2 + """ + parsed = ParsedQuery(query, strip_comments=True) + assert parsed.is_valid_ctas() + + query = """ + -- comment + SET @value = 42; + SELECT @value as foo; + -- comment 2 + """ + parsed = ParsedQuery(query, strip_comments=True) + assert parsed.is_valid_ctas() + + query = """ + -- comment + EXPLAIN SELECT * FROM table + -- comment 2 + """ + parsed = ParsedQuery(query, strip_comments=True) + assert not parsed.is_valid_ctas() + + query = """ + SELECT * FROM table; + INSERT INTO TABLE (foo) VALUES (42); + """ + parsed = ParsedQuery(query, strip_comments=True) + assert not parsed.is_valid_ctas() + + def test_is_valid_cvas(self): + """A valid CVAS has a single SELECT statement""" + query = "SELECT * FROM table" + parsed = ParsedQuery(query, strip_comments=True) + assert parsed.is_valid_cvas() + + query = """ + -- comment + SELECT * FROM table + -- comment 2 + """ + parsed = ParsedQuery(query, strip_comments=True) + assert parsed.is_valid_cvas() + + query = """ + -- comment + SET @value = 42; + SELECT @value as foo; + -- comment 2 + """ + parsed = ParsedQuery(query, strip_comments=True) + assert not parsed.is_valid_cvas() + + query = """ + -- comment + EXPLAIN SELECT * FROM table + -- comment 2 + """ + parsed = ParsedQuery(query, strip_comments=True) + assert not parsed.is_valid_cvas() + + query = """ + SELECT * FROM table; + INSERT INTO TABLE (foo) VALUES (42); + """ + parsed = ParsedQuery(query, strip_comments=True) + assert not parsed.is_valid_cvas() diff --git 
a/tests/sqllab_tests.py b/tests/sqllab_tests.py index a10ef4fae..8c2b10a25 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -18,11 +18,12 @@ """Unit tests for Sql Lab""" import json from datetime import datetime, timedelta -from parameterized import parameterized from random import random from unittest import mock +from parameterized import parameterized import prison +import pytest from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable @@ -30,6 +31,7 @@ from superset.db_engine_specs import BaseEngineSpec from superset.errors import ErrorLevel, SupersetErrorType from superset.models.sql_lab import Query, SavedQuery from superset.result_set import SupersetResultSet +from superset.sql_lab import execute_sql_statements, SqlLabException from superset.sql_parse import CtasMethod from superset.utils.core import ( datetime_to_epoch, @@ -618,3 +620,163 @@ class TestSqlLab(SupersetTestCase): "template_parameters": {"state": "CA"}, "undefined_parameters": ["stat"], } + + @mock.patch("superset.sql_lab.get_query") + @mock.patch("superset.sql_lab.execute_sql_statement") + def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query): + sql = """ + -- comment + SET @value = 42; + SELECT @value AS foo; + -- comment + """ + mock_session = mock.MagicMock() + mock_query = mock.MagicMock() + mock_query.database.allow_run_async = False + mock_cursor = mock.MagicMock() + mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = ( + mock_cursor + ) + mock_query.database.db_engine_spec.run_multiple_statements_as_one = False + mock_get_query.return_value = mock_query + + execute_sql_statements( + query_id=1, + rendered_query=sql, + return_results=True, + store_results=False, + user_name="admin", + session=mock_session, + start_time=None, + expand_data=False, + log_params=None, + ) + mock_execute_sql_statement.assert_has_calls( + [ + mock.call( + "SET @value = 42", + 
mock_query, + "admin", + mock_session, + mock_cursor, + None, + False, + ), + mock.call( + "SELECT @value AS foo", + mock_query, + "admin", + mock_session, + mock_cursor, + None, + False, + ), + ] + ) + + @mock.patch("superset.sql_lab.get_query") + @mock.patch("superset.sql_lab.execute_sql_statement") + def test_execute_sql_statements_ctas( + self, mock_execute_sql_statement, mock_get_query + ): + sql = """ + -- comment + SET @value = 42; + SELECT @value AS foo; + -- comment + """ + mock_session = mock.MagicMock() + mock_query = mock.MagicMock() + mock_query.database.allow_run_async = False + mock_cursor = mock.MagicMock() + mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = ( + mock_cursor + ) + mock_query.database.db_engine_spec.run_multiple_statements_as_one = False + mock_get_query.return_value = mock_query + + # set the query to CTAS + mock_query.select_as_cta = True + mock_query.ctas_method = CtasMethod.TABLE + + execute_sql_statements( + query_id=1, + rendered_query=sql, + return_results=True, + store_results=False, + user_name="admin", + session=mock_session, + start_time=None, + expand_data=False, + log_params=None, + ) + mock_execute_sql_statement.assert_has_calls( + [ + mock.call( + "SET @value = 42", + mock_query, + "admin", + mock_session, + mock_cursor, + None, + False, + ), + mock.call( + "SELECT @value AS foo", + mock_query, + "admin", + mock_session, + mock_cursor, + None, + True, # apply_ctas + ), + ] + ) + + # try invalid CTAS + sql = "DROP TABLE my_table" + with pytest.raises(SqlLabException) as excinfo: + execute_sql_statements( + query_id=1, + rendered_query=sql, + return_results=True, + store_results=False, + user_name="admin", + session=mock_session, + start_time=None, + expand_data=False, + log_params=None, + ) + assert str(excinfo.value) == ( + "CTAS (create table as select) can only be run with " + "a query where the last statement is a SELECT. 
Please " + "make sure your query has a SELECT as its last " + "statement. Then, try running your query again." + ) + + # try invalid CVAS + mock_query.ctas_method = CtasMethod.VIEW + sql = """ + -- comment + SET @value = 42; + SELECT @value AS foo; + -- comment + """ + with pytest.raises(SqlLabException) as excinfo: + execute_sql_statements( + query_id=1, + rendered_query=sql, + return_results=True, + store_results=False, + user_name="admin", + session=mock_session, + start_time=None, + expand_data=False, + log_params=None, + ) + assert str(excinfo.value) == ( + "CVAS (create view as select) can only be run with a " + "query with a single SELECT statement. Please make " + "sure your query has only a SELECT statement. Then, " + "try running your query again." + )