feat: add support for query offset (#10010)
* feat: add support for query offset * Address comments and add new tests
This commit is contained in:
parent
2a3305e7dd
commit
315518d2d2
|
|
@ -16,8 +16,9 @@
|
|||
# under the License.
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from flask_babel import gettext as _
|
||||
from marshmallow import fields, post_load, Schema, validate, ValidationError
|
||||
from marshmallow.validate import Length
|
||||
from marshmallow.validate import Length, Range
|
||||
|
||||
from superset.common.query_context import QueryContext
|
||||
from superset.exceptions import SupersetException
|
||||
|
|
@ -663,6 +664,15 @@ class ChartDataQueryObjectSchema(Schema):
|
|||
)
|
||||
row_limit = fields.Integer(
|
||||
description='Maximum row count. Default: `config["ROW_LIMIT"]`',
|
||||
validate=[
|
||||
Range(min=1, error=_("`row_limit` must be greater than or equal to 1"))
|
||||
],
|
||||
)
|
||||
row_offset = fields.Integer(
|
||||
description="Number of rows to skip. Default: `0`",
|
||||
validate=[
|
||||
Range(min=0, error=_("`row_offset` must be greater than or equal to 0"))
|
||||
],
|
||||
)
|
||||
order_desc = fields.Boolean(
|
||||
description="Reverse order. Default: `false`", required=False
|
||||
|
|
|
|||
|
|
@ -25,14 +25,13 @@ import numpy as np
|
|||
import pandas as pd
|
||||
|
||||
from superset import app, cache, db, security_manager
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import DTTM_ALIAS
|
||||
|
||||
from .query_object import QueryObject
|
||||
|
||||
config = app.config
|
||||
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -156,6 +155,7 @@ class QueryContext:
|
|||
query_obj.metrics = []
|
||||
query_obj.post_processing = []
|
||||
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
|
||||
query_obj.row_offset = 0
|
||||
query_obj.columns = [o.column_name for o in self.datasource.columns]
|
||||
payload = self.get_df_payload(query_obj)
|
||||
df = payload["df"]
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from superset.typing import Metric
|
|||
from superset.utils import core as utils, pandas_postprocessing
|
||||
from superset.views.utils import get_time_range_endpoints
|
||||
|
||||
config = app.config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
|
||||
|
|
@ -66,6 +67,7 @@ class QueryObject:
|
|||
groupby: List[str]
|
||||
metrics: List[Union[Dict[str, Any], str]]
|
||||
row_limit: int
|
||||
row_offset: int
|
||||
filter: List[Dict[str, Any]]
|
||||
timeseries_limit: int
|
||||
timeseries_limit_metric: Optional[Metric]
|
||||
|
|
@ -85,7 +87,8 @@ class QueryObject:
|
|||
time_shift: Optional[str] = None,
|
||||
is_timeseries: bool = False,
|
||||
timeseries_limit: int = 0,
|
||||
row_limit: int = app.config["ROW_LIMIT"],
|
||||
row_limit: Optional[int] = None,
|
||||
row_offset: Optional[int] = None,
|
||||
timeseries_limit_metric: Optional[Metric] = None,
|
||||
order_desc: bool = True,
|
||||
extras: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -100,10 +103,10 @@ class QueryObject:
|
|||
self.granularity = granularity
|
||||
self.from_dttm, self.to_dttm = utils.get_since_until(
|
||||
relative_start=extras.get(
|
||||
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
|
||||
"relative_start", config["DEFAULT_RELATIVE_START_TIME"]
|
||||
),
|
||||
relative_end=extras.get(
|
||||
"relative_end", app.config["DEFAULT_RELATIVE_END_TIME"]
|
||||
"relative_end", config["DEFAULT_RELATIVE_END_TIME"]
|
||||
),
|
||||
time_range=time_range,
|
||||
time_shift=time_shift,
|
||||
|
|
@ -123,14 +126,15 @@ class QueryObject:
|
|||
for metric in metrics
|
||||
]
|
||||
|
||||
self.row_limit = row_limit
|
||||
self.row_limit = row_limit or config["ROW_LIMIT"]
|
||||
self.row_offset = row_offset or 0
|
||||
self.filter = filters or []
|
||||
self.timeseries_limit = timeseries_limit
|
||||
self.timeseries_limit_metric = timeseries_limit_metric
|
||||
self.order_desc = order_desc
|
||||
self.extras = extras
|
||||
|
||||
if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
|
||||
if config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
|
||||
self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={})
|
||||
|
||||
self.columns = columns or []
|
||||
|
|
@ -184,6 +188,7 @@ class QueryObject:
|
|||
"is_timeseries": self.is_timeseries,
|
||||
"metrics": self.metrics,
|
||||
"row_limit": self.row_limit,
|
||||
"row_offset": self.row_offset,
|
||||
"filter": self.filter,
|
||||
"timeseries_limit": self.timeseries_limit,
|
||||
"timeseries_limit_metric": self.timeseries_limit_metric,
|
||||
|
|
|
|||
|
|
@ -1179,6 +1179,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
timeseries_limit: Optional[int] = None,
|
||||
timeseries_limit_metric: Optional[Metric] = None,
|
||||
row_limit: Optional[int] = None,
|
||||
row_offset: Optional[int] = None,
|
||||
inner_from_dttm: Optional[datetime] = None,
|
||||
inner_to_dttm: Optional[datetime] = None,
|
||||
orderby: Optional[Any] = None,
|
||||
|
|
@ -1192,6 +1193,8 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
# TODO refactor into using a TBD Query object
|
||||
client = client or self.cluster.get_pydruid_client()
|
||||
row_limit = row_limit or conf.get("ROW_LIMIT")
|
||||
if row_offset:
|
||||
raise SupersetException("Offset not implemented for Druid connector")
|
||||
|
||||
if not is_timeseries:
|
||||
granularity = "all"
|
||||
|
|
|
|||
|
|
@ -741,6 +741,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
timeseries_limit: int = 15,
|
||||
timeseries_limit_metric: Optional[Metric] = None,
|
||||
row_limit: Optional[int] = None,
|
||||
row_offset: Optional[int] = None,
|
||||
inner_from_dttm: Optional[datetime] = None,
|
||||
inner_to_dttm: Optional[datetime] = None,
|
||||
orderby: Optional[List[Tuple[ColumnElement, bool]]] = None,
|
||||
|
|
@ -753,6 +754,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
"groupby": groupby,
|
||||
"metrics": metrics,
|
||||
"row_limit": row_limit,
|
||||
"row_offset": row_offset,
|
||||
"to_dttm": to_dttm,
|
||||
"filter": filter,
|
||||
"columns": {col.column_name: col for col in self.columns},
|
||||
|
|
@ -967,6 +969,8 @@ class SqlaTable(Model, BaseDatasource):
|
|||
|
||||
if row_limit:
|
||||
qry = qry.limit(row_limit)
|
||||
if row_offset:
|
||||
qry = qry.offset(row_offset)
|
||||
|
||||
if (
|
||||
is_timeseries
|
||||
|
|
|
|||
|
|
@ -14,22 +14,27 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
"""Unit tests for Superset"""
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from unittest import mock
|
||||
|
||||
import prison
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
import tests.test_app
|
||||
from tests.test_app import app
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils import core as utils
|
||||
from tests.base_api_tests import ApiOwnersTestCaseMixin
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.fixtures.query_context import get_query_context
|
||||
|
||||
CHART_DATA_URI = "api/v1/chart/data"
|
||||
|
||||
|
||||
class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
|
||||
resource_name = "chart"
|
||||
|
|
@ -634,32 +639,88 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
|
|||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 0)
|
||||
|
||||
def test_chart_data(self):
|
||||
def test_chart_data_simple(self):
|
||||
"""
|
||||
Query API: Test chart data query
|
||||
Chart data API: Test chart data query
|
||||
"""
|
||||
self.login(username="admin")
|
||||
table = self.get_table_by_name("birth_names")
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
uri = "api/v1/chart/data"
|
||||
rv = self.post_assert_metric(uri, payload, "data")
|
||||
request_payload = get_query_context(table.name, table.id, table.type)
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["result"][0]["rowcount"], 100)
|
||||
|
||||
def test_chart_data_limit_offset(self):
    """
    Chart data API: Test chart data query with limit and offset

    Requests five rows ordered on ``name`` with no offset, remembers the
    name found at index ``offset`` of that page, then repeats the request
    with ``row_offset`` set to ``offset`` and checks that the remembered
    name is now the first row of a page that still holds five rows.
    """
    self.login(username="admin")
    birth_names = self.get_table_by_name("birth_names")
    body = get_query_context(birth_names.name, birth_names.id, birth_names.type)

    # The first query of the payload is edited in place through this alias.
    query = body["queries"][0]
    query["row_limit"] = 5
    query["row_offset"] = 0
    query["orderby"] = [["name", True]]

    response = self.post_assert_metric(CHART_DATA_URI, body, "data")
    first_page = json.loads(response.data.decode("utf-8"))["result"][0]
    self.assertEqual(first_page["rowcount"], 5)

    # ensure that offset works properly
    offset = 2
    expected_name = first_page["data"][offset]["name"]
    query["row_offset"] = offset

    response = self.post_assert_metric(CHART_DATA_URI, body, "data")
    shifted_page = json.loads(response.data.decode("utf-8"))["result"][0]
    self.assertEqual(shifted_page["rowcount"], 5)
    self.assertEqual(shifted_page["data"][0]["name"], expected_name)
