feat: dataset REST API for distinct values (#10595)
* feat: dataset REST API for distinct values * add tests and fix lint * fix mypy, and tests * fix docs * fix test * lint * fix test
This commit is contained in:
parent
f868580f64
commit
692266f4f5
|
|
@ -55,6 +55,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods
|
|||
POST = "post"
|
||||
PUT = "put"
|
||||
RELATED = "related"
|
||||
DISTINCT = "distinct"
|
||||
|
||||
# Commonly used sets
|
||||
API_SET = {API_CREATE, API_DELETE, API_GET, API_READ, API_UPDATE}
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
|
|||
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
|
||||
RouteMethod.EXPORT,
|
||||
RouteMethod.RELATED,
|
||||
RouteMethod.DISTINCT,
|
||||
"refresh",
|
||||
"related_objects",
|
||||
}
|
||||
|
|
@ -151,6 +152,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
|
|||
}
|
||||
filter_rel_fields = {"database": [["id", DatabaseFilter, lambda: []]]}
|
||||
allowed_rel_fields = {"database", "owners"}
|
||||
allowed_distinct_fields = {"schema"}
|
||||
|
||||
openapi_spec_component_schemas = (DatasetRelatedObjectsResponse,)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,12 +19,15 @@ import logging
|
|||
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from apispec import APISpec
|
||||
from apispec.exceptions import DuplicateComponentNameError
|
||||
from flask import Blueprint, Response
|
||||
from flask_appbuilder import AppBuilder, Model, ModelRestApi
|
||||
from flask_appbuilder import AppBuilder, ModelRestApi
|
||||
from flask_appbuilder.api import expose, protect, rison, safe
|
||||
from flask_appbuilder.models.filters import BaseFilter, Filters
|
||||
from flask_appbuilder.models.sqla.filters import FilterStartsWith
|
||||
from marshmallow import Schema
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from marshmallow import fields, Schema
|
||||
from sqlalchemy import distinct, func
|
||||
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
from superset.typing import FlaskResponse
|
||||
|
|
@ -41,6 +44,25 @@ get_related_schema = {
|
|||
}
|
||||
|
||||
|
||||
class RelatedResultResponseSchema(Schema):
|
||||
value = fields.Integer(description="The related item identifier")
|
||||
text = fields.String(description="The related item string representation")
|
||||
|
||||
|
||||
class RelatedResponseSchema(Schema):
|
||||
count = fields.Integer(description="The total number of related values")
|
||||
result = fields.List(fields.Nested(RelatedResultResponseSchema))
|
||||
|
||||
|
||||
class DistinctResultResponseSchema(Schema):
|
||||
text = fields.String(description="The distinct item")
|
||||
|
||||
|
||||
class DistincResponseSchema(Schema):
|
||||
count = fields.Integer(description="The total number of distinct values")
|
||||
result = fields.List(fields.Nested(DistinctResultResponseSchema))
|
||||
|
||||
|
||||
def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Handle sending all statsd metrics from the REST API
|
||||
|
|
@ -78,6 +100,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
"bulk_delete": "delete",
|
||||
"info": "list",
|
||||
"related": "list",
|
||||
"distinct": "list",
|
||||
"thumbnail": "list",
|
||||
"refresh": "edit",
|
||||
"data": "list",
|
||||
|
|
@ -112,6 +135,8 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
""" # pylint: disable=pointless-string-statement
|
||||
allowed_rel_fields: Set[str] = set()
|
||||
|
||||
allowed_distinct_fields: Set[str] = set()
|
||||
|
||||
openapi_spec_component_schemas: Tuple[Type[Schema], ...] = tuple()
|
||||
"""
|
||||
Add extra schemas to the OpenAPI component schemas section
|
||||
|
|
@ -123,15 +148,29 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
show_columns: List[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Setup statsd
|
||||
self.stats_logger = BaseStatsLogger()
|
||||
# Add base API spec base query parameter schemas
|
||||
if self.apispec_parameter_schemas is None: # type: ignore
|
||||
self.apispec_parameter_schemas = {}
|
||||
self.apispec_parameter_schemas["get_related_schema"] = get_related_schema
|
||||
if self.openapi_spec_component_schemas is None:
|
||||
self.openapi_spec_component_schemas = ()
|
||||
self.openapi_spec_component_schemas = self.openapi_spec_component_schemas + (
|
||||
RelatedResponseSchema,
|
||||
DistincResponseSchema,
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def add_apispec_components(self, api_spec: APISpec) -> None:
|
||||
|
||||
for schema in self.openapi_spec_component_schemas:
|
||||
api_spec.components.schema(
|
||||
schema.__name__, schema=schema,
|
||||
)
|
||||
try:
|
||||
api_spec.components.schema(
|
||||
schema.__name__, schema=schema,
|
||||
)
|
||||
except DuplicateComponentNameError:
|
||||
pass
|
||||
super().add_apispec_components(api_spec)
|
||||
|
||||
def create_blueprint(
|
||||
|
|
@ -153,7 +192,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
super()._init_properties()
|
||||
|
||||
def _get_related_filter(
|
||||
self, datamodel: Model, column_name: str, value: str
|
||||
self, datamodel: SQLAInterface, column_name: str, value: str
|
||||
) -> Filters:
|
||||
filter_field = self.related_field_filters.get(column_name)
|
||||
if isinstance(filter_field, str):
|
||||
|
|
@ -170,6 +209,18 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
)
|
||||
return filters
|
||||
|
||||
def _get_distinct_filter(self, column_name: str, value: str) -> Filters:
|
||||
filter_field = RelatedFieldFilter(column_name, FilterStartsWith)
|
||||
filter_field = cast(RelatedFieldFilter, filter_field)
|
||||
search_columns = [filter_field.field_name] if filter_field else None
|
||||
filters = self.datamodel.get_filters(search_columns)
|
||||
filters.add_filter_list(self.base_filters)
|
||||
if value and filter_field:
|
||||
filters.add_filter(
|
||||
filter_field.field_name, filter_field.filter_class, value
|
||||
)
|
||||
return filters
|
||||
|
||||
def incr_stats(self, action: str, func_name: str) -> None:
|
||||
"""
|
||||
Proxy function for statsd.incr to impose a key structure for REST API's
|
||||
|
|
@ -251,39 +302,21 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
page_size:
|
||||
type: integer
|
||||
page:
|
||||
type: integer
|
||||
filter:
|
||||
type: string
|
||||
$ref: '#/components/schemas/get_related_schema'
|
||||
responses:
|
||||
200:
|
||||
description: Related column data
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
count:
|
||||
type: integer
|
||||
result:
|
||||
type: object
|
||||
properties:
|
||||
value:
|
||||
type: integer
|
||||
text:
|
||||
type: string
|
||||
schema:
|
||||
$ref: "#/components/schemas/RelatedResponseSchema"
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
|
|
@ -316,3 +349,68 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
for value in values
|
||||
]
|
||||
return self.response(200, count=count, result=result)
|
||||
|
||||
@expose("/distinct/<column_name>", methods=["GET"])
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@rison(get_related_schema)
|
||||
def distinct(self, column_name: str, **kwargs: Any) -> FlaskResponse:
|
||||
"""Get distinct values from field data
|
||||
---
|
||||
get:
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: string
|
||||
name: column_name
|
||||
- in: query
|
||||
name: q
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/get_related_schema'
|
||||
responses:
|
||||
200:
|
||||
description: Distinct field data
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
schema:
|
||||
$ref: "#/components/schemas/DistincResponseSchema"
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
if column_name not in self.allowed_distinct_fields:
|
||||
self.incr_stats("error", self.related.__name__)
|
||||
return self.response_404()
|
||||
args = kwargs.get("rison", {})
|
||||
# handle pagination
|
||||
page, page_size = self._sanitize_page_args(*self._handle_page_args(args))
|
||||
# Create generic base filters with added request filter
|
||||
filters = self._get_distinct_filter(column_name, args.get("filter"))
|
||||
# Make the query
|
||||
query_count = self.appbuilder.get_session.query(
|
||||
func.count(distinct(getattr(self.datamodel.obj, column_name)))
|
||||
)
|
||||
count = self.datamodel.apply_filters(query_count, filters).scalar()
|
||||
if count == 0:
|
||||
return self.response(200, count=count, result=[])
|
||||
query = self.appbuilder.get_session.query(
|
||||
distinct(getattr(self.datamodel.obj, column_name))
|
||||
)
|
||||
# Apply generic base filters with added request filter
|
||||
query = self.datamodel.apply_filters(query, filters)
|
||||
# Apply sort
|
||||
query = self.datamodel.apply_order_by(query, column_name, "asc")
|
||||
# Apply pagination
|
||||
result = self.datamodel.apply_pagination(query, page, page_size).all()
|
||||
# produce response
|
||||
result = [{"text": item[0]} for item in result if item[0] is not None]
|
||||
return self.response(200, count=count, result=result)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
"""Unit tests for Superset"""
|
||||
import json
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import prison
|
||||
|
|
@ -129,7 +129,6 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"""
|
||||
Dataset API: Test get dataset related databases gamma
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
self.login(username="gamma")
|
||||
uri = "api/v1/dataset/related/database"
|
||||
rv = self.client.get(uri)
|
||||
|
|
@ -170,6 +169,93 @@ class TestDatasetApi(SupersetTestCase):
|
|||
self.assertEqual(len(response["result"]["columns"]), 3)
|
||||
self.assertEqual(len(response["result"]["metrics"]), 2)
|
||||
|
||||
def test_get_dataset_distinct_schema(self):
|
||||
"""
|
||||
Dataset API: Test get dataset distinct schema
|
||||
"""
|
||||
|
||||
def pg_test_query_parameter(query_parameter, expected_response):
|
||||
uri = f"api/v1/dataset/distinct/schema?q={prison.dumps(query_parameter)}"
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
example_db = get_example_database()
|
||||
datasets = []
|
||||
if example_db.backend == "postgresql":
|
||||
datasets.append(
|
||||
self.insert_dataset("ab_permission", "public", [], get_main_database())
|
||||
)
|
||||
datasets.append(
|
||||
self.insert_dataset(
|
||||
"columns", "information_schema", [], get_main_database()
|
||||
)
|
||||
)
|
||||
expected_response = {
|
||||
"count": 5,
|
||||
"result": [
|
||||
{"text": ""},
|
||||
{"text": "admin_database"},
|
||||
{"text": "information_schema"},
|
||||
{"text": "public"},
|
||||
{"text": "superset"},
|
||||
],
|
||||
}
|
||||
self.login(username="admin")
|
||||
uri = "api/v1/dataset/distinct/schema"
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
# Test filter
|
||||
query_parameter = {"filter": "inf"}
|
||||
pg_test_query_parameter(
|
||||
query_parameter,
|
||||
{"count": 1, "result": [{"text": "information_schema"}]},
|
||||
)
|
||||
|
||||
query_parameter = {"page": 0, "page_size": 1}
|
||||
pg_test_query_parameter(
|
||||
query_parameter, {"count": 5, "result": [{"text": ""}]},
|
||||
)
|
||||
|
||||
query_parameter = {"page": 1, "page_size": 1}
|
||||
pg_test_query_parameter(
|
||||
query_parameter, {"count": 5, "result": [{"text": "admin_database"}]}
|
||||
)
|
||||
|
||||
for dataset in datasets:
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_dataset_distinct_not_allowed(self):
|
||||
"""
|
||||
Dataset API: Test get dataset distinct not allowed
|
||||
"""
|
||||
self.login(username="admin")
|
||||
uri = "api/v1/dataset/distinct/table_name"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_dataset_distinct_gamma(self):
|
||||
"""
|
||||
Dataset API: Test get dataset distinct with gamma
|
||||
"""
|
||||
dataset = self.insert_default_dataset()
|
||||
|
||||
self.login(username="gamma")
|
||||
uri = "api/v1/dataset/distinct/schema"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["count"], 0)
|
||||
self.assertEqual(response["result"], [])
|
||||
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_dataset_info(self):
|
||||
"""
|
||||
Dataset API: Test get dataset info
|
||||
|
|
@ -358,6 +444,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
self.assertEqual(rv.status_code, 200)
|
||||
model = db.session.query(SqlaTable).get(dataset.id)
|
||||
self.assertEqual(model.description, dataset_data["description"])
|
||||
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue