diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 0e189b125..5e706c50a 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -104,14 +104,8 @@ class BaseEngineSpec(object): ) return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: - no_limit = re.sub(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+\d+ # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """, '', sql) - return '{no_limit} LIMIT {limit}'.format(**locals()) + sql_without_limit = utils.get_query_without_limit(sql) + return '{sql_without_limit} LIMIT {limit}'.format(**locals()) return sql @staticmethod diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 52b08273d..7aa5d03ef 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -171,7 +171,7 @@ def execute_sql( # Limit enforced only for retrieving the data, not for the CTA queries. superset_query = SupersetQuery(rendered_query) executed_sql = superset_query.stripped() - SQL_MAX_ROWS = int(app.config.get('SQL_MAX_ROW', None)) + SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW') if not superset_query.is_select() and not database.allow_dml: return handle_error( 'Only `SELECT` statements are allowed against this database') @@ -186,7 +186,8 @@ def execute_sql( query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (not query.limit and superset_query.is_select() and SQL_MAX_ROWS): + elif (superset_query.is_select() and SQL_MAX_ROWS and + (not query.limit or query.limit > SQL_MAX_ROWS)): query.limit = SQL_MAX_ROWS executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) query.limit_used = True diff --git a/superset/utils.py b/superset/utils.py index 47626aa8e..09131f6d9 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -18,6 +18,7 @@ import functools import json import logging import os +import re import signal import smtplib import sys @@ -884,15 +885,27 @@ def split_adhoc_filters_into_base_filters(fd): del fd['adhoc_filters'] +def get_query_without_limit(sql): + return re.sub(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+\d+ # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """, '', sql) + + def get_limit_from_sql(sql): - sql = sql.lower() - limit = None - tokens = sql.split() - try: - if 'limit' in tokens: - limit_pos = tokens.index('limit') + 1 - limit = int(tokens[limit_pos]) - except Exception as e: - # fail quietly so we can get the more intelligible error from the database. - logging.error('Non-numeric limit added.\n{}'.format(e)) - return limit + # returns the limit of the quest or None if it has no limit. + + limit_pattern = re.compile(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+(\d+) # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """) + matches = limit_pattern.findall(sql) + + if matches: + return int(matches[0]) diff --git a/tests/celery_tests.py b/tests/celery_tests.py index c785702f5..f6d1a2958 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -196,11 +196,13 @@ class CeleryTestCase(SupersetTestCase): self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertTrue('FROM tmp_async_1' in query.select_sql) + self.assertTrue('LIMIT 666' in query.select_sql) self.assertEqual( 'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role ' "WHERE name='Admin'", query.executed_sql) self.assertEqual(sql_where, query.sql) self.assertEqual(0, query.rows) + self.assertEqual(666, query.limit) self.assertEqual(False, query.limit_used) self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used)