feat!: pass datasource_type and datasource_id to form_data (#19981)
* pass datasource_type and datasource_id to form_data * add datasource_type to delete command * add datasource_type to delete command * fix old keys implementation * add more tests
This commit is contained in:
parent
a813528958
commit
32bb1ce3ff
|
|
@ -67,10 +67,18 @@ export const URL_PARAMS = {
|
|||
name: 'slice_id',
|
||||
type: 'string',
|
||||
},
|
||||
datasourceId: {
|
||||
name: 'datasource_id',
|
||||
type: 'string',
|
||||
},
|
||||
datasetId: {
|
||||
name: 'dataset_id',
|
||||
type: 'string',
|
||||
},
|
||||
datasourceType: {
|
||||
name: 'datasource_type',
|
||||
type: 'string',
|
||||
},
|
||||
dashboardId: {
|
||||
name: 'dashboard_id',
|
||||
type: 'string',
|
||||
|
|
@ -88,6 +96,8 @@ export const URL_PARAMS = {
|
|||
export const RESERVED_CHART_URL_PARAMS: string[] = [
|
||||
URL_PARAMS.formDataKey.name,
|
||||
URL_PARAMS.sliceId.name,
|
||||
URL_PARAMS.datasourceId.name,
|
||||
URL_PARAMS.datasourceType.name,
|
||||
URL_PARAMS.datasetId.name,
|
||||
];
|
||||
export const RESERVED_DASHBOARD_URL_PARAMS: string[] = [
|
||||
|
|
|
|||
|
|
@ -272,6 +272,7 @@ export default class Chart extends React.Component {
|
|||
: undefined;
|
||||
const key = await postFormData(
|
||||
this.props.datasource.id,
|
||||
this.props.datasource.type,
|
||||
this.props.formData,
|
||||
this.props.slice.slice_id,
|
||||
nextTabId,
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ test('generates a new form_data param when none is available', async () => {
|
|||
expect(replaceState).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.stringMatching('dataset_id'),
|
||||
expect.stringMatching('datasource_id'),
|
||||
);
|
||||
replaceState.mockRestore();
|
||||
});
|
||||
|
|
@ -109,7 +109,7 @@ test('generates a different form_data param when one is provided and is mounting
|
|||
expect(replaceState).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.stringMatching('dataset_id'),
|
||||
expect.stringMatching('datasource_id'),
|
||||
);
|
||||
replaceState.mockRestore();
|
||||
});
|
||||
|
|
|
|||
|
|
@ -152,14 +152,24 @@ const ExplorePanelContainer = styled.div`
|
|||
`;
|
||||
|
||||
const updateHistory = debounce(
|
||||
async (formData, datasetId, isReplace, standalone, force, title, tabId) => {
|
||||
async (
|
||||
formData,
|
||||
datasourceId,
|
||||
datasourceType,
|
||||
isReplace,
|
||||
standalone,
|
||||
force,
|
||||
title,
|
||||
tabId,
|
||||
) => {
|
||||
const payload = { ...formData };
|
||||
const chartId = formData.slice_id;
|
||||
const additionalParam = {};
|
||||
if (chartId) {
|
||||
additionalParam[URL_PARAMS.sliceId.name] = chartId;
|
||||
} else {
|
||||
additionalParam[URL_PARAMS.datasetId.name] = datasetId;
|
||||
additionalParam[URL_PARAMS.datasourceId.name] = datasourceId;
|
||||
additionalParam[URL_PARAMS.datasourceType.name] = datasourceType;
|
||||
}
|
||||
|
||||
const urlParams = payload?.url_params || {};
|
||||
|
|
@ -173,11 +183,24 @@ const updateHistory = debounce(
|
|||
let key;
|
||||
let stateModifier;
|
||||
if (isReplace) {
|
||||
key = await postFormData(datasetId, formData, chartId, tabId);
|
||||
key = await postFormData(
|
||||
datasourceId,
|
||||
datasourceType,
|
||||
formData,
|
||||
chartId,
|
||||
tabId,
|
||||
);
|
||||
stateModifier = 'replaceState';
|
||||
} else {
|
||||
key = getUrlParam(URL_PARAMS.formDataKey);
|
||||
await putFormData(datasetId, key, formData, chartId, tabId);
|
||||
await putFormData(
|
||||
datasourceId,
|
||||
datasourceType,
|
||||
key,
|
||||
formData,
|
||||
chartId,
|
||||
tabId,
|
||||
);
|
||||
stateModifier = 'pushState';
|
||||
}
|
||||
const url = mountExploreUrl(
|
||||
|
|
@ -229,11 +252,12 @@ function ExploreViewContainer(props) {
|
|||
dashboardId: props.dashboardId,
|
||||
}
|
||||
: props.form_data;
|
||||
const datasetId = props.datasource.id;
|
||||
const { id: datasourceId, type: datasourceType } = props.datasource;
|
||||
|
||||
updateHistory(
|
||||
formData,
|
||||
datasetId,
|
||||
datasourceId,
|
||||
datasourceType,
|
||||
isReplace,
|
||||
props.standalone,
|
||||
props.force,
|
||||
|
|
@ -245,6 +269,7 @@ function ExploreViewContainer(props) {
|
|||
props.dashboardId,
|
||||
props.form_data,
|
||||
props.datasource.id,
|
||||
props.datasource.type,
|
||||
props.standalone,
|
||||
props.force,
|
||||
tabId,
|
||||
|
|
|
|||
|
|
@ -189,9 +189,9 @@ class DatasourceControl extends React.PureComponent {
|
|||
const isMissingDatasource = datasource.id == null;
|
||||
let isMissingParams = false;
|
||||
if (isMissingDatasource) {
|
||||
const datasetId = getUrlParam(URL_PARAMS.datasetId);
|
||||
const datasourceId = getUrlParam(URL_PARAMS.datasourceId);
|
||||
const sliceId = getUrlParam(URL_PARAMS.sliceId);
|
||||
if (!datasetId && !sliceId) {
|
||||
if (!datasourceId && !sliceId) {
|
||||
isMissingParams = true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ import { omit } from 'lodash';
|
|||
import { SupersetClient, JsonObject } from '@superset-ui/core';
|
||||
|
||||
type Payload = {
|
||||
dataset_id: number;
|
||||
datasource_id: number;
|
||||
datasource_type: string;
|
||||
form_data: string;
|
||||
chart_id?: number;
|
||||
};
|
||||
|
|
@ -42,12 +43,14 @@ const assembleEndpoint = (key?: string, tabId?: string) => {
|
|||
};
|
||||
|
||||
const assemblePayload = (
|
||||
datasetId: number,
|
||||
datasourceId: number,
|
||||
datasourceType: string,
|
||||
formData: JsonObject,
|
||||
chartId?: number,
|
||||
) => {
|
||||
const payload: Payload = {
|
||||
dataset_id: datasetId,
|
||||
datasource_id: datasourceId,
|
||||
datasource_type: datasourceType,
|
||||
form_data: JSON.stringify(sanitizeFormData(formData)),
|
||||
};
|
||||
if (chartId) {
|
||||
|
|
@ -57,18 +60,25 @@ const assemblePayload = (
|
|||
};
|
||||
|
||||
export const postFormData = (
|
||||
datasetId: number,
|
||||
datasourceId: number,
|
||||
datasourceType: string,
|
||||
formData: JsonObject,
|
||||
chartId?: number,
|
||||
tabId?: string,
|
||||
): Promise<string> =>
|
||||
SupersetClient.post({
|
||||
endpoint: assembleEndpoint(undefined, tabId),
|
||||
jsonPayload: assemblePayload(datasetId, formData, chartId),
|
||||
jsonPayload: assemblePayload(
|
||||
datasourceId,
|
||||
datasourceType,
|
||||
formData,
|
||||
chartId,
|
||||
),
|
||||
}).then(r => r.json.key);
|
||||
|
||||
export const putFormData = (
|
||||
datasetId: number,
|
||||
datasourceId: number,
|
||||
datasourceType: string,
|
||||
key: string,
|
||||
formData: JsonObject,
|
||||
chartId?: number,
|
||||
|
|
@ -76,5 +86,10 @@ export const putFormData = (
|
|||
): Promise<string> =>
|
||||
SupersetClient.put({
|
||||
endpoint: assembleEndpoint(key, tabId),
|
||||
jsonPayload: assemblePayload(datasetId, formData, chartId),
|
||||
jsonPayload: assemblePayload(
|
||||
datasourceId,
|
||||
datasourceType,
|
||||
formData,
|
||||
chartId,
|
||||
),
|
||||
}).then(r => r.json.message);
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from superset.charts.schemas import (
|
|||
datasource_type_description,
|
||||
datasource_uid_description,
|
||||
)
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
class Datasource(Schema):
|
||||
|
|
@ -36,7 +37,7 @@ class Datasource(Schema):
|
|||
)
|
||||
datasource_type = fields.String(
|
||||
description=datasource_type_description,
|
||||
validate=validate.OneOf(choices=("druid", "table", "view")),
|
||||
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
|
||||
required=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from superset.db_engine_specs.base import builtin_time_grains
|
|||
from superset.utils import pandas_postprocessing, schema as utils
|
||||
from superset.utils.core import (
|
||||
AnnotationType,
|
||||
DatasourceType,
|
||||
FilterOperator,
|
||||
PostProcessingBoxplotWhiskerType,
|
||||
PostProcessingContributionOrientation,
|
||||
|
|
@ -198,7 +199,7 @@ class ChartPostSchema(Schema):
|
|||
datasource_id = fields.Integer(description=datasource_id_description, required=True)
|
||||
datasource_type = fields.String(
|
||||
description=datasource_type_description,
|
||||
validate=validate.OneOf(choices=("druid", "table", "view")),
|
||||
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
|
||||
required=True,
|
||||
)
|
||||
datasource_name = fields.String(
|
||||
|
|
@ -244,7 +245,7 @@ class ChartPutSchema(Schema):
|
|||
)
|
||||
datasource_type = fields.String(
|
||||
description=datasource_type_description,
|
||||
validate=validate.OneOf(choices=("druid", "table", "view")),
|
||||
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
|
||||
allow_none=True,
|
||||
)
|
||||
dashboards = fields.List(fields.Integer(description=dashboards_description))
|
||||
|
|
@ -983,7 +984,7 @@ class ChartDataDatasourceSchema(Schema):
|
|||
)
|
||||
type = fields.String(
|
||||
description="Datasource type",
|
||||
validate=validate.OneOf(choices=("druid", "table")),
|
||||
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -115,8 +115,24 @@ class RolesNotFoundValidationError(ValidationError):
|
|||
super().__init__([_("Some roles do not exist")], field_name="roles")
|
||||
|
||||
|
||||
class DatasourceTypeInvalidError(ValidationError):
|
||||
status = 422
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
[_("Datasource type is invalid")], field_name="datasource_type"
|
||||
)
|
||||
|
||||
|
||||
class DatasourceNotFoundValidationError(ValidationError):
|
||||
status = 404
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__([_("Dataset does not exist")], field_name="datasource_id")
|
||||
super().__init__([_("Datasource does not exist")], field_name="datasource_id")
|
||||
|
||||
|
||||
class QueryNotFoundValidationError(ValidationError):
|
||||
status = 404
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__([_("Query does not exist")], field_name="datasource_id")
|
||||
|
|
|
|||
|
|
@ -39,11 +39,11 @@ Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
|
|||
class DatasourceDAO(BaseDAO):
|
||||
|
||||
sources: Dict[DatasourceType, Type[Datasource]] = {
|
||||
DatasourceType.SQLATABLE: SqlaTable,
|
||||
DatasourceType.TABLE: SqlaTable,
|
||||
DatasourceType.QUERY: Query,
|
||||
DatasourceType.SAVEDQUERY: SavedQuery,
|
||||
DatasourceType.DATASET: Dataset,
|
||||
DatasourceType.TABLE: Table,
|
||||
DatasourceType.SLTABLE: Table,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -66,7 +66,7 @@ class DatasourceDAO(BaseDAO):
|
|||
|
||||
@classmethod
|
||||
def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]:
|
||||
source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE]
|
||||
source_class = DatasourceDAO.sources[DatasourceType.TABLE]
|
||||
qry = session.query(source_class)
|
||||
qry = source_class.default_query(qry)
|
||||
return qry.all()
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from superset.models.core import Database
|
|||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import TabState
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -75,7 +76,8 @@ class DatabaseDAO(BaseDAO):
|
|||
charts = (
|
||||
db.session.query(Slice)
|
||||
.filter(
|
||||
Slice.datasource_id.in_(dataset_ids), Slice.datasource_type == "table"
|
||||
Slice.datasource_id.in_(dataset_ids),
|
||||
Slice.datasource_type == DatasourceType.TABLE,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from superset.extensions import db
|
|||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
from superset.views.base import DatasourceFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -56,7 +57,8 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
|
|||
charts = (
|
||||
db.session.query(Slice)
|
||||
.filter(
|
||||
Slice.datasource_id == database_id, Slice.datasource_type == "table"
|
||||
Slice.datasource_id == database_id,
|
||||
Slice.datasource_type == DatasourceType.TABLE,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from superset.exceptions import NoDataException
|
|||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from ..utils.database import get_example_database
|
||||
from .helpers import (
|
||||
|
|
@ -205,13 +206,16 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
|
|||
if admin_owner:
|
||||
slice_props = dict(
|
||||
datasource_id=tbl.id,
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
owners=[admin],
|
||||
created_by=admin,
|
||||
)
|
||||
else:
|
||||
slice_props = dict(
|
||||
datasource_id=tbl.id, datasource_type="table", owners=[], created_by=admin
|
||||
datasource_id=tbl.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
owners=[],
|
||||
created_by=admin,
|
||||
)
|
||||
|
||||
print("Creating some slices")
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import superset.utils.database as database_utils
|
|||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from .helpers import (
|
||||
get_example_data,
|
||||
|
|
@ -112,7 +113,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
|
|||
slc = Slice(
|
||||
slice_name="Birth in France by department in 2016",
|
||||
viz_type="country_map",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import json
|
|||
from superset import db
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from .helpers import (
|
||||
get_slice_json,
|
||||
|
|
@ -213,7 +214,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
|
|||
slc = Slice(
|
||||
slice_name="Deck.gl Scatterplot",
|
||||
viz_type="deck_scatter",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
@ -248,7 +249,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
|
|||
slc = Slice(
|
||||
slice_name="Deck.gl Screen grid",
|
||||
viz_type="deck_screengrid",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
@ -284,7 +285,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
|
|||
slc = Slice(
|
||||
slice_name="Deck.gl Hexagons",
|
||||
viz_type="deck_hex",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
@ -321,7 +322,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
|
|||
slc = Slice(
|
||||
slice_name="Deck.gl Grid",
|
||||
viz_type="deck_grid",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
@ -410,7 +411,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
|
|||
slc = Slice(
|
||||
slice_name="Deck.gl Polygons",
|
||||
viz_type="deck_polygon",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=polygon_tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
@ -460,7 +461,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
|
|||
slc = Slice(
|
||||
slice_name="Deck.gl Arcs",
|
||||
viz_type="deck_arc",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=db.session.query(table)
|
||||
.filter_by(table_name="flights")
|
||||
.first()
|
||||
|
|
@ -512,7 +513,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
|
|||
slc = Slice(
|
||||
slice_name="Deck.gl Path",
|
||||
viz_type="deck_path",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=db.session.query(table)
|
||||
.filter_by(table_name="bart_lines")
|
||||
.first()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import superset.utils.database as database_utils
|
|||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from .helpers import (
|
||||
get_example_data,
|
||||
|
|
@ -81,7 +82,7 @@ def load_energy(
|
|||
slc = Slice(
|
||||
slice_name="Energy Sankey",
|
||||
viz_type="sankey",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=textwrap.dedent(
|
||||
"""\
|
||||
|
|
@ -105,7 +106,7 @@ def load_energy(
|
|||
slc = Slice(
|
||||
slice_name="Energy Force Layout",
|
||||
viz_type="graph_chart",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=textwrap.dedent(
|
||||
"""\
|
||||
|
|
@ -129,7 +130,7 @@ def load_energy(
|
|||
slc = Slice(
|
||||
slice_name="Heatmap",
|
||||
viz_type="heatmap",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=textwrap.dedent(
|
||||
"""\
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from sqlalchemy import DateTime, Float, inspect, String
|
|||
import superset.utils.database as database_utils
|
||||
from superset import db
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from .helpers import (
|
||||
get_example_data,
|
||||
|
|
@ -113,7 +114,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
|
|||
slc = Slice(
|
||||
slice_name="Mapbox Long/Lat",
|
||||
viz_type="mapbox",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import json
|
|||
|
||||
from superset import db
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from .birth_names import load_birth_names
|
||||
from .helpers import merge_slice, misc_dash_slices
|
||||
|
|
@ -35,7 +36,7 @@ def load_multi_line(only_metadata: bool = False) -> None:
|
|||
]
|
||||
|
||||
slc = Slice(
|
||||
datasource_type="table", # not true, but needed
|
||||
datasource_type=DatasourceType.TABLE, # not true, but needed
|
||||
datasource_id=1, # cannot be empty
|
||||
slice_name="Multi Line",
|
||||
viz_type="line_multi",
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from sqlalchemy import BigInteger, Date, DateTime, inspect, String
|
|||
|
||||
from superset import app, db
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from ..utils.database import get_example_database
|
||||
from .helpers import (
|
||||
|
|
@ -120,7 +121,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
|
|||
slc = Slice(
|
||||
slice_name=f"Calendar Heatmap multiformat {i}",
|
||||
viz_type="cal_heatmap",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from sqlalchemy import DateTime, inspect, String
|
|||
import superset.utils.database as database_utils
|
||||
from superset import app, db
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from .helpers import (
|
||||
get_example_data,
|
||||
|
|
@ -89,7 +90,7 @@ def load_random_time_series_data(
|
|||
slc = Slice(
|
||||
slice_name="Calendar Heatmap",
|
||||
viz_type="cal_heatmap",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(slice_data),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from superset.connectors.sqla.models import SqlMetric
|
|||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
from ..connectors.base.models import BaseDatasource
|
||||
from .helpers import (
|
||||
|
|
@ -172,7 +173,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Region Filter",
|
||||
viz_type="filter_box",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -201,7 +202,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="World's Population",
|
||||
viz_type="big_number",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -215,7 +216,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Most Populated Countries",
|
||||
viz_type="table",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -227,7 +228,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Growth Rate",
|
||||
viz_type="line",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -241,7 +242,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="% Rural",
|
||||
viz_type="world_map",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -254,7 +255,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Life Expectancy VS Rural %",
|
||||
viz_type="bubble",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -298,7 +299,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Rural Breakdown",
|
||||
viz_type="sunburst",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -313,7 +314,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="World's Pop Growth",
|
||||
viz_type="area",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -327,7 +328,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Box plot",
|
||||
viz_type="box_plot",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -343,7 +344,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Treemap",
|
||||
viz_type="treemap",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
@ -357,7 +358,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
|
|||
Slice(
|
||||
slice_name="Parallel Coordinates",
|
||||
viz_type="para",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=tbl.id,
|
||||
params=get_slice_json(
|
||||
defaults,
|
||||
|
|
|
|||
|
|
@ -104,7 +104,8 @@ class ExploreFormDataRestApi(BaseApi, ABC):
|
|||
tab_id = request.args.get("tab_id")
|
||||
args = CommandParameters(
|
||||
actor=g.user,
|
||||
dataset_id=item["dataset_id"],
|
||||
datasource_id=item["datasource_id"],
|
||||
datasource_type=item["datasource_type"],
|
||||
chart_id=item.get("chart_id"),
|
||||
tab_id=tab_id,
|
||||
form_data=item["form_data"],
|
||||
|
|
@ -123,7 +124,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
|
|||
@safe
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
|
||||
log_to_statsd=False,
|
||||
log_to_statsd=True,
|
||||
)
|
||||
@requires_json
|
||||
def put(self, key: str) -> Response:
|
||||
|
|
@ -174,7 +175,8 @@ class ExploreFormDataRestApi(BaseApi, ABC):
|
|||
tab_id = request.args.get("tab_id")
|
||||
args = CommandParameters(
|
||||
actor=g.user,
|
||||
dataset_id=item["dataset_id"],
|
||||
datasource_id=item["datasource_id"],
|
||||
datasource_type=item["datasource_type"],
|
||||
chart_id=item.get("chart_id"),
|
||||
tab_id=tab_id,
|
||||
key=key,
|
||||
|
|
@ -196,7 +198,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
|
|||
@safe
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
|
||||
log_to_statsd=False,
|
||||
log_to_statsd=True,
|
||||
)
|
||||
def get(self, key: str) -> Response:
|
||||
"""Retrives a form_data.
|
||||
|
|
@ -247,7 +249,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
|
|||
@safe
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.delete",
|
||||
log_to_statsd=False,
|
||||
log_to_statsd=True,
|
||||
)
|
||||
def delete(self, key: str) -> Response:
|
||||
"""Deletes a form_data.
|
||||
|
|
|
|||
|
|
@ -39,20 +39,24 @@ class CreateFormDataCommand(BaseCommand):
|
|||
def run(self) -> str:
|
||||
self.validate()
|
||||
try:
|
||||
dataset_id = self._cmd_params.dataset_id
|
||||
datasource_id = self._cmd_params.datasource_id
|
||||
datasource_type = self._cmd_params.datasource_type
|
||||
chart_id = self._cmd_params.chart_id
|
||||
tab_id = self._cmd_params.tab_id
|
||||
actor = self._cmd_params.actor
|
||||
form_data = self._cmd_params.form_data
|
||||
check_access(dataset_id, chart_id, actor)
|
||||
contextual_key = cache_key(session.get("_id"), tab_id, dataset_id, chart_id)
|
||||
check_access(datasource_id, chart_id, actor, datasource_type)
|
||||
contextual_key = cache_key(
|
||||
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
|
||||
)
|
||||
key = cache_manager.explore_form_data_cache.get(contextual_key)
|
||||
if not key or not tab_id:
|
||||
key = random_key()
|
||||
if form_data:
|
||||
state: TemporaryExploreState = {
|
||||
"owner": get_owner(actor),
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource_id,
|
||||
"datasource_type": datasource_type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": form_data,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
# under the License.
|
||||
import logging
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from flask import session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
|
@ -31,6 +32,7 @@ from superset.temporary_cache.commands.exceptions import (
|
|||
TemporaryCacheDeleteFailedError,
|
||||
)
|
||||
from superset.temporary_cache.utils import cache_key
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -47,14 +49,15 @@ class DeleteFormDataCommand(BaseCommand, ABC):
|
|||
key
|
||||
)
|
||||
if state:
|
||||
dataset_id = state["dataset_id"]
|
||||
chart_id = state["chart_id"]
|
||||
check_access(dataset_id, chart_id, actor)
|
||||
datasource_id: int = state["datasource_id"]
|
||||
chart_id: Optional[int] = state["chart_id"]
|
||||
datasource_type = DatasourceType(state["datasource_type"])
|
||||
check_access(datasource_id, chart_id, actor, datasource_type)
|
||||
if state["owner"] != get_owner(actor):
|
||||
raise TemporaryCacheAccessDeniedError()
|
||||
tab_id = self._cmd_params.tab_id
|
||||
contextual_key = cache_key(
|
||||
session.get("_id"), tab_id, dataset_id, chart_id
|
||||
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
|
||||
)
|
||||
cache_manager.explore_form_data_cache.delete(contextual_key)
|
||||
return cache_manager.explore_form_data_cache.delete(key)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from superset.explore.form_data.commands.state import TemporaryExploreState
|
|||
from superset.explore.form_data.commands.utils import check_access
|
||||
from superset.extensions import cache_manager
|
||||
from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -45,7 +46,12 @@ class GetFormDataCommand(BaseCommand, ABC):
|
|||
key
|
||||
)
|
||||
if state:
|
||||
check_access(state["dataset_id"], state["chart_id"], actor)
|
||||
check_access(
|
||||
state["datasource_id"],
|
||||
state["chart_id"],
|
||||
actor,
|
||||
DatasourceType(state["datasource_type"]),
|
||||
)
|
||||
if self._refresh_timeout:
|
||||
cache_manager.explore_form_data_cache.set(key, state)
|
||||
return state["form_data"]
|
||||
|
|
|
|||
|
|
@ -19,11 +19,14 @@ from typing import Optional
|
|||
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandParameters:
|
||||
actor: User
|
||||
dataset_id: int = 0
|
||||
datasource_type: DatasourceType = DatasourceType.TABLE
|
||||
datasource_id: int = 0
|
||||
chart_id: int = 0
|
||||
tab_id: Optional[int] = None
|
||||
key: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from typing_extensions import TypedDict
|
|||
|
||||
class TemporaryExploreState(TypedDict):
|
||||
owner: Optional[int]
|
||||
dataset_id: int
|
||||
datasource_id: int
|
||||
datasource_type: str
|
||||
chart_id: Optional[int]
|
||||
form_data: str
|
||||
|
|
|
|||
|
|
@ -47,12 +47,13 @@ class UpdateFormDataCommand(BaseCommand, ABC):
|
|||
def run(self) -> Optional[str]:
|
||||
self.validate()
|
||||
try:
|
||||
dataset_id = self._cmd_params.dataset_id
|
||||
datasource_id = self._cmd_params.datasource_id
|
||||
chart_id = self._cmd_params.chart_id
|
||||
datasource_type = self._cmd_params.datasource_type
|
||||
actor = self._cmd_params.actor
|
||||
key = self._cmd_params.key
|
||||
form_data = self._cmd_params.form_data
|
||||
check_access(dataset_id, chart_id, actor)
|
||||
check_access(datasource_id, chart_id, actor, datasource_type)
|
||||
state: TemporaryExploreState = cache_manager.explore_form_data_cache.get(
|
||||
key
|
||||
)
|
||||
|
|
@ -64,7 +65,7 @@ class UpdateFormDataCommand(BaseCommand, ABC):
|
|||
# Generate a new key if tab_id changes or equals 0
|
||||
tab_id = self._cmd_params.tab_id
|
||||
contextual_key = cache_key(
|
||||
session.get("_id"), tab_id, dataset_id, chart_id
|
||||
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
|
||||
)
|
||||
key = cache_manager.explore_form_data_cache.get(contextual_key)
|
||||
if not key or not tab_id:
|
||||
|
|
@ -73,7 +74,8 @@ class UpdateFormDataCommand(BaseCommand, ABC):
|
|||
|
||||
new_state: TemporaryExploreState = {
|
||||
"owner": owner,
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource_id,
|
||||
"datasource_type": datasource_type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": form_data,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,11 +31,17 @@ from superset.temporary_cache.commands.exceptions import (
|
|||
TemporaryCacheAccessDeniedError,
|
||||
TemporaryCacheResourceNotFoundError,
|
||||
)
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
|
||||
def check_access(
|
||||
datasource_id: int,
|
||||
chart_id: Optional[int],
|
||||
actor: User,
|
||||
datasource_type: DatasourceType,
|
||||
) -> None:
|
||||
try:
|
||||
explore_check_access(dataset_id, chart_id, actor)
|
||||
explore_check_access(datasource_id, chart_id, actor, datasource_type)
|
||||
except (ChartNotFoundError, DatasetNotFoundError) as ex:
|
||||
raise TemporaryCacheResourceNotFoundError from ex
|
||||
except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex:
|
||||
|
|
|
|||
|
|
@ -14,12 +14,20 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow import fields, Schema, validate
|
||||
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
class FormDataPostSchema(Schema):
|
||||
dataset_id = fields.Integer(
|
||||
required=True, allow_none=False, description="The dataset ID"
|
||||
datasource_id = fields.Integer(
|
||||
required=True, allow_none=False, description="The datasource ID"
|
||||
)
|
||||
datasource_type = fields.String(
|
||||
required=True,
|
||||
allow_none=False,
|
||||
description="The datasource type",
|
||||
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
|
||||
)
|
||||
chart_id = fields.Integer(required=False, description="The chart ID")
|
||||
form_data = fields.String(
|
||||
|
|
@ -28,8 +36,14 @@ class FormDataPostSchema(Schema):
|
|||
|
||||
|
||||
class FormDataPutSchema(Schema):
|
||||
dataset_id = fields.Integer(
|
||||
required=True, allow_none=False, description="The dataset ID"
|
||||
datasource_id = fields.Integer(
|
||||
required=True, allow_none=False, description="The datasource ID"
|
||||
)
|
||||
datasource_type = fields.String(
|
||||
required=True,
|
||||
allow_none=False,
|
||||
description="The datasource type",
|
||||
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
|
||||
)
|
||||
chart_id = fields.Integer(required=False, description="The chart ID")
|
||||
form_data = fields.String(
|
||||
|
|
|
|||
|
|
@ -22,9 +22,10 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
|
||||
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
|
||||
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
from superset.key_value.commands.create import CreateKeyValueCommand
|
||||
from superset.key_value.utils import encode_permalink_key
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -39,11 +40,16 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
|
|||
def run(self) -> str:
|
||||
self.validate()
|
||||
try:
|
||||
dataset_id = int(self.datasource.split("__")[0])
|
||||
check_access(dataset_id, self.chart_id, self.actor)
|
||||
d_id, d_type = self.datasource.split("__")
|
||||
datasource_id = int(d_id)
|
||||
datasource_type = DatasourceType(d_type)
|
||||
check_chart_access(
|
||||
datasource_id, self.chart_id, self.actor, datasource_type
|
||||
)
|
||||
value = {
|
||||
"chartId": self.chart_id,
|
||||
"datasetId": dataset_id,
|
||||
"datasourceId": datasource_id,
|
||||
"datasourceType": datasource_type,
|
||||
"datasource": self.datasource,
|
||||
"state": self.state,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,10 +24,11 @@ from superset.datasets.commands.exceptions import DatasetNotFoundError
|
|||
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
|
||||
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
|
||||
from superset.explore.permalink.types import ExplorePermalinkValue
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
from superset.key_value.commands.get import GetKeyValueCommand
|
||||
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
|
||||
from superset.key_value.utils import decode_permalink_id
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -47,8 +48,9 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
|
|||
).run()
|
||||
if value:
|
||||
chart_id: Optional[int] = value.get("chartId")
|
||||
dataset_id = value["datasetId"]
|
||||
check_access(dataset_id, chart_id, self.actor)
|
||||
datasource_id: int = value["datasourceId"]
|
||||
datasource_type = DatasourceType(value["datasourceType"])
|
||||
check_chart_access(datasource_id, chart_id, self.actor, datasource_type)
|
||||
return value
|
||||
return None
|
||||
except (
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ class ExplorePermalinkState(TypedDict, total=False):
|
|||
|
||||
class ExplorePermalinkValue(TypedDict):
|
||||
chartId: Optional[int]
|
||||
datasetId: int
|
||||
datasourceId: int
|
||||
datasourceType: str
|
||||
datasource: str
|
||||
state: ExplorePermalinkState
|
||||
|
|
|
|||
|
|
@ -24,11 +24,18 @@ from superset.charts.commands.exceptions import (
|
|||
ChartNotFoundError,
|
||||
)
|
||||
from superset.charts.dao import ChartDAO
|
||||
from superset.commands.exceptions import (
|
||||
DatasourceNotFoundValidationError,
|
||||
DatasourceTypeInvalidError,
|
||||
QueryNotFoundValidationError,
|
||||
)
|
||||
from superset.datasets.commands.exceptions import (
|
||||
DatasetAccessDeniedError,
|
||||
DatasetNotFoundError,
|
||||
)
|
||||
from superset.datasets.dao import DatasetDAO
|
||||
from superset.queries.dao import QueryDAO
|
||||
from superset.utils.core import DatasourceType
|
||||
from superset.views.base import is_user_admin
|
||||
from superset.views.utils import is_owner
|
||||
|
||||
|
|
@ -44,10 +51,41 @@ def check_dataset_access(dataset_id: int) -> Optional[bool]:
|
|||
raise DatasetNotFoundError()
|
||||
|
||||
|
||||
def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
|
||||
check_dataset_access(dataset_id)
|
||||
def check_query_access(query_id: int) -> Optional[bool]:
|
||||
if query_id:
|
||||
query = QueryDAO.find_by_id(query_id)
|
||||
if query:
|
||||
security_manager.raise_for_access(query=query)
|
||||
return True
|
||||
raise QueryNotFoundValidationError()
|
||||
|
||||
|
||||
ACCESS_FUNCTION_MAP = {
|
||||
DatasourceType.TABLE: check_dataset_access,
|
||||
DatasourceType.QUERY: check_query_access,
|
||||
}
|
||||
|
||||
|
||||
def check_datasource_access(
|
||||
datasource_id: int, datasource_type: DatasourceType
|
||||
) -> Optional[bool]:
|
||||
if datasource_id:
|
||||
try:
|
||||
return ACCESS_FUNCTION_MAP[datasource_type](datasource_id)
|
||||
except KeyError as ex:
|
||||
raise DatasourceTypeInvalidError() from ex
|
||||
raise DatasourceNotFoundValidationError()
|
||||
|
||||
|
||||
def check_access(
|
||||
datasource_id: int,
|
||||
chart_id: Optional[int],
|
||||
actor: User,
|
||||
datasource_type: DatasourceType,
|
||||
) -> Optional[bool]:
|
||||
check_datasource_access(datasource_id, datasource_type)
|
||||
if not chart_id:
|
||||
return
|
||||
return True
|
||||
chart = ChartDAO.find_by_id(chart_id)
|
||||
if chart:
|
||||
can_access_chart = (
|
||||
|
|
@ -56,6 +94,6 @@ def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
|
|||
or security_manager.can_access("can_read", "Chart")
|
||||
)
|
||||
if can_access_chart:
|
||||
return
|
||||
return True
|
||||
raise ChartAccessDeniedError()
|
||||
raise ChartNotFoundError()
|
||||
|
|
|
|||
|
|
@ -15,15 +15,40 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from flask import Flask
|
||||
from flask_caching import Cache
|
||||
from markupsafe import Markup
|
||||
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache"
|
||||
|
||||
|
||||
class ExploreFormDataCache(Cache):
|
||||
def get(self, *args: Any, **kwargs: Any) -> Optional[Union[str, Markup]]:
|
||||
cache = self.cache.get(*args, **kwargs)
|
||||
|
||||
if not cache:
|
||||
return None
|
||||
|
||||
# rename data keys for existing cache based on new TemporaryExploreState model
|
||||
if isinstance(cache, dict):
|
||||
cache = {
|
||||
("datasource_id" if key == "dataset_id" else key): value
|
||||
for (key, value) in cache.items()
|
||||
}
|
||||
# add default datasource_type if it doesn't exist
|
||||
# temporarily defaulting to table until sqlatables are deprecated
|
||||
if "datasource_type" not in cache:
|
||||
cache["datasource_type"] = DatasourceType.TABLE
|
||||
|
||||
return cache
|
||||
|
||||
|
||||
class CacheManager:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -32,7 +57,7 @@ class CacheManager:
|
|||
self._data_cache = Cache()
|
||||
self._thumbnail_cache = Cache()
|
||||
self._filter_state_cache = Cache()
|
||||
self._explore_form_data_cache = Cache()
|
||||
self._explore_form_data_cache = ExploreFormDataCache()
|
||||
|
||||
@staticmethod
|
||||
def _init_cache(
|
||||
|
|
|
|||
|
|
@ -175,12 +175,13 @@ class GenericDataType(IntEnum):
|
|||
# ROW = 7
|
||||
|
||||
|
||||
class DatasourceType(Enum):
|
||||
SQLATABLE = "sqlatable"
|
||||
class DatasourceType(str, Enum):
|
||||
SLTABLE = "sl_table"
|
||||
TABLE = "table"
|
||||
DATASET = "dataset"
|
||||
QUERY = "query"
|
||||
SAVEDQUERY = "saved_query"
|
||||
VIEW = "view"
|
||||
|
||||
|
||||
class DatasourceDict(TypedDict):
|
||||
|
|
|
|||
|
|
@ -520,7 +520,13 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response,
|
||||
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}},
|
||||
{
|
||||
"message": {
|
||||
"datasource_type": [
|
||||
"Must be one of: sl_table, table, dataset, query, saved_query, view."
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
chart_data = {
|
||||
"slice_name": "title1",
|
||||
|
|
@ -531,7 +537,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
self.assertEqual(rv.status_code, 422)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response, {"message": {"datasource_id": ["Dataset does not exist"]}}
|
||||
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
@ -686,7 +692,13 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response,
|
||||
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}},
|
||||
{
|
||||
"message": {
|
||||
"datasource_type": [
|
||||
"Must be one of: sl_table, table, dataset, query, saved_query, view."
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
chart_data = {"datasource_id": 0, "datasource_type": "table"}
|
||||
|
|
@ -694,7 +706,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
self.assertEqual(rv.status_code, 422)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response, {"message": {"datasource_id": ["Dataset does not exist"]}}
|
||||
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
|
||||
)
|
||||
|
||||
db.session.delete(chart)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from superset.connectors.sqla.models import SqlaTable
|
|||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import get_example_default_schema
|
||||
from superset.utils.core import DatasourceType, get_example_default_schema
|
||||
|
||||
|
||||
def get_table(
|
||||
|
|
@ -72,7 +72,7 @@ def create_slice(
|
|||
return Slice(
|
||||
slice_name=title,
|
||||
viz_type=viz_type,
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=table.id,
|
||||
params=json.dumps(slices_dict, indent=4, sort_keys=True),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def admin_id() -> int:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id() -> int:
|
||||
def datasource() -> int:
|
||||
with app.app_context() as ctx:
|
||||
session: Session = ctx.app.appbuilder.get_session
|
||||
dataset = (
|
||||
|
|
@ -64,24 +64,26 @@ def dataset_id() -> int:
|
|||
.filter_by(table_name="wb_health_population")
|
||||
.first()
|
||||
)
|
||||
return dataset.id
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cache(chart_id, admin_id, dataset_id):
|
||||
def cache(chart_id, admin_id, datasource):
|
||||
entry: TemporaryExploreState = {
|
||||
"owner": admin_id,
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": INITIAL_FORM_DATA,
|
||||
}
|
||||
cache_manager.explore_form_data_cache.set(KEY, entry)
|
||||
|
||||
|
||||
def test_post(client, chart_id: int, dataset_id: int):
|
||||
def test_post(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": INITIAL_FORM_DATA,
|
||||
}
|
||||
|
|
@ -89,10 +91,11 @@ def test_post(client, chart_id: int, dataset_id: int):
|
|||
assert resp.status_code == 201
|
||||
|
||||
|
||||
def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int):
|
||||
def test_post_bad_request_non_string(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": 1234,
|
||||
}
|
||||
|
|
@ -100,10 +103,11 @@ def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int):
|
|||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int):
|
||||
def test_post_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": "foo",
|
||||
}
|
||||
|
|
@ -111,10 +115,11 @@ def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int
|
|||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_post_access_denied(client, chart_id: int, dataset_id: int):
|
||||
def test_post_access_denied(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "gamma")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": INITIAL_FORM_DATA,
|
||||
}
|
||||
|
|
@ -122,10 +127,11 @@ def test_post_access_denied(client, chart_id: int, dataset_id: int):
|
|||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int):
|
||||
def test_post_same_key_for_same_context(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -139,11 +145,12 @@ def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int):
|
|||
|
||||
|
||||
def test_post_different_key_for_different_context(
|
||||
client, chart_id: int, dataset_id: int
|
||||
client, chart_id: int, datasource: SqlaTable
|
||||
):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -151,7 +158,8 @@ def test_post_different_key_for_different_context(
|
|||
data = json.loads(resp.data.decode("utf-8"))
|
||||
first_key = data.get("key")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"form_data": json.dumps({"test": "initial value"}),
|
||||
}
|
||||
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
|
||||
|
|
@ -160,10 +168,11 @@ def test_post_different_key_for_different_context(
|
|||
assert first_key != second_key
|
||||
|
||||
|
||||
def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
|
||||
def test_post_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": json.dumps({"test": "initial value"}),
|
||||
}
|
||||
|
|
@ -177,11 +186,12 @@ def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
|
|||
|
||||
|
||||
def test_post_different_key_for_different_tab_id(
|
||||
client, chart_id: int, dataset_id: int
|
||||
client, chart_id: int, datasource: SqlaTable
|
||||
):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": json.dumps({"test": "initial value"}),
|
||||
}
|
||||
|
|
@ -194,10 +204,11 @@ def test_post_different_key_for_different_tab_id(
|
|||
assert first_key != second_key
|
||||
|
||||
|
||||
def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int):
|
||||
def test_post_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": INITIAL_FORM_DATA,
|
||||
}
|
||||
|
|
@ -210,10 +221,11 @@ def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int
|
|||
assert first_key != second_key
|
||||
|
||||
|
||||
def test_put(client, chart_id: int, dataset_id: int):
|
||||
def test_put(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -221,10 +233,11 @@ def test_put(client, chart_id: int, dataset_id: int):
|
|||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
|
||||
def test_put_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -237,10 +250,13 @@ def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
|
|||
assert first_key == second_key
|
||||
|
||||
|
||||
def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_id: int):
|
||||
def test_put_different_key_for_different_tab_id(
|
||||
client, chart_id: int, datasource: SqlaTable
|
||||
):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -253,10 +269,11 @@ def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_i
|
|||
assert first_key != second_key
|
||||
|
||||
|
||||
def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int):
|
||||
def test_put_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -269,10 +286,11 @@ def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int)
|
|||
assert first_key != second_key
|
||||
|
||||
|
||||
def test_put_bad_request(client, chart_id: int, dataset_id: int):
|
||||
def test_put_bad_request(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": 1234,
|
||||
}
|
||||
|
|
@ -280,10 +298,11 @@ def test_put_bad_request(client, chart_id: int, dataset_id: int):
|
|||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int):
|
||||
def test_put_bad_request_non_string(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": 1234,
|
||||
}
|
||||
|
|
@ -291,10 +310,11 @@ def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int):
|
|||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int):
|
||||
def test_put_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "admin")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": "foo",
|
||||
}
|
||||
|
|
@ -302,10 +322,11 @@ def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int)
|
|||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_put_access_denied(client, chart_id: int, dataset_id: int):
|
||||
def test_put_access_denied(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "gamma")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -313,10 +334,11 @@ def test_put_access_denied(client, chart_id: int, dataset_id: int):
|
|||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_put_not_owner(client, chart_id: int, dataset_id: int):
|
||||
def test_put_not_owner(client, chart_id: int, datasource: SqlaTable):
|
||||
login(client, "gamma")
|
||||
payload = {
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": UPDATED_FORM_DATA,
|
||||
}
|
||||
|
|
@ -364,12 +386,13 @@ def test_delete_access_denied(client):
|
|||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_not_owner(client, chart_id: int, dataset_id: int, admin_id: int):
|
||||
def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id: int):
|
||||
another_key = "another_key"
|
||||
another_owner = admin_id + 1
|
||||
entry: TemporaryExploreState = {
|
||||
"owner": another_owner,
|
||||
"dataset_id": dataset_id,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"chart_id": chart_id,
|
||||
"form_data": INITIAL_FORM_DATA,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,359 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset import app, db, security, security_manager
|
||||
from superset.commands.exceptions import DatasourceTypeInvalidError
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.form_data.commands.create import CreateFormDataCommand
|
||||
from superset.explore.form_data.commands.delete import DeleteFormDataCommand
|
||||
from superset.explore.form_data.commands.get import GetFormDataCommand
|
||||
from superset.explore.form_data.commands.parameters import CommandParameters
|
||||
from superset.explore.form_data.commands.update import UpdateFormDataCommand
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils.core import DatasourceType, get_example_default_schema
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class TestCreateFormDataCommand(SupersetTestCase):
|
||||
@pytest.fixture()
|
||||
def create_dataset(self):
|
||||
with self.create_app().app_context():
|
||||
dataset = SqlaTable(
|
||||
table_name="dummy_sql_table",
|
||||
database=get_example_database(),
|
||||
schema=get_example_default_schema(),
|
||||
sql="select 123 as intcol, 'abc' as strcol",
|
||||
)
|
||||
session = db.session
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
|
||||
yield dataset
|
||||
|
||||
# rollback
|
||||
session.delete(dataset)
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture()
|
||||
def create_slice(self):
|
||||
with self.create_app().app_context():
|
||||
session = db.session
|
||||
dataset = (
|
||||
session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = Slice(
|
||||
datasource_id=dataset.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_name="tmp_perm_table",
|
||||
slice_name="slice_name",
|
||||
)
|
||||
|
||||
session.add(slice)
|
||||
session.commit()
|
||||
|
||||
yield slice
|
||||
|
||||
# rollback
|
||||
session.delete(slice)
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture()
|
||||
def create_query(self):
|
||||
with self.create_app().app_context():
|
||||
session = db.session
|
||||
|
||||
query = Query(
|
||||
sql="select 1 as foo;",
|
||||
client_id="sldkfjlk",
|
||||
database=get_example_database(),
|
||||
)
|
||||
|
||||
session.add(query)
|
||||
session.commit()
|
||||
|
||||
yield query
|
||||
|
||||
# rollback
|
||||
session.delete(query)
|
||||
session.commit()
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice")
|
||||
def test_create_form_data_command(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
)
|
||||
command = CreateFormDataCommand(args)
|
||||
|
||||
assert isinstance(command.run(), str)
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
|
||||
def test_create_form_data_command_invalid_type(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
|
||||
}
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
create_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type="InvalidType",
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
)
|
||||
with pytest.raises(DatasourceTypeInvalidError) as exc:
|
||||
CreateFormDataCommand(create_args).run()
|
||||
|
||||
assert "Datasource type is invalid" in str(exc.value)
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
|
||||
def test_create_form_data_command_type_as_string(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
|
||||
}
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
create_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type="table",
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
)
|
||||
command = CreateFormDataCommand(create_args)
|
||||
|
||||
assert isinstance(command.run(), str)
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice")
|
||||
def test_get_form_data_command(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
|
||||
}
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
create_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
)
|
||||
key = CreateFormDataCommand(create_args).run()
|
||||
|
||||
key_args = CommandParameters(actor=mock_g.user, key=key)
|
||||
get_command = GetFormDataCommand(key_args)
|
||||
cache_data = json.loads(get_command.run())
|
||||
|
||||
assert cache_data.get("datasource") == datasource
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
|
||||
def test_update_form_data_command(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
|
||||
}
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
query = db.session.query(Query).filter_by(sql="select 1 as foo;").first()
|
||||
|
||||
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
create_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
)
|
||||
key = CreateFormDataCommand(create_args).run()
|
||||
|
||||
query_datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
update_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=query.id,
|
||||
datasource_type=DatasourceType.QUERY,
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": query_datasource}),
|
||||
key=key,
|
||||
)
|
||||
|
||||
update_command = UpdateFormDataCommand(update_args)
|
||||
new_key = update_command.run()
|
||||
|
||||
# it should return a key
|
||||
assert isinstance(new_key, str)
|
||||
# the updated key returned should be different from the old one
|
||||
assert new_key != key
|
||||
|
||||
key_args = CommandParameters(actor=mock_g.user, key=key)
|
||||
get_command = GetFormDataCommand(key_args)
|
||||
|
||||
cache_data = json.loads(get_command.run())
|
||||
|
||||
assert cache_data.get("datasource") == query_datasource
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
|
||||
def test_update_form_data_command_same_form_data(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
|
||||
}
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
create_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
)
|
||||
key = CreateFormDataCommand(create_args).run()
|
||||
|
||||
update_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
key=key,
|
||||
)
|
||||
|
||||
update_command = UpdateFormDataCommand(update_args)
|
||||
new_key = update_command.run()
|
||||
|
||||
# it should return a key
|
||||
assert isinstance(new_key, str)
|
||||
|
||||
# the updated key returned should be the same as the old one
|
||||
assert new_key == key
|
||||
|
||||
key_args = CommandParameters(actor=mock_g.user, key=key)
|
||||
get_command = GetFormDataCommand(key_args)
|
||||
|
||||
cache_data = json.loads(get_command.run())
|
||||
|
||||
assert cache_data.get("datasource") == datasource
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
|
||||
def test_delete_form_data_command(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
|
||||
}
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
|
||||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
datasource = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
create_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
datasource_id=dataset.id,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
chart_id=slice.id,
|
||||
tab_id=1,
|
||||
form_data=json.dumps({"datasource": datasource}),
|
||||
)
|
||||
key = CreateFormDataCommand(create_args).run()
|
||||
|
||||
delete_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
key=key,
|
||||
)
|
||||
|
||||
delete_command = DeleteFormDataCommand(delete_args)
|
||||
response = delete_command.run()
|
||||
|
||||
assert response == True
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
|
||||
def test_delete_form_data_command_key_expired(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
"REFRESH_TIMEOUT_ON_RETRIEVAL": True
|
||||
}
|
||||
|
||||
delete_args = CommandParameters(
|
||||
actor=mock_g.user,
|
||||
key="some_expired_key",
|
||||
)
|
||||
|
||||
delete_command = DeleteFormDataCommand(delete_args)
|
||||
response = delete_command.run()
|
||||
|
||||
assert response == False
|
||||
|
|
@ -27,6 +27,7 @@ from superset.key_value.models import KeyValueEntry
|
|||
from superset.key_value.types import KeyValueResource
|
||||
from superset.key_value.utils import decode_permalink_id, encode_permalink_key
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
from tests.integration_tests.base_tests import login
|
||||
from tests.integration_tests.fixtures.client import client
|
||||
from tests.integration_tests.fixtures.world_bank_dashboard import (
|
||||
|
|
@ -97,7 +98,8 @@ def test_get_missing_chart(client, chart, permalink_salt: str) -> None:
|
|||
value=pickle.dumps(
|
||||
{
|
||||
"chartId": chart_id,
|
||||
"datasetId": chart.datasource.id,
|
||||
"datasourceId": chart.datasource.id,
|
||||
"datasourceType": DatasourceType.TABLE,
|
||||
"formData": {
|
||||
"slice_id": chart_id,
|
||||
"datasource": f"{chart.datasource.id}__{chart.datasource.type}",
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from superset.dashboards.commands.importers.v0 import import_chart, import_dashb
|
|||
from superset.datasets.commands.importers.v0 import import_dataset
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import get_example_default_schema
|
||||
from superset.utils.core import DatasourceType, get_example_default_schema
|
||||
from superset.utils.database import get_example_database
|
||||
|
||||
from tests.integration_tests.fixtures.world_bank_dashboard import (
|
||||
|
|
@ -103,7 +103,7 @@ class TestImportExport(SupersetTestCase):
|
|||
|
||||
return Slice(
|
||||
slice_name=name,
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
viz_type="bubble",
|
||||
params=json.dumps(params),
|
||||
datasource_id=ds_id,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
# under the License.
|
||||
# isort:skip_file
|
||||
import json
|
||||
from superset.utils.core import DatasourceType
|
||||
import textwrap
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
|
@ -604,7 +605,7 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
dashboard = self.get_dash_by_slug("births")
|
||||
slc = Slice(
|
||||
slice_name="slice with adhoc column",
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
viz_type="table",
|
||||
params=json.dumps(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ from superset.models.core import Database
|
|||
from superset.models.slice import Slice
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils.core import (
|
||||
DatasourceType,
|
||||
backend,
|
||||
get_example_default_schema,
|
||||
)
|
||||
|
|
@ -120,7 +121,7 @@ class TestRolePermission(SupersetTestCase):
|
|||
|
||||
ds_slices = (
|
||||
session.query(Slice)
|
||||
.filter_by(datasource_type="table")
|
||||
.filter_by(datasource_type=DatasourceType.TABLE)
|
||||
.filter_by(datasource_id=ds.id)
|
||||
.all()
|
||||
)
|
||||
|
|
@ -143,7 +144,7 @@ class TestRolePermission(SupersetTestCase):
|
|||
ds.schema_perm = None
|
||||
ds_slices = (
|
||||
session.query(Slice)
|
||||
.filter_by(datasource_type="table")
|
||||
.filter_by(datasource_type=DatasourceType.TABLE)
|
||||
.filter_by(datasource_id=ds.id)
|
||||
.all()
|
||||
)
|
||||
|
|
@ -365,7 +366,7 @@ class TestRolePermission(SupersetTestCase):
|
|||
# no schema permission
|
||||
slice = Slice(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_name="tmp_perm_table",
|
||||
slice_name="slice_name",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.extensions import cache_manager
|
||||
from superset.utils.core import backend, DatasourceType
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class UtilsCacheManagerTests(SupersetTestCase):
|
||||
def test_get_set_explore_form_data_cache(self):
|
||||
key = "12345"
|
||||
data = {"foo": "bar", "datasource_type": "query"}
|
||||
cache_manager.explore_form_data_cache.set(key, data)
|
||||
assert cache_manager.explore_form_data_cache.get(key) == data
|
||||
|
||||
def test_get_same_context_twice(self):
|
||||
key = "12345"
|
||||
data = {"foo": "bar", "datasource_type": "query"}
|
||||
cache_manager.explore_form_data_cache.set(key, data)
|
||||
assert cache_manager.explore_form_data_cache.get(key) == data
|
||||
assert cache_manager.explore_form_data_cache.get(key) == data
|
||||
|
||||
def test_get_set_explore_form_data_cache_no_datasource_type(self):
|
||||
key = "12345"
|
||||
data = {"foo": "bar"}
|
||||
cache_manager.explore_form_data_cache.set(key, data)
|
||||
# datasource_type should be added because it is not present
|
||||
assert cache_manager.explore_form_data_cache.get(key) == {
|
||||
"datasource_type": DatasourceType.TABLE,
|
||||
**data,
|
||||
}
|
||||
|
||||
def test_get_explore_form_data_cache_invalid_key(self):
|
||||
assert cache_manager.explore_form_data_cache.get("foo") == None
|
||||
|
|
@ -106,7 +106,7 @@ def test_get_datasource_sqlatable(
|
|||
from superset.dao.datasource.dao import DatasourceDAO
|
||||
|
||||
result = DatasourceDAO.get_datasource(
|
||||
datasource_type=DatasourceType.SQLATABLE,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=1,
|
||||
session=session_with_data,
|
||||
)
|
||||
|
|
@ -151,7 +151,9 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session)
|
|||
# todo(hugh): This will break once we remove the dual write
|
||||
# update the datsource_id=1 and this will pass again
|
||||
result = DatasourceDAO.get_datasource(
|
||||
datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data
|
||||
datasource_type=DatasourceType.SLTABLE,
|
||||
datasource_id=2,
|
||||
session=session_with_data,
|
||||
)
|
||||
|
||||
assert result.id == 2
|
||||
|
|
|
|||
|
|
@ -23,12 +23,21 @@ from superset.charts.commands.exceptions import (
|
|||
ChartAccessDeniedError,
|
||||
ChartNotFoundError,
|
||||
)
|
||||
from superset.commands.exceptions import (
|
||||
DatasourceNotFoundValidationError,
|
||||
DatasourceTypeInvalidError,
|
||||
OwnersNotFoundValidationError,
|
||||
QueryNotFoundValidationError,
|
||||
)
|
||||
from superset.datasets.commands.exceptions import (
|
||||
DatasetAccessDeniedError,
|
||||
DatasetNotFoundError,
|
||||
)
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id"
|
||||
query_find_by_id = "superset.queries.dao.QueryDAO.find_by_id"
|
||||
chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id"
|
||||
is_user_admin = "superset.explore.utils.is_user_admin"
|
||||
is_owner = "superset.explore.utils.is_owner"
|
||||
|
|
@ -36,88 +45,142 @@ can_access_datasource = (
|
|||
"superset.security.SupersetSecurityManager.can_access_datasource"
|
||||
)
|
||||
can_access = "superset.security.SupersetSecurityManager.can_access"
|
||||
raise_for_access = "superset.security.SupersetSecurityManager.raise_for_access"
|
||||
query_datasources_by_name = (
|
||||
"superset.connectors.sqla.models.SqlaTable.query_datasources_by_name"
|
||||
)
|
||||
|
||||
|
||||
def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None:
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
||||
with raises(DatasetNotFoundError):
|
||||
check_access(dataset_id=0, chart_id=0, actor=User())
|
||||
with raises(DatasourceNotFoundValidationError):
|
||||
check_chart_access(
|
||||
datasource_id=0,
|
||||
chart_id=0,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_unsaved_chart_unknown_dataset_id(
|
||||
mocker: MockFixture, app_context: AppContext
|
||||
) -> None:
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
||||
with raises(DatasetNotFoundError):
|
||||
mocker.patch(dataset_find_by_id, return_value=None)
|
||||
check_access(dataset_id=1, chart_id=0, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=0,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_unsaved_chart_unknown_query_id(
|
||||
mocker: MockFixture, app_context: AppContext
|
||||
) -> None:
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
||||
with raises(QueryNotFoundValidationError):
|
||||
mocker.patch(query_find_by_id, return_value=None)
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=0,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.QUERY,
|
||||
)
|
||||
|
||||
|
||||
def test_unsaved_chart_unauthorized_dataset(
|
||||
mocker: MockFixture, app_context: AppContext
|
||||
) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore import utils
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
||||
with raises(DatasetAccessDeniedError):
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
mocker.patch(can_access_datasource, return_value=False)
|
||||
utils.check_access(dataset_id=1, chart_id=0, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=0,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_unsaved_chart_authorized_dataset(
|
||||
mocker: MockFixture, app_context: AppContext
|
||||
) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
mocker.patch(can_access_datasource, return_value=True)
|
||||
check_access(dataset_id=1, chart_id=0, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=0,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_saved_chart_unknown_chart_id(
|
||||
mocker: MockFixture, app_context: AppContext
|
||||
) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
||||
with raises(ChartNotFoundError):
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
mocker.patch(can_access_datasource, return_value=True)
|
||||
mocker.patch(chart_find_by_id, return_value=None)
|
||||
check_access(dataset_id=1, chart_id=1, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=1,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_saved_chart_unauthorized_dataset(
|
||||
mocker: MockFixture, app_context: AppContext
|
||||
) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore import utils
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
||||
with raises(DatasetAccessDeniedError):
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
mocker.patch(can_access_datasource, return_value=False)
|
||||
utils.check_access(dataset_id=1, chart_id=1, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=1,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
from superset.models.slice import Slice
|
||||
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
mocker.patch(can_access_datasource, return_value=True)
|
||||
mocker.patch(is_user_admin, return_value=True)
|
||||
mocker.patch(chart_find_by_id, return_value=Slice())
|
||||
check_access(dataset_id=1, chart_id=1, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=1,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
from superset.models.slice import Slice
|
||||
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
|
|
@ -125,12 +188,17 @@ def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> N
|
|||
mocker.patch(is_user_admin, return_value=False)
|
||||
mocker.patch(is_owner, return_value=True)
|
||||
mocker.patch(chart_find_by_id, return_value=Slice())
|
||||
check_access(dataset_id=1, chart_id=1, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=1,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
from superset.models.slice import Slice
|
||||
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
|
|
@ -139,12 +207,17 @@ def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) ->
|
|||
mocker.patch(is_owner, return_value=False)
|
||||
mocker.patch(can_access, return_value=True)
|
||||
mocker.patch(chart_find_by_id, return_value=Slice())
|
||||
check_access(dataset_id=1, chart_id=1, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=1,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_access
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
from superset.models.slice import Slice
|
||||
|
||||
with raises(ChartAccessDeniedError):
|
||||
|
|
@ -154,4 +227,66 @@ def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) ->
|
|||
mocker.patch(is_owner, return_value=False)
|
||||
mocker.patch(can_access, return_value=False)
|
||||
mocker.patch(chart_find_by_id, return_value=Slice())
|
||||
check_access(dataset_id=1, chart_id=1, actor=User())
|
||||
check_chart_access(
|
||||
datasource_id=1,
|
||||
chart_id=1,
|
||||
actor=User(),
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
|
||||
|
||||
def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_datasource_access
|
||||
|
||||
mocker.patch(dataset_find_by_id, return_value=SqlaTable())
|
||||
mocker.patch(can_access_datasource, return_value=True)
|
||||
mocker.patch(is_user_admin, return_value=False)
|
||||
mocker.patch(is_owner, return_value=False)
|
||||
mocker.patch(can_access, return_value=True)
|
||||
assert (
|
||||
check_datasource_access(
|
||||
datasource_id=1,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
)
|
||||
== True
|
||||
)
|
||||
|
||||
|
||||
def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None:
|
||||
from superset.explore.utils import check_datasource_access
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
mocker.patch(query_find_by_id, return_value=Query())
|
||||
mocker.patch(raise_for_access, return_value=True)
|
||||
mocker.patch(is_user_admin, return_value=False)
|
||||
mocker.patch(is_owner, return_value=False)
|
||||
mocker.patch(can_access, return_value=True)
|
||||
assert (
|
||||
check_datasource_access(
|
||||
datasource_id=1,
|
||||
datasource_type=DatasourceType.QUERY,
|
||||
)
|
||||
== True
|
||||
)
|
||||
|
||||
|
||||
def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.explore.utils import check_datasource_access
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
with raises(SupersetSecurityException):
|
||||
mocker.patch(
|
||||
query_find_by_id,
|
||||
return_value=Query(database=Database(), sql="select * from foo"),
|
||||
)
|
||||
mocker.patch(query_datasources_by_name, return_value=[SqlaTable()])
|
||||
mocker.patch(is_user_admin, return_value=False)
|
||||
mocker.patch(is_owner, return_value=False)
|
||||
mocker.patch(can_access, return_value=False)
|
||||
check_datasource_access(
|
||||
datasource_id=1,
|
||||
datasource_type=DatasourceType.QUERY,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue