From 3cc540019f2aa6c3dac1d356f0f21eeca96b34f2 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Thu, 14 May 2020 17:00:02 +0100 Subject: [PATCH] fix(mssql): reverts #9644 and displays a better error msg (#9752) --- superset/db_engine_specs/mssql.py | 11 ++-- superset/sql_parse.py | 45 +--------------- superset/utils/core.py | 14 ++--- tests/db_engine_specs/mssql_tests.py | 81 ++++++++-------------------- 4 files changed, 36 insertions(+), 115 deletions(-) diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 4fc6e6f44..fde69b3c2 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -22,7 +22,6 @@ from typing import Any, List, Optional, Tuple, TYPE_CHECKING from sqlalchemy.types import String, TypeEngine, UnicodeText from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod -from superset.sql_parse import ParsedQuery if TYPE_CHECKING: from superset.models.core import Database # pylint: disable=unused-import @@ -85,6 +84,10 @@ class MssqlEngineSpec(BaseEngineSpec): return None @classmethod - def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str: - new_sql = ParsedQuery(sql).set_alias() - return super().apply_limit_to_sql(new_sql, limit, database) + def extract_error_message(cls, ex: Exception) -> str: + if str(ex).startswith("(8155,"): + return ( + f"{cls.engine} error: All your SQL functions need to " + "have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1" + ) + return f"{cls.engine} error: {cls._extract_error_message(ex)}" diff --git a/superset/sql_parse.py b/superset/sql_parse.py index bb6f34123..e841db10c 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -20,14 +20,7 @@ from urllib import parse import sqlparse from dataclasses import dataclass -from sqlparse.sql import ( - Function, - Identifier, - IdentifierList, - remove_quotes, - Token, - TokenList, -) +from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace from sqlparse.utils import imt @@ -284,39 +277,3 @@ class ParsedQuery: for i in statement.tokens: str_res += str(i.value) return str_res - - def set_alias(self) -> str: - """ - Returns a new query string where all functions have alias. - This is particularly necessary for MSSQL engines. - - :return: String with new aliased SQL query - """ - new_sql = "" - changed_counter = 1 - for token in self._parsed[0].tokens: - # Identifier list (list of columns) - if isinstance(token, IdentifierList) and token.ttype is None: - for i, identifier in enumerate(token.get_identifiers()): - # Functions are anonymous on MSSQL - if isinstance(identifier, Function) and not identifier.has_alias(): - identifier.value = ( - f"{identifier.value} AS" - f" {identifier.get_real_name()}_{changed_counter}" - ) - changed_counter += 1 - new_sql += str(identifier.value) - # If not last identifier - if i != len(list(token.get_identifiers())) - 1: - new_sql += ", " - # Just a lonely function? - elif isinstance(token, Function) and token.ttype is None: - if not token.has_alias(): - token.value = ( - f"{token.value} AS {token.get_real_name()}_{changed_counter}" - ) - new_sql += str(token.value) - # Nothing to change, assemble what we have - else: - new_sql += str(token.value) - return new_sql diff --git a/superset/utils/core.py b/superset/utils/core.py index 41deae504..47fe8c889 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -444,7 +444,7 @@ def json_dumps_w_dates(payload): return json.dumps(payload, default=json_int_dttm_ser) -def error_msg_from_exception(e: Exception) -> str: +def error_msg_from_exception(ex: Exception) -> str: """Translate exception into error message Database have different ways to handle exception. This function attempts @@ -459,12 +459,12 @@ def error_msg_from_exception(e: Exception) -> str: The latter version is parsed correctly by this function. """ msg = "" - if hasattr(e, "message"): - if isinstance(e.message, dict): # type: ignore - msg = e.message.get("message") # type: ignore - elif e.message: # type: ignore - msg = e.message # type: ignore - return msg or str(e) + if hasattr(ex, "message"): + if isinstance(ex.message, dict): # type: ignore + msg = ex.message.get("message") # type: ignore + elif ex.message: # type: ignore + msg = ex.message # type: ignore + return msg or str(ex) def markdown(s: str, markup_wrap: Optional[bool] = False) -> str: diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py index 9f5351cc9..0a254de02 100644 --- a/tests/db_engine_specs/mssql_tests.py +++ b/tests/db_engine_specs/mssql_tests.py @@ -15,18 +15,15 @@ # specific language governing permissions and limitations # under the License. import unittest.mock as mock -from typing import Optional from sqlalchemy import column, table from sqlalchemy.dialects import mssql from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR -from sqlalchemy.sql import select, Select +from sqlalchemy.sql import select from sqlalchemy.types import String, UnicodeText from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.mssql import MssqlEngineSpec -from superset.extensions import db -from superset.models.core import Database from tests.db_engine_specs.base_tests import DbEngineSpecTestCase @@ -97,64 +94,28 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase): for actual, expected in test_cases: self.assertEqual(actual, expected) - def test_apply_limit(self): - def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str: - return str( - qry.compile( - dialect=mssql.dialect(), compile_kwargs={"literal_binds": True} - ) - ) - - database = Database( - database_name="mssql_test", - sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb", + def test_extract_error_message(self): + test_mssql_exception = Exception( + "(8155, b\"No column name was specified for column 1 of 'inner_qry'." + "DB-Lib error message 20018, severity 16:\\nGeneral SQL Server error: " + 'Check messages from the SQL Server\\n")' ) - db.session.add(database) - db.session.commit() + error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception) + expected_message = ( + "mssql error: All your SQL functions need to " + "have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1" + ) + self.assertEqual(expected_message, error_message) - with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query): - test_sql = "SELECT COUNT(*) FROM FOO_TABLE" - - limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database) - - expected_sql = ( - "SELECT TOP 1000 * \n" - "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry" - ) - self.assertEqual(expected_sql, limited_sql) - - test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE" - limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database) - - expected_sql = ( - "SELECT TOP 1000 * \n" - "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) " - "AS inner_qry" - ) - self.assertEqual(expected_sql, limited_sql) - - test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1" - limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database) - - expected_sql = ( - "SELECT TOP 1000 * \n" - "FROM (SELECT COUNT(*) AS COUNT_1, " - "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)" - " AS inner_qry" - ) - self.assertEqual(expected_sql, limited_sql) - - test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE" - limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database) - expected_sql = ( - "SELECT TOP 1000 * \n" - "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)" - " AS inner_qry" - ) - self.assertEqual(expected_sql, limited_sql) - - db.session.delete(database) - db.session.commit() + test_mssql_exception = Exception( + '(8200, b"A correlated expression is invalid because it is not in a ' + "GROUP BY clause.\\n\")'" + ) + error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception) + expected_message = "mssql error: " + MssqlEngineSpec._extract_error_message( + test_mssql_exception + ) + self.assertEqual(expected_message, error_message) @mock.patch.object( MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"