fix(mssql): reverts #9644 and displays a better error msg (#9752)

This commit is contained in:
Daniel Vaz Gaspar 2020-05-14 17:00:02 +01:00 committed by GitHub
parent 4427d65717
commit 3cc540019f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 115 deletions

View File

@ -22,7 +22,6 @@ from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.sql_parse import ParsedQuery
if TYPE_CHECKING:
from superset.models.core import Database # pylint: disable=unused-import
@ -85,6 +84,10 @@ class MssqlEngineSpec(BaseEngineSpec):
return None
@classmethod
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
new_sql = ParsedQuery(sql).set_alias()
return super().apply_limit_to_sql(new_sql, limit, database)
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
return (
f"{cls.engine} error: All your SQL functions need to "
"have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
)
return f"{cls.engine} error: {cls._extract_error_message(ex)}"

View File

@ -20,14 +20,7 @@ from urllib import parse
import sqlparse
from dataclasses import dataclass
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
remove_quotes,
Token,
TokenList,
)
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
@ -284,39 +277,3 @@ class ParsedQuery:
for i in statement.tokens:
str_res += str(i.value)
return str_res
def set_alias(self) -> str:
"""
Returns a new query string where all functions have alias.
This is particularly necessary for MSSQL engines.
:return: String with new aliased SQL query
"""
new_sql = ""
changed_counter = 1
for token in self._parsed[0].tokens:
# Identifier list (list of columns)
if isinstance(token, IdentifierList) and token.ttype is None:
for i, identifier in enumerate(token.get_identifiers()):
# Functions are anonymous on MSSQL
if isinstance(identifier, Function) and not identifier.has_alias():
identifier.value = (
f"{identifier.value} AS"
f" {identifier.get_real_name()}_{changed_counter}"
)
changed_counter += 1
new_sql += str(identifier.value)
# If not last identifier
if i != len(list(token.get_identifiers())) - 1:
new_sql += ", "
# Just a lonely function?
elif isinstance(token, Function) and token.ttype is None:
if not token.has_alias():
token.value = (
f"{token.value} AS {token.get_real_name()}_{changed_counter}"
)
new_sql += str(token.value)
# Nothing to change, assemble what we have
else:
new_sql += str(token.value)
return new_sql

View File

@ -444,7 +444,7 @@ def json_dumps_w_dates(payload):
return json.dumps(payload, default=json_int_dttm_ser)
def error_msg_from_exception(e: Exception) -> str:
def error_msg_from_exception(ex: Exception) -> str:
"""Translate exception into error message
Database have different ways to handle exception. This function attempts
@ -459,12 +459,12 @@ def error_msg_from_exception(e: Exception) -> str:
The latter version is parsed correctly by this function.
"""
msg = ""
if hasattr(e, "message"):
if isinstance(e.message, dict): # type: ignore
msg = e.message.get("message") # type: ignore
elif e.message: # type: ignore
msg = e.message # type: ignore
return msg or str(e)
if hasattr(ex, "message"):
if isinstance(ex.message, dict): # type: ignore
msg = ex.message.get("message") # type: ignore
elif ex.message: # type: ignore
msg = ex.message # type: ignore
return msg or str(ex)
def markdown(s: str, markup_wrap: Optional[bool] = False) -> str:

View File

@ -15,18 +15,15 @@
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from typing import Optional
from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select, Select
from sqlalchemy.sql import select
from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.extensions import db
from superset.models.core import Database
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
@ -97,64 +94,28 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
for actual, expected in test_cases:
self.assertEqual(actual, expected)
def test_apply_limit(self):
def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
return str(
qry.compile(
dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
)
)
database = Database(
database_name="mssql_test",
sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
def test_extract_error_message(self):
test_mssql_exception = Exception(
"(8155, b\"No column name was specified for column 1 of 'inner_qry'."
"DB-Lib error message 20018, severity 16:\\nGeneral SQL Server error: "
'Check messages from the SQL Server\\n")'
)
db.session.add(database)
db.session.commit()
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
expected_message = (
"mssql error: All your SQL functions need to "
"have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
)
self.assertEqual(expected_message, error_message)
with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
test_sql = "SELECT COUNT(*) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
"AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, "
"FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
" AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
" AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
db.session.delete(database)
db.session.commit()
test_mssql_exception = Exception(
'(8200, b"A correlated expression is invalid because it is not in a '
"GROUP BY clause.\\n\")'"
)
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
expected_message = "mssql error: " + MssqlEngineSpec._extract_error_message(
test_mssql_exception
)
self.assertEqual(expected_message, error_message)
@mock.patch.object(
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"