feat: add samples endpoint (#20170)

This commit is contained in:
Yongjie Zhao 2022-05-25 18:18:58 +08:00 committed by GitHub
parent 365acee663
commit 40abb44ba1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 183 additions and 0 deletions

View File

@ -129,6 +129,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
"available": "read",
"validate_sql": "read",
"get_data": "read",
"samples": "read",
}
EXTRA_FORM_DATA_APPEND_KEYS = {

View File

@ -45,11 +45,13 @@ from superset.datasets.commands.exceptions import (
DatasetInvalidError,
DatasetNotFoundError,
DatasetRefreshFailedError,
DatasetSamplesFailedError,
DatasetUpdateFailedError,
)
from superset.datasets.commands.export import ExportDatasetsCommand
from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand
from superset.datasets.commands.refresh import RefreshDatasetCommand
from superset.datasets.commands.samples import SamplesDatasetCommand
from superset.datasets.commands.update import UpdateDatasetCommand
from superset.datasets.dao import DatasetDAO
from superset.datasets.filters import DatasetCertifiedFilter, DatasetIsNullOrEmptyFilter
@ -90,6 +92,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"bulk_delete",
"refresh",
"related_objects",
"samples",
}
list_columns = [
"id",
@ -760,3 +763,64 @@ class DatasetRestApi(BaseSupersetModelRestApi):
)
command.run()
return self.response(200, message="OK")
@expose("/<pk>/samples")
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
    log_to_statsd=False,
)
def samples(self, pk: int) -> Response:
    """Return sample rows for the dataset identified by ``pk``.

    The optional ``force`` query argument is parsed as a boolean and handed
    to ``SamplesDatasetCommand``. Dataset-specific failures raised by the
    command are translated into 404 / 403 / 422 responses.
    ---
    get:
      description: >-
        get samples from a Dataset
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
      - in: query
        schema:
          type: boolean
        name: force
      responses:
        200:
          description: Dataset samples
          content:
            application/json:
              schema:
                type: object
                properties:
                  result:
                    $ref: '#/components/schemas/ChartDataResponseResult'
        401:
          $ref: '#/components/responses/401'
        403:
          $ref: '#/components/responses/403'
        404:
          $ref: '#/components/responses/404'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    force_requested = parse_boolean_string(request.args.get("force"))
    command = SamplesDatasetCommand(g.user, pk, force_requested)

    # Only the command execution can raise the dataset errors handled below.
    try:
        sample_payload = command.run()
    except DatasetNotFoundError:
        return self.response_404()
    except DatasetForbiddenError:
        return self.response_403()
    except DatasetSamplesFailedError as err:
        logger.error(
            "Error get dataset samples %s: %s",
            self.__class__.__name__,
            str(err),
            exc_info=True,
        )
        return self.response_422(message=str(err))

    return self.response(200, result=sample_payload)

View File

@ -173,6 +173,10 @@ class DatasetRefreshFailedError(UpdateFailedError):
message = _("Dataset could not be updated.")
class DatasetSamplesFailedError(CommandInvalidError):
    """Signals that sample rows for a dataset could not be retrieved."""

    # Translatable default message carried by this error.
    message = _("Samples for dataset could not be retrieved.")
class DatasetForbiddenError(ForbiddenError):
    """Signals that the requested operation on a dataset is forbidden."""

    # Translatable default message carried by this error.
    message = _("Changing this dataset is forbidden")

View File

@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from flask_appbuilder.security.sqla.models import User
from superset.commands.base import BaseCommand
from superset.common.chart_data import ChartDataResultType
from superset.common.query_context_factory import QueryContextFactory
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import (
DatasetForbiddenError,
DatasetNotFoundError,
DatasetSamplesFailedError,
)
from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException
from superset.views.base import check_ownership
logger = logging.getLogger(__name__)
class SamplesDatasetCommand(BaseCommand):
    """Command that fetches the samples query result for one dataset."""

    def __init__(self, user: User, model_id: int, force: bool):
        """Store the acting user, the dataset id and the ``force`` flag."""
        self._actor = user
        self._model_id = model_id
        self._force = force
        # Populated by ``validate()`` once the dataset has been looked up.
        self._model: Optional[SqlaTable] = None

    def run(self) -> Dict[str, Any]:
        """Validate the request and return the first query result.

        Builds a query context with result type ``SAMPLES`` for the dataset
        and returns the first entry under the payload's ``"queries"`` key.

        Raises:
            DatasetNotFoundError: the dataset does not exist.
            DatasetForbiddenError: the ownership check failed.
            DatasetSamplesFailedError: the payload holds no first query entry.
        """
        self.validate()
        dataset = self._model
        if not dataset:
            raise DatasetNotFoundError()

        factory = QueryContextFactory()
        datasource_ref = {
            "type": dataset.type,
            "id": dataset.id,
        }
        query_context = factory.create(
            datasource=datasource_ref,
            queries=[{}],
            result_type=ChartDataResultType.SAMPLES,
            force=self._force,
        )
        payload = query_context.get_payload()

        try:
            first_query = payload["queries"][0]
        except (IndexError, KeyError) as exc:
            raise DatasetSamplesFailedError from exc
        return first_query

    def validate(self) -> None:
        """Load the dataset and check that it passes the ownership check.

        Raises:
            DatasetNotFoundError: no dataset matches the stored id.
            DatasetForbiddenError: ``check_ownership`` rejected the dataset.
        """
        # Look the dataset up and remember it for ``run()``.
        found = DatasetDAO.find_by_id(self._model_id)
        self._model = found
        if not found:
            raise DatasetNotFoundError()

        # Translate the security failure into the dataset-level error.
        try:
            check_ownership(found)
        except SupersetSecurityException as ex:
            raise DatasetForbiddenError() from ex

View File

@ -1863,3 +1863,43 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(table_w_certification)
db.session.commit()
@pytest.mark.usefixtures("create_datasets")
def test_get_dataset_samples(self):
    """
    Dataset API: Test get dataset samples
    """
    dataset = self.get_fixture_datasets()[0]
    self.login(username="admin")

    # 1. A repeated request should be answered from the cache.
    cached_uri = f"api/v1/dataset/{dataset.id}/samples"
    # First call populates the cache ...
    self.client.get(cached_uri)
    # ... second call should be served from it.
    cached_response = self.client.get(cached_uri)
    cached_body = json.loads(cached_response.data)
    assert cached_response.status_code == 200
    assert "result" in cached_body
    assert cached_body["result"]["cached_dttm"] is not None

    # 2. force=true should bypass the cache.
    forced_uri = f"api/v1/dataset/{dataset.id}/samples?force=true"
    # First call populates the cache ...
    self.client.get(forced_uri)
    # ... second call must still run the query again.
    forced_response = self.client.get(forced_uri)
    forced_body = json.loads(forced_response.data)
    assert forced_body["result"]["cached_dttm"] is None

    # 3. The payload should carry the same rows as a direct query.
    for expected_key in ("colnames", "coltypes", "data"):
        assert expected_key in forced_body["result"]

    row_limit = self.app.config["SAMPLES_ROW_LIMIT"]
    direct_query = f"select * from {dataset.table_name}" f" limit {row_limit}"
    eager_samples = dataset.database.get_df(direct_query).to_dict(
        orient="records"
    )
    assert eager_samples == forced_body["result"]["data"]