parent
4427d65717
commit
3cc540019f
|
|
@ -22,7 +22,6 @@ from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
|||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
|
||||
from superset.sql_parse import ParsedQuery
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database # pylint: disable=unused-import
|
||||
|
|
@ -85,6 +84,10 @@ class MssqlEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
|
||||
new_sql = ParsedQuery(sql).set_alias()
|
||||
return super().apply_limit_to_sql(new_sql, limit, database)
|
||||
def extract_error_message(cls, ex: Exception) -> str:
|
||||
if str(ex).startswith("(8155,"):
|
||||
return (
|
||||
f"{cls.engine} error: All your SQL functions need to "
|
||||
"have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
|
||||
)
|
||||
return f"{cls.engine} error: {cls._extract_error_message(ex)}"
|
||||
|
|
|
|||
|
|
@ -20,14 +20,7 @@ from urllib import parse
|
|||
|
||||
import sqlparse
|
||||
from dataclasses import dataclass
|
||||
from sqlparse.sql import (
|
||||
Function,
|
||||
Identifier,
|
||||
IdentifierList,
|
||||
remove_quotes,
|
||||
Token,
|
||||
TokenList,
|
||||
)
|
||||
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
|
||||
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
|
||||
from sqlparse.utils import imt
|
||||
|
||||
|
|
@ -284,39 +277,3 @@ class ParsedQuery:
|
|||
for i in statement.tokens:
|
||||
str_res += str(i.value)
|
||||
return str_res
|
||||
|
||||
def set_alias(self) -> str:
|
||||
"""
|
||||
Returns a new query string where all functions have alias.
|
||||
This is particularly necessary for MSSQL engines.
|
||||
|
||||
:return: String with new aliased SQL query
|
||||
"""
|
||||
new_sql = ""
|
||||
changed_counter = 1
|
||||
for token in self._parsed[0].tokens:
|
||||
# Identifier list (list of columns)
|
||||
if isinstance(token, IdentifierList) and token.ttype is None:
|
||||
for i, identifier in enumerate(token.get_identifiers()):
|
||||
# Functions are anonymous on MSSQL
|
||||
if isinstance(identifier, Function) and not identifier.has_alias():
|
||||
identifier.value = (
|
||||
f"{identifier.value} AS"
|
||||
f" {identifier.get_real_name()}_{changed_counter}"
|
||||
)
|
||||
changed_counter += 1
|
||||
new_sql += str(identifier.value)
|
||||
# If not last identifier
|
||||
if i != len(list(token.get_identifiers())) - 1:
|
||||
new_sql += ", "
|
||||
# Just a lonely function?
|
||||
elif isinstance(token, Function) and token.ttype is None:
|
||||
if not token.has_alias():
|
||||
token.value = (
|
||||
f"{token.value} AS {token.get_real_name()}_{changed_counter}"
|
||||
)
|
||||
new_sql += str(token.value)
|
||||
# Nothing to change, assemble what we have
|
||||
else:
|
||||
new_sql += str(token.value)
|
||||
return new_sql
|
||||
|
|
|
|||
|
|
@ -444,7 +444,7 @@ def json_dumps_w_dates(payload):
|
|||
return json.dumps(payload, default=json_int_dttm_ser)
|
||||
|
||||
|
||||
def error_msg_from_exception(e: Exception) -> str:
|
||||
def error_msg_from_exception(ex: Exception) -> str:
|
||||
"""Translate exception into error message
|
||||
|
||||
Database have different ways to handle exception. This function attempts
|
||||
|
|
@ -459,12 +459,12 @@ def error_msg_from_exception(e: Exception) -> str:
|
|||
The latter version is parsed correctly by this function.
|
||||
"""
|
||||
msg = ""
|
||||
if hasattr(e, "message"):
|
||||
if isinstance(e.message, dict): # type: ignore
|
||||
msg = e.message.get("message") # type: ignore
|
||||
elif e.message: # type: ignore
|
||||
msg = e.message # type: ignore
|
||||
return msg or str(e)
|
||||
if hasattr(ex, "message"):
|
||||
if isinstance(ex.message, dict): # type: ignore
|
||||
msg = ex.message.get("message") # type: ignore
|
||||
elif ex.message: # type: ignore
|
||||
msg = ex.message # type: ignore
|
||||
return msg or str(ex)
|
||||
|
||||
|
||||
def markdown(s: str, markup_wrap: Optional[bool] = False) -> str:
|
||||
|
|
|
|||
|
|
@ -15,18 +15,15 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import unittest.mock as mock
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import column, table
|
||||
from sqlalchemy.dialects import mssql
|
||||
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
|
||||
from sqlalchemy.sql import select, Select
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy.types import String, UnicodeText
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.mssql import MssqlEngineSpec
|
||||
from superset.extensions import db
|
||||
from superset.models.core import Database
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
|
|
@ -97,64 +94,28 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
|
|||
for actual, expected in test_cases:
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_apply_limit(self):
|
||||
def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
|
||||
return str(
|
||||
qry.compile(
|
||||
dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
|
||||
)
|
||||
)
|
||||
|
||||
database = Database(
|
||||
database_name="mssql_test",
|
||||
sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
|
||||
def test_extract_error_message(self):
|
||||
test_mssql_exception = Exception(
|
||||
"(8155, b\"No column name was specified for column 1 of 'inner_qry'."
|
||||
"DB-Lib error message 20018, severity 16:\\nGeneral SQL Server error: "
|
||||
'Check messages from the SQL Server\\n")'
|
||||
)
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
|
||||
expected_message = (
|
||||
"mssql error: All your SQL functions need to "
|
||||
"have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
|
||||
)
|
||||
self.assertEqual(expected_message, error_message)
|
||||
|
||||
with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
|
||||
test_sql = "SELECT COUNT(*) FROM FOO_TABLE"
|
||||
|
||||
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
|
||||
|
||||
expected_sql = (
|
||||
"SELECT TOP 1000 * \n"
|
||||
"FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
|
||||
)
|
||||
self.assertEqual(expected_sql, limited_sql)
|
||||
|
||||
test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
|
||||
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
|
||||
|
||||
expected_sql = (
|
||||
"SELECT TOP 1000 * \n"
|
||||
"FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
|
||||
"AS inner_qry"
|
||||
)
|
||||
self.assertEqual(expected_sql, limited_sql)
|
||||
|
||||
test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
|
||||
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
|
||||
|
||||
expected_sql = (
|
||||
"SELECT TOP 1000 * \n"
|
||||
"FROM (SELECT COUNT(*) AS COUNT_1, "
|
||||
"FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
|
||||
" AS inner_qry"
|
||||
)
|
||||
self.assertEqual(expected_sql, limited_sql)
|
||||
|
||||
test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
|
||||
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
|
||||
expected_sql = (
|
||||
"SELECT TOP 1000 * \n"
|
||||
"FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
|
||||
" AS inner_qry"
|
||||
)
|
||||
self.assertEqual(expected_sql, limited_sql)
|
||||
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
test_mssql_exception = Exception(
|
||||
'(8200, b"A correlated expression is invalid because it is not in a '
|
||||
"GROUP BY clause.\\n\")'"
|
||||
)
|
||||
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
|
||||
expected_message = "mssql error: " + MssqlEngineSpec._extract_error_message(
|
||||
test_mssql_exception
|
||||
)
|
||||
self.assertEqual(expected_message, error_message)
|
||||
|
||||
@mock.patch.object(
|
||||
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
|
||||
|
|
|
|||
Loading…
Reference in New Issue