diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 727047848..d6e5fe62d 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -37,6 +37,7 @@ from superset.datasets.commands.exceptions import ( ) from superset.datasets.commands.refresh import RefreshDatasetCommand from superset.datasets.commands.update import UpdateDatasetCommand +from superset.datasets.dao import DatasetDAO from superset.datasets.schemas import ( DatasetPostSchema, DatasetPutSchema, @@ -66,6 +67,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): RouteMethod.EXPORT, RouteMethod.RELATED, "refresh", + "related_objects", } list_columns = [ "id", @@ -412,3 +414,98 @@ class DatasetRestApi(BaseSupersetModelRestApi): "Error refreshing dataset %s: %s", self.__class__.__name__, str(ex) ) return self.response_422(message=str(ex)) + + @expose("//related_objects", methods=["GET"]) + @protect() + @safe + @statsd_metrics + def related_objects(self, pk: int) -> Response: + """Get charts and dashboards count associated to a dataset + --- + get: + description: + Get charts and dashboards count associated to a dataset + parameters: + - in: path + name: pk + schema: + type: integer + responses: + 200: + description: chart and dashboard counts + content: + application/json: + schema: + type: object + properties: + charts: + type: object + properties: + count: + type: integer + result: + type: array + items: + type: object + properties: + id: + type: integer + slice_name: + type: string + viz_type: + type: string + dashboards: + type: object + properties: + count: + type: integer + result: + type: array + items: + type: object + properties: + id: + type: integer + json_metadata: + type: object + slug: + type: string + title: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + try: + data = DatasetDAO.get_related_objects(pk) + charts = [ + { + "id": chart.id, + "slice_name": chart.slice_name, + "viz_type": chart.viz_type, + } + for chart in data["charts"] + ] + dashboards = [ + { + "id": dashboard.id, + "json_metadata": dashboard.json_metadata, + "slug": dashboard.slug, + "title": dashboard.dashboard_title, + } + for dashboard in data["dashboards"] + ] + return self.response( + 200, + charts={"count": len(charts), "result": charts}, + dashboards={"count": len(dashboards), "result": dashboards}, + ) + except DatasetNotFoundError: + return self.response_404() + except DatasetForbiddenError: + return self.response_403() diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index 0a253fa5a..b5c278c02 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -24,6 +24,8 @@ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.dao.base import BaseDAO from superset.extensions import db from superset.models.core import Database +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice from superset.views.base import DatasourceFilter logger = logging.getLogger(__name__) @@ -49,6 +51,28 @@ class DatasetDAO(BaseDAO): logger.error("Could not get database by id: %s", str(ex)) return None + @staticmethod + def get_related_objects(database_id: int) -> Dict[str, Any]: + charts = ( + db.session.query(Slice) + .filter( + Slice.datasource_id == database_id, Slice.datasource_type == "table" + ) + .all() + ) + chart_ids = [chart.id for chart in charts] + + dashboards = ( + ( + db.session.query(Dashboard) + .join(Dashboard.slices) + .filter(Slice.id.in_(chart_ids)) + ) + .distinct() + .all() + ) + return dict(charts=charts, dashboards=dashboards) + @staticmethod def validate_table_exists(database: Database, table_name: str, schema: str) -> bool: try: diff --git a/superset/views/base_api.py b/superset/views/base_api.py index a72a1c5e5..2279adee5 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -83,6 +83,7 @@ class BaseSupersetModelRestApi(ModelRestApi): "data": "list", "viz_types": "list", "datasources": "list", + "related_objects": "list", } order_rel_fields: Dict[str, Tuple[str, str]] = {} diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index e1d642ff8..08c751f01 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -750,3 +750,17 @@ class TestDatasetApi(SupersetTestCase): self.login(username="gamma") rv = self.client.get(uri) self.assertEqual(rv.status_code, 401) + + def test_get_dataset_related_objects(self): + """ + Dataset API: Test get chart and dashboard count related to a dataset + :return: + """ + self.login(username="admin") + table = self.get_birth_names_dataset() + uri = f"api/v1/dataset/{table.id}/related_objects" + rv = self.get_assert_metric(uri, "related_objects") + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["charts"]["count"], 18) + self.assertEqual(response["dashboards"]["count"], 2)