reuse_regex_logic
This commit is contained in:
parent
1aced9b562
commit
d38315a307
|
|
@ -104,14 +104,8 @@ class BaseEngineSpec(object):
|
|||
)
|
||||
return database.compile_sqla_query(qry)
|
||||
elif LimitMethod.FORCE_LIMIT:
|
||||
no_limit = re.sub(r"""
|
||||
(?ix) # case insensitive, verbose
|
||||
\s+ # whitespace
|
||||
LIMIT\s+\d+ # LIMIT $ROWS
|
||||
;? # optional semi-colon
|
||||
(\s|;)*$ # remove trailing spaces tabs or semicolons
|
||||
""", '', sql)
|
||||
return '{no_limit} LIMIT {limit}'.format(**locals())
|
||||
sql_without_limit = utils.get_query_without_limit(sql)
|
||||
return '{sql_without_limit} LIMIT {limit}'.format(**locals())
|
||||
return sql
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ def execute_sql(
|
|||
# Limit enforced only for retrieving the data, not for the CTA queries.
|
||||
superset_query = SupersetQuery(rendered_query)
|
||||
executed_sql = superset_query.stripped()
|
||||
SQL_MAX_ROWS = int(app.config.get('SQL_MAX_ROW', None))
|
||||
SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW')
|
||||
if not superset_query.is_select() and not database.allow_dml:
|
||||
return handle_error(
|
||||
'Only `SELECT` statements are allowed against this database')
|
||||
|
|
@ -186,7 +186,8 @@ def execute_sql(
|
|||
query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
|
||||
executed_sql = superset_query.as_create_table(query.tmp_table_name)
|
||||
query.select_as_cta_used = True
|
||||
elif (not query.limit and superset_query.is_select() and SQL_MAX_ROWS):
|
||||
elif (superset_query.is_select() and SQL_MAX_ROWS and
|
||||
(not query.limit or query.limit > SQL_MAX_ROWS)):
|
||||
query.limit = SQL_MAX_ROWS
|
||||
executed_sql = database.apply_limit_to_sql(executed_sql, query.limit)
|
||||
query.limit_used = True
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import functools
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import smtplib
|
||||
import sys
|
||||
|
|
@ -884,15 +885,27 @@ def split_adhoc_filters_into_base_filters(fd):
|
|||
del fd['adhoc_filters']
|
||||
|
||||
|
||||
def get_query_without_limit(sql):
|
||||
return re.sub(r"""
|
||||
(?ix) # case insensitive, verbose
|
||||
\s+ # whitespace
|
||||
LIMIT\s+\d+ # LIMIT $ROWS
|
||||
;? # optional semi-colon
|
||||
(\s|;)*$ # remove trailing spaces tabs or semicolons
|
||||
""", '', sql)
|
||||
|
||||
|
||||
def get_limit_from_sql(sql):
|
||||
sql = sql.lower()
|
||||
limit = None
|
||||
tokens = sql.split()
|
||||
try:
|
||||
if 'limit' in tokens:
|
||||
limit_pos = tokens.index('limit') + 1
|
||||
limit = int(tokens[limit_pos])
|
||||
except Exception as e:
|
||||
# fail quietly so we can get the more intelligible error from the database.
|
||||
logging.error('Non-numeric limit added.\n{}'.format(e))
|
||||
return limit
|
||||
# returns the limit of the quest or None if it has no limit.
|
||||
|
||||
limit_pattern = re.compile(r"""
|
||||
(?ix) # case insensitive, verbose
|
||||
\s+ # whitespace
|
||||
LIMIT\s+(\d+) # LIMIT $ROWS
|
||||
;? # optional semi-colon
|
||||
(\s|;)*$ # remove trailing spaces tabs or semicolons
|
||||
""")
|
||||
matches = limit_pattern.findall(sql)
|
||||
|
||||
if matches:
|
||||
return int(matches[0])
|
||||
|
|
|
|||
|
|
@ -196,11 +196,13 @@ class CeleryTestCase(SupersetTestCase):
|
|||
self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records'))
|
||||
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
||||
self.assertTrue('FROM tmp_async_1' in query.select_sql)
|
||||
self.assertTrue('LIMIT 666' in query.select_sql)
|
||||
self.assertEqual(
|
||||
'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role '
|
||||
"WHERE name='Admin'", query.executed_sql)
|
||||
self.assertEqual(sql_where, query.sql)
|
||||
self.assertEqual(0, query.rows)
|
||||
self.assertEqual(666, query.limit)
|
||||
self.assertEqual(False, query.limit_used)
|
||||
self.assertEqual(True, query.select_as_cta)
|
||||
self.assertEqual(True, query.select_as_cta_used)
|
||||
|
|
|
|||
Loading…
Reference in New Issue