[dataset] columns and metrics API (nested) (#9268)

* [dataset] columns and metrics API (nested)

* [dataset] tests and validation

* [datasets] Fix, revert list field name to database_name
This commit is contained in:
Daniel Vaz Gaspar 2020-03-24 17:24:08 +00:00 committed by GitHub
parent 46e39d1036
commit ccb22dc976
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 733 additions and 128 deletions

View File

@ -21,7 +21,7 @@ croniter==0.3.31 # via apache-superset (setup.py)
cryptography==2.8 # via apache-superset (setup.py)
decorator==4.4.1 # via retry
defusedxml==0.6.0 # via python3-openid
flask-appbuilder==2.3.0
flask-appbuilder==2.3.0 # via apache-superset (setup.py)
flask-babel==1.0.0 # via flask-appbuilder
flask-caching==1.8.0
flask-compress==1.4.0

View File

@ -25,7 +25,10 @@ from superset.exceptions import SupersetException
class CommandException(SupersetException):
    """Common base class for Command exceptions."""

    def __repr__(self) -> str:
        # BUG FIX: __repr__ must return a string. The original returned the
        # wrapped exception object (or `self`), which makes repr() raise
        # "TypeError: __repr__ returned non-string" the first time the
        # exception is logged or formatted.
        if self._exception:
            return repr(self._exception)
        return super().__repr__()
class CommandInvalidError(CommandException):

View File

@ -1073,7 +1073,7 @@ class SqlaTable(Model, BaseDatasource):
def get_sqla_table_object(self) -> Table:
return self.database.get_table(self.table_name, schema=self.schema)
def fetch_metadata(self) -> None:
def fetch_metadata(self, commit=True) -> None:
"""Fetches the metadata for the table and merges it in"""
try:
table = self.get_sqla_table_object()
@ -1086,7 +1086,6 @@ class SqlaTable(Model, BaseDatasource):
).format(self.table_name)
)
M = SqlMetric
metrics = []
any_date_col = None
db_engine_spec = self.database.db_engine_spec
@ -1123,7 +1122,7 @@ class SqlaTable(Model, BaseDatasource):
any_date_col = col.name
metrics.append(
M(
SqlMetric(
metric_name="count",
verbose_name="COUNT(*)",
metric_type="count",
@ -1134,7 +1133,8 @@ class SqlaTable(Model, BaseDatasource):
self.main_dttm_col = any_date_col
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()
if commit:
db.session.commit()
@classmethod
def import_obj(cls, i_datasource, import_time=None) -> int:

View File

@ -30,8 +30,10 @@ from superset.datasets.commands.exceptions import (
DatasetForbiddenError,
DatasetInvalidError,
DatasetNotFoundError,
DatasetRefreshFailedError,
DatasetUpdateFailedError,
)
from superset.datasets.commands.refresh import RefreshDatasetCommand
from superset.datasets.commands.update import UpdateDatasetCommand
from superset.datasets.schemas import DatasetPostSchema, DatasetPutSchema
from superset.views.base import DatasourceFilter
@ -49,9 +51,12 @@ class DatasetRestApi(BaseSupersetModelRestApi):
allow_browser_login = True
class_permission_name = "TableModelView"
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {RouteMethod.RELATED}
include_route_methods = (
RouteMethod.REST_MODEL_VIEW_CRUD_SET | {RouteMethod.RELATED} | {"refresh"}
)
list_columns = [
"database_name",
"changed_by_name",
"changed_by_url",
"changed_by.username",
@ -79,6 +84,8 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"template_params",
"owners.id",
"owners.username",
"columns",
"metrics",
]
add_model_schema = DatasetPostSchema()
edit_model_schema = DatasetPutSchema()
@ -97,6 +104,8 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"is_sqllab_view",
"template_params",
"owners",
"columns",
"metrics",
]
openapi_spec_tag = "Datasets"
@ -268,3 +277,49 @@ class DatasetRestApi(BaseSupersetModelRestApi):
except DatasetDeleteFailedError as e:
logger.error(f"Error deleting model {self.__class__.__name__}: {e}")
return self.response_422(message=str(e))
@expose("/<pk>/refresh", methods=["PUT"])
@protect()
@safe
def refresh(self, pk: int) -> Response:  # pylint: disable=invalid-name
    """Refresh a Dataset
    ---
    put:
      description: >-
        Refreshes and updates columns of a dataset
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
      responses:
        200:
          description: Dataset refreshed
          content:
            application/json:
              schema:
                type: object
                properties:
                  message:
                    type: string
        401:
          $ref: '#/components/responses/401'
        403:
          $ref: '#/components/responses/403'
        404:
          $ref: '#/components/responses/404'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    # Delegate to the command layer: it validates existence + ownership and
    # re-fetches the physical table's columns/metrics into the metadata DB.
    try:
        RefreshDatasetCommand(g.user, pk).run()
        return self.response(200, message="OK")
    except DatasetNotFoundError:
        return self.response_404()
    except DatasetForbiddenError:
        # Current user failed the ownership check
        return self.response_403()
    except DatasetRefreshFailedError as e:
        logger.error(f"Error refreshing dataset {self.__class__.__name__}: {e}")
        return self.response_422(message=str(e))

View File

@ -19,6 +19,7 @@ from typing import Dict, List, Optional
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand
from superset.commands.utils import populate_owners
@ -31,6 +32,7 @@ from superset.datasets.commands.exceptions import (
TableNotFoundValidationError,
)
from superset.datasets.dao import DatasetDAO
from superset.extensions import db, security_manager
logger = logging.getLogger(__name__)
@ -43,9 +45,23 @@ class CreateDatasetCommand(BaseCommand):
def run(self):
self.validate()
try:
dataset = DatasetDAO.create(self._properties)
except DAOCreateFailedError as e:
logger.exception(e.exception)
# Creates SqlaTable (Dataset)
dataset = DatasetDAO.create(self._properties, commit=False)
# Updates columns and metrics from the dataset
dataset.fetch_metadata(commit=False)
# Add datasource access permission
security_manager.add_permission_view_menu(
"datasource_access", dataset.get_perm()
)
# Add schema access permission if exists
if dataset.schema:
security_manager.add_permission_view_menu(
"schema_access", dataset.schema_perm
)
db.session.commit()
except (SQLAlchemyError, DAOCreateFailedError) as e:
logger.exception(e)
db.session.rollback()
raise DatasetCreateFailedError()
return dataset

View File

@ -18,6 +18,7 @@ import logging
from typing import Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
@ -29,6 +30,7 @@ from superset.datasets.commands.exceptions import (
)
from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager
from superset.views.base import check_ownership
logger = logging.getLogger(__name__)
@ -43,9 +45,14 @@ class DeleteDatasetCommand(BaseCommand):
def run(self):
self.validate()
try:
dataset = DatasetDAO.delete(self._model)
except DAODeleteFailedError as e:
logger.exception(e.exception)
dataset = DatasetDAO.delete(self._model, commit=False)
security_manager.del_permission_view_menu(
"datasource_access", dataset.get_perm()
)
db.session.commit()
except (SQLAlchemyError, DAODeleteFailedError) as e:
logger.exception(e)
db.session.rollback()
raise DatasetDeleteFailedError()
return dataset

View File

@ -57,6 +57,68 @@ class DatasetExistsValidationError(ValidationError):
)
class DatasetColumnNotFoundValidationError(ValidationError):
    """Marshmallow validation error raised when a column sent for update does not exist."""

    def __init__(self):
        message = _("One or more columns do not exist")
        super().__init__(message, field_names=["columns"])
class DatasetColumnsDuplicateValidationError(ValidationError):
    """Marshmallow validation error raised when the columns payload contains duplicates."""

    def __init__(self):
        message = _("One or more columns are duplicated")
        super().__init__(message, field_names=["columns"])
class DatasetColumnsExistsValidationError(ValidationError):
    """Marshmallow validation error raised when a new column's name already exists."""

    def __init__(self):
        message = _("One or more columns already exist")
        super().__init__(message, field_names=["columns"])
class DatasetMetricsNotFoundValidationError(ValidationError):
    """Marshmallow validation error raised when a metric sent for update does not exist."""

    def __init__(self):
        message = _("One or more metrics do not exist")
        super().__init__(message, field_names=["metrics"])
class DatasetMetricsDuplicateValidationError(ValidationError):
    """Marshmallow validation error raised when the metrics payload contains duplicates."""

    def __init__(self):
        message = _("One or more metrics are duplicated")
        super().__init__(message, field_names=["metrics"])
class DatasetMetricsExistsValidationError(ValidationError):
    """Marshmallow validation error raised when a new metric's name already exists."""

    def __init__(self):
        message = _("One or more metrics already exist")
        super().__init__(message, field_names=["metrics"])
class TableNotFoundValidationError(ValidationError):
"""
Marshmallow validation error when a table does not exist on the database
@ -99,5 +161,9 @@ class DatasetDeleteFailedError(DeleteFailedError):
message = _("Dataset could not be deleted.")
class DatasetRefreshFailedError(UpdateFailedError):
    # Raised by RefreshDatasetCommand when fetch_metadata() fails.
    # NOTE(review): the message says "updated" although this error belongs to
    # the refresh flow — possibly deliberate string reuse; confirm before
    # changing, since translation catalogs key off the exact string.
    message = _("Dataset could not be updated.")
class DatasetForbiddenError(ForbiddenError):
    # Raised when the current user fails the ownership check for the dataset.
    message = _("Changing this dataset is forbidden")

View File

@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import (
DatasetForbiddenError,
DatasetNotFoundError,
DatasetRefreshFailedError,
)
from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException
from superset.views.base import check_ownership
logger = logging.getLogger(__name__)
class RefreshDatasetCommand(BaseCommand):
    """Re-sync a dataset's columns and metrics from its physical table.

    Validation checks that the dataset exists and that the acting user
    passes the ownership check; `run` then delegates to
    `SqlaTable.fetch_metadata()`.
    """

    def __init__(self, user: User, model_id: int):
        self._actor = user
        self._model_id = model_id
        self._model: Optional[SqlaTable] = None

    def run(self):
        self.validate()
        try:
            # Updates columns and metrics from the dataset
            self._model.fetch_metadata()
        except Exception as ex:  # pylint: disable=broad-except
            logger.exception(ex)
            # Chain the cause so the original traceback is preserved (PEP 3134)
            raise DatasetRefreshFailedError() from ex
        return self._model

    def validate(self) -> None:
        # Validate/populate model exists
        self._model = DatasetDAO.find_by_id(self._model_id)
        if not self._model:
            raise DatasetNotFoundError()
        # Check ownership
        try:
            check_ownership(self._model)
        except SupersetSecurityException as ex:
            raise DatasetForbiddenError() from ex

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from collections import Counter
from typing import Dict, List, Optional
from flask_appbuilder.security.sqla.models import User
@ -26,9 +27,15 @@ from superset.connectors.sqla.models import SqlaTable
from superset.dao.exceptions import DAOUpdateFailedError
from superset.datasets.commands.exceptions import (
DatabaseChangeValidationError,
DatasetColumnNotFoundValidationError,
DatasetColumnsDuplicateValidationError,
DatasetColumnsExistsValidationError,
DatasetExistsValidationError,
DatasetForbiddenError,
DatasetInvalidError,
DatasetMetricsDuplicateValidationError,
DatasetMetricsExistsValidationError,
DatasetMetricsNotFoundValidationError,
DatasetNotFoundError,
DatasetUpdateFailedError,
)
@ -84,7 +91,64 @@ class UpdateDatasetCommand(BaseCommand):
self._properties["owners"] = owners
except ValidationError as e:
exceptions.append(e)
# Validate columns
columns = self._properties.get("columns")
if columns:
self._validate_columns(columns, exceptions)
# Validate metrics
metrics = self._properties.get("metrics")
if metrics:
self._validate_metrics(metrics, exceptions)
if exceptions:
exception = DatasetInvalidError()
exception.add_list(exceptions)
raise exception
def _validate_columns(self, columns: List[Dict], exceptions: List[ValidationError]):
    """Validate the nested ``columns`` payload, appending errors to *exceptions*."""
    # Reject the whole payload if it contains duplicate column names
    if self._get_duplicates(columns, "column_name"):
        exceptions.append(DatasetColumnsDuplicateValidationError())
    else:
        # Items carrying an "id" are updates: each id must belong to this dataset
        columns_ids: List[int] = [
            column["id"] for column in columns if "id" in column
        ]
        if not DatasetDAO.validate_columns_exist(self._model_id, columns_ids):
            exceptions.append(DatasetColumnNotFoundValidationError())

        # Items without an "id" are new columns: their names must not collide
        # with columns already stored for this dataset
        columns_names: List[str] = [
            column["column_name"] for column in columns if "id" not in column
        ]
        if not DatasetDAO.validate_columns_uniqueness(
            self._model_id, columns_names
        ):
            exceptions.append(DatasetColumnsExistsValidationError())
def _validate_metrics(self, metrics: List[Dict], exceptions: List[ValidationError]):
    """Validate the nested ``metrics`` payload, appending errors to *exceptions*."""
    # Reject the whole payload if it contains duplicate metric names
    if self._get_duplicates(metrics, "metric_name"):
        exceptions.append(DatasetMetricsDuplicateValidationError())
    else:
        # Items carrying an "id" are updates: each id must belong to this dataset
        metrics_ids: List[int] = [
            metric["id"] for metric in metrics if "id" in metric
        ]
        if not DatasetDAO.validate_metrics_exist(self._model_id, metrics_ids):
            exceptions.append(DatasetMetricsNotFoundValidationError())

        # Items without an "id" are new metrics: their names must not collide
        # with metrics already stored for this dataset
        metric_names: List[str] = [
            metric["metric_name"] for metric in metrics if "id" not in metric
        ]
        if not DatasetDAO.validate_metrics_uniqueness(self._model_id, metric_names):
            exceptions.append(DatasetMetricsExistsValidationError())
@staticmethod
def _get_duplicates(data: List[Dict], key: str):
duplicates = [
name
for name, count in Counter([item[key] for item in data]).items()
if count > 1
]
return duplicates

View File

@ -15,18 +15,13 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Dict, Optional
from typing import Dict, List, Optional
from flask import current_app
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.exceptions import (
CreateFailedError,
DeleteFailedError,
UpdateFailedError,
)
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.dao.base import BaseDAO
from superset.extensions import db
from superset.models.core import Database
from superset.views.base import DatasourceFilter
@ -34,7 +29,10 @@ from superset.views.base import DatasourceFilter
logger = logging.getLogger(__name__)
class DatasetDAO:
class DatasetDAO(BaseDAO):
model_cls = SqlaTable
base_filter = DatasourceFilter
@staticmethod
def get_owner_by_id(owner_id: int) -> Optional[object]:
return (
@ -79,47 +77,108 @@ class DatasetDAO:
return not db.session.query(dataset_query.exists()).scalar()
@staticmethod
def find_by_id(model_id: int) -> SqlaTable:
data_model = SQLAInterface(SqlaTable, db.session)
query = db.session.query(SqlaTable)
query = DatasourceFilter("id", data_model).apply(query, None)
return query.filter_by(id=model_id).one_or_none()
def validate_columns_exist(dataset_id: int, columns_ids: List[int]) -> bool:
dataset_query = (
db.session.query(TableColumn.id).filter(
TableColumn.table_id == dataset_id, TableColumn.id.in_(columns_ids)
)
).all()
return len(columns_ids) == len(dataset_query)
@staticmethod
def create(properties: Dict, commit=True) -> Optional[SqlaTable]:
model = SqlaTable()
for key, value in properties.items():
setattr(model, key, value)
try:
db.session.add(model)
if commit:
db.session.commit()
except SQLAlchemyError as e: # pragma: no cover
db.session.rollback()
raise CreateFailedError(exception=e)
return model
def validate_columns_uniqueness(dataset_id: int, columns_names: List[str]) -> bool:
dataset_query = (
db.session.query(TableColumn.id).filter(
TableColumn.table_id == dataset_id,
TableColumn.column_name.in_(columns_names),
)
).all()
return len(dataset_query) == 0
@staticmethod
def update(model: SqlaTable, properties: Dict, commit=True) -> Optional[SqlaTable]:
for key, value in properties.items():
setattr(model, key, value)
try:
db.session.merge(model)
if commit:
db.session.commit()
except SQLAlchemyError as e: # pragma: no cover
db.session.rollback()
raise UpdateFailedError(exception=e)
return model
def validate_metrics_exist(dataset_id: int, metrics_ids: List[int]) -> bool:
dataset_query = (
db.session.query(SqlMetric.id).filter(
SqlMetric.table_id == dataset_id, SqlMetric.id.in_(metrics_ids)
)
).all()
return len(metrics_ids) == len(dataset_query)
@staticmethod
def delete(model: SqlaTable, commit=True):
try:
db.session.delete(model)
if commit:
db.session.commit()
except SQLAlchemyError as e: # pragma: no cover
logger.error(f"Failed to delete dataset: {e}")
db.session.rollback()
raise DeleteFailedError(exception=e)
return model
def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bool:
dataset_query = (
db.session.query(SqlMetric.id).filter(
SqlMetric.table_id == dataset_id,
SqlMetric.metric_name.in_(metrics_names),
)
).all()
return len(dataset_query) == 0
@classmethod
def update(
    cls, model: SqlaTable, properties: Dict, commit=True
) -> Optional[SqlaTable]:
    """
    Updates a Dataset model on the metadata DB.

    Nested ``columns``/``metrics`` dicts are converted into ORM objects
    before the generic update: an item with an ``id`` updates the existing
    row, an item without one creates a new row.
    """
    if "columns" in properties:
        # Replace the list of plain dicts with ORM objects so the
        # relationship assignment in the generic update persists them
        new_columns = list()
        for column in properties.get("columns", []):
            if column.get("id"):
                # Existing column: load it and apply the incoming fields
                column_obj = db.session.query(TableColumn).get(column.get("id"))
                column_obj = DatasetDAO.update_column(
                    column_obj, column, commit=commit
                )
            else:
                # No id: treat as a brand-new column
                column_obj = DatasetDAO.create_column(column, commit=commit)
            new_columns.append(column_obj)
        properties["columns"] = new_columns

    if "metrics" in properties:
        # Same dict -> ORM object conversion for metrics
        new_metrics = list()
        for metric in properties.get("metrics", []):
            if metric.get("id"):
                metric_obj = db.session.query(SqlMetric).get(metric.get("id"))
                metric_obj = DatasetDAO.update_metric(
                    metric_obj, metric, commit=commit
                )
            else:
                metric_obj = DatasetDAO.create_metric(metric, commit=commit)
            new_metrics.append(metric_obj)
        properties["metrics"] = new_metrics

    # Delegate the remaining scalar fields to the generic BaseDAO update
    return super().update(model, properties, commit=commit)
@classmethod
def update_column(
    cls, model: TableColumn, properties: Dict, commit=True
) -> Optional[TableColumn]:
    """Updates a Dataset column (TableColumn) on the metadata DB."""
    return DatasetColumnDAO.update(model, properties, commit=commit)
@classmethod
def create_column(cls, properties: Dict, commit=True) -> Optional[TableColumn]:
    """
    Creates a Dataset column (TableColumn) on the metadata DB
    """
    return DatasetColumnDAO.create(properties, commit=commit)
@classmethod
def update_metric(
    cls, model: SqlMetric, properties: Dict, commit=True
) -> Optional[SqlMetric]:
    """Updates a Dataset metric (SqlMetric) on the metadata DB."""
    return DatasetMetricDAO.update(model, properties, commit=commit)
@classmethod
def create_metric(cls, properties: Dict, commit=True) -> Optional[SqlMetric]:
    """
    Creates a Dataset metric (SqlMetric) on the metadata DB
    """
    return DatasetMetricDAO.create(properties, commit=commit)
class DatasetColumnDAO(BaseDAO):
    # DAO for dataset columns; create/update/delete come from BaseDAO
    model_cls = TableColumn
class DatasetMetricDAO(BaseDAO):
    # DAO for dataset metrics; create/update/delete come from BaseDAO
    model_cls = SqlMetric

View File

@ -14,11 +14,54 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from marshmallow import fields, Schema
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema, ValidationError
from marshmallow.validate import Length
def validate_python_date_format(value):
    """Validate that *value* is ``epoch_s``/``epoch_ms`` or a supported
    strftime-style date (``%Y``, ``%Y-%m``, ``%Y-%m-%d``) with an optional
    time part (``%H``, ``%H:%M``, ``%H:%M:%S``, ``%H:%M:%S.%f``) separated
    by whitespace or ``T``. Raises ValidationError otherwise.
    """
    pattern = (
        r"^("
        r"epoch_s|epoch_ms|"
        r"(?P<date>%Y(-%m(-%d)?)?)([\sT](?P<time>%H(:%M(:%S(\.%f)?)?)?))?"
        r")$"
    )
    # `value or ""` guards against None being passed by marshmallow
    if re.match(pattern, value or "") is None:
        raise ValidationError(_("Invalid date/timestamp format"))
class DatasetColumnsPutSchema(Schema):
    """Nested schema validating a dataset's columns on PUT requests.

    Items carrying an ``id`` update an existing column; items without one
    create a new column.
    """

    id = fields.Integer()  # pylint: disable=invalid-name
    column_name = fields.String(required=True, validate=Length(1, 255))
    type = fields.String(validate=Length(1, 32))
    # BUG FIX: `Length=(1, 1024)` was an arbitrary keyword argument, which
    # marshmallow stores as inert field metadata — the length constraint was
    # never enforced. It belongs in `validate=`.
    verbose_name = fields.String(allow_none=True, validate=Length(1, 1024))
    description = fields.String(allow_none=True)
    expression = fields.String(allow_none=True)
    filterable = fields.Boolean()
    groupby = fields.Boolean()
    is_active = fields.Boolean()
    is_dttm = fields.Boolean(default=False)
    python_date_format = fields.String(
        allow_none=True, validate=[Length(1, 255), validate_python_date_format]
    )
class DatasetMetricsPutSchema(Schema):
    """Nested schema validating a dataset's metrics on PUT requests.

    Items carrying an ``id`` update an existing metric; items without one
    create a new metric.
    """

    id = fields.Integer()  # pylint: disable=invalid-name
    expression = fields.String(required=True)
    description = fields.String(allow_none=True)
    metric_name = fields.String(required=True, validate=Length(1, 255))
    metric_type = fields.String(allow_none=True, validate=Length(1, 32))
    d3format = fields.String(allow_none=True, validate=Length(1, 128))
    warning_text = fields.String(allow_none=True)
class DatasetPostSchema(Schema):
database = fields.Integer(required=True)
schema = fields.String(validate=Length(0, 250))
@ -31,7 +74,7 @@ class DatasetPutSchema(Schema):
sql = fields.String(allow_none=True)
filter_select_enabled = fields.Boolean(allow_none=True)
fetch_values_predicate = fields.String(allow_none=True, validate=Length(0, 1000))
schema = fields.String(allow_none=True, validate=Length(1, 255))
schema = fields.String(allow_none=True, validate=Length(0, 255))
description = fields.String(allow_none=True)
main_dttm_col = fields.String(allow_none=True)
offset = fields.Integer(allow_none=True)
@ -40,3 +83,5 @@ class DatasetPutSchema(Schema):
is_sqllab_view = fields.Boolean(allow_none=True)
template_params = fields.String(allow_none=True)
owners = fields.List(fields.Integer())
columns = fields.List(fields.Nested(DatasetColumnsPutSchema))
metrics = fields.List(fields.Nested(DatasetMetricsPutSchema))

View File

@ -74,6 +74,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
"bulk_delete": "delete",
"info": "list",
"related": "list",
"refresh": "edit",
}
order_rel_fields: Dict[str, Tuple[str, str]] = {}

View File

@ -20,9 +20,10 @@ from typing import List
from unittest.mock import patch
import prison
from sqlalchemy.sql import func
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.dao.exceptions import (
DAOCreateFailedError,
DAODeleteFailedError,
@ -30,8 +31,7 @@ from superset.dao.exceptions import (
)
from superset.models.core import Database
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
from tests.base_tests import SupersetTestCase
class DatasetApiTests(SupersetTestCase):
@ -48,8 +48,23 @@ class DatasetApiTests(SupersetTestCase):
)
db.session.add(table)
db.session.commit()
table.fetch_metadata()
return table
def insert_default_dataset(self):
    """Insert the standard test dataset: ``ab_permission`` owned by admin."""
    admin_id = self.get_user("admin").id
    return self.insert_dataset("ab_permission", "", [admin_id], get_example_database())
@staticmethod
def get_birth_names_dataset():
    """Fetch the ``birth_names`` example dataset from the metadata DB."""
    query = db.session.query(SqlaTable).filter_by(
        database=get_example_database(), table_name="birth_names"
    )
    return query.one()
def test_get_dataset_list(self):
"""
Dataset API: Test get dataset list
@ -109,12 +124,7 @@ class DatasetApiTests(SupersetTestCase):
"""
Dataset API: Test get dataset item
"""
example_db = get_example_database()
table = (
db.session.query(SqlaTable)
.filter_by(database=example_db, table_name="birth_names")
.one()
)
table = self.get_birth_names_dataset()
self.login(username="admin")
uri = f"api/v1/dataset/{table.id}"
rv = self.client.get(uri)
@ -136,7 +146,10 @@ class DatasetApiTests(SupersetTestCase):
"table_name": "birth_names",
"template_params": None,
}
self.assertEqual(response["result"], expected_result)
for key, value in expected_result.items():
self.assertEqual(response["result"][key], expected_result[key])
self.assertEqual(len(response["result"]["columns"]), 8)
self.assertEqual(len(response["result"]["metrics"]), 2)
def test_get_dataset_info(self):
"""
@ -162,9 +175,30 @@ class DatasetApiTests(SupersetTestCase):
rv = self.client.post(uri, json=table_data)
self.assertEqual(rv.status_code, 201)
data = json.loads(rv.data.decode("utf-8"))
model = db.session.query(SqlaTable).get(data.get("id"))
table_id = data.get("id")
model = db.session.query(SqlaTable).get(table_id)
self.assertEqual(model.table_name, table_data["table_name"])
self.assertEqual(model.database_id, table_data["database"])
# Assert that columns were created
columns = (
db.session.query(TableColumn)
.filter_by(table_id=table_id)
.order_by("column_name")
.all()
)
self.assertEqual(columns[0].column_name, "id")
self.assertEqual(columns[1].column_name, "name")
# Assert that metrics were created
columns = (
db.session.query(SqlMetric)
.filter_by(table_id=table_id)
.order_by("metric_name")
.all()
)
self.assertEqual(columns[0].expression, "COUNT(*)")
db.session.delete(model)
db.session.commit()
@ -252,9 +286,9 @@ class DatasetApiTests(SupersetTestCase):
Dataset API: Test create dataset validate database exists
"""
self.login(username="admin")
table_data = {"database": 1000, "schema": "", "table_name": "birth_names"}
dataset_data = {"database": 1000, "schema": "", "table_name": "birth_names"}
uri = "api/v1/dataset/"
rv = self.client.post(uri, json=table_data)
rv = self.client.post(uri, json=dataset_data)
self.assertEqual(rv.status_code, 422)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data, {"message": {"database": ["Database does not exist"]}})
@ -297,73 +331,224 @@ class DatasetApiTests(SupersetTestCase):
"""
Dataset API: Test update dataset item
"""
table = self.insert_dataset("ab_permission", "", [], get_example_database())
dataset = self.insert_default_dataset()
self.login(username="admin")
table_data = {"description": "changed_description"}
uri = f"api/v1/dataset/{table.id}"
rv = self.client.put(uri, json=table_data)
dataset_data = {"description": "changed_description"}
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=dataset_data)
self.assertEqual(rv.status_code, 200)
model = db.session.query(SqlaTable).get(table.id)
self.assertEqual(model.description, table_data["description"])
db.session.delete(table)
model = db.session.query(SqlaTable).get(dataset.id)
self.assertEqual(model.description, dataset_data["description"])
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_create_column(self):
"""
Dataset API: Test update dataset create column
"""
# create example dataset by Command
dataset = self.insert_default_dataset()
new_column_data = {
"column_name": "new_col",
"description": "description",
"expression": "expression",
"type": "INTEGER",
"verbose_name": "New Col",
}
uri = f"api/v1/dataset/{dataset.id}"
# Get current cols and append the new column
self.login(username="admin")
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
data["result"]["columns"].append(new_column_data)
rv = self.client.put(uri, json={"columns": data["result"]["columns"]})
self.assertEqual(rv.status_code, 200)
columns = (
db.session.query(TableColumn)
.filter_by(table_id=dataset.id)
.order_by("column_name")
.all()
)
self.assertEqual(columns[0].column_name, "id")
self.assertEqual(columns[1].column_name, "name")
self.assertEqual(columns[2].column_name, new_column_data["column_name"])
self.assertEqual(columns[2].description, new_column_data["description"])
self.assertEqual(columns[2].expression, new_column_data["expression"])
self.assertEqual(columns[2].type, new_column_data["type"])
self.assertEqual(columns[2].verbose_name, new_column_data["verbose_name"])
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_update_column(self):
    """
    Dataset API: Test update dataset columns
    """
    dataset = self.insert_default_dataset()
    self.login(username="admin")
    uri = f"api/v1/dataset/{dataset.id}"
    # Get current cols and alter one
    rv = self.client.get(uri)
    resp_columns = json.loads(rv.data.decode("utf-8"))["result"]["columns"]
    resp_columns[0]["groupby"] = False
    resp_columns[0]["filterable"] = False
    # BUG FIX: the PUT response was assigned to `v`, so the status assertion
    # below silently re-checked the earlier GET response instead of the PUT.
    rv = self.client.put(uri, json={"columns": resp_columns})
    self.assertEqual(rv.status_code, 200)
    columns = (
        db.session.query(TableColumn)
        .filter_by(table_id=dataset.id)
        .order_by("column_name")
        .all()
    )
    self.assertEqual(columns[0].column_name, "id")
    self.assertEqual(columns[1].column_name, "name")
    self.assertEqual(columns[0].groupby, False)
    self.assertEqual(columns[0].filterable, False)
    db.session.delete(dataset)
    db.session.commit()