|
||||
|
||||
@mock.patch("superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7})
def test_chart_data_default_row_limit(self):
    """
    Chart data API: Ensure row count doesn't exceed default limit

    ``superset.common.query_object.config`` is patched so that ``ROW_LIMIT``
    is 7, and ``row_limit`` is removed from the request; the response is
    expected to report exactly 7 rows.
    """
    self.login(username="admin")
    birth_names = self.get_table_by_name("birth_names")
    body = get_query_context(birth_names.name, birth_names.id, birth_names.type)

    # No explicit limit in the request: ``pop`` without a default raises
    # ``KeyError`` if the fixture ever stops providing the key.
    body["queries"][0].pop("row_limit")

    response = self.post_assert_metric(CHART_DATA_URI, body, "data")
    first_result = json.loads(response.data.decode("utf-8"))["result"][0]
    self.assertEqual(first_result["rowcount"], 7)
|
||||
|
||||
@mock.patch(
    "superset.common.query_context.config",
    {**app.config, "SAMPLES_ROW_LIMIT": 5},
)
def test_chart_data_default_sample_limit(self):
    """
    Chart data API: Ensure sample response row count doesn't exceed default limit

    ``superset.common.query_context.config`` is patched so that
    ``SAMPLES_ROW_LIMIT`` is 5 while the request asks for 10 rows of
    samples; the response is expected to report 5 rows.
    """
    self.login(username="admin")
    birth_names = self.get_table_by_name("birth_names")
    body = get_query_context(birth_names.name, birth_names.id, birth_names.type)
    body["result_type"] = utils.ChartDataResultType.SAMPLES
    # Deliberately larger than the patched samples limit.
    body["queries"][0]["row_limit"] = 10

    response = self.post_assert_metric(CHART_DATA_URI, body, "data")
    first_result = json.loads(response.data.decode("utf-8"))["result"][0]
    self.assertEqual(first_result["rowcount"], 5)
|
||||
|
||||
def test_chart_data_with_invalid_datasource(self):
|
||||
"""Query API: Test chart data query with invalid schema
|
||||
"""Chart data API: Test chart data query with invalid schema
|
||||
"""
|
||||
self.login(username="admin")
|
||||
table = self.get_table_by_name("birth_names")
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["datasource"] = "abc"
|
||||
uri = "api/v1/chart/data"
|
||||
rv = self.post_assert_metric(uri, payload, "data")
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
|
||||
def test_chart_data_with_invalid_enum_value(self):
|
||||
"""Query API: Test chart data query with invalid enum value
|
||||
"""Chart data API: Test chart data query with invalid enum value
|
||||
"""
|
||||
self.login(username="admin")
|
||||
table = self.get_table_by_name("birth_names")
|
||||
|
|
@ -668,19 +729,17 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
|
|||
"abc",
|
||||
"EXCLUSIVE",
|
||||
]
|
||||
uri = "api/v1/chart/data"
|
||||
rv = self.client.post(uri, json=payload)
|
||||
rv = self.client.post(CHART_DATA_URI, json=payload)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
|
||||
def test_query_exec_not_allowed(self):
|
||||
"""
|
||||
Query API: Test chart data query not allowed
|
||||
Chart data API: Test chart data query not allowed
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
table = self.get_table_by_name("birth_names")
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
uri = "api/v1/chart/data"
|
||||
rv = self.post_assert_metric(uri, payload, "data")
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
|
||||
self.assertEqual(rv.status_code, 401)
|
||||
|
||||
def test_datasources(self):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for Superset"""
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from superset.charts.schemas import ChartDataQueryContextSchema
|
||||
from superset.common.query_context import QueryContext
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.fixtures.query_context import get_query_context
|
||||
from tests.test_app import app
|
||||
|
||||
|
||||
def load_query_context(payload: Dict[str, Any]) -> Tuple[QueryContext, Dict[str, Any]]:
    """Load a chart data request payload through ``ChartDataQueryContextSchema``.

    :param payload: request body to deserialize
    :returns: the result of ``ChartDataQueryContextSchema().load(payload)``;
        callers in this module unpack it as ``(query_context, errors)``
    """
    schema = ChartDataQueryContextSchema()
    return schema.load(payload)
