fix(api): return total count on related endpoint (#16397)

* fix(api): return total count on related endpoint

* update response code from 400 to 422
This commit is contained in:
Ville Brofeldt 2021-08-24 15:07:58 +03:00 committed by GitHub
parent 1fc9318594
commit f6637cac7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 4 deletions

View File

@ -487,6 +487,12 @@ class BaseSupersetModelRestApi(ModelRestApi):
# handle pagination
page, page_size = self._handle_page_args(args)
ids = args.get("include_ids")
if page and ids:
# pagination with forced ids is not supported
return self.response_422()
try:
datamodel = self.datamodel.get_related_interface(column_name)
except KeyError:
@ -501,7 +507,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
# handle filters
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
# Make the query
_, rows = datamodel.query(
total_rows, rows = datamodel.query(
filters, order_column, order_direction, page=page, page_size=page_size
)
@ -509,10 +515,11 @@ class BaseSupersetModelRestApi(ModelRestApi):
result = self._get_result_from_rows(datamodel, rows, column_name)
# If ids are specified make sure we fetch and include them on the response
ids = args.get("include_ids")
self._add_extra_ids_to_result(datamodel, column_name, ids, result)
if ids:
self._add_extra_ids_to_result(datamodel, column_name, ids, result)
total_rows = len(result)
return self.response(200, count=len(result), result=result)
return self.response(200, count=total_rows, result=result)
@expose("/distinct/<column_name>", methods=["GET"])
@protect()

View File

@ -202,6 +202,40 @@ class ApiOwnersTestCaseMixin:
for expected_user in expected_users:
assert expected_user in response_users
def test_get_related_owners_paginated(self):
"""
API: Test get related owners with pagination
"""
self.login(username="admin")
page_size = 1
argument = {"page_size": page_size}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
users = db.session.query(security_manager.user_model).all()
# the count should correspond with the total number of users
assert response["count"] == len(users)
# the length of the result should be at most equal to the page size
assert len(response["result"]) == min(page_size, len(users))
# make sure all received users are included in the full set of users
all_users = [str(user) for user in users]
for received_user in [result["text"] for result in response["result"]]:
assert received_user in all_users
def test_get_ids_related_owners_paginated(self):
"""
API: Test get related owners with pagination returns 422
"""
self.login(username="admin")
argument = {"page": 1, "page_size": 1, "include_ids": [2]}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
rv = self.client.get(uri)
assert rv.status_code == 422
def test_get_filter_related_owners(self):
"""
API: Test get filter related owners