feat: request ids on API related endpoints (#12663)
* feat: request ids on API related endpoints * rename ids to include_ids
This commit is contained in:
parent
11ca7301b5
commit
365770e7c3
|
|
@ -110,7 +110,10 @@ openapi_spec_methods_override = {
|
|||
}
|
||||
},
|
||||
"related": {
|
||||
"get": {"description": "Get a list of all possible owners for a chart."}
|
||||
"get": {
|
||||
"description": "Get a list of all possible owners for a chart. "
|
||||
"Use `owners` has the `column_name` parameter"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ get_related_schema = {
|
|||
"properties": {
|
||||
"page_size": {"type": "integer"},
|
||||
"page": {"type": "integer"},
|
||||
"include_ids": {"type": "array", "items": {"type": "integer"}},
|
||||
"filter": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
|
@ -213,7 +214,10 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
super().__init__()
|
||||
|
||||
def add_apispec_components(self, api_spec: APISpec) -> None:
|
||||
|
||||
"""
|
||||
Adds extra OpenApi schema spec components, these are declared
|
||||
on the `openapi_spec_component_schemas` class property
|
||||
"""
|
||||
for schema in self.openapi_spec_component_schemas:
|
||||
try:
|
||||
api_spec.components.schema(
|
||||
|
|
@ -271,6 +275,40 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
)
|
||||
return filters
|
||||
|
||||
def _get_text_for_model(self, model: Model, column_name: str) -> str:
|
||||
if column_name in self.text_field_rel_fields:
|
||||
model_column_name = self.text_field_rel_fields.get(column_name)
|
||||
if model_column_name:
|
||||
return getattr(model, model_column_name)
|
||||
return str(model)
|
||||
|
||||
def _get_result_from_rows(
|
||||
self, datamodel: SQLAInterface, rows: List[Model], column_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"value": datamodel.get_pk_value(row),
|
||||
"text": self._get_text_for_model(row, column_name),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def _add_extra_ids_to_result(
|
||||
self,
|
||||
datamodel: SQLAInterface,
|
||||
column_name: str,
|
||||
ids: List[int],
|
||||
result: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
if ids:
|
||||
# Filter out already present values on the result
|
||||
values = [row["value"] for row in result]
|
||||
ids = [id_ for id_ in ids if id_ not in values]
|
||||
pk_col = datamodel.get_pk()
|
||||
# Fetch requested values from ids
|
||||
extra_rows = db.session.query(datamodel.obj).filter(pk_col.in_(ids)).all()
|
||||
result += self._get_result_from_rows(datamodel, extra_rows, column_name)
|
||||
|
||||
def incr_stats(self, action: str, func_name: str) -> None:
|
||||
"""
|
||||
Proxy function for statsd.incr to impose a key structure for REST API's
|
||||
|
|
@ -424,18 +462,11 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
|
||||
def get_text_for_model(model: Model) -> str:
|
||||
if column_name in self.text_field_rel_fields:
|
||||
model_column_name = self.text_field_rel_fields.get(column_name)
|
||||
if model_column_name:
|
||||
return getattr(model, model_column_name)
|
||||
return str(model)
|
||||
|
||||
if column_name not in self.allowed_rel_fields:
|
||||
self.incr_stats("error", self.related.__name__)
|
||||
return self.response_404()
|
||||
args = kwargs.get("rison", {})
|
||||
|
||||
# handle pagination
|
||||
page, page_size = self._handle_page_args(args)
|
||||
try:
|
||||
|
|
@ -452,15 +483,18 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
# handle filters
|
||||
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
|
||||
# Make the query
|
||||
count, values = datamodel.query(
|
||||
_, rows = datamodel.query(
|
||||
filters, order_column, order_direction, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
# produce response
|
||||
result = [
|
||||
{"value": datamodel.get_pk_value(value), "text": get_text_for_model(value)}
|
||||
for value in values
|
||||
]
|
||||
return self.response(200, count=count, result=result)
|
||||
result = self._get_result_from_rows(datamodel, rows, column_name)
|
||||
|
||||
# If ids are specified make sure we fetch and include them on the response
|
||||
ids = args.get("include_ids")
|
||||
self._add_extra_ids_to_result(datamodel, column_name, ids, result)
|
||||
|
||||
return self.response(200, count=len(result), result=result)
|
||||
|
||||
@expose("/distinct/<column_name>", methods=["GET"])
|
||||
@protect()
|
||||
|
|
|
|||
|
|
@ -184,48 +184,86 @@ class ApiOwnersTestCaseMixin:
|
|||
|
||||
def test_get_related_owners(self):
|
||||
"""
|
||||
API: Test get related owners
|
||||
API: Test get related owners
|
||||
"""
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/{self.resource_name}/related/owners"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
users = db.session.query(security_manager.user_model).all()
|
||||
expected_users = [str(user) for user in users]
|
||||
self.assertEqual(response["count"], len(users))
|
||||
assert response["count"] == len(users)
|
||||
# This needs to be implemented like this, because ordering varies between
|
||||
# postgres and mysql
|
||||
response_users = [result["text"] for result in response["result"]]
|
||||
for expected_user in expected_users:
|
||||
self.assertIn(expected_user, response_users)
|
||||
assert expected_user in response_users
|
||||
|
||||
def test_get_filter_related_owners(self):
|
||||
"""
|
||||
API: Test get filter related owners
|
||||
API: Test get filter related owners
|
||||
"""
|
||||
self.login(username="admin")
|
||||
argument = {"filter": "gamma"}
|
||||
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
|
||||
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(3, response["count"])
|
||||
assert 3 == response["count"]
|
||||
sorted_results = sorted(response["result"], key=lambda value: value["text"])
|
||||
expected_results = [
|
||||
{"text": "gamma user", "value": 2},
|
||||
{"text": "gamma2 user", "value": 3},
|
||||
{"text": "gamma_sqllab user", "value": 4},
|
||||
]
|
||||
self.assertEqual(expected_results, sorted_results)
|
||||
assert expected_results == sorted_results
|
||||
|
||||
def test_get_ids_related_owners(self):
|
||||
"""
|
||||
API: Test get filter related owners
|
||||
"""
|
||||
self.login(username="admin")
|
||||
argument = {"filter": "gamma_sqllab", "include_ids": [2]}
|
||||
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
|
||||
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
assert rv.status_code == 200
|
||||
assert 2 == response["count"]
|
||||
sorted_results = sorted(response["result"], key=lambda value: value["text"])
|
||||
expected_results = [
|
||||
{"text": "gamma user", "value": 2},
|
||||
{"text": "gamma_sqllab user", "value": 4},
|
||||
]
|
||||
assert expected_results == sorted_results
|
||||
|
||||
def test_get_repeated_ids_related_owners(self):
|
||||
"""
|
||||
API: Test get filter related owners
|
||||
"""
|
||||
self.login(username="admin")
|
||||
argument = {"filter": "gamma_sqllab", "include_ids": [2, 4]}
|
||||
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
|
||||
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
assert rv.status_code == 200
|
||||
assert 2 == response["count"]
|
||||
sorted_results = sorted(response["result"], key=lambda value: value["text"])
|
||||
expected_results = [
|
||||
{"text": "gamma user", "value": 2},
|
||||
{"text": "gamma_sqllab user", "value": 4},
|
||||
]
|
||||
assert expected_results == sorted_results
|
||||
|
||||
def test_get_related_fail(self):
|
||||
"""
|
||||
API: Test get related fail
|
||||
API: Test get related fail
|
||||
"""
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/{self.resource_name}/related/owner"
|
||||
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
|
|
|||
Loading…
Reference in New Issue