def test_update_dataset_update_column_uniqueness(self):
    """
    Dataset API: Test update dataset columns uniqueness
    """
    dataset = self.insert_default_dataset()
    self.login(username="admin")
    uri = f"api/v1/dataset/{dataset.id}"
    # add a new column (no "id" key) whose name collides with an existing one
    data = {"columns": [{"column_name": "id", "type": "INTEGER"}]}
    rv = self.client.put(uri, json=data)
    self.assertEqual(rv.status_code, 422)
    data = json.loads(rv.data.decode("utf-8"))
    expected_result = {
        "message": {"columns": ["One or more columns already exist"]}
    }
    self.assertEqual(data, expected_result)
    db.session.delete(dataset)
    db.session.commit()
def test_update_dataset_update_metric_uniqueness(self):
    """
    Dataset API: Test update dataset metric uniqueness
    """
    dataset = self.insert_default_dataset()
    self.login(username="admin")
    uri = f"api/v1/dataset/{dataset.id}"
    # add a new metric (no "id" key) whose name collides with an existing one
    data = {"metrics": [{"metric_name": "count", "expression": "COUNT(*)"}]}
    rv = self.client.put(uri, json=data)
    self.assertEqual(rv.status_code, 422)
    data = json.loads(rv.data.decode("utf-8"))
    expected_result = {
        "message": {"metrics": ["One or more metrics already exist"]}
    }
    self.assertEqual(data, expected_result)
    db.session.delete(dataset)
    db.session.commit()
