diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 7fb63ebb2..e80134605 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -487,6 +487,12 @@ class BaseSupersetModelRestApi(ModelRestApi): # handle pagination page, page_size = self._handle_page_args(args) + + ids = args.get("include_ids") + if page and ids: + # pagination with forced ids is not supported + return self.response_422() + try: datamodel = self.datamodel.get_related_interface(column_name) except KeyError: @@ -501,7 +507,7 @@ class BaseSupersetModelRestApi(ModelRestApi): # handle filters filters = self._get_related_filter(datamodel, column_name, args.get("filter")) # Make the query - _, rows = datamodel.query( + total_rows, rows = datamodel.query( filters, order_column, order_direction, page=page, page_size=page_size ) @@ -509,10 +515,11 @@ class BaseSupersetModelRestApi(ModelRestApi): result = self._get_result_from_rows(datamodel, rows, column_name) # If ids are specified make sure we fetch and include them on the response - ids = args.get("include_ids") - self._add_extra_ids_to_result(datamodel, column_name, ids, result) + if ids: + self._add_extra_ids_to_result(datamodel, column_name, ids, result) + total_rows = len(result) - return self.response(200, count=len(result), result=result) + return self.response(200, count=total_rows, result=result) @expose("/distinct/", methods=["GET"]) @protect() diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py index e6e795f4d..a76346149 100644 --- a/tests/integration_tests/base_api_tests.py +++ b/tests/integration_tests/base_api_tests.py @@ -202,6 +202,40 @@ class ApiOwnersTestCaseMixin: for expected_user in expected_users: assert expected_user in response_users + def test_get_related_owners_paginated(self): + """ + API: Test get related owners with pagination + """ + self.login(username="admin") + page_size = 1 + argument = {"page_size": page_size} + uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" + rv = self.client.get(uri) + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + users = db.session.query(security_manager.user_model).all() + + # the count should correspond with the total number of users + assert response["count"] == len(users) + + # the length of the result should be at most equal to the page size + assert len(response["result"]) == min(page_size, len(users)) + + # make sure all received users are included in the full set of users + all_users = [str(user) for user in users] + for received_user in [result["text"] for result in response["result"]]: + assert received_user in all_users + + def test_get_ids_related_owners_paginated(self): + """ + API: Test get related owners with pagination returns 422 + """ + self.login(username="admin") + argument = {"page": 1, "page_size": 1, "include_ids": [2]} + uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" + rv = self.client.get(uri) + assert rv.status_code == 422 + def test_get_filter_related_owners(self): """ API: Test get filter related owners