diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index d62707e0b..c44ae757a 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -110,7 +110,10 @@ openapi_spec_methods_override = { } }, "related": { - "get": {"description": "Get a list of all possible owners for a chart."} + "get": { + "description": "Get a list of all possible owners for a chart. " + "Use `owners` has the `column_name` parameter" + } }, } diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 3e11f9230..295605810 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -46,6 +46,7 @@ get_related_schema = { "properties": { "page_size": {"type": "integer"}, "page": {"type": "integer"}, + "include_ids": {"type": "array", "items": {"type": "integer"}}, "filter": {"type": "string"}, }, } @@ -213,7 +214,10 @@ class BaseSupersetModelRestApi(ModelRestApi): super().__init__() def add_apispec_components(self, api_spec: APISpec) -> None: - + """ + Adds extra OpenApi schema spec components, these are declared + on the `openapi_spec_component_schemas` class property + """ for schema in self.openapi_spec_component_schemas: try: api_spec.components.schema( @@ -271,6 +275,40 @@ class BaseSupersetModelRestApi(ModelRestApi): ) return filters + def _get_text_for_model(self, model: Model, column_name: str) -> str: + if column_name in self.text_field_rel_fields: + model_column_name = self.text_field_rel_fields.get(column_name) + if model_column_name: + return getattr(model, model_column_name) + return str(model) + + def _get_result_from_rows( + self, datamodel: SQLAInterface, rows: List[Model], column_name: str + ) -> List[Dict[str, Any]]: + return [ + { + "value": datamodel.get_pk_value(row), + "text": self._get_text_for_model(row, column_name), + } + for row in rows + ] + + def _add_extra_ids_to_result( + self, + datamodel: SQLAInterface, + column_name: str, + ids: List[int], + result: List[Dict[str, Any]], + ) -> None: + if ids: + # Filter out already present values on the result + values = [row["value"] for row in result] + ids = [id_ for id_ in ids if id_ not in values] + pk_col = datamodel.get_pk() + # Fetch requested values from ids + extra_rows = db.session.query(datamodel.obj).filter(pk_col.in_(ids)).all() + result += self._get_result_from_rows(datamodel, extra_rows, column_name) + def incr_stats(self, action: str, func_name: str) -> None: """ Proxy function for statsd.incr to impose a key structure for REST API's @@ -424,18 +462,11 @@ class BaseSupersetModelRestApi(ModelRestApi): 500: $ref: '#/components/responses/500' """ - - def get_text_for_model(model: Model) -> str: - if column_name in self.text_field_rel_fields: - model_column_name = self.text_field_rel_fields.get(column_name) - if model_column_name: - return getattr(model, model_column_name) - return str(model) - if column_name not in self.allowed_rel_fields: self.incr_stats("error", self.related.__name__) return self.response_404() args = kwargs.get("rison", {}) + # handle pagination page, page_size = self._handle_page_args(args) try: @@ -452,15 +483,18 @@ class BaseSupersetModelRestApi(ModelRestApi): # handle filters filters = self._get_related_filter(datamodel, column_name, args.get("filter")) # Make the query - count, values = datamodel.query( + _, rows = datamodel.query( filters, order_column, order_direction, page=page, page_size=page_size ) + # produce response - result = [ - {"value": datamodel.get_pk_value(value), "text": get_text_for_model(value)} - for value in values - ] - return self.response(200, count=count, result=result) + result = self._get_result_from_rows(datamodel, rows, column_name) + + # If ids are specified make sure we fetch and include them on the response + ids = args.get("include_ids") + self._add_extra_ids_to_result(datamodel, column_name, ids, result) + + return self.response(200, count=len(result), result=result) @expose("/distinct/", methods=["GET"]) @protect() diff --git a/tests/base_api_tests.py b/tests/base_api_tests.py index 3dd21dcfc..f23b01e8a 100644 --- a/tests/base_api_tests.py +++ b/tests/base_api_tests.py @@ -184,48 +184,86 @@ class ApiOwnersTestCaseMixin: def test_get_related_owners(self): """ - API: Test get related owners + API: Test get related owners """ self.login(username="admin") uri = f"api/v1/{self.resource_name}/related/owners" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) users = db.session.query(security_manager.user_model).all() expected_users = [str(user) for user in users] - self.assertEqual(response["count"], len(users)) + assert response["count"] == len(users) # This needs to be implemented like this, because ordering varies between # postgres and mysql response_users = [result["text"] for result in response["result"]] for expected_user in expected_users: - self.assertIn(expected_user, response_users) + assert expected_user in response_users def test_get_filter_related_owners(self): """ - API: Test get filter related owners + API: Test get filter related owners """ self.login(username="admin") argument = {"filter": "gamma"} uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(3, response["count"]) + assert 3 == response["count"] sorted_results = sorted(response["result"], key=lambda value: value["text"]) expected_results = [ {"text": "gamma user", "value": 2}, {"text": "gamma2 user", "value": 3}, {"text": "gamma_sqllab user", "value": 4}, ] - self.assertEqual(expected_results, sorted_results) + assert expected_results == sorted_results + + def test_get_ids_related_owners(self): + """ + API: Test get filter related owners + """ + self.login(username="admin") + argument = {"filter": "gamma_sqllab", "include_ids": [2]} + uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" + + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 200 + assert 2 == response["count"] + sorted_results = sorted(response["result"], key=lambda value: value["text"]) + expected_results = [ + {"text": "gamma user", "value": 2}, + {"text": "gamma_sqllab user", "value": 4}, + ] + assert expected_results == sorted_results + + def test_get_repeated_ids_related_owners(self): + """ + API: Test get filter related owners + """ + self.login(username="admin") + argument = {"filter": "gamma_sqllab", "include_ids": [2, 4]} + uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" + + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 200 + assert 2 == response["count"] + sorted_results = sorted(response["result"], key=lambda value: value["text"]) + expected_results = [ + {"text": "gamma user", "value": 2}, + {"text": "gamma_sqllab user", "value": 4}, + ] + assert expected_results == sorted_results def test_get_related_fail(self): """ - API: Test get related fail + API: Test get related fail """ self.login(username="admin") uri = f"api/v1/{self.resource_name}/related/owner" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404