def test_update_dataset_update_column_duplicate(self):
    """
    Dataset API: Test update dataset columns duplicate
    """
    dataset = self.insert_default_dataset()
    self.login(username="admin")
    uri = f"api/v1/dataset/{dataset.id}"
    # payload contains two new columns with the same name
    data = {
        "columns": [
            {"column_name": "id", "type": "INTEGER"},
            {"column_name": "id", "type": "VARCHAR"},
        ]
    }
    rv = self.client.put(uri, json=data)
    self.assertEqual(rv.status_code, 422)
    data = json.loads(rv.data.decode("utf-8"))
    expected_result = {
        "message": {"columns": ["One or more columns are duplicated"]}
    }
    self.assertEqual(data, expected_result)
    db.session.delete(dataset)
    db.session.commit()
def test_update_dataset_update_metric_duplicate(self):
    """
    Dataset API: Test update dataset metric duplicate
    """
    dataset = self.insert_default_dataset()
    self.login(username="admin")
    uri = f"api/v1/dataset/{dataset.id}"
    # payload contains two new metrics with the same name
    data = {
        "metrics": [
            {"metric_name": "dup", "expression": "COUNT(*)"},
            {"metric_name": "dup", "expression": "DIFF_COUNT(*)"},
        ]
    }
    rv = self.client.put(uri, json=data)
    self.assertEqual(rv.status_code, 422)
    data = json.loads(rv.data.decode("utf-8"))
    expected_result = {
        "message": {"metrics": ["One or more metrics are duplicated"]}
    }
    self.assertEqual(data, expected_result)
    db.session.delete(dataset)
    db.session.commit()