|
||||
|
||||
|
||||
class SchemaTestCase(SupersetTestCase):
    """Checks how the chart data schema handles ``row_limit`` and ``row_offset``."""

    def test_query_context_limit_and_offset(self):
        """Cover default, valid and out-of-range ``row_limit`` / ``row_offset``.

        The same payload is reused for all three phases; only the two keys
        under test are changed on its first query between loads.
        """
        self.login(username="admin")
        table = self.get_table_by_name("birth_names")
        payload = get_query_context(table.name, table.id, table.type)
        # Alias of the first query dict; edits below mutate ``payload`` in place.
        first_query = payload["queries"][0]

        # Use defaults
        for key in ("row_limit", "row_offset"):
            first_query.pop(key, None)
        query_context, errors = load_query_context(payload)
        self.assertEqual(errors, {})
        loaded = query_context.queries[0]
        self.assertEqual(loaded.row_limit, app.config["ROW_LIMIT"])
        self.assertEqual(loaded.row_offset, 0)

        # Valid limit and offset
        first_query["row_limit"] = 100
        first_query["row_offset"] = 200
        query_context, errors = load_query_context(payload)
        self.assertEqual(errors, {})
        loaded = query_context.queries[0]
        self.assertEqual(loaded.row_limit, 100)
        self.assertEqual(loaded.row_offset, 200)

        # too low limit and offset
        first_query["row_limit"] = 0
        first_query["row_offset"] = -1
        _, errors = load_query_context(payload)
        query_errors = errors["queries"][0]
        self.assertIn("row_limit", query_errors)
        self.assertIn("row_offset", query_errors)
|
||||
Loading…
Reference in New Issue