feat!: pass datasource_type and datasource_id to form_data (#19981)

* pass datasource_type and datasource_id to form_data

* add datasource_type to delete command

* add datasource_type to delete command

* fix old keys implementation

* add more tests
This commit is contained in:
Elizabeth Thompson 2022-06-02 16:48:16 -07:00 committed by GitHub
parent a813528958
commit 32bb1ce3ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 959 additions and 176 deletions

View File

@ -67,10 +67,18 @@ export const URL_PARAMS = {
name: 'slice_id',
type: 'string',
},
datasourceId: {
name: 'datasource_id',
type: 'string',
},
datasetId: {
name: 'dataset_id',
type: 'string',
},
datasourceType: {
name: 'datasource_type',
type: 'string',
},
dashboardId: {
name: 'dashboard_id',
type: 'string',
@ -88,6 +96,8 @@ export const URL_PARAMS = {
export const RESERVED_CHART_URL_PARAMS: string[] = [
URL_PARAMS.formDataKey.name,
URL_PARAMS.sliceId.name,
URL_PARAMS.datasourceId.name,
URL_PARAMS.datasourceType.name,
URL_PARAMS.datasetId.name,
];
export const RESERVED_DASHBOARD_URL_PARAMS: string[] = [

View File

@ -272,6 +272,7 @@ export default class Chart extends React.Component {
: undefined;
const key = await postFormData(
this.props.datasource.id,
this.props.datasource.type,
this.props.formData,
this.props.slice.slice_id,
nextTabId,

View File

@ -92,7 +92,7 @@ test('generates a new form_data param when none is available', async () => {
expect(replaceState).toHaveBeenCalledWith(
expect.anything(),
undefined,
expect.stringMatching('dataset_id'),
expect.stringMatching('datasource_id'),
);
replaceState.mockRestore();
});
@ -109,7 +109,7 @@ test('generates a different form_data param when one is provided and is mounting
expect(replaceState).toHaveBeenCalledWith(
expect.anything(),
undefined,
expect.stringMatching('dataset_id'),
expect.stringMatching('datasource_id'),
);
replaceState.mockRestore();
});

View File

@ -152,14 +152,24 @@ const ExplorePanelContainer = styled.div`
`;
const updateHistory = debounce(
async (formData, datasetId, isReplace, standalone, force, title, tabId) => {
async (
formData,
datasourceId,
datasourceType,
isReplace,
standalone,
force,
title,
tabId,
) => {
const payload = { ...formData };
const chartId = formData.slice_id;
const additionalParam = {};
if (chartId) {
additionalParam[URL_PARAMS.sliceId.name] = chartId;
} else {
additionalParam[URL_PARAMS.datasetId.name] = datasetId;
additionalParam[URL_PARAMS.datasourceId.name] = datasourceId;
additionalParam[URL_PARAMS.datasourceType.name] = datasourceType;
}
const urlParams = payload?.url_params || {};
@ -173,11 +183,24 @@ const updateHistory = debounce(
let key;
let stateModifier;
if (isReplace) {
key = await postFormData(datasetId, formData, chartId, tabId);
key = await postFormData(
datasourceId,
datasourceType,
formData,
chartId,
tabId,
);
stateModifier = 'replaceState';
} else {
key = getUrlParam(URL_PARAMS.formDataKey);
await putFormData(datasetId, key, formData, chartId, tabId);
await putFormData(
datasourceId,
datasourceType,
key,
formData,
chartId,
tabId,
);
stateModifier = 'pushState';
}
const url = mountExploreUrl(
@ -229,11 +252,12 @@ function ExploreViewContainer(props) {
dashboardId: props.dashboardId,
}
: props.form_data;
const datasetId = props.datasource.id;
const { id: datasourceId, type: datasourceType } = props.datasource;
updateHistory(
formData,
datasetId,
datasourceId,
datasourceType,
isReplace,
props.standalone,
props.force,
@ -245,6 +269,7 @@ function ExploreViewContainer(props) {
props.dashboardId,
props.form_data,
props.datasource.id,
props.datasource.type,
props.standalone,
props.force,
tabId,

View File

@ -189,9 +189,9 @@ class DatasourceControl extends React.PureComponent {
const isMissingDatasource = datasource.id == null;
let isMissingParams = false;
if (isMissingDatasource) {
const datasetId = getUrlParam(URL_PARAMS.datasetId);
const datasourceId = getUrlParam(URL_PARAMS.datasourceId);
const sliceId = getUrlParam(URL_PARAMS.sliceId);
if (!datasetId && !sliceId) {
if (!datasourceId && !sliceId) {
isMissingParams = true;
}
}

View File

@ -20,7 +20,8 @@ import { omit } from 'lodash';
import { SupersetClient, JsonObject } from '@superset-ui/core';
type Payload = {
dataset_id: number;
datasource_id: number;
datasource_type: string;
form_data: string;
chart_id?: number;
};
@ -42,12 +43,14 @@ const assembleEndpoint = (key?: string, tabId?: string) => {
};
const assemblePayload = (
datasetId: number,
datasourceId: number,
datasourceType: string,
formData: JsonObject,
chartId?: number,
) => {
const payload: Payload = {
dataset_id: datasetId,
datasource_id: datasourceId,
datasource_type: datasourceType,
form_data: JSON.stringify(sanitizeFormData(formData)),
};
if (chartId) {
@ -57,18 +60,25 @@ const assemblePayload = (
};
export const postFormData = (
datasetId: number,
datasourceId: number,
datasourceType: string,
formData: JsonObject,
chartId?: number,
tabId?: string,
): Promise<string> =>
SupersetClient.post({
endpoint: assembleEndpoint(undefined, tabId),
jsonPayload: assemblePayload(datasetId, formData, chartId),
jsonPayload: assemblePayload(
datasourceId,
datasourceType,
formData,
chartId,
),
}).then(r => r.json.key);
export const putFormData = (
datasetId: number,
datasourceId: number,
datasourceType: string,
key: string,
formData: JsonObject,
chartId?: number,
@ -76,5 +86,10 @@ export const putFormData = (
): Promise<string> =>
SupersetClient.put({
endpoint: assembleEndpoint(key, tabId),
jsonPayload: assemblePayload(datasetId, formData, chartId),
jsonPayload: assemblePayload(
datasourceId,
datasourceType,
formData,
chartId,
),
}).then(r => r.json.message);

View File

@ -22,6 +22,7 @@ from superset.charts.schemas import (
datasource_type_description,
datasource_uid_description,
)
from superset.utils.core import DatasourceType
class Datasource(Schema):
@ -36,7 +37,7 @@ class Datasource(Schema):
)
datasource_type = fields.String(
description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")),
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
required=True,
)

View File

@ -31,6 +31,7 @@ from superset.db_engine_specs.base import builtin_time_grains
from superset.utils import pandas_postprocessing, schema as utils
from superset.utils.core import (
AnnotationType,
DatasourceType,
FilterOperator,
PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation,
@ -198,7 +199,7 @@ class ChartPostSchema(Schema):
datasource_id = fields.Integer(description=datasource_id_description, required=True)
datasource_type = fields.String(
description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")),
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
required=True,
)
datasource_name = fields.String(
@ -244,7 +245,7 @@ class ChartPutSchema(Schema):
)
datasource_type = fields.String(
description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")),
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
allow_none=True,
)
dashboards = fields.List(fields.Integer(description=dashboards_description))
@ -983,7 +984,7 @@ class ChartDataDatasourceSchema(Schema):
)
type = fields.String(
description="Datasource type",
validate=validate.OneOf(choices=("druid", "table")),
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
)

View File

@ -115,8 +115,24 @@ class RolesNotFoundValidationError(ValidationError):
super().__init__([_("Some roles do not exist")], field_name="roles")
class DatasourceTypeInvalidError(ValidationError):
status = 422
def __init__(self) -> None:
super().__init__(
[_("Datasource type is invalid")], field_name="datasource_type"
)
class DatasourceNotFoundValidationError(ValidationError):
status = 404
def __init__(self) -> None:
super().__init__([_("Dataset does not exist")], field_name="datasource_id")
super().__init__([_("Datasource does not exist")], field_name="datasource_id")
class QueryNotFoundValidationError(ValidationError):
status = 404
def __init__(self) -> None:
super().__init__([_("Query does not exist")], field_name="datasource_id")

View File

@ -39,11 +39,11 @@ Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
class DatasourceDAO(BaseDAO):
sources: Dict[DatasourceType, Type[Datasource]] = {
DatasourceType.SQLATABLE: SqlaTable,
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
DatasourceType.SAVEDQUERY: SavedQuery,
DatasourceType.DATASET: Dataset,
DatasourceType.TABLE: Table,
DatasourceType.SLTABLE: Table,
}
@classmethod
@ -66,7 +66,7 @@ class DatasourceDAO(BaseDAO):
@classmethod
def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]:
source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE]
source_class = DatasourceDAO.sources[DatasourceType.TABLE]
qry = session.query(source_class)
qry = source_class.default_query(qry)
return qry.all()

View File

@ -24,6 +24,7 @@ from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import TabState
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
@ -75,7 +76,8 @@ class DatabaseDAO(BaseDAO):
charts = (
db.session.query(Slice)
.filter(
Slice.datasource_id.in_(dataset_ids), Slice.datasource_type == "table"
Slice.datasource_id.in_(dataset_ids),
Slice.datasource_type == DatasourceType.TABLE,
)
.all()
)

View File

@ -26,6 +26,7 @@ from superset.extensions import db
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from superset.views.base import DatasourceFilter
logger = logging.getLogger(__name__)
@ -56,7 +57,8 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
charts = (
db.session.query(Slice)
.filter(
Slice.datasource_id == database_id, Slice.datasource_type == "table"
Slice.datasource_id == database_id,
Slice.datasource_type == DatasourceType.TABLE,
)
.all()
)

View File

@ -29,6 +29,7 @@ from superset.exceptions import NoDataException
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from ..utils.database import get_example_database
from .helpers import (
@ -205,13 +206,16 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
if admin_owner:
slice_props = dict(
datasource_id=tbl.id,
datasource_type="table",
datasource_type=DatasourceType.TABLE,
owners=[admin],
created_by=admin,
)
else:
slice_props = dict(
datasource_id=tbl.id, datasource_type="table", owners=[], created_by=admin
datasource_id=tbl.id,
datasource_type=DatasourceType.TABLE,
owners=[],
created_by=admin,
)
print("Creating some slices")

View File

@ -24,6 +24,7 @@ import superset.utils.database as database_utils
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import (
get_example_data,
@ -112,7 +113,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
slc = Slice(
slice_name="Birth in France by department in 2016",
viz_type="country_map",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)

View File

@ -19,6 +19,7 @@ import json
from superset import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import (
get_slice_json,
@ -213,7 +214,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice(
slice_name="Deck.gl Scatterplot",
viz_type="deck_scatter",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
@ -248,7 +249,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice(
slice_name="Deck.gl Screen grid",
viz_type="deck_screengrid",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
@ -284,7 +285,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice(
slice_name="Deck.gl Hexagons",
viz_type="deck_hex",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
@ -321,7 +322,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice(
slice_name="Deck.gl Grid",
viz_type="deck_grid",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
@ -410,7 +411,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice(
slice_name="Deck.gl Polygons",
viz_type="deck_polygon",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=polygon_tbl.id,
params=get_slice_json(slice_data),
)
@ -460,7 +461,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice(
slice_name="Deck.gl Arcs",
viz_type="deck_arc",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=db.session.query(table)
.filter_by(table_name="flights")
.first()
@ -512,7 +513,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice(
slice_name="Deck.gl Path",
viz_type="deck_path",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=db.session.query(table)
.filter_by(table_name="bart_lines")
.first()

View File

@ -25,6 +25,7 @@ import superset.utils.database as database_utils
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import (
get_example_data,
@ -81,7 +82,7 @@ def load_energy(
slc = Slice(
slice_name="Energy Sankey",
viz_type="sankey",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=textwrap.dedent(
"""\
@ -105,7 +106,7 @@ def load_energy(
slc = Slice(
slice_name="Energy Force Layout",
viz_type="graph_chart",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=textwrap.dedent(
"""\
@ -129,7 +130,7 @@ def load_energy(
slc = Slice(
slice_name="Heatmap",
viz_type="heatmap",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=textwrap.dedent(
"""\

View File

@ -24,6 +24,7 @@ from sqlalchemy import DateTime, Float, inspect, String
import superset.utils.database as database_utils
from superset import db
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import (
get_example_data,
@ -113,7 +114,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
slc = Slice(
slice_name="Mapbox Long/Lat",
viz_type="mapbox",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)

View File

@ -18,6 +18,7 @@ import json
from superset import db
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .birth_names import load_birth_names
from .helpers import merge_slice, misc_dash_slices
@ -35,7 +36,7 @@ def load_multi_line(only_metadata: bool = False) -> None:
]
slc = Slice(
datasource_type="table", # not true, but needed
datasource_type=DatasourceType.TABLE, # not true, but needed
datasource_id=1, # cannot be empty
slice_name="Multi Line",
viz_type="line_multi",

View File

@ -21,6 +21,7 @@ from sqlalchemy import BigInteger, Date, DateTime, inspect, String
from superset import app, db
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from ..utils.database import get_example_database
from .helpers import (
@ -120,7 +121,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
slc = Slice(
slice_name=f"Calendar Heatmap multiformat {i}",
viz_type="cal_heatmap",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)

View File

@ -21,6 +21,7 @@ from sqlalchemy import DateTime, inspect, String
import superset.utils.database as database_utils
from superset import app, db
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import (
get_example_data,
@ -89,7 +90,7 @@ def load_random_time_series_data(
slc = Slice(
slice_name="Calendar Heatmap",
viz_type="cal_heatmap",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)

View File

@ -29,6 +29,7 @@ from superset.connectors.sqla.models import SqlMetric
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from superset.utils.core import DatasourceType
from ..connectors.base.models import BaseDatasource
from .helpers import (
@ -172,7 +173,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Region Filter",
viz_type="filter_box",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -201,7 +202,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="World's Population",
viz_type="big_number",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -215,7 +216,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Most Populated Countries",
viz_type="table",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -227,7 +228,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Growth Rate",
viz_type="line",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -241,7 +242,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="% Rural",
viz_type="world_map",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -254,7 +255,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Life Expectancy VS Rural %",
viz_type="bubble",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -298,7 +299,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Rural Breakdown",
viz_type="sunburst",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -313,7 +314,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="World's Pop Growth",
viz_type="area",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -327,7 +328,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Box plot",
viz_type="box_plot",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -343,7 +344,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Treemap",
viz_type="treemap",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
@ -357,7 +358,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice(
slice_name="Parallel Coordinates",
viz_type="para",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id,
params=get_slice_json(
defaults,

View File

@ -104,7 +104,8 @@ class ExploreFormDataRestApi(BaseApi, ABC):
tab_id = request.args.get("tab_id")
args = CommandParameters(
actor=g.user,
dataset_id=item["dataset_id"],
datasource_id=item["datasource_id"],
datasource_type=item["datasource_type"],
chart_id=item.get("chart_id"),
tab_id=tab_id,
form_data=item["form_data"],
@ -123,7 +124,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
log_to_statsd=False,
log_to_statsd=True,
)
@requires_json
def put(self, key: str) -> Response:
@ -174,7 +175,8 @@ class ExploreFormDataRestApi(BaseApi, ABC):
tab_id = request.args.get("tab_id")
args = CommandParameters(
actor=g.user,
dataset_id=item["dataset_id"],
datasource_id=item["datasource_id"],
datasource_type=item["datasource_type"],
chart_id=item.get("chart_id"),
tab_id=tab_id,
key=key,
@ -196,7 +198,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
log_to_statsd=False,
log_to_statsd=True,
)
def get(self, key: str) -> Response:
"""Retrives a form_data.
@ -247,7 +249,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.delete",
log_to_statsd=False,
log_to_statsd=True,
)
def delete(self, key: str) -> Response:
"""Deletes a form_data.

View File

@ -39,20 +39,24 @@ class CreateFormDataCommand(BaseCommand):
def run(self) -> str:
self.validate()
try:
dataset_id = self._cmd_params.dataset_id
datasource_id = self._cmd_params.datasource_id
datasource_type = self._cmd_params.datasource_type
chart_id = self._cmd_params.chart_id
tab_id = self._cmd_params.tab_id
actor = self._cmd_params.actor
form_data = self._cmd_params.form_data
check_access(dataset_id, chart_id, actor)
contextual_key = cache_key(session.get("_id"), tab_id, dataset_id, chart_id)
check_access(datasource_id, chart_id, actor, datasource_type)
contextual_key = cache_key(
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
)
key = cache_manager.explore_form_data_cache.get(contextual_key)
if not key or not tab_id:
key = random_key()
if form_data:
state: TemporaryExploreState = {
"owner": get_owner(actor),
"dataset_id": dataset_id,
"datasource_id": datasource_id,
"datasource_type": datasource_type,
"chart_id": chart_id,
"form_data": form_data,
}

View File

@ -16,6 +16,7 @@
# under the License.
import logging
from abc import ABC
from typing import Optional
from flask import session
from sqlalchemy.exc import SQLAlchemyError
@ -31,6 +32,7 @@ from superset.temporary_cache.commands.exceptions import (
TemporaryCacheDeleteFailedError,
)
from superset.temporary_cache.utils import cache_key
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
@ -47,14 +49,15 @@ class DeleteFormDataCommand(BaseCommand, ABC):
key
)
if state:
dataset_id = state["dataset_id"]
chart_id = state["chart_id"]
check_access(dataset_id, chart_id, actor)
datasource_id: int = state["datasource_id"]
chart_id: Optional[int] = state["chart_id"]
datasource_type = DatasourceType(state["datasource_type"])
check_access(datasource_id, chart_id, actor, datasource_type)
if state["owner"] != get_owner(actor):
raise TemporaryCacheAccessDeniedError()
tab_id = self._cmd_params.tab_id
contextual_key = cache_key(
session.get("_id"), tab_id, dataset_id, chart_id
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
)
cache_manager.explore_form_data_cache.delete(contextual_key)
return cache_manager.explore_form_data_cache.delete(key)

View File

@ -27,6 +27,7 @@ from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.commands.utils import check_access
from superset.extensions import cache_manager
from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
@ -45,7 +46,12 @@ class GetFormDataCommand(BaseCommand, ABC):
key
)
if state:
check_access(state["dataset_id"], state["chart_id"], actor)
check_access(
state["datasource_id"],
state["chart_id"],
actor,
DatasourceType(state["datasource_type"]),
)
if self._refresh_timeout:
cache_manager.explore_form_data_cache.set(key, state)
return state["form_data"]

View File

@ -19,11 +19,14 @@ from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset.utils.core import DatasourceType
@dataclass
class CommandParameters:
actor: User
dataset_id: int = 0
datasource_type: DatasourceType = DatasourceType.TABLE
datasource_id: int = 0
chart_id: int = 0
tab_id: Optional[int] = None
key: Optional[str] = None

View File

@ -21,6 +21,7 @@ from typing_extensions import TypedDict
class TemporaryExploreState(TypedDict):
owner: Optional[int]
dataset_id: int
datasource_id: int
datasource_type: str
chart_id: Optional[int]
form_data: str

View File

@ -47,12 +47,13 @@ class UpdateFormDataCommand(BaseCommand, ABC):
def run(self) -> Optional[str]:
self.validate()
try:
dataset_id = self._cmd_params.dataset_id
datasource_id = self._cmd_params.datasource_id
chart_id = self._cmd_params.chart_id
datasource_type = self._cmd_params.datasource_type
actor = self._cmd_params.actor
key = self._cmd_params.key
form_data = self._cmd_params.form_data
check_access(dataset_id, chart_id, actor)
check_access(datasource_id, chart_id, actor, datasource_type)
state: TemporaryExploreState = cache_manager.explore_form_data_cache.get(
key
)
@ -64,7 +65,7 @@ class UpdateFormDataCommand(BaseCommand, ABC):
# Generate a new key if tab_id changes or equals 0
tab_id = self._cmd_params.tab_id
contextual_key = cache_key(
session.get("_id"), tab_id, dataset_id, chart_id
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
)
key = cache_manager.explore_form_data_cache.get(contextual_key)
if not key or not tab_id:
@ -73,7 +74,8 @@ class UpdateFormDataCommand(BaseCommand, ABC):
new_state: TemporaryExploreState = {
"owner": owner,
"dataset_id": dataset_id,
"datasource_id": datasource_id,
"datasource_type": datasource_type,
"chart_id": chart_id,
"form_data": form_data,
}

View File

@ -31,11 +31,17 @@ from superset.temporary_cache.commands.exceptions import (
TemporaryCacheAccessDeniedError,
TemporaryCacheResourceNotFoundError,
)
from superset.utils.core import DatasourceType
def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
def check_access(
datasource_id: int,
chart_id: Optional[int],
actor: User,
datasource_type: DatasourceType,
) -> None:
try:
explore_check_access(dataset_id, chart_id, actor)
explore_check_access(datasource_id, chart_id, actor, datasource_type)
except (ChartNotFoundError, DatasetNotFoundError) as ex:
raise TemporaryCacheResourceNotFoundError from ex
except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex:

View File

@ -14,12 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
from marshmallow import fields, Schema, validate
from superset.utils.core import DatasourceType
class FormDataPostSchema(Schema):
dataset_id = fields.Integer(
required=True, allow_none=False, description="The dataset ID"
datasource_id = fields.Integer(
required=True, allow_none=False, description="The datasource ID"
)
datasource_type = fields.String(
required=True,
allow_none=False,
description="The datasource type",
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
)
chart_id = fields.Integer(required=False, description="The chart ID")
form_data = fields.String(
@ -28,8 +36,14 @@ class FormDataPostSchema(Schema):
class FormDataPutSchema(Schema):
dataset_id = fields.Integer(
required=True, allow_none=False, description="The dataset ID"
datasource_id = fields.Integer(
required=True, allow_none=False, description="The datasource ID"
)
datasource_type = fields.String(
required=True,
allow_none=False,
description="The datasource type",
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
)
chart_id = fields.Integer(required=False, description="The chart ID")
form_data = fields.String(

View File

@ -22,9 +22,10 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.utils import encode_permalink_key
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
@ -39,11 +40,16 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
def run(self) -> str:
self.validate()
try:
dataset_id = int(self.datasource.split("__")[0])
check_access(dataset_id, self.chart_id, self.actor)
d_id, d_type = self.datasource.split("__")
datasource_id = int(d_id)
datasource_type = DatasourceType(d_type)
check_chart_access(
datasource_id, self.chart_id, self.actor, datasource_type
)
value = {
"chartId": self.chart_id,
"datasetId": dataset_id,
"datasourceId": datasource_id,
"datasourceType": datasource_type,
"datasource": self.datasource,
"state": self.state,
}

View File

@ -24,10 +24,11 @@ from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.explore.permalink.types import ExplorePermalinkValue
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.utils import decode_permalink_id
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
@ -47,8 +48,9 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
).run()
if value:
chart_id: Optional[int] = value.get("chartId")
dataset_id = value["datasetId"]
check_access(dataset_id, chart_id, self.actor)
datasource_id: int = value["datasourceId"]
datasource_type = DatasourceType(value["datasourceType"])
check_chart_access(datasource_id, chart_id, self.actor, datasource_type)
return value
return None
except (

View File

@ -24,6 +24,7 @@ class ExplorePermalinkState(TypedDict, total=False):
class ExplorePermalinkValue(TypedDict):
chartId: Optional[int]
datasetId: int
datasourceId: int
datasourceType: str
datasource: str
state: ExplorePermalinkState

View File

@ -24,11 +24,18 @@ from superset.charts.commands.exceptions import (
ChartNotFoundError,
)
from superset.charts.dao import ChartDAO
from superset.commands.exceptions import (
DatasourceNotFoundValidationError,
DatasourceTypeInvalidError,
QueryNotFoundValidationError,
)
from superset.datasets.commands.exceptions import (
DatasetAccessDeniedError,
DatasetNotFoundError,
)
from superset.datasets.dao import DatasetDAO
from superset.queries.dao import QueryDAO
from superset.utils.core import DatasourceType
from superset.views.base import is_user_admin
from superset.views.utils import is_owner
@ -44,10 +51,41 @@ def check_dataset_access(dataset_id: int) -> Optional[bool]:
raise DatasetNotFoundError()
def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
check_dataset_access(dataset_id)
def check_query_access(query_id: int) -> Optional[bool]:
if query_id:
query = QueryDAO.find_by_id(query_id)
if query:
security_manager.raise_for_access(query=query)
return True
raise QueryNotFoundValidationError()
ACCESS_FUNCTION_MAP = {
DatasourceType.TABLE: check_dataset_access,
DatasourceType.QUERY: check_query_access,
}
def check_datasource_access(
datasource_id: int, datasource_type: DatasourceType
) -> Optional[bool]:
if datasource_id:
try:
return ACCESS_FUNCTION_MAP[datasource_type](datasource_id)
except KeyError as ex:
raise DatasourceTypeInvalidError() from ex
raise DatasourceNotFoundValidationError()
def check_access(
datasource_id: int,
chart_id: Optional[int],
actor: User,
datasource_type: DatasourceType,
) -> Optional[bool]:
check_datasource_access(datasource_id, datasource_type)
if not chart_id:
return
return True
chart = ChartDAO.find_by_id(chart_id)
if chart:
can_access_chart = (
@ -56,6 +94,6 @@ def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
or security_manager.can_access("can_read", "Chart")
)
if can_access_chart:
return
return True
raise ChartAccessDeniedError()
raise ChartNotFoundError()

View File

@ -15,15 +15,40 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional, Union
from flask import Flask
from flask_caching import Cache
from markupsafe import Markup
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache"
class ExploreFormDataCache(Cache):
def get(self, *args: Any, **kwargs: Any) -> Optional[Union[str, Markup]]:
cache = self.cache.get(*args, **kwargs)
if not cache:
return None
# rename data keys for existing cache based on new TemporaryExploreState model
if isinstance(cache, dict):
cache = {
("datasource_id" if key == "dataset_id" else key): value
for (key, value) in cache.items()
}
# add default datasource_type if it doesn't exist
# temporarily defaulting to table until sqlatables are deprecated
if "datasource_type" not in cache:
cache["datasource_type"] = DatasourceType.TABLE
return cache
class CacheManager:
def __init__(self) -> None:
super().__init__()
@ -32,7 +57,7 @@ class CacheManager:
self._data_cache = Cache()
self._thumbnail_cache = Cache()
self._filter_state_cache = Cache()
self._explore_form_data_cache = Cache()
self._explore_form_data_cache = ExploreFormDataCache()
@staticmethod
def _init_cache(

View File

@ -175,12 +175,13 @@ class GenericDataType(IntEnum):
# ROW = 7
class DatasourceType(Enum):
SQLATABLE = "sqlatable"
class DatasourceType(str, Enum):
SLTABLE = "sl_table"
TABLE = "table"
DATASET = "dataset"
QUERY = "query"
SAVEDQUERY = "saved_query"
VIEW = "view"
class DatasourceDict(TypedDict):

View File

@ -520,7 +520,13 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}},
{
"message": {
"datasource_type": [
"Must be one of: sl_table, table, dataset, query, saved_query, view."
]
}
},
)
chart_data = {
"slice_name": "title1",
@ -531,7 +537,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"datasource_id": ["Dataset does not exist"]}}
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -686,7 +692,13 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}},
{
"message": {
"datasource_type": [
"Must be one of: sl_table, table, dataset, query, saved_query, view."
]
}
},
)
chart_data = {"datasource_id": 0, "datasource_type": "table"}
@ -694,7 +706,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"datasource_id": ["Dataset does not exist"]}}
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
)
db.session.delete(chart)

View File

@ -26,7 +26,7 @@ from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_default_schema
from superset.utils.core import DatasourceType, get_example_default_schema
def get_table(
@ -72,7 +72,7 @@ def create_slice(
return Slice(
slice_name=title,
viz_type=viz_type,
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_id=table.id,
params=json.dumps(slices_dict, indent=4, sort_keys=True),
)

View File

@ -56,7 +56,7 @@ def admin_id() -> int:
@pytest.fixture
def dataset_id() -> int:
def datasource() -> int:
with app.app_context() as ctx:
session: Session = ctx.app.appbuilder.get_session
dataset = (
@ -64,24 +64,26 @@ def dataset_id() -> int:
.filter_by(table_name="wb_health_population")
.first()
)
return dataset.id
return dataset
@pytest.fixture(autouse=True)
def cache(chart_id, admin_id, dataset_id):
def cache(chart_id, admin_id, datasource):
entry: TemporaryExploreState = {
"owner": admin_id,
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": INITIAL_FORM_DATA,
}
cache_manager.explore_form_data_cache.set(KEY, entry)
def test_post(client, chart_id: int, dataset_id: int):
def test_post(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": INITIAL_FORM_DATA,
}
@ -89,10 +91,11 @@ def test_post(client, chart_id: int, dataset_id: int):
assert resp.status_code == 201
def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int):
def test_post_bad_request_non_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": 1234,
}
@ -100,10 +103,11 @@ def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int):
assert resp.status_code == 400
def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int):
def test_post_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": "foo",
}
@ -111,10 +115,11 @@ def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int
assert resp.status_code == 400
def test_post_access_denied(client, chart_id: int, dataset_id: int):
def test_post_access_denied(client, chart_id: int, datasource: SqlaTable):
login(client, "gamma")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": INITIAL_FORM_DATA,
}
@ -122,10 +127,11 @@ def test_post_access_denied(client, chart_id: int, dataset_id: int):
assert resp.status_code == 404
def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int):
def test_post_same_key_for_same_context(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -139,11 +145,12 @@ def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int):
def test_post_different_key_for_different_context(
client, chart_id: int, dataset_id: int
client, chart_id: int, datasource: SqlaTable
):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -151,7 +158,8 @@ def test_post_different_key_for_different_context(
data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"form_data": json.dumps({"test": "initial value"}),
}
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
@ -160,10 +168,11 @@ def test_post_different_key_for_different_context(
assert first_key != second_key
def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
def test_post_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": json.dumps({"test": "initial value"}),
}
@ -177,11 +186,12 @@ def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
def test_post_different_key_for_different_tab_id(
client, chart_id: int, dataset_id: int
client, chart_id: int, datasource: SqlaTable
):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": json.dumps({"test": "initial value"}),
}
@ -194,10 +204,11 @@ def test_post_different_key_for_different_tab_id(
assert first_key != second_key
def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int):
def test_post_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": INITIAL_FORM_DATA,
}
@ -210,10 +221,11 @@ def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int
assert first_key != second_key
def test_put(client, chart_id: int, dataset_id: int):
def test_put(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -221,10 +233,11 @@ def test_put(client, chart_id: int, dataset_id: int):
assert resp.status_code == 200
def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
def test_put_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -237,10 +250,13 @@ def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
assert first_key == second_key
def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_id: int):
def test_put_different_key_for_different_tab_id(
client, chart_id: int, datasource: SqlaTable
):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -253,10 +269,11 @@ def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_i
assert first_key != second_key
def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int):
def test_put_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -269,10 +286,11 @@ def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int)
assert first_key != second_key
def test_put_bad_request(client, chart_id: int, dataset_id: int):
def test_put_bad_request(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": 1234,
}
@ -280,10 +298,11 @@ def test_put_bad_request(client, chart_id: int, dataset_id: int):
assert resp.status_code == 400
def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int):
def test_put_bad_request_non_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": 1234,
}
@ -291,10 +310,11 @@ def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int):
assert resp.status_code == 400
def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int):
def test_put_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": "foo",
}
@ -302,10 +322,11 @@ def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int)
assert resp.status_code == 400
def test_put_access_denied(client, chart_id: int, dataset_id: int):
def test_put_access_denied(client, chart_id: int, datasource: SqlaTable):
login(client, "gamma")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -313,10 +334,11 @@ def test_put_access_denied(client, chart_id: int, dataset_id: int):
assert resp.status_code == 404
def test_put_not_owner(client, chart_id: int, dataset_id: int):
def test_put_not_owner(client, chart_id: int, datasource: SqlaTable):
login(client, "gamma")
payload = {
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": UPDATED_FORM_DATA,
}
@ -364,12 +386,13 @@ def test_delete_access_denied(client):
assert resp.status_code == 404
def test_delete_not_owner(client, chart_id: int, dataset_id: int, admin_id: int):
def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id: int):
another_key = "another_key"
another_owner = admin_id + 1
entry: TemporaryExploreState = {
"owner": another_owner,
"dataset_id": dataset_id,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id,
"form_data": INITIAL_FORM_DATA,
}

View File

@ -0,0 +1,359 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from unittest.mock import patch
import pytest
from superset import app, db, security, security_manager
from superset.commands.exceptions import DatasourceTypeInvalidError
from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.commands.create import CreateFormDataCommand
from superset.explore.form_data.commands.delete import DeleteFormDataCommand
from superset.explore.form_data.commands.get import GetFormDataCommand
from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.update import UpdateFormDataCommand
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.base_tests import SupersetTestCase
class TestCreateFormDataCommand(SupersetTestCase):
@pytest.fixture()
def create_dataset(self):
with self.create_app().app_context():
dataset = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
session = db.session
session.add(dataset)
session.commit()
yield dataset
# rollback
session.delete(dataset)
session.commit()
@pytest.fixture()
def create_slice(self):
with self.create_app().app_context():
session = db.session
dataset = (
session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = Slice(
datasource_id=dataset.id,
datasource_type=DatasourceType.TABLE,
datasource_name="tmp_perm_table",
slice_name="slice_name",
)
session.add(slice)
session.commit()
yield slice
# rollback
session.delete(slice)
session.commit()
@pytest.fixture()
def create_query(self):
with self.create_app().app_context():
session = db.session
query = Query(
sql="select 1 as foo;",
client_id="sldkfjlk",
database=get_example_database(),
)
session.add(query)
session.commit()
yield query
# rollback
session.delete(query)
session.commit()
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice")
def test_create_form_data_command(self, mock_g):
mock_g.user = security_manager.find_user("admin")
dataset = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type=DatasourceType.TABLE,
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
)
command = CreateFormDataCommand(args)
assert isinstance(command.run(), str)
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
def test_create_form_data_command_invalid_type(self, mock_g):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
}
dataset = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
create_args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type="InvalidType",
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
)
with pytest.raises(DatasourceTypeInvalidError) as exc:
CreateFormDataCommand(create_args).run()
assert "Datasource type is invalid" in str(exc.value)
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
def test_create_form_data_command_type_as_string(self, mock_g):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
}
dataset = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
create_args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type="table",
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
)
command = CreateFormDataCommand(create_args)
assert isinstance(command.run(), str)
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice")
def test_get_form_data_command(self, mock_g):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
}
dataset = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
create_args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type=DatasourceType.TABLE,
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
)
key = CreateFormDataCommand(create_args).run()
key_args = CommandParameters(actor=mock_g.user, key=key)
get_command = GetFormDataCommand(key_args)
cache_data = json.loads(get_command.run())
assert cache_data.get("datasource") == datasource
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
def test_update_form_data_command(self, mock_g):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
}
dataset = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
query = db.session.query(Query).filter_by(sql="select 1 as foo;").first()
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
create_args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type=DatasourceType.TABLE,
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
)
key = CreateFormDataCommand(create_args).run()
query_datasource = f"{dataset.id}__{DatasourceType.TABLE}"
update_args = CommandParameters(
actor=mock_g.user,
datasource_id=query.id,
datasource_type=DatasourceType.QUERY,
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": query_datasource}),
key=key,
)
update_command = UpdateFormDataCommand(update_args)
new_key = update_command.run()
# it should return a key
assert isinstance(new_key, str)
# the updated key returned should be different from the old one
assert new_key != key
key_args = CommandParameters(actor=mock_g.user, key=key)
get_command = GetFormDataCommand(key_args)
cache_data = json.loads(get_command.run())
assert cache_data.get("datasource") == query_datasource
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
def test_update_form_data_command_same_form_data(self, mock_g):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
}
dataset = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
create_args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type=DatasourceType.TABLE,
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
)
key = CreateFormDataCommand(create_args).run()
update_args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type=DatasourceType.TABLE,
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
key=key,
)
update_command = UpdateFormDataCommand(update_args)
new_key = update_command.run()
# it should return a key
assert isinstance(new_key, str)
# the updated key returned should be the same as the old one
assert new_key == key
key_args = CommandParameters(actor=mock_g.user, key=key)
get_command = GetFormDataCommand(key_args)
cache_data = json.loads(get_command.run())
assert cache_data.get("datasource") == datasource
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
def test_delete_form_data_command(self, mock_g):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
}
dataset = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
create_args = CommandParameters(
actor=mock_g.user,
datasource_id=dataset.id,
datasource_type=DatasourceType.TABLE,
chart_id=slice.id,
tab_id=1,
form_data=json.dumps({"datasource": datasource}),
)
key = CreateFormDataCommand(create_args).run()
delete_args = CommandParameters(
actor=mock_g.user,
key=key,
)
delete_command = DeleteFormDataCommand(delete_args)
response = delete_command.run()
assert response == True
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
def test_delete_form_data_command_key_expired(self, mock_g):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
}
delete_args = CommandParameters(
actor=mock_g.user,
key="some_expired_key",
)
delete_command = DeleteFormDataCommand(delete_args)
response = delete_command.run()
assert response == False

View File

@ -27,6 +27,7 @@ from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import decode_permalink_id, encode_permalink_key
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import (
@ -97,7 +98,8 @@ def test_get_missing_chart(client, chart, permalink_salt: str) -> None:
value=pickle.dumps(
{
"chartId": chart_id,
"datasetId": chart.datasource.id,
"datasourceId": chart.datasource.id,
"datasourceType": DatasourceType.TABLE,
"formData": {
"slice_id": chart_id,
"datasource": f"{chart.datasource.id}__{chart.datasource.type}",

View File

@ -40,7 +40,7 @@ from superset.dashboards.commands.importers.v0 import import_chart, import_dashb
from superset.datasets.commands.importers.v0 import import_dataset
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_default_schema
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.fixtures.world_bank_dashboard import (
@ -103,7 +103,7 @@ class TestImportExport(SupersetTestCase):
return Slice(
slice_name=name,
datasource_type="table",
datasource_type=DatasourceType.TABLE,
viz_type="bubble",
params=json.dumps(params),
datasource_id=ds_id,

View File

@ -16,6 +16,7 @@
# under the License.
# isort:skip_file
import json
from superset.utils.core import DatasourceType
import textwrap
import unittest
from unittest import mock
@ -604,7 +605,7 @@ class TestSqlaTableModel(SupersetTestCase):
dashboard = self.get_dash_by_slug("births")
slc = Slice(
slice_name="slice with adhoc column",
datasource_type="table",
datasource_type=DatasourceType.TABLE,
viz_type="table",
params=json.dumps(
{

View File

@ -39,6 +39,7 @@ from superset.models.core import Database
from superset.models.slice import Slice
from superset.sql_parse import Table
from superset.utils.core import (
DatasourceType,
backend,
get_example_default_schema,
)
@ -120,7 +121,7 @@ class TestRolePermission(SupersetTestCase):
ds_slices = (
session.query(Slice)
.filter_by(datasource_type="table")
.filter_by(datasource_type=DatasourceType.TABLE)
.filter_by(datasource_id=ds.id)
.all()
)
@ -143,7 +144,7 @@ class TestRolePermission(SupersetTestCase):
ds.schema_perm = None
ds_slices = (
session.query(Slice)
.filter_by(datasource_type="table")
.filter_by(datasource_type=DatasourceType.TABLE)
.filter_by(datasource_id=ds.id)
.all()
)
@ -365,7 +366,7 @@ class TestRolePermission(SupersetTestCase):
# no schema permission
slice = Slice(
datasource_id=table.id,
datasource_type="table",
datasource_type=DatasourceType.TABLE,
datasource_name="tmp_perm_table",
slice_name="slice_name",
)

View File

@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.extensions import cache_manager
from superset.utils.core import backend, DatasourceType
from tests.integration_tests.base_tests import SupersetTestCase
class UtilsCacheManagerTests(SupersetTestCase):
def test_get_set_explore_form_data_cache(self):
key = "12345"
data = {"foo": "bar", "datasource_type": "query"}
cache_manager.explore_form_data_cache.set(key, data)
assert cache_manager.explore_form_data_cache.get(key) == data
def test_get_same_context_twice(self):
key = "12345"
data = {"foo": "bar", "datasource_type": "query"}
cache_manager.explore_form_data_cache.set(key, data)
assert cache_manager.explore_form_data_cache.get(key) == data
assert cache_manager.explore_form_data_cache.get(key) == data
def test_get_set_explore_form_data_cache_no_datasource_type(self):
key = "12345"
data = {"foo": "bar"}
cache_manager.explore_form_data_cache.set(key, data)
# datasource_type should be added because it is not present
assert cache_manager.explore_form_data_cache.get(key) == {
"datasource_type": DatasourceType.TABLE,
**data,
}
def test_get_explore_form_data_cache_invalid_key(self):
assert cache_manager.explore_form_data_cache.get("foo") == None

View File

@ -106,7 +106,7 @@ def test_get_datasource_sqlatable(
from superset.dao.datasource.dao import DatasourceDAO
result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.SQLATABLE,
datasource_type=DatasourceType.TABLE,
datasource_id=1,
session=session_with_data,
)
@ -151,7 +151,9 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session)
# todo(hugh): This will break once we remove the dual write
# update the datsource_id=1 and this will pass again
result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data
datasource_type=DatasourceType.SLTABLE,
datasource_id=2,
session=session_with_data,
)
assert result.id == 2

View File

@ -23,12 +23,21 @@ from superset.charts.commands.exceptions import (
ChartAccessDeniedError,
ChartNotFoundError,
)
from superset.commands.exceptions import (
DatasourceNotFoundValidationError,
DatasourceTypeInvalidError,
OwnersNotFoundValidationError,
QueryNotFoundValidationError,
)
from superset.datasets.commands.exceptions import (
DatasetAccessDeniedError,
DatasetNotFoundError,
)
from superset.exceptions import SupersetSecurityException
from superset.utils.core import DatasourceType
dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id"
query_find_by_id = "superset.queries.dao.QueryDAO.find_by_id"
chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id"
is_user_admin = "superset.explore.utils.is_user_admin"
is_owner = "superset.explore.utils.is_owner"
@ -36,88 +45,142 @@ can_access_datasource = (
"superset.security.SupersetSecurityManager.can_access_datasource"
)
can_access = "superset.security.SupersetSecurityManager.can_access"
raise_for_access = "superset.security.SupersetSecurityManager.raise_for_access"
query_datasources_by_name = (
"superset.connectors.sqla.models.SqlaTable.query_datasources_by_name"
)
def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None:
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
with raises(DatasetNotFoundError):
check_access(dataset_id=0, chart_id=0, actor=User())
with raises(DatasourceNotFoundValidationError):
check_chart_access(
datasource_id=0,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_unsaved_chart_unknown_dataset_id(
mocker: MockFixture, app_context: AppContext
) -> None:
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
with raises(DatasetNotFoundError):
mocker.patch(dataset_find_by_id, return_value=None)
check_access(dataset_id=1, chart_id=0, actor=User())
check_chart_access(
datasource_id=1,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_unsaved_chart_unknown_query_id(
mocker: MockFixture, app_context: AppContext
) -> None:
from superset.explore.utils import check_access as check_chart_access
with raises(QueryNotFoundValidationError):
mocker.patch(query_find_by_id, return_value=None)
check_chart_access(
datasource_id=1,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.QUERY,
)
def test_unsaved_chart_unauthorized_dataset(
mocker: MockFixture, app_context: AppContext
) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore import utils
from superset.explore.utils import check_access as check_chart_access
with raises(DatasetAccessDeniedError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=False)
utils.check_access(dataset_id=1, chart_id=0, actor=User())
check_chart_access(
datasource_id=1,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_unsaved_chart_authorized_dataset(
mocker: MockFixture, app_context: AppContext
) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True)
check_access(dataset_id=1, chart_id=0, actor=User())
check_chart_access(
datasource_id=1,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_unknown_chart_id(
mocker: MockFixture, app_context: AppContext
) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
with raises(ChartNotFoundError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True)
mocker.patch(chart_find_by_id, return_value=None)
check_access(dataset_id=1, chart_id=1, actor=User())
check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_unauthorized_dataset(
mocker: MockFixture, app_context: AppContext
) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore import utils
from superset.explore.utils import check_access as check_chart_access
with raises(DatasetAccessDeniedError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=False)
utils.check_access(dataset_id=1, chart_id=1, actor=User())
check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True)
mocker.patch(is_user_admin, return_value=True)
mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User())
check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -125,12 +188,17 @@ def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> N
mocker.patch(is_user_admin, return_value=False)
mocker.patch(is_owner, return_value=True)
mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User())
check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -139,12 +207,17 @@ def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) ->
mocker.patch(is_owner, return_value=False)
mocker.patch(can_access, return_value=True)
mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User())
check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access
from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice
with raises(ChartAccessDeniedError):
@ -154,4 +227,66 @@ def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) ->
mocker.patch(is_owner, return_value=False)
mocker.patch(can_access, return_value=False)
mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User())
check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_datasource_access
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True)
mocker.patch(is_user_admin, return_value=False)
mocker.patch(is_owner, return_value=False)
mocker.patch(can_access, return_value=True)
assert (
check_datasource_access(
datasource_id=1,
datasource_type=DatasourceType.TABLE,
)
== True
)
def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.explore.utils import check_datasource_access
from superset.models.sql_lab import Query
mocker.patch(query_find_by_id, return_value=Query())
mocker.patch(raise_for_access, return_value=True)
mocker.patch(is_user_admin, return_value=False)
mocker.patch(is_owner, return_value=False)
mocker.patch(can_access, return_value=True)
assert (
check_datasource_access(
datasource_id=1,
datasource_type=DatasourceType.QUERY,
)
== True
)
def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_datasource_access
from superset.models.core import Database
from superset.models.sql_lab import Query
with raises(SupersetSecurityException):
mocker.patch(
query_find_by_id,
return_value=Query(database=Database(), sql="select * from foo"),
)
mocker.patch(query_datasources_by_name, return_value=[SqlaTable()])
mocker.patch(is_user_admin, return_value=False)
mocker.patch(is_owner, return_value=False)
mocker.patch(can_access, return_value=False)
check_datasource_access(
datasource_id=1,
datasource_type=DatasourceType.QUERY,
)