def test_update_dataset_item_gamma(self):
"""
Dataset API: Test update dataset item gamma
"""
table = self.insert_dataset("ab_permission", "", [], get_example_database())
dataset = self.insert_default_dataset()
self.login(username="gamma")
table_data = {"description": "changed_description"}
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=table_data)
self.assertEqual(rv.status_code, 401)
db.session.delete(table)
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_item_not_owned(self):
"""
Dataset API: Test update dataset item not owned
"""
admin = self.get_user("admin")
table = self.insert_dataset(
"ab_permission", "", [admin.id], get_example_database()
)
dataset = self.insert_default_dataset()
self.login(username="alpha")
table_data = {"description": "changed_description"}
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=table_data)
self.assertEqual(rv.status_code, 403)
db.session.delete(table)
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_item_owners_invalid(self):
"""
Dataset API: Test update dataset item owner invalid
"""
admin = self.get_user("admin")
table = self.insert_dataset(
"ab_permission", "", [admin.id], get_example_database()
)
dataset = self.insert_default_dataset()
self.login(username="admin")
table_data = {"description": "changed_description", "owners": [1000]}
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=table_data)
self.assertEqual(rv.status_code, 422)
db.session.delete(table)
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_item_uniqueness(self):
"""
Dataset API: Test update dataset uniqueness
"""
admin = self.get_user("admin")
table = self.insert_dataset(
"ab_permission", "", [admin.id], get_example_database()
)
dataset = self.insert_default_dataset()
self.login(username="admin")
table_data = {"table_name": "birth_names"}
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=table_data)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
@ -371,7 +556,7 @@ class DatasetApiTests(SupersetTestCase):
"message": {"table_name": ["Datasource birth_names already exists"]}
}
self.assertEqual(data, expected_response)
db.session.delete(table)
db.session.delete(dataset)
db.session.commit()
@patch("superset.datasets.dao.DatasetDAO.update")
@ -381,25 +566,25 @@ class DatasetApiTests(SupersetTestCase):
"""
mock_dao_update.side_effect = DAOUpdateFailedError()
table = self.insert_dataset("ab_permission", "", [], get_example_database())
dataset = self.insert_default_dataset()
self.login(username="admin")
table_data = {"description": "changed_description"}
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=table_data)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
self.assertEqual(data, {"message": "Dataset could not be updated."})
db.session.delete(dataset)
db.session.commit()
def test_delete_dataset_item(self):
"""
Dataset API: Test delete dataset item
"""
admin = self.get_user("admin")
table = self.insert_dataset(
"ab_permission", "", [admin.id], get_example_database()
)
dataset = self.insert_default_dataset()
self.login(username="admin")
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 200)
@ -407,30 +592,24 @@ class DatasetApiTests(SupersetTestCase):
"""
Dataset API: Test delete item not owned
"""
admin = self.get_user("admin")
table = self.insert_dataset(
"ab_permission", "", [admin.id], get_example_database()
)
dataset = self.insert_default_dataset()
self.login(username="alpha")
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 403)
db.session.delete(table)
db.session.delete(dataset)
db.session.commit()
def test_delete_dataset_item_not_authorized(self):
"""
Dataset API: Test delete item not authorized
"""
admin = self.get_user("admin")
table = self.insert_dataset(
"ab_permission", "", [admin.id], get_example_database()
)
dataset = self.insert_default_dataset()
self.login(username="gamma")
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 401)
db.session.delete(table)
db.session.delete(dataset)
db.session.commit()
@patch("superset.datasets.dao.DatasetDAO.delete")
@ -440,15 +619,64 @@ class DatasetApiTests(SupersetTestCase):
"""
mock_dao_delete.side_effect = DAODeleteFailedError()
admin = self.get_user("admin")
table = self.insert_dataset(
"ab_permission", "", [admin.id], get_example_database()
)
dataset = self.insert_default_dataset()
self.login(username="admin")
uri = f"api/v1/dataset/{table.id}"
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.delete(uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
self.assertEqual(data, {"message": "Dataset could not be deleted."})
db.session.delete(table)
db.session.delete(dataset)
db.session.commit()
def test_dataset_item_refresh(self):
"""
Dataset API: Test item refresh
"""
dataset = self.insert_default_dataset()
# delete a column
id_column = (
db.session.query(TableColumn)
.filter_by(table_id=dataset.id, column_name="id")
.one()
)
db.session.delete(id_column)
db.session.commit()
self.login(username="admin")
uri = f"api/v1/dataset/{dataset.id}/refresh"
rv = self.client.put(uri)
self.assertEqual(rv.status_code, 200)
# Assert the column is restored on refresh
id_column = (
db.session.query(TableColumn)
.filter_by(table_id=dataset.id, column_name="id")
.one()
)
self.assertIsNotNone(id_column)
db.session.delete(dataset)
db.session.commit()
def test_dataset_item_refresh_not_found(self):
"""
Dataset API: Test item refresh not found dataset
"""
max_id = db.session.query(func.max(SqlaTable.id)).scalar()
self.login(username="admin")
uri = f"api/v1/dataset/{max_id + 1}/refresh"
rv = self.client.put(uri)
self.assertEqual(rv.status_code, 404)
def test_dataset_item_refresh_not_owned(self):
"""
Dataset API: Test item refresh not owned dataset
"""
dataset = self.insert_default_dataset()
self.login(username="alpha")
uri = f"api/v1/dataset/{dataset.id}/refresh"
rv = self.client.put(uri)
self.assertEqual(rv.status_code, 403)
db.session.delete(dataset)
db.session.commit()