feat!: pass datasource_type and datasource_id to form_data (#19981)

* pass datasource_type and datasource_id to form_data

* add datasource_type to delete command

* add datasource_type to delete command

* fix old keys implementation

* add more tests
This commit is contained in:
Elizabeth Thompson 2022-06-02 16:48:16 -07:00 committed by GitHub
parent a813528958
commit 32bb1ce3ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 959 additions and 176 deletions

View File

@ -67,10 +67,18 @@ export const URL_PARAMS = {
name: 'slice_id', name: 'slice_id',
type: 'string', type: 'string',
}, },
datasourceId: {
name: 'datasource_id',
type: 'string',
},
datasetId: { datasetId: {
name: 'dataset_id', name: 'dataset_id',
type: 'string', type: 'string',
}, },
datasourceType: {
name: 'datasource_type',
type: 'string',
},
dashboardId: { dashboardId: {
name: 'dashboard_id', name: 'dashboard_id',
type: 'string', type: 'string',
@ -88,6 +96,8 @@ export const URL_PARAMS = {
export const RESERVED_CHART_URL_PARAMS: string[] = [ export const RESERVED_CHART_URL_PARAMS: string[] = [
URL_PARAMS.formDataKey.name, URL_PARAMS.formDataKey.name,
URL_PARAMS.sliceId.name, URL_PARAMS.sliceId.name,
URL_PARAMS.datasourceId.name,
URL_PARAMS.datasourceType.name,
URL_PARAMS.datasetId.name, URL_PARAMS.datasetId.name,
]; ];
export const RESERVED_DASHBOARD_URL_PARAMS: string[] = [ export const RESERVED_DASHBOARD_URL_PARAMS: string[] = [

View File

@ -272,6 +272,7 @@ export default class Chart extends React.Component {
: undefined; : undefined;
const key = await postFormData( const key = await postFormData(
this.props.datasource.id, this.props.datasource.id,
this.props.datasource.type,
this.props.formData, this.props.formData,
this.props.slice.slice_id, this.props.slice.slice_id,
nextTabId, nextTabId,

View File

@ -92,7 +92,7 @@ test('generates a new form_data param when none is available', async () => {
expect(replaceState).toHaveBeenCalledWith( expect(replaceState).toHaveBeenCalledWith(
expect.anything(), expect.anything(),
undefined, undefined,
expect.stringMatching('dataset_id'), expect.stringMatching('datasource_id'),
); );
replaceState.mockRestore(); replaceState.mockRestore();
}); });
@ -109,7 +109,7 @@ test('generates a different form_data param when one is provided and is mounting
expect(replaceState).toHaveBeenCalledWith( expect(replaceState).toHaveBeenCalledWith(
expect.anything(), expect.anything(),
undefined, undefined,
expect.stringMatching('dataset_id'), expect.stringMatching('datasource_id'),
); );
replaceState.mockRestore(); replaceState.mockRestore();
}); });

View File

@ -152,14 +152,24 @@ const ExplorePanelContainer = styled.div`
`; `;
const updateHistory = debounce( const updateHistory = debounce(
async (formData, datasetId, isReplace, standalone, force, title, tabId) => { async (
formData,
datasourceId,
datasourceType,
isReplace,
standalone,
force,
title,
tabId,
) => {
const payload = { ...formData }; const payload = { ...formData };
const chartId = formData.slice_id; const chartId = formData.slice_id;
const additionalParam = {}; const additionalParam = {};
if (chartId) { if (chartId) {
additionalParam[URL_PARAMS.sliceId.name] = chartId; additionalParam[URL_PARAMS.sliceId.name] = chartId;
} else { } else {
additionalParam[URL_PARAMS.datasetId.name] = datasetId; additionalParam[URL_PARAMS.datasourceId.name] = datasourceId;
additionalParam[URL_PARAMS.datasourceType.name] = datasourceType;
} }
const urlParams = payload?.url_params || {}; const urlParams = payload?.url_params || {};
@ -173,11 +183,24 @@ const updateHistory = debounce(
let key; let key;
let stateModifier; let stateModifier;
if (isReplace) { if (isReplace) {
key = await postFormData(datasetId, formData, chartId, tabId); key = await postFormData(
datasourceId,
datasourceType,
formData,
chartId,
tabId,
);
stateModifier = 'replaceState'; stateModifier = 'replaceState';
} else { } else {
key = getUrlParam(URL_PARAMS.formDataKey); key = getUrlParam(URL_PARAMS.formDataKey);
await putFormData(datasetId, key, formData, chartId, tabId); await putFormData(
datasourceId,
datasourceType,
key,
formData,
chartId,
tabId,
);
stateModifier = 'pushState'; stateModifier = 'pushState';
} }
const url = mountExploreUrl( const url = mountExploreUrl(
@ -229,11 +252,12 @@ function ExploreViewContainer(props) {
dashboardId: props.dashboardId, dashboardId: props.dashboardId,
} }
: props.form_data; : props.form_data;
const datasetId = props.datasource.id; const { id: datasourceId, type: datasourceType } = props.datasource;
updateHistory( updateHistory(
formData, formData,
datasetId, datasourceId,
datasourceType,
isReplace, isReplace,
props.standalone, props.standalone,
props.force, props.force,
@ -245,6 +269,7 @@ function ExploreViewContainer(props) {
props.dashboardId, props.dashboardId,
props.form_data, props.form_data,
props.datasource.id, props.datasource.id,
props.datasource.type,
props.standalone, props.standalone,
props.force, props.force,
tabId, tabId,

View File

@ -189,9 +189,9 @@ class DatasourceControl extends React.PureComponent {
const isMissingDatasource = datasource.id == null; const isMissingDatasource = datasource.id == null;
let isMissingParams = false; let isMissingParams = false;
if (isMissingDatasource) { if (isMissingDatasource) {
const datasetId = getUrlParam(URL_PARAMS.datasetId); const datasourceId = getUrlParam(URL_PARAMS.datasourceId);
const sliceId = getUrlParam(URL_PARAMS.sliceId); const sliceId = getUrlParam(URL_PARAMS.sliceId);
if (!datasetId && !sliceId) { if (!datasourceId && !sliceId) {
isMissingParams = true; isMissingParams = true;
} }
} }

View File

@ -20,7 +20,8 @@ import { omit } from 'lodash';
import { SupersetClient, JsonObject } from '@superset-ui/core'; import { SupersetClient, JsonObject } from '@superset-ui/core';
type Payload = { type Payload = {
dataset_id: number; datasource_id: number;
datasource_type: string;
form_data: string; form_data: string;
chart_id?: number; chart_id?: number;
}; };
@ -42,12 +43,14 @@ const assembleEndpoint = (key?: string, tabId?: string) => {
}; };
const assemblePayload = ( const assemblePayload = (
datasetId: number, datasourceId: number,
datasourceType: string,
formData: JsonObject, formData: JsonObject,
chartId?: number, chartId?: number,
) => { ) => {
const payload: Payload = { const payload: Payload = {
dataset_id: datasetId, datasource_id: datasourceId,
datasource_type: datasourceType,
form_data: JSON.stringify(sanitizeFormData(formData)), form_data: JSON.stringify(sanitizeFormData(formData)),
}; };
if (chartId) { if (chartId) {
@ -57,18 +60,25 @@ const assemblePayload = (
}; };
export const postFormData = ( export const postFormData = (
datasetId: number, datasourceId: number,
datasourceType: string,
formData: JsonObject, formData: JsonObject,
chartId?: number, chartId?: number,
tabId?: string, tabId?: string,
): Promise<string> => ): Promise<string> =>
SupersetClient.post({ SupersetClient.post({
endpoint: assembleEndpoint(undefined, tabId), endpoint: assembleEndpoint(undefined, tabId),
jsonPayload: assemblePayload(datasetId, formData, chartId), jsonPayload: assemblePayload(
datasourceId,
datasourceType,
formData,
chartId,
),
}).then(r => r.json.key); }).then(r => r.json.key);
export const putFormData = ( export const putFormData = (
datasetId: number, datasourceId: number,
datasourceType: string,
key: string, key: string,
formData: JsonObject, formData: JsonObject,
chartId?: number, chartId?: number,
@ -76,5 +86,10 @@ export const putFormData = (
): Promise<string> => ): Promise<string> =>
SupersetClient.put({ SupersetClient.put({
endpoint: assembleEndpoint(key, tabId), endpoint: assembleEndpoint(key, tabId),
jsonPayload: assemblePayload(datasetId, formData, chartId), jsonPayload: assemblePayload(
datasourceId,
datasourceType,
formData,
chartId,
),
}).then(r => r.json.message); }).then(r => r.json.message);

View File

@ -22,6 +22,7 @@ from superset.charts.schemas import (
datasource_type_description, datasource_type_description,
datasource_uid_description, datasource_uid_description,
) )
from superset.utils.core import DatasourceType
class Datasource(Schema): class Datasource(Schema):
@ -36,7 +37,7 @@ class Datasource(Schema):
) )
datasource_type = fields.String( datasource_type = fields.String(
description=datasource_type_description, description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")), validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
required=True, required=True,
) )

View File

@ -31,6 +31,7 @@ from superset.db_engine_specs.base import builtin_time_grains
from superset.utils import pandas_postprocessing, schema as utils from superset.utils import pandas_postprocessing, schema as utils
from superset.utils.core import ( from superset.utils.core import (
AnnotationType, AnnotationType,
DatasourceType,
FilterOperator, FilterOperator,
PostProcessingBoxplotWhiskerType, PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation, PostProcessingContributionOrientation,
@ -198,7 +199,7 @@ class ChartPostSchema(Schema):
datasource_id = fields.Integer(description=datasource_id_description, required=True) datasource_id = fields.Integer(description=datasource_id_description, required=True)
datasource_type = fields.String( datasource_type = fields.String(
description=datasource_type_description, description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")), validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
required=True, required=True,
) )
datasource_name = fields.String( datasource_name = fields.String(
@ -244,7 +245,7 @@ class ChartPutSchema(Schema):
) )
datasource_type = fields.String( datasource_type = fields.String(
description=datasource_type_description, description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")), validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
allow_none=True, allow_none=True,
) )
dashboards = fields.List(fields.Integer(description=dashboards_description)) dashboards = fields.List(fields.Integer(description=dashboards_description))
@ -983,7 +984,7 @@ class ChartDataDatasourceSchema(Schema):
) )
type = fields.String( type = fields.String(
description="Datasource type", description="Datasource type",
validate=validate.OneOf(choices=("druid", "table")), validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
) )

View File

@ -115,8 +115,24 @@ class RolesNotFoundValidationError(ValidationError):
super().__init__([_("Some roles do not exist")], field_name="roles") super().__init__([_("Some roles do not exist")], field_name="roles")
class DatasourceTypeInvalidError(ValidationError):
status = 422
def __init__(self) -> None:
super().__init__(
[_("Datasource type is invalid")], field_name="datasource_type"
)
class DatasourceNotFoundValidationError(ValidationError): class DatasourceNotFoundValidationError(ValidationError):
status = 404 status = 404
def __init__(self) -> None: def __init__(self) -> None:
super().__init__([_("Dataset does not exist")], field_name="datasource_id") super().__init__([_("Datasource does not exist")], field_name="datasource_id")
class QueryNotFoundValidationError(ValidationError):
status = 404
def __init__(self) -> None:
super().__init__([_("Query does not exist")], field_name="datasource_id")

View File

@ -39,11 +39,11 @@ Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
class DatasourceDAO(BaseDAO): class DatasourceDAO(BaseDAO):
sources: Dict[DatasourceType, Type[Datasource]] = { sources: Dict[DatasourceType, Type[Datasource]] = {
DatasourceType.SQLATABLE: SqlaTable, DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query, DatasourceType.QUERY: Query,
DatasourceType.SAVEDQUERY: SavedQuery, DatasourceType.SAVEDQUERY: SavedQuery,
DatasourceType.DATASET: Dataset, DatasourceType.DATASET: Dataset,
DatasourceType.TABLE: Table, DatasourceType.SLTABLE: Table,
} }
@classmethod @classmethod
@ -66,7 +66,7 @@ class DatasourceDAO(BaseDAO):
@classmethod @classmethod
def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]: def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]:
source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE] source_class = DatasourceDAO.sources[DatasourceType.TABLE]
qry = session.query(source_class) qry = session.query(source_class)
qry = source_class.default_query(qry) qry = source_class.default_query(qry)
return qry.all() return qry.all()

View File

@ -24,6 +24,7 @@ from superset.models.core import Database
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.models.sql_lab import TabState from superset.models.sql_lab import TabState
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -75,7 +76,8 @@ class DatabaseDAO(BaseDAO):
charts = ( charts = (
db.session.query(Slice) db.session.query(Slice)
.filter( .filter(
Slice.datasource_id.in_(dataset_ids), Slice.datasource_type == "table" Slice.datasource_id.in_(dataset_ids),
Slice.datasource_type == DatasourceType.TABLE,
) )
.all() .all()
) )

View File

@ -26,6 +26,7 @@ from superset.extensions import db
from superset.models.core import Database from superset.models.core import Database
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from superset.views.base import DatasourceFilter from superset.views.base import DatasourceFilter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +57,8 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
charts = ( charts = (
db.session.query(Slice) db.session.query(Slice)
.filter( .filter(
Slice.datasource_id == database_id, Slice.datasource_type == "table" Slice.datasource_id == database_id,
Slice.datasource_type == DatasourceType.TABLE,
) )
.all() .all()
) )

View File

@ -29,6 +29,7 @@ from superset.exceptions import NoDataException
from superset.models.core import Database from superset.models.core import Database
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from ..utils.database import get_example_database from ..utils.database import get_example_database
from .helpers import ( from .helpers import (
@ -205,13 +206,16 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
if admin_owner: if admin_owner:
slice_props = dict( slice_props = dict(
datasource_id=tbl.id, datasource_id=tbl.id,
datasource_type="table", datasource_type=DatasourceType.TABLE,
owners=[admin], owners=[admin],
created_by=admin, created_by=admin,
) )
else: else:
slice_props = dict( slice_props = dict(
datasource_id=tbl.id, datasource_type="table", owners=[], created_by=admin datasource_id=tbl.id,
datasource_type=DatasourceType.TABLE,
owners=[],
created_by=admin,
) )
print("Creating some slices") print("Creating some slices")

View File

@ -24,6 +24,7 @@ import superset.utils.database as database_utils
from superset import db from superset import db
from superset.connectors.sqla.models import SqlMetric from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import ( from .helpers import (
get_example_data, get_example_data,
@ -112,7 +113,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
slc = Slice( slc = Slice(
slice_name="Birth in France by department in 2016", slice_name="Birth in France by department in 2016",
viz_type="country_map", viz_type="country_map",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )

View File

@ -19,6 +19,7 @@ import json
from superset import db from superset import db
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import ( from .helpers import (
get_slice_json, get_slice_json,
@ -213,7 +214,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice( slc = Slice(
slice_name="Deck.gl Scatterplot", slice_name="Deck.gl Scatterplot",
viz_type="deck_scatter", viz_type="deck_scatter",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -248,7 +249,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice( slc = Slice(
slice_name="Deck.gl Screen grid", slice_name="Deck.gl Screen grid",
viz_type="deck_screengrid", viz_type="deck_screengrid",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -284,7 +285,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice( slc = Slice(
slice_name="Deck.gl Hexagons", slice_name="Deck.gl Hexagons",
viz_type="deck_hex", viz_type="deck_hex",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -321,7 +322,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice( slc = Slice(
slice_name="Deck.gl Grid", slice_name="Deck.gl Grid",
viz_type="deck_grid", viz_type="deck_grid",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -410,7 +411,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice( slc = Slice(
slice_name="Deck.gl Polygons", slice_name="Deck.gl Polygons",
viz_type="deck_polygon", viz_type="deck_polygon",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=polygon_tbl.id, datasource_id=polygon_tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -460,7 +461,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice( slc = Slice(
slice_name="Deck.gl Arcs", slice_name="Deck.gl Arcs",
viz_type="deck_arc", viz_type="deck_arc",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=db.session.query(table) datasource_id=db.session.query(table)
.filter_by(table_name="flights") .filter_by(table_name="flights")
.first() .first()
@ -512,7 +513,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
slc = Slice( slc = Slice(
slice_name="Deck.gl Path", slice_name="Deck.gl Path",
viz_type="deck_path", viz_type="deck_path",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=db.session.query(table) datasource_id=db.session.query(table)
.filter_by(table_name="bart_lines") .filter_by(table_name="bart_lines")
.first() .first()

View File

@ -25,6 +25,7 @@ import superset.utils.database as database_utils
from superset import db from superset import db
from superset.connectors.sqla.models import SqlMetric from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import ( from .helpers import (
get_example_data, get_example_data,
@ -81,7 +82,7 @@ def load_energy(
slc = Slice( slc = Slice(
slice_name="Energy Sankey", slice_name="Energy Sankey",
viz_type="sankey", viz_type="sankey",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=textwrap.dedent( params=textwrap.dedent(
"""\ """\
@ -105,7 +106,7 @@ def load_energy(
slc = Slice( slc = Slice(
slice_name="Energy Force Layout", slice_name="Energy Force Layout",
viz_type="graph_chart", viz_type="graph_chart",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=textwrap.dedent( params=textwrap.dedent(
"""\ """\
@ -129,7 +130,7 @@ def load_energy(
slc = Slice( slc = Slice(
slice_name="Heatmap", slice_name="Heatmap",
viz_type="heatmap", viz_type="heatmap",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=textwrap.dedent( params=textwrap.dedent(
"""\ """\

View File

@ -24,6 +24,7 @@ from sqlalchemy import DateTime, Float, inspect, String
import superset.utils.database as database_utils import superset.utils.database as database_utils
from superset import db from superset import db
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import ( from .helpers import (
get_example_data, get_example_data,
@ -113,7 +114,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
slc = Slice( slc = Slice(
slice_name="Mapbox Long/Lat", slice_name="Mapbox Long/Lat",
viz_type="mapbox", viz_type="mapbox",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )

View File

@ -18,6 +18,7 @@ import json
from superset import db from superset import db
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .birth_names import load_birth_names from .birth_names import load_birth_names
from .helpers import merge_slice, misc_dash_slices from .helpers import merge_slice, misc_dash_slices
@ -35,7 +36,7 @@ def load_multi_line(only_metadata: bool = False) -> None:
] ]
slc = Slice( slc = Slice(
datasource_type="table", # not true, but needed datasource_type=DatasourceType.TABLE, # not true, but needed
datasource_id=1, # cannot be empty datasource_id=1, # cannot be empty
slice_name="Multi Line", slice_name="Multi Line",
viz_type="line_multi", viz_type="line_multi",

View File

@ -21,6 +21,7 @@ from sqlalchemy import BigInteger, Date, DateTime, inspect, String
from superset import app, db from superset import app, db
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from ..utils.database import get_example_database from ..utils.database import get_example_database
from .helpers import ( from .helpers import (
@ -120,7 +121,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
slc = Slice( slc = Slice(
slice_name=f"Calendar Heatmap multiformat {i}", slice_name=f"Calendar Heatmap multiformat {i}",
viz_type="cal_heatmap", viz_type="cal_heatmap",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )

View File

@ -21,6 +21,7 @@ from sqlalchemy import DateTime, inspect, String
import superset.utils.database as database_utils import superset.utils.database as database_utils
from superset import app, db from superset import app, db
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from .helpers import ( from .helpers import (
get_example_data, get_example_data,
@ -89,7 +90,7 @@ def load_random_time_series_data(
slc = Slice( slc = Slice(
slice_name="Calendar Heatmap", slice_name="Calendar Heatmap",
viz_type="cal_heatmap", viz_type="cal_heatmap",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )

View File

@ -29,6 +29,7 @@ from superset.connectors.sqla.models import SqlMetric
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils import core as utils from superset.utils import core as utils
from superset.utils.core import DatasourceType
from ..connectors.base.models import BaseDatasource from ..connectors.base.models import BaseDatasource
from .helpers import ( from .helpers import (
@ -172,7 +173,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Region Filter", slice_name="Region Filter",
viz_type="filter_box", viz_type="filter_box",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -201,7 +202,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="World's Population", slice_name="World's Population",
viz_type="big_number", viz_type="big_number",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -215,7 +216,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Most Populated Countries", slice_name="Most Populated Countries",
viz_type="table", viz_type="table",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -227,7 +228,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Growth Rate", slice_name="Growth Rate",
viz_type="line", viz_type="line",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -241,7 +242,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="% Rural", slice_name="% Rural",
viz_type="world_map", viz_type="world_map",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -254,7 +255,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Life Expectancy VS Rural %", slice_name="Life Expectancy VS Rural %",
viz_type="bubble", viz_type="bubble",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -298,7 +299,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Rural Breakdown", slice_name="Rural Breakdown",
viz_type="sunburst", viz_type="sunburst",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -313,7 +314,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="World's Pop Growth", slice_name="World's Pop Growth",
viz_type="area", viz_type="area",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -327,7 +328,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Box plot", slice_name="Box plot",
viz_type="box_plot", viz_type="box_plot",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -343,7 +344,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Treemap", slice_name="Treemap",
viz_type="treemap", viz_type="treemap",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
@ -357,7 +358,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]:
Slice( Slice(
slice_name="Parallel Coordinates", slice_name="Parallel Coordinates",
viz_type="para", viz_type="para",
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,

View File

@ -104,7 +104,8 @@ class ExploreFormDataRestApi(BaseApi, ABC):
tab_id = request.args.get("tab_id") tab_id = request.args.get("tab_id")
args = CommandParameters( args = CommandParameters(
actor=g.user, actor=g.user,
dataset_id=item["dataset_id"], datasource_id=item["datasource_id"],
datasource_type=item["datasource_type"],
chart_id=item.get("chart_id"), chart_id=item.get("chart_id"),
tab_id=tab_id, tab_id=tab_id,
form_data=item["form_data"], form_data=item["form_data"],
@ -123,7 +124,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
@safe @safe
@event_logger.log_this_with_context( @event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
log_to_statsd=False, log_to_statsd=True,
) )
@requires_json @requires_json
def put(self, key: str) -> Response: def put(self, key: str) -> Response:
@ -174,7 +175,8 @@ class ExploreFormDataRestApi(BaseApi, ABC):
tab_id = request.args.get("tab_id") tab_id = request.args.get("tab_id")
args = CommandParameters( args = CommandParameters(
actor=g.user, actor=g.user,
dataset_id=item["dataset_id"], datasource_id=item["datasource_id"],
datasource_type=item["datasource_type"],
chart_id=item.get("chart_id"), chart_id=item.get("chart_id"),
tab_id=tab_id, tab_id=tab_id,
key=key, key=key,
@ -196,7 +198,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
@safe @safe
@event_logger.log_this_with_context( @event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get", action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
log_to_statsd=False, log_to_statsd=True,
) )
def get(self, key: str) -> Response: def get(self, key: str) -> Response:
"""Retrives a form_data. """Retrives a form_data.
@ -247,7 +249,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
@safe @safe
@event_logger.log_this_with_context( @event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.delete", action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.delete",
log_to_statsd=False, log_to_statsd=True,
) )
def delete(self, key: str) -> Response: def delete(self, key: str) -> Response:
"""Deletes a form_data. """Deletes a form_data.

View File

@ -39,20 +39,24 @@ class CreateFormDataCommand(BaseCommand):
def run(self) -> str: def run(self) -> str:
self.validate() self.validate()
try: try:
dataset_id = self._cmd_params.dataset_id datasource_id = self._cmd_params.datasource_id
datasource_type = self._cmd_params.datasource_type
chart_id = self._cmd_params.chart_id chart_id = self._cmd_params.chart_id
tab_id = self._cmd_params.tab_id tab_id = self._cmd_params.tab_id
actor = self._cmd_params.actor actor = self._cmd_params.actor
form_data = self._cmd_params.form_data form_data = self._cmd_params.form_data
check_access(dataset_id, chart_id, actor) check_access(datasource_id, chart_id, actor, datasource_type)
contextual_key = cache_key(session.get("_id"), tab_id, dataset_id, chart_id) contextual_key = cache_key(
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
)
key = cache_manager.explore_form_data_cache.get(contextual_key) key = cache_manager.explore_form_data_cache.get(contextual_key)
if not key or not tab_id: if not key or not tab_id:
key = random_key() key = random_key()
if form_data: if form_data:
state: TemporaryExploreState = { state: TemporaryExploreState = {
"owner": get_owner(actor), "owner": get_owner(actor),
"dataset_id": dataset_id, "datasource_id": datasource_id,
"datasource_type": datasource_type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": form_data, "form_data": form_data,
} }

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
import logging import logging
from abc import ABC from abc import ABC
from typing import Optional
from flask import session from flask import session
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -31,6 +32,7 @@ from superset.temporary_cache.commands.exceptions import (
TemporaryCacheDeleteFailedError, TemporaryCacheDeleteFailedError,
) )
from superset.temporary_cache.utils import cache_key from superset.temporary_cache.utils import cache_key
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,14 +49,15 @@ class DeleteFormDataCommand(BaseCommand, ABC):
key key
) )
if state: if state:
dataset_id = state["dataset_id"] datasource_id: int = state["datasource_id"]
chart_id = state["chart_id"] chart_id: Optional[int] = state["chart_id"]
check_access(dataset_id, chart_id, actor) datasource_type = DatasourceType(state["datasource_type"])
check_access(datasource_id, chart_id, actor, datasource_type)
if state["owner"] != get_owner(actor): if state["owner"] != get_owner(actor):
raise TemporaryCacheAccessDeniedError() raise TemporaryCacheAccessDeniedError()
tab_id = self._cmd_params.tab_id tab_id = self._cmd_params.tab_id
contextual_key = cache_key( contextual_key = cache_key(
session.get("_id"), tab_id, dataset_id, chart_id session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
) )
cache_manager.explore_form_data_cache.delete(contextual_key) cache_manager.explore_form_data_cache.delete(contextual_key)
return cache_manager.explore_form_data_cache.delete(key) return cache_manager.explore_form_data_cache.delete(key)

View File

@ -27,6 +27,7 @@ from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.commands.utils import check_access from superset.explore.form_data.commands.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,7 +46,12 @@ class GetFormDataCommand(BaseCommand, ABC):
key key
) )
if state: if state:
check_access(state["dataset_id"], state["chart_id"], actor) check_access(
state["datasource_id"],
state["chart_id"],
actor,
DatasourceType(state["datasource_type"]),
)
if self._refresh_timeout: if self._refresh_timeout:
cache_manager.explore_form_data_cache.set(key, state) cache_manager.explore_form_data_cache.set(key, state)
return state["form_data"] return state["form_data"]

View File

@ -19,11 +19,14 @@ from typing import Optional
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
from superset.utils.core import DatasourceType
@dataclass @dataclass
class CommandParameters: class CommandParameters:
actor: User actor: User
dataset_id: int = 0 datasource_type: DatasourceType = DatasourceType.TABLE
datasource_id: int = 0
chart_id: int = 0 chart_id: int = 0
tab_id: Optional[int] = None tab_id: Optional[int] = None
key: Optional[str] = None key: Optional[str] = None

View File

@ -21,6 +21,7 @@ from typing_extensions import TypedDict
class TemporaryExploreState(TypedDict): class TemporaryExploreState(TypedDict):
owner: Optional[int] owner: Optional[int]
dataset_id: int datasource_id: int
datasource_type: str
chart_id: Optional[int] chart_id: Optional[int]
form_data: str form_data: str

View File

@ -47,12 +47,13 @@ class UpdateFormDataCommand(BaseCommand, ABC):
def run(self) -> Optional[str]: def run(self) -> Optional[str]:
self.validate() self.validate()
try: try:
dataset_id = self._cmd_params.dataset_id datasource_id = self._cmd_params.datasource_id
chart_id = self._cmd_params.chart_id chart_id = self._cmd_params.chart_id
datasource_type = self._cmd_params.datasource_type
actor = self._cmd_params.actor actor = self._cmd_params.actor
key = self._cmd_params.key key = self._cmd_params.key
form_data = self._cmd_params.form_data form_data = self._cmd_params.form_data
check_access(dataset_id, chart_id, actor) check_access(datasource_id, chart_id, actor, datasource_type)
state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( state: TemporaryExploreState = cache_manager.explore_form_data_cache.get(
key key
) )
@ -64,7 +65,7 @@ class UpdateFormDataCommand(BaseCommand, ABC):
# Generate a new key if tab_id changes or equals 0 # Generate a new key if tab_id changes or equals 0
tab_id = self._cmd_params.tab_id tab_id = self._cmd_params.tab_id
contextual_key = cache_key( contextual_key = cache_key(
session.get("_id"), tab_id, dataset_id, chart_id session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
) )
key = cache_manager.explore_form_data_cache.get(contextual_key) key = cache_manager.explore_form_data_cache.get(contextual_key)
if not key or not tab_id: if not key or not tab_id:
@ -73,7 +74,8 @@ class UpdateFormDataCommand(BaseCommand, ABC):
new_state: TemporaryExploreState = { new_state: TemporaryExploreState = {
"owner": owner, "owner": owner,
"dataset_id": dataset_id, "datasource_id": datasource_id,
"datasource_type": datasource_type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": form_data, "form_data": form_data,
} }

View File

@ -31,11 +31,17 @@ from superset.temporary_cache.commands.exceptions import (
TemporaryCacheAccessDeniedError, TemporaryCacheAccessDeniedError,
TemporaryCacheResourceNotFoundError, TemporaryCacheResourceNotFoundError,
) )
from superset.utils.core import DatasourceType
def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None: def check_access(
datasource_id: int,
chart_id: Optional[int],
actor: User,
datasource_type: DatasourceType,
) -> None:
try: try:
explore_check_access(dataset_id, chart_id, actor) explore_check_access(datasource_id, chart_id, actor, datasource_type)
except (ChartNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DatasetNotFoundError) as ex:
raise TemporaryCacheResourceNotFoundError from ex raise TemporaryCacheResourceNotFoundError from ex
except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex: except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex:

View File

@ -14,12 +14,20 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from marshmallow import fields, Schema from marshmallow import fields, Schema, validate
from superset.utils.core import DatasourceType
class FormDataPostSchema(Schema): class FormDataPostSchema(Schema):
dataset_id = fields.Integer( datasource_id = fields.Integer(
required=True, allow_none=False, description="The dataset ID" required=True, allow_none=False, description="The datasource ID"
)
datasource_type = fields.String(
required=True,
allow_none=False,
description="The datasource type",
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
) )
chart_id = fields.Integer(required=False, description="The chart ID") chart_id = fields.Integer(required=False, description="The chart ID")
form_data = fields.String( form_data = fields.String(
@ -28,8 +36,14 @@ class FormDataPostSchema(Schema):
class FormDataPutSchema(Schema): class FormDataPutSchema(Schema):
dataset_id = fields.Integer( datasource_id = fields.Integer(
required=True, allow_none=False, description="The dataset ID" required=True, allow_none=False, description="The datasource ID"
)
datasource_type = fields.String(
required=True,
allow_none=False,
description="The datasource type",
validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]),
) )
chart_id = fields.Integer(required=False, description="The chart ID") chart_id = fields.Integer(required=False, description="The chart ID")
form_data = fields.String( form_data = fields.String(

View File

@ -22,9 +22,10 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.utils import encode_permalink_key from superset.key_value.utils import encode_permalink_key
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,11 +40,16 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
def run(self) -> str: def run(self) -> str:
self.validate() self.validate()
try: try:
dataset_id = int(self.datasource.split("__")[0]) d_id, d_type = self.datasource.split("__")
check_access(dataset_id, self.chart_id, self.actor) datasource_id = int(d_id)
datasource_type = DatasourceType(d_type)
check_chart_access(
datasource_id, self.chart_id, self.actor, datasource_type
)
value = { value = {
"chartId": self.chart_id, "chartId": self.chart_id,
"datasetId": dataset_id, "datasourceId": datasource_id,
"datasourceType": datasource_type,
"datasource": self.datasource, "datasource": self.datasource,
"state": self.state, "state": self.state,
} }

View File

@ -24,10 +24,11 @@ from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.explore.permalink.types import ExplorePermalinkValue from superset.explore.permalink.types import ExplorePermalinkValue
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
from superset.key_value.commands.get import GetKeyValueCommand from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.utils import decode_permalink_id from superset.key_value.utils import decode_permalink_id
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,8 +48,9 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
).run() ).run()
if value: if value:
chart_id: Optional[int] = value.get("chartId") chart_id: Optional[int] = value.get("chartId")
dataset_id = value["datasetId"] datasource_id: int = value["datasourceId"]
check_access(dataset_id, chart_id, self.actor) datasource_type = DatasourceType(value["datasourceType"])
check_chart_access(datasource_id, chart_id, self.actor, datasource_type)
return value return value
return None return None
except ( except (

View File

@ -24,6 +24,7 @@ class ExplorePermalinkState(TypedDict, total=False):
class ExplorePermalinkValue(TypedDict): class ExplorePermalinkValue(TypedDict):
chartId: Optional[int] chartId: Optional[int]
datasetId: int datasourceId: int
datasourceType: str
datasource: str datasource: str
state: ExplorePermalinkState state: ExplorePermalinkState

View File

@ -24,11 +24,18 @@ from superset.charts.commands.exceptions import (
ChartNotFoundError, ChartNotFoundError,
) )
from superset.charts.dao import ChartDAO from superset.charts.dao import ChartDAO
from superset.commands.exceptions import (
DatasourceNotFoundValidationError,
DatasourceTypeInvalidError,
QueryNotFoundValidationError,
)
from superset.datasets.commands.exceptions import ( from superset.datasets.commands.exceptions import (
DatasetAccessDeniedError, DatasetAccessDeniedError,
DatasetNotFoundError, DatasetNotFoundError,
) )
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.queries.dao import QueryDAO
from superset.utils.core import DatasourceType
from superset.views.base import is_user_admin from superset.views.base import is_user_admin
from superset.views.utils import is_owner from superset.views.utils import is_owner
@ -44,10 +51,41 @@ def check_dataset_access(dataset_id: int) -> Optional[bool]:
raise DatasetNotFoundError() raise DatasetNotFoundError()
def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None: def check_query_access(query_id: int) -> Optional[bool]:
check_dataset_access(dataset_id) if query_id:
query = QueryDAO.find_by_id(query_id)
if query:
security_manager.raise_for_access(query=query)
return True
raise QueryNotFoundValidationError()
ACCESS_FUNCTION_MAP = {
DatasourceType.TABLE: check_dataset_access,
DatasourceType.QUERY: check_query_access,
}
def check_datasource_access(
datasource_id: int, datasource_type: DatasourceType
) -> Optional[bool]:
if datasource_id:
try:
return ACCESS_FUNCTION_MAP[datasource_type](datasource_id)
except KeyError as ex:
raise DatasourceTypeInvalidError() from ex
raise DatasourceNotFoundValidationError()
def check_access(
datasource_id: int,
chart_id: Optional[int],
actor: User,
datasource_type: DatasourceType,
) -> Optional[bool]:
check_datasource_access(datasource_id, datasource_type)
if not chart_id: if not chart_id:
return return True
chart = ChartDAO.find_by_id(chart_id) chart = ChartDAO.find_by_id(chart_id)
if chart: if chart:
can_access_chart = ( can_access_chart = (
@ -56,6 +94,6 @@ def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
or security_manager.can_access("can_read", "Chart") or security_manager.can_access("can_read", "Chart")
) )
if can_access_chart: if can_access_chart:
return return True
raise ChartAccessDeniedError() raise ChartAccessDeniedError()
raise ChartNotFoundError() raise ChartNotFoundError()

View File

@ -15,15 +15,40 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Any, Optional, Union
from flask import Flask from flask import Flask
from flask_caching import Cache from flask_caching import Cache
from markupsafe import Markup
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache" CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache"
class ExploreFormDataCache(Cache):
def get(self, *args: Any, **kwargs: Any) -> Optional[Union[str, Markup]]:
cache = self.cache.get(*args, **kwargs)
if not cache:
return None
# rename data keys for existing cache based on new TemporaryExploreState model
if isinstance(cache, dict):
cache = {
("datasource_id" if key == "dataset_id" else key): value
for (key, value) in cache.items()
}
# add default datasource_type if it doesn't exist
# temporarily defaulting to table until sqlatables are deprecated
if "datasource_type" not in cache:
cache["datasource_type"] = DatasourceType.TABLE
return cache
class CacheManager: class CacheManager:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -32,7 +57,7 @@ class CacheManager:
self._data_cache = Cache() self._data_cache = Cache()
self._thumbnail_cache = Cache() self._thumbnail_cache = Cache()
self._filter_state_cache = Cache() self._filter_state_cache = Cache()
self._explore_form_data_cache = Cache() self._explore_form_data_cache = ExploreFormDataCache()
@staticmethod @staticmethod
def _init_cache( def _init_cache(

View File

@ -175,12 +175,13 @@ class GenericDataType(IntEnum):
# ROW = 7 # ROW = 7
class DatasourceType(Enum): class DatasourceType(str, Enum):
SQLATABLE = "sqlatable" SLTABLE = "sl_table"
TABLE = "table" TABLE = "table"
DATASET = "dataset" DATASET = "dataset"
QUERY = "query" QUERY = "query"
SAVEDQUERY = "saved_query" SAVEDQUERY = "saved_query"
VIEW = "view"
class DatasourceDict(TypedDict): class DatasourceDict(TypedDict):

View File

@ -520,7 +520,13 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
response = json.loads(rv.data.decode("utf-8")) response = json.loads(rv.data.decode("utf-8"))
self.assertEqual( self.assertEqual(
response, response,
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, {
"message": {
"datasource_type": [
"Must be one of: sl_table, table, dataset, query, saved_query, view."
]
}
},
) )
chart_data = { chart_data = {
"slice_name": "title1", "slice_name": "title1",
@ -531,7 +537,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
self.assertEqual(rv.status_code, 422) self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8")) response = json.loads(rv.data.decode("utf-8"))
self.assertEqual( self.assertEqual(
response, {"message": {"datasource_id": ["Dataset does not exist"]}} response, {"message": {"datasource_id": ["Datasource does not exist"]}}
) )
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -686,7 +692,13 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
response = json.loads(rv.data.decode("utf-8")) response = json.loads(rv.data.decode("utf-8"))
self.assertEqual( self.assertEqual(
response, response,
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, {
"message": {
"datasource_type": [
"Must be one of: sl_table, table, dataset, query, saved_query, view."
]
}
},
) )
chart_data = {"datasource_id": 0, "datasource_type": "table"} chart_data = {"datasource_id": 0, "datasource_type": "table"}
@ -694,7 +706,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
self.assertEqual(rv.status_code, 422) self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8")) response = json.loads(rv.data.decode("utf-8"))
self.assertEqual( self.assertEqual(
response, {"message": {"datasource_id": ["Dataset does not exist"]}} response, {"message": {"datasource_id": ["Datasource does not exist"]}}
) )
db.session.delete(chart) db.session.delete(chart)

View File

@ -26,7 +26,7 @@ from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database from superset.models.core import Database
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import get_example_default_schema from superset.utils.core import DatasourceType, get_example_default_schema
def get_table( def get_table(
@ -72,7 +72,7 @@ def create_slice(
return Slice( return Slice(
slice_name=title, slice_name=title,
viz_type=viz_type, viz_type=viz_type,
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_id=table.id, datasource_id=table.id,
params=json.dumps(slices_dict, indent=4, sort_keys=True), params=json.dumps(slices_dict, indent=4, sort_keys=True),
) )

View File

@ -56,7 +56,7 @@ def admin_id() -> int:
@pytest.fixture @pytest.fixture
def dataset_id() -> int: def datasource() -> int:
with app.app_context() as ctx: with app.app_context() as ctx:
session: Session = ctx.app.appbuilder.get_session session: Session = ctx.app.appbuilder.get_session
dataset = ( dataset = (
@ -64,24 +64,26 @@ def dataset_id() -> int:
.filter_by(table_name="wb_health_population") .filter_by(table_name="wb_health_population")
.first() .first()
) )
return dataset.id return dataset
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cache(chart_id, admin_id, dataset_id): def cache(chart_id, admin_id, datasource):
entry: TemporaryExploreState = { entry: TemporaryExploreState = {
"owner": admin_id, "owner": admin_id,
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": INITIAL_FORM_DATA, "form_data": INITIAL_FORM_DATA,
} }
cache_manager.explore_form_data_cache.set(KEY, entry) cache_manager.explore_form_data_cache.set(KEY, entry)
def test_post(client, chart_id: int, dataset_id: int): def test_post(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": INITIAL_FORM_DATA, "form_data": INITIAL_FORM_DATA,
} }
@ -89,10 +91,11 @@ def test_post(client, chart_id: int, dataset_id: int):
assert resp.status_code == 201 assert resp.status_code == 201
def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int): def test_post_bad_request_non_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": 1234, "form_data": 1234,
} }
@ -100,10 +103,11 @@ def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int):
assert resp.status_code == 400 assert resp.status_code == 400
def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int): def test_post_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "foo", "form_data": "foo",
} }
@ -111,10 +115,11 @@ def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int
assert resp.status_code == 400 assert resp.status_code == 400
def test_post_access_denied(client, chart_id: int, dataset_id: int): def test_post_access_denied(client, chart_id: int, datasource: SqlaTable):
login(client, "gamma") login(client, "gamma")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": INITIAL_FORM_DATA, "form_data": INITIAL_FORM_DATA,
} }
@ -122,10 +127,11 @@ def test_post_access_denied(client, chart_id: int, dataset_id: int):
assert resp.status_code == 404 assert resp.status_code == 404
def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int): def test_post_same_key_for_same_context(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -139,11 +145,12 @@ def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int):
def test_post_different_key_for_different_context( def test_post_different_key_for_different_context(
client, chart_id: int, dataset_id: int client, chart_id: int, datasource: SqlaTable
): ):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -151,7 +158,8 @@ def test_post_different_key_for_different_context(
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"form_data": json.dumps({"test": "initial value"}), "form_data": json.dumps({"test": "initial value"}),
} }
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
@ -160,10 +168,11 @@ def test_post_different_key_for_different_context(
assert first_key != second_key assert first_key != second_key
def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int): def test_post_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": json.dumps({"test": "initial value"}), "form_data": json.dumps({"test": "initial value"}),
} }
@ -177,11 +186,12 @@ def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
def test_post_different_key_for_different_tab_id( def test_post_different_key_for_different_tab_id(
client, chart_id: int, dataset_id: int client, chart_id: int, datasource: SqlaTable
): ):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": json.dumps({"test": "initial value"}), "form_data": json.dumps({"test": "initial value"}),
} }
@ -194,10 +204,11 @@ def test_post_different_key_for_different_tab_id(
assert first_key != second_key assert first_key != second_key
def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int): def test_post_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": INITIAL_FORM_DATA, "form_data": INITIAL_FORM_DATA,
} }
@ -210,10 +221,11 @@ def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int
assert first_key != second_key assert first_key != second_key
def test_put(client, chart_id: int, dataset_id: int): def test_put(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -221,10 +233,11 @@ def test_put(client, chart_id: int, dataset_id: int):
assert resp.status_code == 200 assert resp.status_code == 200
def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int): def test_put_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -237,10 +250,13 @@ def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
assert first_key == second_key assert first_key == second_key
def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_id: int): def test_put_different_key_for_different_tab_id(
client, chart_id: int, datasource: SqlaTable
):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -253,10 +269,11 @@ def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_i
assert first_key != second_key assert first_key != second_key
def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int): def test_put_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -269,10 +286,11 @@ def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int)
assert first_key != second_key assert first_key != second_key
def test_put_bad_request(client, chart_id: int, dataset_id: int): def test_put_bad_request(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": 1234, "form_data": 1234,
} }
@ -280,10 +298,11 @@ def test_put_bad_request(client, chart_id: int, dataset_id: int):
assert resp.status_code == 400 assert resp.status_code == 400
def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int): def test_put_bad_request_non_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": 1234, "form_data": 1234,
} }
@ -291,10 +310,11 @@ def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int):
assert resp.status_code == 400 assert resp.status_code == 400
def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int): def test_put_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "foo", "form_data": "foo",
} }
@ -302,10 +322,11 @@ def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int)
assert resp.status_code == 400 assert resp.status_code == 400
def test_put_access_denied(client, chart_id: int, dataset_id: int): def test_put_access_denied(client, chart_id: int, datasource: SqlaTable):
login(client, "gamma") login(client, "gamma")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -313,10 +334,11 @@ def test_put_access_denied(client, chart_id: int, dataset_id: int):
assert resp.status_code == 404 assert resp.status_code == 404
def test_put_not_owner(client, chart_id: int, dataset_id: int): def test_put_not_owner(client, chart_id: int, datasource: SqlaTable):
login(client, "gamma") login(client, "gamma")
payload = { payload = {
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": UPDATED_FORM_DATA, "form_data": UPDATED_FORM_DATA,
} }
@ -364,12 +386,13 @@ def test_delete_access_denied(client):
assert resp.status_code == 404 assert resp.status_code == 404
def test_delete_not_owner(client, chart_id: int, dataset_id: int, admin_id: int): def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id: int):
another_key = "another_key" another_key = "another_key"
another_owner = admin_id + 1 another_owner = admin_id + 1
entry: TemporaryExploreState = { entry: TemporaryExploreState = {
"owner": another_owner, "owner": another_owner,
"dataset_id": dataset_id, "datasource_id": datasource.id,
"datasource_type": datasource.type,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": INITIAL_FORM_DATA, "form_data": INITIAL_FORM_DATA,
} }

View File

@ -0,0 +1,359 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from unittest.mock import patch
import pytest
from superset import app, db, security, security_manager
from superset.commands.exceptions import DatasourceTypeInvalidError
from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.commands.create import CreateFormDataCommand
from superset.explore.form_data.commands.delete import DeleteFormDataCommand
from superset.explore.form_data.commands.get import GetFormDataCommand
from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.update import UpdateFormDataCommand
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.base_tests import SupersetTestCase
class TestCreateFormDataCommand(SupersetTestCase):
    """Integration tests for the explore form-data commands
    (create / get / update / delete) against the key-value cache.

    Fixtures build a throwaway dataset, slice and query; everything after
    each fixture's ``yield`` is teardown that removes the row again.
    """

    @pytest.fixture()
    def create_dataset(self):
        # Yield-style fixture: the code after `yield` runs as teardown.
        with self.create_app().app_context():
            dataset = SqlaTable(
                table_name="dummy_sql_table",
                database=get_example_database(),
                schema=get_example_default_schema(),
                sql="select 123 as intcol, 'abc' as strcol",
            )
            session = db.session
            session.add(dataset)
            session.commit()
            yield dataset
            # rollback
            session.delete(dataset)
            session.commit()

    @pytest.fixture()
    def create_slice(self):
        # Slice wired to the dataset created by `create_dataset`.
        with self.create_app().app_context():
            session = db.session
            dataset = (
                session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
            )
            slice = Slice(
                datasource_id=dataset.id,
                datasource_type=DatasourceType.TABLE,
                datasource_name="tmp_perm_table",
                slice_name="slice_name",
            )
            session.add(slice)
            session.commit()
            yield slice
            # rollback
            session.delete(slice)
            session.commit()

    @pytest.fixture()
    def create_query(self):
        # Saved SQL Lab query used by the "update to a query datasource" test.
        with self.create_app().app_context():
            session = db.session
            query = Query(
                sql="select 1 as foo;",
                client_id="sldkfjlk",
                database=get_example_database(),
            )
            session.add(query)
            session.commit()
            yield query
            # rollback
            session.delete(query)
            session.commit()

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice")
    def test_create_form_data_command(self, mock_g):
        """Creating form data returns a string cache key."""
        mock_g.user = security_manager.find_user("admin")
        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type=DatasourceType.TABLE,
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        command = CreateFormDataCommand(args)
        assert isinstance(command.run(), str)

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
    def test_create_form_data_command_invalid_type(self, mock_g):
        """An unknown datasource_type string is rejected with a specific error."""
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }
        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type="InvalidType",
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        with pytest.raises(DatasourceTypeInvalidError) as exc:
            CreateFormDataCommand(create_args).run()
        assert "Datasource type is invalid" in str(exc.value)

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
    def test_create_form_data_command_type_as_string(self, mock_g):
        """A lowercase string ("table") is accepted in place of the enum."""
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }
        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type="table",
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        command = CreateFormDataCommand(create_args)
        assert isinstance(command.run(), str)

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice")
    def test_get_form_data_command(self, mock_g):
        """Data stored by the create command round-trips through get."""
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }
        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type=DatasourceType.TABLE,
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        key = CreateFormDataCommand(create_args).run()
        key_args = CommandParameters(actor=mock_g.user, key=key)
        get_command = GetFormDataCommand(key_args)
        cache_data = json.loads(get_command.run())
        assert cache_data.get("datasource") == datasource

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
    def test_update_form_data_command(self, mock_g):
        """Updating with a different datasource yields a new key and new data."""
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }
        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
        query = db.session.query(Query).filter_by(sql="select 1 as foo;").first()
        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type=DatasourceType.TABLE,
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        key = CreateFormDataCommand(create_args).run()
        # NOTE(review): this string still uses dataset.id/TABLE even though the
        # update below targets the query datasource — confirm this is intended.
        query_datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        update_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=query.id,
            datasource_type=DatasourceType.QUERY,
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": query_datasource}),
            key=key,
        )
        update_command = UpdateFormDataCommand(update_args)
        new_key = update_command.run()
        # it should return a key
        assert isinstance(new_key, str)
        # the updated key returned should be different from the old one
        assert new_key != key
        key_args = CommandParameters(actor=mock_g.user, key=key)
        get_command = GetFormDataCommand(key_args)
        cache_data = json.loads(get_command.run())
        assert cache_data.get("datasource") == query_datasource

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
    def test_update_form_data_command_same_form_data(self, mock_g):
        """Updating with identical form data keeps the same key."""
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }
        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type=DatasourceType.TABLE,
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        key = CreateFormDataCommand(create_args).run()
        update_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type=DatasourceType.TABLE,
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
            key=key,
        )
        update_command = UpdateFormDataCommand(update_args)
        new_key = update_command.run()
        # it should return a key
        assert isinstance(new_key, str)
        # the updated key returned should be the same as the old one
        assert new_key == key
        key_args = CommandParameters(actor=mock_g.user, key=key)
        get_command = GetFormDataCommand(key_args)
        cache_data = json.loads(get_command.run())
        assert cache_data.get("datasource") == datasource

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
    def test_delete_form_data_command(self, mock_g):
        """Deleting an existing key reports success."""
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }
        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            actor=mock_g.user,
            datasource_id=dataset.id,
            datasource_type=DatasourceType.TABLE,
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        key = CreateFormDataCommand(create_args).run()
        delete_args = CommandParameters(
            actor=mock_g.user,
            key=key,
        )
        delete_command = DeleteFormDataCommand(delete_args)
        response = delete_command.run()
        # `is True` rather than `== True` (idiomatic identity check, E712)
        assert response is True

    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query")
    def test_delete_form_data_command_key_expired(self, mock_g):
        """Deleting a missing/expired key reports failure rather than raising."""
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }
        delete_args = CommandParameters(
            actor=mock_g.user,
            key="some_expired_key",
        )
        delete_command = DeleteFormDataCommand(delete_args)
        response = delete_command.run()
        assert response is False

View File

@ -27,6 +27,7 @@ from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource from superset.key_value.types import KeyValueResource
from superset.key_value.utils import decode_permalink_id, encode_permalink_key from superset.key_value.utils import decode_permalink_id, encode_permalink_key
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from tests.integration_tests.base_tests import login from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import ( from tests.integration_tests.fixtures.world_bank_dashboard import (
@ -97,7 +98,8 @@ def test_get_missing_chart(client, chart, permalink_salt: str) -> None:
value=pickle.dumps( value=pickle.dumps(
{ {
"chartId": chart_id, "chartId": chart_id,
"datasetId": chart.datasource.id, "datasourceId": chart.datasource.id,
"datasourceType": DatasourceType.TABLE,
"formData": { "formData": {
"slice_id": chart_id, "slice_id": chart_id,
"datasource": f"{chart.datasource.id}__{chart.datasource.type}", "datasource": f"{chart.datasource.id}__{chart.datasource.type}",

View File

@ -40,7 +40,7 @@ from superset.dashboards.commands.importers.v0 import import_chart, import_dashb
from superset.datasets.commands.importers.v0 import import_dataset from superset.datasets.commands.importers.v0 import import_dataset
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import get_example_default_schema from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database from superset.utils.database import get_example_database
from tests.integration_tests.fixtures.world_bank_dashboard import ( from tests.integration_tests.fixtures.world_bank_dashboard import (
@ -103,7 +103,7 @@ class TestImportExport(SupersetTestCase):
return Slice( return Slice(
slice_name=name, slice_name=name,
datasource_type="table", datasource_type=DatasourceType.TABLE,
viz_type="bubble", viz_type="bubble",
params=json.dumps(params), params=json.dumps(params),
datasource_id=ds_id, datasource_id=ds_id,

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
# isort:skip_file # isort:skip_file
import json import json
from superset.utils.core import DatasourceType
import textwrap import textwrap
import unittest import unittest
from unittest import mock from unittest import mock
@ -604,7 +605,7 @@ class TestSqlaTableModel(SupersetTestCase):
dashboard = self.get_dash_by_slug("births") dashboard = self.get_dash_by_slug("births")
slc = Slice( slc = Slice(
slice_name="slice with adhoc column", slice_name="slice with adhoc column",
datasource_type="table", datasource_type=DatasourceType.TABLE,
viz_type="table", viz_type="table",
params=json.dumps( params=json.dumps(
{ {

View File

@ -39,6 +39,7 @@ from superset.models.core import Database
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.sql_parse import Table from superset.sql_parse import Table
from superset.utils.core import ( from superset.utils.core import (
DatasourceType,
backend, backend,
get_example_default_schema, get_example_default_schema,
) )
@ -120,7 +121,7 @@ class TestRolePermission(SupersetTestCase):
ds_slices = ( ds_slices = (
session.query(Slice) session.query(Slice)
.filter_by(datasource_type="table") .filter_by(datasource_type=DatasourceType.TABLE)
.filter_by(datasource_id=ds.id) .filter_by(datasource_id=ds.id)
.all() .all()
) )
@ -143,7 +144,7 @@ class TestRolePermission(SupersetTestCase):
ds.schema_perm = None ds.schema_perm = None
ds_slices = ( ds_slices = (
session.query(Slice) session.query(Slice)
.filter_by(datasource_type="table") .filter_by(datasource_type=DatasourceType.TABLE)
.filter_by(datasource_id=ds.id) .filter_by(datasource_id=ds.id)
.all() .all()
) )
@ -365,7 +366,7 @@ class TestRolePermission(SupersetTestCase):
# no schema permission # no schema permission
slice = Slice( slice = Slice(
datasource_id=table.id, datasource_id=table.id,
datasource_type="table", datasource_type=DatasourceType.TABLE,
datasource_name="tmp_perm_table", datasource_name="tmp_perm_table",
slice_name="slice_name", slice_name="slice_name",
) )

View File

@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.extensions import cache_manager
from superset.utils.core import backend, DatasourceType
from tests.integration_tests.base_tests import SupersetTestCase
class UtilsCacheManagerTests(SupersetTestCase):
    """Tests for the explore form-data cache exposed by ``cache_manager``."""

    def test_get_set_explore_form_data_cache(self):
        """A set value round-trips through get unchanged."""
        key = "12345"
        data = {"foo": "bar", "datasource_type": "query"}
        cache_manager.explore_form_data_cache.set(key, data)
        assert cache_manager.explore_form_data_cache.get(key) == data

    def test_get_same_context_twice(self):
        """Repeated gets of the same key return the same value."""
        key = "12345"
        data = {"foo": "bar", "datasource_type": "query"}
        cache_manager.explore_form_data_cache.set(key, data)
        assert cache_manager.explore_form_data_cache.get(key) == data
        assert cache_manager.explore_form_data_cache.get(key) == data

    def test_get_set_explore_form_data_cache_no_datasource_type(self):
        """A stored entry without datasource_type is backfilled on read."""
        key = "12345"
        data = {"foo": "bar"}
        cache_manager.explore_form_data_cache.set(key, data)
        # datasource_type should be added because it is not present
        assert cache_manager.explore_form_data_cache.get(key) == {
            "datasource_type": DatasourceType.TABLE,
            **data,
        }

    def test_get_explore_form_data_cache_invalid_key(self):
        # `is None` rather than `== None` (idiomatic identity check, E711)
        assert cache_manager.explore_form_data_cache.get("foo") is None

View File

@ -106,7 +106,7 @@ def test_get_datasource_sqlatable(
from superset.dao.datasource.dao import DatasourceDAO from superset.dao.datasource.dao import DatasourceDAO
result = DatasourceDAO.get_datasource( result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.SQLATABLE, datasource_type=DatasourceType.TABLE,
datasource_id=1, datasource_id=1,
session=session_with_data, session=session_with_data,
) )
@ -151,7 +151,9 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session)
# todo(hugh): This will break once we remove the dual write # todo(hugh): This will break once we remove the dual write
# update the datsource_id=1 and this will pass again # update the datsource_id=1 and this will pass again
result = DatasourceDAO.get_datasource( result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data datasource_type=DatasourceType.SLTABLE,
datasource_id=2,
session=session_with_data,
) )
assert result.id == 2 assert result.id == 2

View File

@ -23,12 +23,21 @@ from superset.charts.commands.exceptions import (
ChartAccessDeniedError, ChartAccessDeniedError,
ChartNotFoundError, ChartNotFoundError,
) )
from superset.commands.exceptions import (
DatasourceNotFoundValidationError,
DatasourceTypeInvalidError,
OwnersNotFoundValidationError,
QueryNotFoundValidationError,
)
from superset.datasets.commands.exceptions import ( from superset.datasets.commands.exceptions import (
DatasetAccessDeniedError, DatasetAccessDeniedError,
DatasetNotFoundError, DatasetNotFoundError,
) )
from superset.exceptions import SupersetSecurityException
from superset.utils.core import DatasourceType
dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id" dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id"
query_find_by_id = "superset.queries.dao.QueryDAO.find_by_id"
chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id" chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id"
is_user_admin = "superset.explore.utils.is_user_admin" is_user_admin = "superset.explore.utils.is_user_admin"
is_owner = "superset.explore.utils.is_owner" is_owner = "superset.explore.utils.is_owner"
@ -36,88 +45,142 @@ can_access_datasource = (
"superset.security.SupersetSecurityManager.can_access_datasource" "superset.security.SupersetSecurityManager.can_access_datasource"
) )
can_access = "superset.security.SupersetSecurityManager.can_access" can_access = "superset.security.SupersetSecurityManager.can_access"
raise_for_access = "superset.security.SupersetSecurityManager.raise_for_access"
query_datasources_by_name = (
"superset.connectors.sqla.models.SqlaTable.query_datasources_by_name"
)
def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None: def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None:
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
with raises(DatasetNotFoundError): with raises(DatasourceNotFoundValidationError):
check_access(dataset_id=0, chart_id=0, actor=User()) check_chart_access(
datasource_id=0,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_unsaved_chart_unknown_dataset_id( def test_unsaved_chart_unknown_dataset_id(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
with raises(DatasetNotFoundError): with raises(DatasetNotFoundError):
mocker.patch(dataset_find_by_id, return_value=None) mocker.patch(dataset_find_by_id, return_value=None)
check_access(dataset_id=1, chart_id=0, actor=User()) check_chart_access(
datasource_id=1,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_unsaved_chart_unknown_query_id(
    mocker: MockFixture, app_context: AppContext
) -> None:
    # An unsaved chart (chart_id=0) referencing a query id that does not
    # resolve should surface QueryNotFoundValidationError.
    from superset.explore.utils import check_access as check_chart_access
    with raises(QueryNotFoundValidationError):
        # Patch inside the raises block so the stub is active for the call.
        mocker.patch(query_find_by_id, return_value=None)
        check_chart_access(
            datasource_id=1,
            chart_id=0,
            actor=User(),
            datasource_type=DatasourceType.QUERY,
        )
def test_unsaved_chart_unauthorized_dataset( def test_unsaved_chart_unauthorized_dataset(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore import utils from superset.explore.utils import check_access as check_chart_access
with raises(DatasetAccessDeniedError): with raises(DatasetAccessDeniedError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=False) mocker.patch(can_access_datasource, return_value=False)
utils.check_access(dataset_id=1, chart_id=0, actor=User()) check_chart_access(
datasource_id=1,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_unsaved_chart_authorized_dataset( def test_unsaved_chart_authorized_dataset(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True) mocker.patch(can_access_datasource, return_value=True)
check_access(dataset_id=1, chart_id=0, actor=User()) check_chart_access(
datasource_id=1,
chart_id=0,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_unknown_chart_id( def test_saved_chart_unknown_chart_id(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
with raises(ChartNotFoundError): with raises(ChartNotFoundError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True) mocker.patch(can_access_datasource, return_value=True)
mocker.patch(chart_find_by_id, return_value=None) mocker.patch(chart_find_by_id, return_value=None)
check_access(dataset_id=1, chart_id=1, actor=User()) check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_unauthorized_dataset( def test_saved_chart_unauthorized_dataset(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore import utils from superset.explore.utils import check_access as check_chart_access
with raises(DatasetAccessDeniedError): with raises(DatasetAccessDeniedError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=False) mocker.patch(can_access_datasource, return_value=False)
utils.check_access(dataset_id=1, chart_id=1, actor=User()) check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True) mocker.patch(can_access_datasource, return_value=True)
mocker.patch(is_user_admin, return_value=True) mocker.patch(is_user_admin, return_value=True)
mocker.patch(chart_find_by_id, return_value=Slice()) mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User()) check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -125,12 +188,17 @@ def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> N
mocker.patch(is_user_admin, return_value=False) mocker.patch(is_user_admin, return_value=False)
mocker.patch(is_owner, return_value=True) mocker.patch(is_owner, return_value=True)
mocker.patch(chart_find_by_id, return_value=Slice()) mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User()) check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -139,12 +207,17 @@ def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) ->
mocker.patch(is_owner, return_value=False) mocker.patch(is_owner, return_value=False)
mocker.patch(can_access, return_value=True) mocker.patch(can_access, return_value=True)
mocker.patch(chart_find_by_id, return_value=Slice()) mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User()) check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.utils import check_access from superset.explore.utils import check_access as check_chart_access
from superset.models.slice import Slice from superset.models.slice import Slice
with raises(ChartAccessDeniedError): with raises(ChartAccessDeniedError):
@ -154,4 +227,66 @@ def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) ->
mocker.patch(is_owner, return_value=False) mocker.patch(is_owner, return_value=False)
mocker.patch(can_access, return_value=False) mocker.patch(can_access, return_value=False)
mocker.patch(chart_find_by_id, return_value=Slice()) mocker.patch(chart_find_by_id, return_value=Slice())
check_access(dataset_id=1, chart_id=1, actor=User()) check_chart_access(
datasource_id=1,
chart_id=1,
actor=User(),
datasource_type=DatasourceType.TABLE,
)
def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> None:
    """A dataset the user may access passes check_datasource_access."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_datasource_access
    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(is_user_admin, return_value=False)
    mocker.patch(is_owner, return_value=False)
    mocker.patch(can_access, return_value=True)
    # `is True` rather than `== True` (idiomatic identity check, E712)
    assert (
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.TABLE,
        )
        is True
    )
def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None:
    """A saved query the user may access passes check_datasource_access."""
    from superset.explore.utils import check_datasource_access
    from superset.models.sql_lab import Query
    mocker.patch(query_find_by_id, return_value=Query())
    mocker.patch(raise_for_access, return_value=True)
    mocker.patch(is_user_admin, return_value=False)
    mocker.patch(is_owner, return_value=False)
    mocker.patch(can_access, return_value=True)
    # `is True` rather than `== True` (idiomatic identity check, E712)
    assert (
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.QUERY,
        )
        is True
    )
def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None:
    # A saved query the user cannot access should raise
    # SupersetSecurityException from the (unpatched) security check.
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_datasource_access
    from superset.models.core import Database
    from superset.models.sql_lab import Query
    with raises(SupersetSecurityException):
        # Patches are installed inside the raises block so they are active
        # for the call below; raise_for_access is NOT patched here, so the
        # real check runs against the denied-access mocks.
        mocker.patch(
            query_find_by_id,
            return_value=Query(database=Database(), sql="select * from foo"),
        )
        mocker.patch(query_datasources_by_name, return_value=[SqlaTable()])
        mocker.patch(is_user_admin, return_value=False)
        mocker.patch(is_owner, return_value=False)
        mocker.patch(can_access, return_value=False)
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.QUERY,
        )