feat: add permalink to dashboard and explore (#19078)

* rename key_value to temporary_cache

* add migration

* create new key_value package

* add commands

* lots of new stuff

* fix schema reference

* remove redundant filter state from bootstrap data

* add missing license headers

* fix pylint

* fix dashboard permalink access

* use valid json mocks for filter state tests

* fix temporary cache tests

* add anchors to dashboard state

* lint

* fix util test

* fix url shortlink button tests

* remove legacy shortner

* remove unused imports

* fix js tests

* fix test

* add native filter state to anchor link

* add UPDATING.md section

* address comments

* address comments

* lint

* fix test

* add utils tests + other test stubs

* add key_value integration tests

* add filter box state to permalink state

* fully support persisting url parameters

* lint, add redirects and a few integration tests

* fix test + clean up trailing comma

* fix anchor bug

* change value to LargeBinary to support persisting binary values

* fix urlParams type and simplify urlencode

* lint

* add optional entry expiration

* fix incorrect chart id + add test
This commit is contained in:
Ville Brofeldt 2022-03-17 01:15:52 +02:00 committed by GitHub
parent d01fdad1d8
commit b7a0559aaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
94 changed files with 2943 additions and 439 deletions

View File

@ -49,11 +49,12 @@ flag for the legacy datasource editor (DISABLE_LEGACY_DATASOURCE_EDITOR) in conf
### Deprecations ### Deprecations
- [19078](https://github.com/apache/superset/pull/19078): Creation of old shorturl links has been deprecated in favor of a new permalink feature that solves the long url problem (old shorturls will still work, though!). By default, new permalinks use UUID4 as the key. However, to use serial ids similar to the old shorturls, add the following to your `superset_config.py`: `PERMALINK_KEY_TYPE = "id"`.
- [18960](https://github.com/apache/superset/pull/18960): Persisting URL params in chart metadata is no longer supported. To set a default value for URL params in Jinja code, use the optional second argument: `url_param("my-param", "my-default-value")`. - [18960](https://github.com/apache/superset/pull/18960): Persisting URL params in chart metadata is no longer supported. To set a default value for URL params in Jinja code, use the optional second argument: `url_param("my-param", "my-default-value")`.
### Other ### Other
- [17589](https://github.com/apache/incubator-superset/pull/17589): It is now possible to limit access to users' recent activity data by setting the `ENABLE_BROAD_ACTIVITY_ACCESS` config flag to false, or customizing the `raise_for_user_activity_access` method in the security manager. - [17589](https://github.com/apache/superset/pull/17589): It is now possible to limit access to users' recent activity data by setting the `ENABLE_BROAD_ACTIVITY_ACCESS` config flag to false, or customizing the `raise_for_user_activity_access` method in the security manager.
- [17536](https://github.com/apache/superset/pull/17536): introduced a key-value endpoint to store dashboard filter state. This endpoint is backed by Flask-Caching and the default configuration assumes that the values will be stored in the file system. If you are already using another cache backend like Redis or Memchached, you'll probably want to change this setting in `superset_config.py`. The key is `FILTER_STATE_CACHE_CONFIG` and the available settings can be found in Flask-Caching [docs](https://flask-caching.readthedocs.io/en/latest/). - [17536](https://github.com/apache/superset/pull/17536): introduced a key-value endpoint to store dashboard filter state. This endpoint is backed by Flask-Caching and the default configuration assumes that the values will be stored in the file system. If you are already using another cache backend like Redis or Memchached, you'll probably want to change this setting in `superset_config.py`. The key is `FILTER_STATE_CACHE_CONFIG` and the available settings can be found in Flask-Caching [docs](https://flask-caching.readthedocs.io/en/latest/).
- [17882](https://github.com/apache/superset/pull/17882): introduced a key-value endpoint to store Explore form data. This endpoint is backed by Flask-Caching and the default configuration assumes that the values will be stored in the file system. If you are already using another cache backend like Redis or Memchached, you'll probably want to change this setting in `superset_config.py`. The key is `EXPLORE_FORM_DATA_CACHE_CONFIG` and the available settings can be found in Flask-Caching [docs](https://flask-caching.readthedocs.io/en/latest/). - [17882](https://github.com/apache/superset/pull/17882): introduced a key-value endpoint to store Explore form data. This endpoint is backed by Flask-Caching and the default configuration assumes that the values will be stored in the file system. If you are already using another cache backend like Redis or Memchached, you'll probably want to change this setting in `superset_config.py`. The key is `EXPLORE_FORM_DATA_CACHE_CONFIG` and the available settings can be found in Flask-Caching [docs](https://flask-caching.readthedocs.io/en/latest/).

View File

@ -25,6 +25,7 @@ import URLShortLinkButton from 'src/components/URLShortLinkButton';
describe('AnchorLink', () => { describe('AnchorLink', () => {
const props = { const props = {
anchorLinkId: 'CHART-123', anchorLinkId: 'CHART-123',
dashboardId: 10,
}; };
const globalLocation = window.location; const globalLocation = window.location;
@ -64,8 +65,9 @@ describe('AnchorLink', () => {
expect(wrapper.find(URLShortLinkButton)).toExist(); expect(wrapper.find(URLShortLinkButton)).toExist();
expect(wrapper.find(URLShortLinkButton)).toHaveProp({ placement: 'right' }); expect(wrapper.find(URLShortLinkButton)).toHaveProp({ placement: 'right' });
const targetUrl = wrapper.find(URLShortLinkButton).prop('url'); const anchorLinkId = wrapper.find(URLShortLinkButton).prop('anchorLinkId');
const hash = targetUrl.slice(targetUrl.indexOf('#') + 1); const dashboardId = wrapper.find(URLShortLinkButton).prop('dashboardId');
expect(hash).toBe(props.anchorLinkId); expect(anchorLinkId).toBe(props.anchorLinkId);
expect(dashboardId).toBe(props.dashboardId);
}); });
}); });

View File

@ -21,11 +21,11 @@ import PropTypes from 'prop-types';
import { t } from '@superset-ui/core'; import { t } from '@superset-ui/core';
import URLShortLinkButton from 'src/components/URLShortLinkButton'; import URLShortLinkButton from 'src/components/URLShortLinkButton';
import getDashboardUrl from 'src/dashboard/util/getDashboardUrl';
import getLocationHash from 'src/dashboard/util/getLocationHash'; import getLocationHash from 'src/dashboard/util/getLocationHash';
const propTypes = { const propTypes = {
anchorLinkId: PropTypes.string.isRequired, anchorLinkId: PropTypes.string.isRequired,
dashboardId: PropTypes.number,
filters: PropTypes.object, filters: PropTypes.object,
showShortLinkButton: PropTypes.bool, showShortLinkButton: PropTypes.bool,
inFocus: PropTypes.bool, inFocus: PropTypes.bool,
@ -70,17 +70,14 @@ class AnchorLink extends React.PureComponent {
} }
render() { render() {
const { anchorLinkId, filters, showShortLinkButton, placement } = const { anchorLinkId, dashboardId, showShortLinkButton, placement } =
this.props; this.props;
return ( return (
<span className="anchor-link-container" id={anchorLinkId}> <span className="anchor-link-container" id={anchorLinkId}>
{showShortLinkButton && ( {showShortLinkButton && (
<URLShortLinkButton <URLShortLinkButton
url={getDashboardUrl({ anchorLinkId={anchorLinkId}
pathname: window.location.pathname, dashboardId={dashboardId}
filters,
hash: anchorLinkId,
})}
emailSubject={t('Superset chart')} emailSubject={t('Superset chart')}
emailContent={t('Check out this chart in dashboard:')} emailContent={t('Check out this chart in dashboard:')}
placement={placement} placement={placement}

View File

@ -23,48 +23,76 @@ import fetchMock from 'fetch-mock';
import URLShortLinkButton from 'src/components/URLShortLinkButton'; import URLShortLinkButton from 'src/components/URLShortLinkButton';
import ToastContainer from 'src/components/MessageToasts/ToastContainer'; import ToastContainer from 'src/components/MessageToasts/ToastContainer';
const fakeUrl = 'http://fakeurl.com'; const DASHBOARD_ID = 10;
const PERMALINK_PAYLOAD = {
key: '123',
url: 'http://fakeurl.com/123',
};
const FILTER_STATE_PAYLOAD = {
value: '{}',
};
fetchMock.post('glob:*/r/shortner/', fakeUrl); const props = {
dashboardId: DASHBOARD_ID,
};
fetchMock.get(
`glob:*/api/v1/dashboard/${DASHBOARD_ID}/filter_state*`,
FILTER_STATE_PAYLOAD,
);
fetchMock.post(
`glob:*/api/v1/dashboard/${DASHBOARD_ID}/permalink`,
PERMALINK_PAYLOAD,
);
test('renders with default props', () => { test('renders with default props', () => {
render(<URLShortLinkButton />, { useRedux: true }); render(<URLShortLinkButton {...props} />, { useRedux: true });
expect(screen.getByRole('button')).toBeInTheDocument(); expect(screen.getByRole('button')).toBeInTheDocument();
}); });
test('renders overlay on click', async () => { test('renders overlay on click', async () => {
render(<URLShortLinkButton />, { useRedux: true }); render(<URLShortLinkButton {...props} />, { useRedux: true });
userEvent.click(screen.getByRole('button')); userEvent.click(screen.getByRole('button'));
expect(await screen.findByRole('tooltip')).toBeInTheDocument(); expect(await screen.findByRole('tooltip')).toBeInTheDocument();
}); });
test('obtains short url', async () => { test('obtains short url', async () => {
render(<URLShortLinkButton />, { useRedux: true }); render(<URLShortLinkButton {...props} />, { useRedux: true });
userEvent.click(screen.getByRole('button')); userEvent.click(screen.getByRole('button'));
expect(await screen.findByRole('tooltip')).toHaveTextContent(fakeUrl); expect(await screen.findByRole('tooltip')).toHaveTextContent(
PERMALINK_PAYLOAD.url,
);
}); });
test('creates email anchor', async () => { test('creates email anchor', async () => {
const subject = 'Subject'; const subject = 'Subject';
const content = 'Content'; const content = 'Content';
render(<URLShortLinkButton emailSubject={subject} emailContent={content} />, { render(
useRedux: true, <URLShortLinkButton
}); {...props}
emailSubject={subject}
emailContent={content}
/>,
{
useRedux: true,
},
);
const href = `mailto:?Subject=${subject}%20&Body=${content}${fakeUrl}`; const href = `mailto:?Subject=${subject}%20&Body=${content}${PERMALINK_PAYLOAD.url}`;
userEvent.click(screen.getByRole('button')); userEvent.click(screen.getByRole('button'));
expect(await screen.findByRole('link')).toHaveAttribute('href', href); expect(await screen.findByRole('link')).toHaveAttribute('href', href);
}); });
test('renders error message on short url error', async () => { test('renders error message on short url error', async () => {
fetchMock.mock('glob:*/r/shortner/', 500, { fetchMock.mock(`glob:*/api/v1/dashboard/${DASHBOARD_ID}/permalink`, 500, {
overwriteRoutes: true, overwriteRoutes: true,
}); });
render( render(
<> <>
<URLShortLinkButton /> <URLShortLinkButton {...props} />
<ToastContainer /> <ToastContainer />
</>, </>,
{ useRedux: true }, { useRedux: true },

View File

@ -21,14 +21,17 @@ import PropTypes from 'prop-types';
import { t } from '@superset-ui/core'; import { t } from '@superset-ui/core';
import Popover from 'src/components/Popover'; import Popover from 'src/components/Popover';
import CopyToClipboard from 'src/components/CopyToClipboard'; import CopyToClipboard from 'src/components/CopyToClipboard';
import { getShortUrl } from 'src/utils/urlUtils'; import { getDashboardPermalink, getUrlParam } from 'src/utils/urlUtils';
import withToasts from 'src/components/MessageToasts/withToasts'; import withToasts from 'src/components/MessageToasts/withToasts';
import { URL_PARAMS } from 'src/constants';
import { getFilterValue } from 'src/dashboard/components/nativeFilters/FilterBar/keyValue';
const propTypes = { const propTypes = {
url: PropTypes.string, addDangerToast: PropTypes.func.isRequired,
anchorLinkId: PropTypes.string,
dashboardId: PropTypes.number,
emailSubject: PropTypes.string, emailSubject: PropTypes.string,
emailContent: PropTypes.string, emailContent: PropTypes.string,
addDangerToast: PropTypes.func.isRequired,
placement: PropTypes.oneOf(['right', 'left', 'top', 'bottom']), placement: PropTypes.oneOf(['right', 'left', 'top', 'bottom']),
}; };
@ -50,9 +53,20 @@ class URLShortLinkButton extends React.Component {
getCopyUrl(e) { getCopyUrl(e) {
e.stopPropagation(); e.stopPropagation();
getShortUrl(this.props.url) const nativeFiltersKey = getUrlParam(URL_PARAMS.nativeFiltersKey);
.then(this.onShortUrlSuccess) if (this.props.dashboardId) {
.catch(this.props.addDangerToast); getFilterValue(this.props.dashboardId, nativeFiltersKey)
.then(filterState =>
getDashboardPermalink(
String(this.props.dashboardId),
filterState,
this.props.anchorLinkId,
)
.then(this.onShortUrlSuccess)
.catch(this.props.addDangerToast),
)
.catch(this.props.addDangerToast);
}
} }
renderPopover() { renderPopover() {
@ -96,7 +110,6 @@ class URLShortLinkButton extends React.Component {
} }
URLShortLinkButton.defaultProps = { URLShortLinkButton.defaultProps = {
url: window.location.href.substring(window.location.origin.length),
placement: 'left', placement: 'left',
emailSubject: '', emailSubject: '',
emailContent: '', emailContent: '',

View File

@ -71,8 +71,24 @@ export const URL_PARAMS = {
name: 'force', name: 'force',
type: 'boolean', type: 'boolean',
}, },
permalinkKey: {
name: 'permalink_key',
type: 'string',
},
} as const; } as const;
export const RESERVED_CHART_URL_PARAMS: string[] = [
URL_PARAMS.formDataKey.name,
URL_PARAMS.sliceId.name,
URL_PARAMS.datasetId.name,
];
export const RESERVED_DASHBOARD_URL_PARAMS: string[] = [
URL_PARAMS.nativeFilters.name,
URL_PARAMS.nativeFiltersKey.name,
URL_PARAMS.permalinkKey.name,
URL_PARAMS.preselectFilters.name,
];
/** /**
* Faster debounce delay for inputs without expensive operation. * Faster debounce delay for inputs without expensive operation.
*/ */

View File

@ -135,8 +135,8 @@ test('should show the share actions', async () => {
}; };
render(setup(canShareProps)); render(setup(canShareProps));
await openDropdown(); await openDropdown();
expect(screen.getByText('Copy dashboard URL')).toBeInTheDocument(); expect(screen.getByText('Copy permalink to clipboard')).toBeInTheDocument();
expect(screen.getByText('Share dashboard by email')).toBeInTheDocument(); expect(screen.getByText('Share permalink by email')).toBeInTheDocument();
}); });
test('should render the "Save Modal" when user can save', async () => { test('should render the "Save Modal" when user can save', async () => {

View File

@ -257,8 +257,8 @@ class HeaderActionsDropdown extends React.PureComponent {
{userCanShare && ( {userCanShare && (
<ShareMenuItems <ShareMenuItems
url={url} url={url}
copyMenuItemTitle={t('Copy dashboard URL')} copyMenuItemTitle={t('Copy permalink to clipboard')}
emailMenuItemTitle={t('Share dashboard by email')} emailMenuItemTitle={t('Share permalink by email')}
emailSubject={emailSubject} emailSubject={emailSubject}
emailBody={emailBody} emailBody={emailBody}
addSuccessToast={addSuccessToast} addSuccessToast={addSuccessToast}

View File

@ -21,6 +21,7 @@ import moment from 'moment';
import { import {
Behavior, Behavior,
getChartMetadataRegistry, getChartMetadataRegistry,
QueryFormData,
styled, styled,
t, t,
} from '@superset-ui/core'; } from '@superset-ui/core';
@ -98,7 +99,7 @@ export interface SliceHeaderControlsProps {
isExpanded?: boolean; isExpanded?: boolean;
updatedDttm: number | null; updatedDttm: number | null;
isFullSize?: boolean; isFullSize?: boolean;
formData: { slice_id: number; datasource: string }; formData: Pick<QueryFormData, 'slice_id' | 'datasource'>;
onExploreChart: () => void; onExploreChart: () => void;
forceRefresh: (sliceId: number, dashboardId: number) => void; forceRefresh: (sliceId: number, dashboardId: number) => void;
@ -309,8 +310,8 @@ class SliceHeaderControls extends React.PureComponent<
{supersetCanShare && ( {supersetCanShare && (
<ShareMenuItems <ShareMenuItems
copyMenuItemTitle={t('Copy chart URL')} copyMenuItemTitle={t('Copy permalink to clipboard')}
emailMenuItemTitle={t('Share chart by email')} emailMenuItemTitle={t('Share permalink by email')}
emailSubject={t('Superset chart')} emailSubject={t('Superset chart')}
emailBody={t('Check out this chart: ')} emailBody={t('Check out this chart: ')}
addSuccessToast={addSuccessToast} addSuccessToast={addSuccessToast}

View File

@ -30,6 +30,7 @@ export const RENDER_TAB = 'RENDER_TAB';
export const RENDER_TAB_CONTENT = 'RENDER_TAB_CONTENT'; export const RENDER_TAB_CONTENT = 'RENDER_TAB_CONTENT';
const propTypes = { const propTypes = {
dashboardId: PropTypes.number.isRequired,
id: PropTypes.string.isRequired, id: PropTypes.string.isRequired,
parentId: PropTypes.string.isRequired, parentId: PropTypes.string.isRequired,
component: componentShape.isRequired, component: componentShape.isRequired,
@ -237,6 +238,7 @@ export default class Tab extends React.PureComponent {
{!editMode && ( {!editMode && (
<AnchorLink <AnchorLink
anchorLinkId={component.id} anchorLinkId={component.id}
dashboardId={this.props.dashboardId}
filters={filters} filters={filters}
showShortLinkButton showShortLinkButton
placement={index >= 5 ? 'left' : 'right'} placement={index >= 5 ? 'left' : 'right'}

View File

@ -31,7 +31,7 @@ const DASHBOARD_ID = '26';
const createProps = () => ({ const createProps = () => ({
addDangerToast: jest.fn(), addDangerToast: jest.fn(),
addSuccessToast: jest.fn(), addSuccessToast: jest.fn(),
url: `/superset/dashboard/${DASHBOARD_ID}/?preselect_filters=%7B%7D`, url: `/superset/dashboard/${DASHBOARD_ID}`,
copyMenuItemTitle: 'Copy dashboard URL', copyMenuItemTitle: 'Copy dashboard URL',
emailMenuItemTitle: 'Share dashboard by email', emailMenuItemTitle: 'Share dashboard by email',
emailSubject: 'Superset dashboard COVID Vaccine Dashboard', emailSubject: 'Superset dashboard COVID Vaccine Dashboard',
@ -45,10 +45,10 @@ beforeAll((): void => {
// @ts-ignore // @ts-ignore
delete window.location; delete window.location;
fetchMock.post( fetchMock.post(
'http://localhost/r/shortner/', `http://localhost/api/v1/dashboard/${DASHBOARD_ID}/permalink`,
{ body: 'http://localhost:8088/r/3' }, { key: '123', url: 'http://localhost/superset/dashboard/p/123/' },
{ {
sendAsJson: false, sendAsJson: true,
}, },
); );
}); });
@ -104,7 +104,7 @@ test('Click on "Copy dashboard URL" and succeed', async () => {
await waitFor(() => { await waitFor(() => {
expect(spy).toBeCalledTimes(1); expect(spy).toBeCalledTimes(1);
expect(spy).toBeCalledWith('http://localhost:8088/r/3'); expect(spy).toBeCalledWith('http://localhost/superset/dashboard/p/123/');
expect(props.addSuccessToast).toBeCalledTimes(1); expect(props.addSuccessToast).toBeCalledTimes(1);
expect(props.addSuccessToast).toBeCalledWith('Copied to clipboard!'); expect(props.addSuccessToast).toBeCalledWith('Copied to clipboard!');
expect(props.addDangerToast).toBeCalledTimes(0); expect(props.addDangerToast).toBeCalledTimes(0);
@ -130,7 +130,7 @@ test('Click on "Copy dashboard URL" and fail', async () => {
await waitFor(() => { await waitFor(() => {
expect(spy).toBeCalledTimes(1); expect(spy).toBeCalledTimes(1);
expect(spy).toBeCalledWith('http://localhost:8088/r/3'); expect(spy).toBeCalledWith('http://localhost/superset/dashboard/p/123/');
expect(props.addSuccessToast).toBeCalledTimes(0); expect(props.addSuccessToast).toBeCalledTimes(0);
expect(props.addDangerToast).toBeCalledTimes(1); expect(props.addDangerToast).toBeCalledTimes(1);
expect(props.addDangerToast).toBeCalledWith( expect(props.addDangerToast).toBeCalledWith(
@ -159,14 +159,14 @@ test('Click on "Share dashboard by email" and succeed', async () => {
await waitFor(() => { await waitFor(() => {
expect(props.addDangerToast).toBeCalledTimes(0); expect(props.addDangerToast).toBeCalledTimes(0);
expect(window.location.href).toBe( expect(window.location.href).toBe(
'mailto:?Subject=Superset%20dashboard%20COVID%20Vaccine%20Dashboard%20&Body=Check%20out%20this%20dashboard%3A%20http%3A%2F%2Flocalhost%3A8088%2Fr%2F3', 'mailto:?Subject=Superset%20dashboard%20COVID%20Vaccine%20Dashboard%20&Body=Check%20out%20this%20dashboard%3A%20http%3A%2F%2Flocalhost%2Fsuperset%2Fdashboard%2Fp%2F123%2F',
); );
}); });
}); });
test('Click on "Share dashboard by email" and fail', async () => { test('Click on "Share dashboard by email" and fail', async () => {
fetchMock.post( fetchMock.post(
'http://localhost/r/shortner/', `http://localhost/api/v1/dashboard/${DASHBOARD_ID}/permalink`,
{ status: 404 }, { status: 404 },
{ overwriteRoutes: true }, { overwriteRoutes: true },
); );

View File

@ -17,19 +17,16 @@
* under the License. * under the License.
*/ */
import React from 'react'; import React from 'react';
import { useUrlShortener } from 'src/hooks/useUrlShortener';
import copyTextToClipboard from 'src/utils/copy'; import copyTextToClipboard from 'src/utils/copy';
import { t, logging } from '@superset-ui/core'; import { t, logging, QueryFormData } from '@superset-ui/core';
import { Menu } from 'src/components/Menu'; import { Menu } from 'src/components/Menu';
import { getUrlParam } from 'src/utils/urlUtils';
import { postFormData } from 'src/explore/exploreUtils/formData';
import { useTabId } from 'src/hooks/useTabId';
import { URL_PARAMS } from 'src/constants';
import { mountExploreUrl } from 'src/explore/exploreUtils';
import { import {
createFilterKey, getChartPermalink,
getFilterValue, getDashboardPermalink,
} from 'src/dashboard/components/nativeFilters/FilterBar/keyValue'; getUrlParam,
} from 'src/utils/urlUtils';
import { RESERVED_DASHBOARD_URL_PARAMS, URL_PARAMS } from 'src/constants';
import { getFilterValue } from 'src/dashboard/components/nativeFilters/FilterBar/keyValue';
interface ShareMenuItemProps { interface ShareMenuItemProps {
url?: string; url?: string;
@ -40,12 +37,11 @@ interface ShareMenuItemProps {
addDangerToast: Function; addDangerToast: Function;
addSuccessToast: Function; addSuccessToast: Function;
dashboardId?: string; dashboardId?: string;
formData?: { slice_id: number; datasource: string }; formData?: Pick<QueryFormData, 'slice_id' | 'datasource'>;
} }
const ShareMenuItems = (props: ShareMenuItemProps) => { const ShareMenuItems = (props: ShareMenuItemProps) => {
const { const {
url,
copyMenuItemTitle, copyMenuItemTitle,
emailMenuItemTitle, emailMenuItemTitle,
emailSubject, emailSubject,
@ -57,47 +53,25 @@ const ShareMenuItems = (props: ShareMenuItemProps) => {
...rest ...rest
} = props; } = props;
const tabId = useTabId();
const getShortUrl = useUrlShortener(url || '');
async function getCopyUrl() {
const risonObj = getUrlParam(URL_PARAMS.nativeFilters);
if (typeof risonObj === 'object' || !dashboardId) return null;
const prevData = await getFilterValue(
dashboardId,
getUrlParam(URL_PARAMS.nativeFiltersKey),
);
const newDataMaskKey = await createFilterKey(
dashboardId,
JSON.stringify(prevData),
tabId,
);
const newUrl = new URL(`${window.location.origin}${url}`);
newUrl.searchParams.set(URL_PARAMS.nativeFilters.name, newDataMaskKey);
return `${newUrl.pathname}${newUrl.search}`;
}
async function generateUrl() { async function generateUrl() {
// chart
if (formData) { if (formData) {
const key = await postFormData( // we need to remove reserved dashboard url params
parseInt(formData.datasource.split('_')[0], 10), return getChartPermalink(formData, RESERVED_DASHBOARD_URL_PARAMS);
formData,
formData.slice_id,
tabId,
);
return `${window.location.origin}${mountExploreUrl(null, {
[URL_PARAMS.formDataKey.name]: key,
[URL_PARAMS.sliceId.name]: formData.slice_id,
})}`;
} }
const copyUrl = await getCopyUrl(); // dashboard
return getShortUrl(copyUrl); const nativeFiltersKey = getUrlParam(URL_PARAMS.nativeFiltersKey);
let filterState = {};
if (nativeFiltersKey && dashboardId) {
filterState = await getFilterValue(dashboardId, nativeFiltersKey);
}
return getDashboardPermalink(String(dashboardId), filterState);
} }
async function onCopyLink() { async function onCopyLink() {
try { try {
await copyTextToClipboard(await generateUrl()); const url = await generateUrl();
await copyTextToClipboard(url);
addSuccessToast(t('Copied to clipboard!')); addSuccessToast(t('Copied to clipboard!'));
} catch (error) { } catch (error) {
logging.error(error); logging.error(error);

View File

@ -165,6 +165,11 @@ export interface FiltersBarProps {
offset: number; offset: number;
} }
const EXCLUDED_URL_PARAMS: string[] = [
URL_PARAMS.nativeFilters.name,
URL_PARAMS.permalinkKey.name,
];
const publishDataMask = debounce( const publishDataMask = debounce(
async ( async (
history, history,
@ -177,9 +182,9 @@ const publishDataMask = debounce(
const { search } = location; const { search } = location;
const previousParams = new URLSearchParams(search); const previousParams = new URLSearchParams(search);
const newParams = new URLSearchParams(); const newParams = new URLSearchParams();
let dataMaskKey: string; let dataMaskKey: string | null;
previousParams.forEach((value, key) => { previousParams.forEach((value, key) => {
if (key !== URL_PARAMS.nativeFilters.name) { if (!EXCLUDED_URL_PARAMS.includes(key)) {
newParams.append(key, value); newParams.append(key, value);
} }
}); });
@ -200,7 +205,9 @@ const publishDataMask = debounce(
} else { } else {
dataMaskKey = await createFilterKey(dashboardId, dataMask, tabId); dataMaskKey = await createFilterKey(dashboardId, dataMask, tabId);
} }
newParams.set(URL_PARAMS.nativeFiltersKey.name, dataMaskKey); if (dataMaskKey) {
newParams.set(URL_PARAMS.nativeFiltersKey.name, dataMaskKey);
}
// pathname could be updated somewhere else through window.history // pathname could be updated somewhere else through window.history
// keep react router history in sync with window history // keep react router history in sync with window history

View File

@ -17,6 +17,7 @@
* under the License. * under the License.
*/ */
import { SupersetClient, logging } from '@superset-ui/core'; import { SupersetClient, logging } from '@superset-ui/core';
import { DashboardPermalinkValue } from 'src/dashboard/types';
const assembleEndpoint = ( const assembleEndpoint = (
dashId: string | number, dashId: string | number,
@ -58,7 +59,7 @@ export const createFilterKey = (
endpoint: assembleEndpoint(dashId, undefined, tabId), endpoint: assembleEndpoint(dashId, undefined, tabId),
jsonPayload: { value }, jsonPayload: { value },
}) })
.then(r => r.json.key) .then(r => r.json.key as string)
.catch(err => { .catch(err => {
logging.error(err); logging.error(err);
return null; return null;
@ -73,3 +74,13 @@ export const getFilterValue = (dashId: string | number, key: string) =>
logging.error(err); logging.error(err);
return null; return null;
}); });
export const getPermalinkValue = (key: string) =>
SupersetClient.get({
endpoint: `/api/v1/dashboard/permalink/${key}`,
})
.then(({ json }) => json as DashboardPermalinkValue)
.catch(err => {
logging.error(err);
return null;
});

View File

@ -49,7 +49,10 @@ import { URL_PARAMS } from 'src/constants';
import { getUrlParam } from 'src/utils/urlUtils'; import { getUrlParam } from 'src/utils/urlUtils';
import { canUserEditDashboard } from 'src/dashboard/util/findPermission'; import { canUserEditDashboard } from 'src/dashboard/util/findPermission';
import { getFilterSets } from '../actions/nativeFilters'; import { getFilterSets } from '../actions/nativeFilters';
import { getFilterValue } from '../components/nativeFilters/FilterBar/keyValue'; import {
getFilterValue,
getPermalinkValue,
} from '../components/nativeFilters/FilterBar/keyValue';
import { filterCardPopoverStyle } from '../styles'; import { filterCardPopoverStyle } from '../styles';
export const MigrationContext = React.createContext( export const MigrationContext = React.createContext(
@ -161,12 +164,17 @@ const DashboardPage: FC = () => {
useEffect(() => { useEffect(() => {
// eslint-disable-next-line consistent-return // eslint-disable-next-line consistent-return
async function getDataMaskApplied() { async function getDataMaskApplied() {
const permalinkKey = getUrlParam(URL_PARAMS.permalinkKey);
const nativeFilterKeyValue = getUrlParam(URL_PARAMS.nativeFiltersKey); const nativeFilterKeyValue = getUrlParam(URL_PARAMS.nativeFiltersKey);
let dataMaskFromUrl = nativeFilterKeyValue || {}; let dataMaskFromUrl = nativeFilterKeyValue || {};
const isOldRison = getUrlParam(URL_PARAMS.nativeFilters); const isOldRison = getUrlParam(URL_PARAMS.nativeFilters);
// check if key from key_value api and get datamask if (permalinkKey) {
if (nativeFilterKeyValue) { const permalinkValue = await getPermalinkValue(permalinkKey);
if (permalinkValue) {
dataMaskFromUrl = permalinkValue.state.filterState;
}
} else if (nativeFilterKeyValue) {
dataMaskFromUrl = await getFilterValue(id, nativeFilterKeyValue); dataMaskFromUrl = await getFilterValue(id, nativeFilterKeyValue);
} }
if (isOldRison) { if (isOldRison) {

View File

@ -144,3 +144,11 @@ type ActiveFilter = {
export type ActiveFilters = { export type ActiveFilters = {
[key: string]: ActiveFilter; [key: string]: ActiveFilter;
}; };
export type DashboardPermalinkValue = {
dashboardId: string;
state: {
filterState: DataMaskStateWithId;
hash: string;
};
};

View File

@ -25,6 +25,7 @@ import Icons from 'src/components/Icons';
import { Tooltip } from 'src/components/Tooltip'; import { Tooltip } from 'src/components/Tooltip';
import CopyToClipboard from 'src/components/CopyToClipboard'; import CopyToClipboard from 'src/components/CopyToClipboard';
import { URL_PARAMS } from 'src/constants'; import { URL_PARAMS } from 'src/constants';
import { getChartPermalink } from 'src/utils/urlUtils';
export default class EmbedCodeButton extends React.Component { export default class EmbedCodeButton extends React.Component {
constructor(props) { constructor(props) {
@ -32,8 +33,11 @@ export default class EmbedCodeButton extends React.Component {
this.state = { this.state = {
height: '400', height: '400',
width: '600', width: '600',
url: '',
errorMessage: '',
}; };
this.handleInputChange = this.handleInputChange.bind(this); this.handleInputChange = this.handleInputChange.bind(this);
this.updateUrl = this.updateUrl.bind(this);
} }
handleInputChange(e) { handleInputChange(e) {
@ -43,8 +47,21 @@ export default class EmbedCodeButton extends React.Component {
this.setState(data); this.setState(data);
} }
updateUrl() {
this.setState({ url: '' });
getChartPermalink(this.props.formData)
.then(url => this.setState({ errorMessage: '', url }))
.catch(() => {
this.setState({ errorMessage: t('Error') });
this.props.addDangerToast(
t('Sorry, something went wrong. Try again later.'),
);
});
}
generateEmbedHTML() { generateEmbedHTML() {
const srcLink = `${window.location.href}&${URL_PARAMS.standalone.name}=1&height=${this.state.height}`; if (!this.state.url) return '';
const srcLink = `${this.state.url}?${URL_PARAMS.standalone.name}=1&height=${this.state.height}`;
return ( return (
'<iframe\n' + '<iframe\n' +
` width="${this.state.width}"\n` + ` width="${this.state.width}"\n` +
@ -60,6 +77,8 @@ export default class EmbedCodeButton extends React.Component {
renderPopoverContent() { renderPopoverContent() {
const html = this.generateEmbedHTML(); const html = this.generateEmbedHTML();
const text =
this.state.errorMessage || html || t('Generating link, please wait..');
return ( return (
<div id="embed-code-popover" data-test="embed-code-popover"> <div id="embed-code-popover" data-test="embed-code-popover">
<div className="row"> <div className="row">
@ -67,7 +86,8 @@ export default class EmbedCodeButton extends React.Component {
<textarea <textarea
data-test="embed-code-textarea" data-test="embed-code-textarea"
name="embedCode" name="embedCode"
value={html} disabled={!html}
value={text}
rows="4" rows="4"
readOnly readOnly
className="form-control input-sm" className="form-control input-sm"
@ -125,6 +145,7 @@ export default class EmbedCodeButton extends React.Component {
<Popover <Popover
trigger="click" trigger="click"
placement="left" placement="left"
onClick={this.updateUrl}
content={this.renderPopoverContent()} content={this.renderPopoverContent()}
> >
<Tooltip <Tooltip

View File

@ -34,12 +34,14 @@ describe('EmbedCodeButton', () => {
}); });
it('returns correct embed code', () => { it('returns correct embed code', () => {
const href = 'http://localhost/explore?form_data_key=xxxxxxxxx'; const wrapper = mount(
Object.defineProperty(window, 'location', { value: { href } }); <EmbedCodeButton formData={{}} addDangerToast={() => {}} />,
const wrapper = mount(<EmbedCodeButton />); );
const url = 'http://localhost/explore/p/100';
wrapper.find(EmbedCodeButton).setState({ wrapper.find(EmbedCodeButton).setState({
height: '1000', height: '1000',
width: '2000', width: '2000',
url,
}); });
const embedHTML = const embedHTML =
`${ `${
@ -49,7 +51,7 @@ describe('EmbedCodeButton', () => {
' seamless\n' + ' seamless\n' +
' frameBorder="0"\n' + ' frameBorder="0"\n' +
' scrolling="no"\n' + ' scrolling="no"\n' +
` src="${href}&standalone=` ` src="${url}?standalone=`
}${DashboardStandaloneMode.HIDE_NAV}&height=1000"\n` + }${DashboardStandaloneMode.HIDE_NAV}&height=1000"\n` +
`>\n` + `>\n` +
`</iframe>`; `</iframe>`;

View File

@ -21,7 +21,9 @@ import cx from 'classnames';
import { QueryFormData, t } from '@superset-ui/core'; import { QueryFormData, t } from '@superset-ui/core';
import Icons from 'src/components/Icons'; import Icons from 'src/components/Icons';
import { Tooltip } from 'src/components/Tooltip'; import { Tooltip } from 'src/components/Tooltip';
import { Slice } from 'src/types/Chart';
import copyTextToClipboard from 'src/utils/copy'; import copyTextToClipboard from 'src/utils/copy';
import { getChartPermalink } from 'src/utils/urlUtils';
import withToasts from 'src/components/MessageToasts/withToasts'; import withToasts from 'src/components/MessageToasts/withToasts';
import EmbedCodeButton from './EmbedCodeButton'; import EmbedCodeButton from './EmbedCodeButton';
import { exportChart } from '../exploreUtils'; import { exportChart } from '../exploreUtils';
@ -45,7 +47,7 @@ type ExploreActionButtonsProps = {
chartStatus: string; chartStatus: string;
latestQueryFormData: QueryFormData; latestQueryFormData: QueryFormData;
queriesResponse: {}; queriesResponse: {};
slice: { slice_name: string }; slice: Slice;
addDangerToast: Function; addDangerToast: Function;
addSuccessToast: Function; addSuccessToast: Function;
}; };
@ -101,26 +103,26 @@ const ExploreActionButtons = (props: ExploreActionButtonsProps) => {
addSuccessToast, addSuccessToast,
} = props; } = props;
const copyTooltipText = t('Copy chart URL to clipboard'); const copyTooltipText = t('Copy permalink to clipboard');
const [copyTooltip, setCopyTooltip] = useState(copyTooltipText); const [copyTooltip, setCopyTooltip] = useState(copyTooltipText);
const doCopyLink = async () => { const doCopyLink = async () => {
try { try {
setCopyTooltip(t('Loading...')); setCopyTooltip(t('Loading...'));
const url = window.location.href; const url = await getChartPermalink(latestQueryFormData);
await copyTextToClipboard(url); await copyTextToClipboard(url);
setCopyTooltip(t('Copied to clipboard!')); setCopyTooltip(t('Copied to clipboard!'));
addSuccessToast(t('Copied to clipboard!')); addSuccessToast(t('Copied to clipboard!'));
} catch (error) { } catch (error) {
setCopyTooltip(t('Sorry, your browser does not support copying.')); setCopyTooltip(t('Copying permalink failed.'));
addDangerToast(t('Sorry, your browser does not support copying.')); addDangerToast(t('Sorry, something went wrong. Try again later.'));
} }
}; };
const doShareEmail = async () => { const doShareEmail = async () => {
try { try {
const subject = t('Superset Chart'); const subject = t('Superset Chart');
const url = window.location.href; const url = await getChartPermalink(latestQueryFormData);
const body = encodeURIComponent(t('%s%s', 'Check out this chart: ', url)); const body = encodeURIComponent(t('%s%s', 'Check out this chart: ', url));
window.location.href = `mailto:?Subject=${subject}%20&Body=${body}`; window.location.href = `mailto:?Subject=${subject}%20&Body=${body}`;
} catch (error) { } catch (error) {
@ -173,10 +175,13 @@ const ExploreActionButtons = (props: ExploreActionButtonsProps) => {
/> />
<ActionButton <ActionButton
prefixIcon={<Icons.Email iconSize="l" />} prefixIcon={<Icons.Email iconSize="l" />}
tooltip={t('Share chart by email')} tooltip={t('Share permalink by email')}
onClick={doShareEmail} onClick={doShareEmail}
/> />
<EmbedCodeButton /> <EmbedCodeButton
formData={latestQueryFormData}
addDangerToast={addDangerToast}
/>
<ActionButton <ActionButton
prefixIcon={<Icons.FileTextOutlined iconSize="m" />} prefixIcon={<Icons.FileTextOutlined iconSize="m" />}
text=".JSON" text=".JSON"

View File

@ -36,7 +36,7 @@ import {
setItem, setItem,
LocalStorageKeys, LocalStorageKeys,
} from 'src/utils/localStorageHelpers'; } from 'src/utils/localStorageHelpers';
import { URL_PARAMS } from 'src/constants'; import { RESERVED_CHART_URL_PARAMS, URL_PARAMS } from 'src/constants';
import { getUrlParam } from 'src/utils/urlUtils'; import { getUrlParam } from 'src/utils/urlUtils';
import cx from 'classnames'; import cx from 'classnames';
import * as chartActions from 'src/components/Chart/chartAction'; import * as chartActions from 'src/components/Chart/chartAction';
@ -177,13 +177,7 @@ const updateHistory = debounce(
const urlParams = payload?.url_params || {}; const urlParams = payload?.url_params || {};
Object.entries(urlParams).forEach(([key, value]) => { Object.entries(urlParams).forEach(([key, value]) => {
if ( if (!RESERVED_CHART_URL_PARAMS.includes(key)) {
![
URL_PARAMS.sliceId.name,
URL_PARAMS.formDataKey.name,
URL_PARAMS.datasetId.name,
].includes(key)
) {
additionalParam[key] = value; additionalParam[key] = value;
} }
}); });

View File

@ -1,39 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useState, useEffect } from 'react';
import { getShortUrl as getShortUrlUtil } from 'src/utils/urlUtils';
export function useUrlShortener(url: string): Function {
const [update, setUpdate] = useState(false);
const [shortUrl, setShortUrl] = useState('');
async function getShortUrl(urlOverride?: string) {
if (update) {
const newShortUrl = await getShortUrlUtil(urlOverride || url);
setShortUrl(newShortUrl);
setUpdate(false);
return newShortUrl;
}
return shortUrl;
}
useEffect(() => setUpdate(true), [url]);
return getShortUrl;
}

View File

@ -16,10 +16,17 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
import { SupersetClient } from '@superset-ui/core'; import { JsonObject, QueryFormData, SupersetClient } from '@superset-ui/core';
import rison from 'rison'; import rison from 'rison';
import { isEmpty } from 'lodash';
import { getClientErrorObject } from './getClientErrorObject'; import { getClientErrorObject } from './getClientErrorObject';
import { URL_PARAMS } from '../constants'; import {
RESERVED_CHART_URL_PARAMS,
RESERVED_DASHBOARD_URL_PARAMS,
URL_PARAMS,
} from '../constants';
import { getActiveFilters } from '../dashboard/util/activeDashboardFilters';
import serializeActiveFilterValues from '../dashboard/util/serializeActiveFilterValues';
export type UrlParamType = 'string' | 'number' | 'boolean' | 'object' | 'rison'; export type UrlParamType = 'string' | 'number' | 'boolean' | 'object' | 'rison';
export type UrlParam = typeof URL_PARAMS[keyof typeof URL_PARAMS]; export type UrlParam = typeof URL_PARAMS[keyof typeof URL_PARAMS];
@ -72,14 +79,55 @@ export function getUrlParam({ name, type }: UrlParam): unknown {
} }
} }
export function getShortUrl(longUrl: string) { function getUrlParams(excludedParams: string[]): URLSearchParams {
const urlParams = new URLSearchParams();
const currentParams = new URLSearchParams(window.location.search);
currentParams.forEach((value, key) => {
if (!excludedParams.includes(key)) urlParams.append(key, value);
});
return urlParams;
}
type UrlParamEntries = [string, string][];
function getUrlParamEntries(urlParams: URLSearchParams): UrlParamEntries {
const urlEntries: [string, string][] = [];
urlParams.forEach((value, key) => urlEntries.push([key, value]));
return urlEntries;
}
function getChartUrlParams(excludedUrlParams?: string[]): UrlParamEntries {
const excludedParams = excludedUrlParams || RESERVED_CHART_URL_PARAMS;
const urlParams = getUrlParams(excludedParams);
const filterBoxFilters = getActiveFilters();
if (
!isEmpty(filterBoxFilters) &&
!excludedParams.includes(URL_PARAMS.preselectFilters.name)
)
urlParams.append(
URL_PARAMS.preselectFilters.name,
JSON.stringify(serializeActiveFilterValues(getActiveFilters())),
);
return getUrlParamEntries(urlParams);
}
function getDashboardUrlParams(): UrlParamEntries {
const urlParams = getUrlParams(RESERVED_DASHBOARD_URL_PARAMS);
const filterBoxFilters = getActiveFilters();
if (!isEmpty(filterBoxFilters))
urlParams.append(
URL_PARAMS.preselectFilters.name,
JSON.stringify(serializeActiveFilterValues(getActiveFilters())),
);
return getUrlParamEntries(urlParams);
}
function getPermalink(endpoint: string, jsonPayload: JsonObject) {
return SupersetClient.post({ return SupersetClient.post({
endpoint: '/r/shortner/', endpoint,
postPayload: { data: `/${longUrl}` }, // note: url should contain 2x '/' to redirect properly jsonPayload,
parseMethod: 'text',
stringify: false, // the url saves with an extra set of string quotes without this
}) })
.then(({ text }) => text) .then(result => result.json.url as string)
.catch(response => .catch(response =>
// @ts-ignore // @ts-ignore
getClientErrorObject(response).then(({ error, statusText }) => getClientErrorObject(response).then(({ error, statusText }) =>
@ -87,3 +135,26 @@ export function getShortUrl(longUrl: string) {
), ),
); );
} }
export function getChartPermalink(
formData: Pick<QueryFormData, 'datasource'>,
excludedUrlParams?: string[],
) {
return getPermalink('/api/v1/explore/permalink', {
formData,
urlParams: getChartUrlParams(excludedUrlParams),
});
}
export function getDashboardPermalink(
dashboardId: string,
filterState: JsonObject,
hash?: string,
) {
// only encode filter box state if non-empty
return getPermalink(`/api/v1/dashboard/${dashboardId}/permalink`, {
filterState,
urlParams: getDashboardUrlParams(),
hash,
});
}

View File

@ -44,6 +44,7 @@ from werkzeug.local import LocalProxy
from superset.constants import CHANGE_ME_SECRET_KEY from superset.constants import CHANGE_ME_SECRET_KEY
from superset.jinja_context import BaseTemplateProcessor from superset.jinja_context import BaseTemplateProcessor
from superset.key_value.types import KeyType
from superset.stats_logger import DummyStatsLogger from superset.stats_logger import DummyStatsLogger
from superset.typing import CacheConfig from superset.typing import CacheConfig
from superset.utils.core import is_test, parse_boolean_string from superset.utils.core import is_test, parse_boolean_string
@ -611,6 +612,8 @@ EXPLORE_FORM_DATA_CACHE_CONFIG: CacheConfig = {
# store cache keys by datasource UID (via CacheKey) for custom processing/invalidation # store cache keys by datasource UID (via CacheKey) for custom processing/invalidation
STORE_CACHE_KEYS_IN_METADATA_DB = False STORE_CACHE_KEYS_IN_METADATA_DB = False
PERMALINK_KEY_TYPE: KeyType = "uuid"
# CORS Options # CORS Options
ENABLE_CORS = False ENABLE_CORS = False
CORS_OPTIONS: Dict[Any, Any] = {} CORS_OPTIONS: Dict[Any, Any] = {}

View File

@ -25,12 +25,12 @@ from superset.dashboards.filter_state.commands.delete import DeleteFilterStateCo
from superset.dashboards.filter_state.commands.get import GetFilterStateCommand from superset.dashboards.filter_state.commands.get import GetFilterStateCommand
from superset.dashboards.filter_state.commands.update import UpdateFilterStateCommand from superset.dashboards.filter_state.commands.update import UpdateFilterStateCommand
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.key_value.api import KeyValueRestApi from superset.temporary_cache.api import TemporaryCacheRestApi
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DashboardFilterStateRestApi(KeyValueRestApi): class DashboardFilterStateRestApi(TemporaryCacheRestApi):
class_permission_name = "DashboardFilterStateRestApi" class_permission_name = "DashboardFilterStateRestApi"
resource_name = "dashboard" resource_name = "dashboard"
openapi_spec_tag = "Dashboard Filter State" openapi_spec_tag = "Dashboard Filter State"
@ -74,7 +74,7 @@ class DashboardFilterStateRestApi(KeyValueRestApi):
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/KeyValuePostSchema' $ref: '#/components/schemas/TemporaryCachePostSchema'
responses: responses:
201: 201:
description: The value was stored successfully. description: The value was stored successfully.
@ -128,7 +128,7 @@ class DashboardFilterStateRestApi(KeyValueRestApi):
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/KeyValuePutSchema' $ref: '#/components/schemas/TemporaryCachePutSchema'
responses: responses:
200: 200:
description: The value was stored successfully. description: The value was stored successfully.

View File

@ -18,13 +18,13 @@ from flask import session
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.create import CreateKeyValueCommand from superset.temporary_cache.commands.create import CreateTemporaryCacheCommand
from superset.key_value.commands.entry import Entry from superset.temporary_cache.commands.entry import Entry
from superset.key_value.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.key_value.utils import cache_key, random_key from superset.temporary_cache.utils import cache_key, random_key
class CreateFilterStateCommand(CreateKeyValueCommand): class CreateFilterStateCommand(CreateTemporaryCacheCommand):
def create(self, cmd_params: CommandParameters) -> str: def create(self, cmd_params: CommandParameters) -> str:
resource_id = cmd_params.resource_id resource_id = cmd_params.resource_id
actor = cmd_params.actor actor = cmd_params.actor

View File

@ -18,14 +18,14 @@ from flask import session
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.temporary_cache.commands.delete import DeleteTemporaryCacheCommand
from superset.key_value.commands.entry import Entry from superset.temporary_cache.commands.entry import Entry
from superset.key_value.commands.exceptions import KeyValueAccessDeniedError from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError
from superset.key_value.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.key_value.utils import cache_key from superset.temporary_cache.utils import cache_key
class DeleteFilterStateCommand(DeleteKeyValueCommand): class DeleteFilterStateCommand(DeleteTemporaryCacheCommand):
def delete(self, cmd_params: CommandParameters) -> bool: def delete(self, cmd_params: CommandParameters) -> bool:
resource_id = cmd_params.resource_id resource_id = cmd_params.resource_id
actor = cmd_params.actor actor = cmd_params.actor
@ -35,7 +35,7 @@ class DeleteFilterStateCommand(DeleteKeyValueCommand):
entry: Entry = cache_manager.filter_state_cache.get(key) entry: Entry = cache_manager.filter_state_cache.get(key)
if entry: if entry:
if entry["owner"] != actor.get_user_id(): if entry["owner"] != actor.get_user_id():
raise KeyValueAccessDeniedError() raise TemporaryCacheAccessDeniedError()
tab_id = cmd_params.tab_id tab_id = cmd_params.tab_id
contextual_key = cache_key(session.get("_id"), tab_id, resource_id) contextual_key = cache_key(session.get("_id"), tab_id, resource_id)
cache_manager.filter_state_cache.delete(contextual_key) cache_manager.filter_state_cache.delete(contextual_key)

View File

@ -20,12 +20,12 @@ from flask import current_app as app
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.get import GetKeyValueCommand from superset.temporary_cache.commands.get import GetTemporaryCacheCommand
from superset.key_value.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.key_value.utils import cache_key from superset.temporary_cache.utils import cache_key
class GetFilterStateCommand(GetKeyValueCommand): class GetFilterStateCommand(GetTemporaryCacheCommand):
def __init__(self, cmd_params: CommandParameters) -> None: def __init__(self, cmd_params: CommandParameters) -> None:
super().__init__(cmd_params) super().__init__(cmd_params)
config = app.config["FILTER_STATE_CACHE_CONFIG"] config = app.config["FILTER_STATE_CACHE_CONFIG"]

View File

@ -20,14 +20,14 @@ from flask import session
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.entry import Entry from superset.temporary_cache.commands.entry import Entry
from superset.key_value.commands.exceptions import KeyValueAccessDeniedError from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError
from superset.key_value.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.key_value.commands.update import UpdateKeyValueCommand from superset.temporary_cache.commands.update import UpdateTemporaryCacheCommand
from superset.key_value.utils import cache_key, random_key from superset.temporary_cache.utils import cache_key, random_key
class UpdateFilterStateCommand(UpdateKeyValueCommand): class UpdateFilterStateCommand(UpdateTemporaryCacheCommand):
def update(self, cmd_params: CommandParameters) -> Optional[str]: def update(self, cmd_params: CommandParameters) -> Optional[str]:
resource_id = cmd_params.resource_id resource_id = cmd_params.resource_id
actor = cmd_params.actor actor = cmd_params.actor
@ -41,7 +41,7 @@ class UpdateFilterStateCommand(UpdateKeyValueCommand):
if entry: if entry:
user_id = actor.get_user_id() user_id = actor.get_user_id()
if entry["owner"] != user_id: if entry["owner"] != user_id:
raise KeyValueAccessDeniedError() raise TemporaryCacheAccessDeniedError()
# Generate a new key if tab_id changes or equals 0 # Generate a new key if tab_id changes or equals 0
contextual_key = cache_key( contextual_key = cache_key(

View File

@ -0,0 +1,171 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from flask import current_app, g, request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe
from marshmallow import ValidationError
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.dashboards.commands.exceptions import (
DashboardAccessDeniedError,
DashboardNotFoundError,
)
from superset.dashboards.permalink.commands.create import (
CreateDashboardPermalinkCommand,
)
from superset.dashboards.permalink.commands.get import GetDashboardPermalinkCommand
from superset.dashboards.permalink.exceptions import DashboardPermalinkInvalidStateError
from superset.dashboards.permalink.schemas import DashboardPermalinkPostSchema
from superset.extensions import event_logger
from superset.key_value.exceptions import KeyValueAccessDeniedError
from superset.views.base_api import requires_json
logger = logging.getLogger(__name__)
class DashboardPermalinkRestApi(BaseApi):
add_model_schema = DashboardPermalinkPostSchema()
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
include_route_methods = {
RouteMethod.POST,
RouteMethod.PUT,
RouteMethod.GET,
RouteMethod.DELETE,
}
allow_browser_login = True
class_permission_name = "DashboardPermalinkRestApi"
resource_name = "dashboard"
openapi_spec_tag = "Dashboard Permanent Link"
openapi_spec_component_schemas = (DashboardPermalinkPostSchema,)
@expose("/<pk>/permalink", methods=["POST"])
@protect()
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
log_to_statsd=False,
)
@requires_json
def post(self, pk: str) -> Response:
"""Stores a new permanent link.
---
post:
description: >-
Stores a new permanent link.
parameters:
- in: path
schema:
type: string
name: pk
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/DashboardPermalinkPostSchema'
responses:
201:
description: The permanent link was stored successfully.
content:
application/json:
schema:
type: object
properties:
key:
type: string
description: The key to retrieve the permanent link data.
url:
type: string
description: permanent link.
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
key_type = current_app.config["PERMALINK_KEY_TYPE"]
try:
state = self.add_model_schema.load(request.json)
key = CreateDashboardPermalinkCommand(
actor=g.user, dashboard_id=pk, state=state, key_type=key_type,
).run()
http_origin = request.headers.environ.get("HTTP_ORIGIN")
url = f"{http_origin}/superset/dashboard/p/{key}/"
return self.response(201, key=key, url=url)
except (ValidationError, DashboardPermalinkInvalidStateError) as ex:
return self.response(400, message=str(ex))
except (DashboardAccessDeniedError, KeyValueAccessDeniedError,) as ex:
return self.response(403, message=str(ex))
except DashboardNotFoundError as ex:
return self.response(404, message=str(ex))
@expose("/permalink/<string:key>", methods=["GET"])
@protect()
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
log_to_statsd=False,
)
def get(self, key: str) -> Response:
"""Retrives permanent link state for dashboard.
---
get:
description: >-
Retrives dashboard state associated with a permanent link.
parameters:
- in: path
schema:
type: string
name: key
responses:
200:
description: Returns the stored state.
content:
application/json:
schema:
type: object
properties:
state:
type: object
description: The stored state
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
key_type = current_app.config["PERMALINK_KEY_TYPE"]
value = GetDashboardPermalinkCommand(
actor=g.user, key=key, key_type=key_type
).run()
if not value:
return self.response_404()
return self.response(200, **value)
except DashboardAccessDeniedError as ex:
return self.response(403, message=str(ex))
except DashboardNotFoundError as ex:
return self.response(404, message=str(ex))

View File

@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from abc import ABC
from superset.commands.base import BaseCommand
class BaseDashboardPermalinkCommand(BaseCommand, ABC):
resource = "dashboard_permalink"

View File

@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
from superset.dashboards.dao import DashboardDAO
from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCommand
from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.types import KeyType
logger = logging.getLogger(__name__)
class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
def __init__(
self,
actor: User,
dashboard_id: str,
state: DashboardPermalinkState,
key_type: KeyType,
):
self.actor = actor
self.dashboard_id = dashboard_id
self.state = state
self.key_type = key_type
def run(self) -> str:
self.validate()
try:
DashboardDAO.get_by_id_or_slug(self.dashboard_id)
value = {
"dashboardId": self.dashboard_id,
"state": self.state,
}
return CreateKeyValueCommand(
self.actor, self.resource, value, self.key_type
).run()
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise DashboardPermalinkCreateFailedError() from ex
def validate(self) -> None:
pass

View File

@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
from superset.dashboards.commands.exceptions import DashboardNotFoundError
from superset.dashboards.dao import DashboardDAO
from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCommand
from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError
from superset.dashboards.permalink.types import DashboardPermalinkValue
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.types import KeyType
logger = logging.getLogger(__name__)
class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
def __init__(
self, actor: User, key: str, key_type: KeyType,
):
self.actor = actor
self.key = key
self.key_type = key_type
def run(self) -> Optional[DashboardPermalinkValue]:
self.validate()
try:
command = GetKeyValueCommand(
self.resource, self.key, key_type=self.key_type
)
value: Optional[DashboardPermalinkValue] = command.run()
if value:
DashboardDAO.get_by_id_or_slug(value["dashboardId"])
return value
return None
except (
DashboardNotFoundError,
KeyValueGetFailedError,
KeyValueParseKeyError,
) as ex:
raise DashboardPermalinkGetFailedError(message=ex.message) from ex
except SQLAlchemyError as ex:
logger.exception("Error running get command")
raise DashboardPermalinkGetFailedError() from ex
def validate(self) -> None:
pass

View File

@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from superset.commands.exceptions import CommandException, CreateFailedError
class DashboardPermalinkInvalidStateError(CommandException):
message = _("Invalid state.")
class DashboardPermalinkCreateFailedError(CreateFailedError):
message = _("An error occurred while creating the value.")
class DashboardPermalinkGetFailedError(CommandException):
message = _("An error occurred while accessing the value.")

View File

@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
class DashboardPermalinkPostSchema(Schema):
filterState = fields.Dict(
required=True, allow_none=False, description="Native filter state",
)
urlParams = fields.List(
fields.Tuple(
(
fields.String(required=True, allow_none=True, description="Key"),
fields.String(required=True, allow_none=True, description="Value"),
),
required=False,
allow_none=True,
description="URL Parameter key-value pair",
),
required=False,
allow_none=True,
description="URL Parameters",
)
hash = fields.String(
required=False, allow_none=True, description="Optional anchor link"
)

View File

@ -14,16 +14,15 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from marshmallow import fields, Schema from typing import Any, Dict, List, Optional, Tuple, TypedDict
class KeyValuePostSchema(Schema): class DashboardPermalinkState(TypedDict):
value = fields.String( filterState: Dict[str, Any]
required=True, allow_none=False, description="Any type of JSON supported text." hash: Optional[str]
) urlParams: Optional[List[Tuple[str, str]]]
class KeyValuePutSchema(Schema): class DashboardPermalinkValue(TypedDict):
value = fields.String( dashboardId: str
required=True, allow_none=False, description="Any type of JSON supported text." state: DashboardPermalinkState
)

View File

@ -37,7 +37,7 @@ from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.update import UpdateFormDataCommand from superset.explore.form_data.commands.update import UpdateFormDataCommand
from superset.explore.form_data.schemas import FormDataPostSchema, FormDataPutSchema from superset.explore.form_data.schemas import FormDataPostSchema, FormDataPutSchema
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.key_value.commands.exceptions import KeyValueAccessDeniedError from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError
from superset.views.base_api import requires_json from superset.views.base_api import requires_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -121,7 +121,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
except ( except (
ChartAccessDeniedError, ChartAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DatasetNotFoundError) as ex:
@ -198,7 +198,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
except ( except (
ChartAccessDeniedError, ChartAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DatasetNotFoundError) as ex:
@ -253,7 +253,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
except ( except (
ChartAccessDeniedError, ChartAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DatasetNotFoundError) as ex:
@ -309,7 +309,7 @@ class ExploreFormDataRestApi(BaseApi, ABC):
except ( except (
ChartAccessDeniedError, ChartAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DatasetNotFoundError) as ex:

View File

@ -22,10 +22,11 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.exceptions import KeyValueCreateFailedError from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError
from superset.key_value.utils import cache_key, random_key from superset.temporary_cache.utils import cache_key, random_key
from superset.utils.schema import validate_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,6 +36,7 @@ class CreateFormDataCommand(BaseCommand):
self._cmd_params = cmd_params self._cmd_params = cmd_params
def run(self) -> str: def run(self) -> str:
self.validate()
try: try:
dataset_id = self._cmd_params.dataset_id dataset_id = self._cmd_params.dataset_id
chart_id = self._cmd_params.chart_id chart_id = self._cmd_params.chart_id
@ -58,7 +60,8 @@ class CreateFormDataCommand(BaseCommand):
return key return key
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running create command") logger.exception("Error running create command")
raise KeyValueCreateFailedError() from ex raise TemporaryCacheCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass if self._cmd_params.form_data:
validate_json(self._cmd_params.form_data)

View File

@ -23,13 +23,13 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.exceptions import ( from superset.temporary_cache.commands.exceptions import (
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
KeyValueDeleteFailedError, TemporaryCacheDeleteFailedError,
) )
from superset.key_value.utils import cache_key from superset.temporary_cache.utils import cache_key
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,7 +50,7 @@ class DeleteFormDataCommand(BaseCommand, ABC):
chart_id = state["chart_id"] chart_id = state["chart_id"]
check_access(dataset_id, chart_id, actor) check_access(dataset_id, chart_id, actor)
if state["owner"] != actor.get_user_id(): if state["owner"] != actor.get_user_id():
raise KeyValueAccessDeniedError() raise TemporaryCacheAccessDeniedError()
tab_id = self._cmd_params.tab_id tab_id = self._cmd_params.tab_id
contextual_key = cache_key( contextual_key = cache_key(
session.get("_id"), tab_id, dataset_id, chart_id session.get("_id"), tab_id, dataset_id, chart_id
@ -60,7 +60,7 @@ class DeleteFormDataCommand(BaseCommand, ABC):
return False return False
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running delete command") logger.exception("Error running delete command")
raise KeyValueDeleteFailedError() from ex raise TemporaryCacheDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -24,9 +24,9 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.exceptions import KeyValueGetFailedError from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,7 +52,7 @@ class GetFormDataCommand(BaseCommand, ABC):
return None return None
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running get command") logger.exception("Error running get command")
raise KeyValueGetFailedError() from ex raise TemporaryCacheGetFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -24,13 +24,14 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.exceptions import ( from superset.temporary_cache.commands.exceptions import (
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
KeyValueUpdateFailedError, TemporaryCacheUpdateFailedError,
) )
from superset.key_value.utils import cache_key, random_key from superset.temporary_cache.utils import cache_key, random_key
from superset.utils.schema import validate_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,6 +43,7 @@ class UpdateFormDataCommand(BaseCommand, ABC):
self._cmd_params = cmd_params self._cmd_params = cmd_params
def run(self) -> Optional[str]: def run(self) -> Optional[str]:
self.validate()
try: try:
dataset_id = self._cmd_params.dataset_id dataset_id = self._cmd_params.dataset_id
chart_id = self._cmd_params.chart_id chart_id = self._cmd_params.chart_id
@ -55,7 +57,7 @@ class UpdateFormDataCommand(BaseCommand, ABC):
if state and form_data: if state and form_data:
user_id = actor.get_user_id() user_id = actor.get_user_id()
if state["owner"] != user_id: if state["owner"] != user_id:
raise KeyValueAccessDeniedError() raise TemporaryCacheAccessDeniedError()
# Generate a new key if tab_id changes or equals 0 # Generate a new key if tab_id changes or equals 0
tab_id = self._cmd_params.tab_id tab_id = self._cmd_params.tab_id
@ -77,7 +79,8 @@ class UpdateFormDataCommand(BaseCommand, ABC):
return key return key
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running update command") logger.exception("Error running update command")
raise KeyValueUpdateFailedError() from ex raise TemporaryCacheUpdateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass if self._cmd_params.form_data:
validate_json(self._cmd_params.form_data)

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,174 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from flask import current_app, g, request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe
from marshmallow import ValidationError
from superset.charts.commands.exceptions import (
ChartAccessDeniedError,
ChartNotFoundError,
)
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.datasets.commands.exceptions import (
DatasetAccessDeniedError,
DatasetNotFoundError,
)
from superset.explore.permalink.commands.create import CreateExplorePermalinkCommand
from superset.explore.permalink.commands.get import GetExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkInvalidStateError
from superset.explore.permalink.schemas import ExplorePermalinkPostSchema
from superset.extensions import event_logger
from superset.key_value.exceptions import KeyValueAccessDeniedError
from superset.views.base_api import requires_json
logger = logging.getLogger(__name__)
class ExplorePermalinkRestApi(BaseApi):
add_model_schema = ExplorePermalinkPostSchema()
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
include_route_methods = {
RouteMethod.POST,
RouteMethod.PUT,
RouteMethod.GET,
RouteMethod.DELETE,
}
allow_browser_login = True
class_permission_name = "ExplorePermalinkRestApi"
resource_name = "explore"
openapi_spec_tag = "Explore Permanent Link"
openapi_spec_component_schemas = (ExplorePermalinkPostSchema,)
@expose("/permalink", methods=["POST"])
@protect()
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
log_to_statsd=False,
)
@requires_json
def post(self) -> Response:
"""Stores a new permanent link.
---
post:
description: >-
Stores a new permanent link.
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/ExplorePermalinkPostSchema'
responses:
201:
description: The permanent link was stored successfully.
content:
application/json:
schema:
type: object
properties:
key:
type: string
description: The key to retrieve the permanent link data.
url:
type: string
description: pemanent link.
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
key_type = current_app.config["PERMALINK_KEY_TYPE"]
try:
state = self.add_model_schema.load(request.json)
key = CreateExplorePermalinkCommand(
actor=g.user, state=state, key_type=key_type,
).run()
http_origin = request.headers.environ.get("HTTP_ORIGIN")
url = f"{http_origin}/superset/explore/p/{key}/"
return self.response(201, key=key, url=url)
except ValidationError as ex:
return self.response(400, message=ex.messages)
except (
ChartAccessDeniedError,
DatasetAccessDeniedError,
KeyValueAccessDeniedError,
) as ex:
return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex:
return self.response(404, message=str(ex))
@expose("/permalink/<string:key>", methods=["GET"])
@protect()
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
log_to_statsd=False,
)
def get(self, key: str) -> Response:
"""Retrives permanent link state for chart.
---
get:
description: >-
Retrives chart state associated with a permanent link.
parameters:
- in: path
schema:
type: string
name: key
responses:
200:
description: Returns the stored form_data.
content:
application/json:
schema:
type: object
properties:
state:
type: object
description: The stored state
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
key_type = current_app.config["PERMALINK_KEY_TYPE"]
value = GetExplorePermalinkCommand(
actor=g.user, key=key, key_type=key_type
).run()
if not value:
return self.response_404()
return self.response(200, **value)
except ExplorePermalinkInvalidStateError as ex:
return self.response(400, message=str(ex))
except (ChartAccessDeniedError, DatasetAccessDeniedError,) as ex:
return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex:
return self.response(404, message=str(ex))

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from abc import ABC
from superset.commands.base import BaseCommand
class BaseExplorePermalinkCommand(BaseCommand, ABC):
resource = "explore_permalink"

View File

@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.types import KeyType
logger = logging.getLogger(__name__)
class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
def __init__(self, actor: User, state: Dict[str, Any], key_type: KeyType):
self.actor = actor
self.chart_id: Optional[int] = state["formData"].get("slice_id")
self.datasource: str = state["formData"]["datasource"]
self.state = state
self.key_type = key_type
def run(self) -> str:
self.validate()
try:
dataset_id = int(self.datasource.split("__")[0])
check_access(dataset_id, self.chart_id, self.actor)
value = {
"chartId": self.chart_id,
"datasetId": dataset_id,
"datasource": self.datasource,
"state": self.state,
}
command = CreateKeyValueCommand(
self.actor, self.resource, value, self.key_type
)
return command.run()
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise ExplorePermalinkCreateFailedError() from ex
def validate(self) -> None:
pass

View File

@ -0,0 +1,66 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.explore.permalink.types import ExplorePermalinkValue
from superset.explore.utils import check_access
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.types import KeyType
logger = logging.getLogger(__name__)
class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
def __init__(
self, actor: User, key: str, key_type: KeyType,
):
self.actor = actor
self.key = key
self.key_type = key_type
def run(self) -> Optional[ExplorePermalinkValue]:
self.validate()
try:
value: Optional[ExplorePermalinkValue] = GetKeyValueCommand(
self.resource, self.key, key_type=self.key_type
).run()
if value:
chart_id: Optional[int] = value.get("chartId")
dataset_id = value["datasetId"]
check_access(dataset_id, chart_id, self.actor)
return value
return None
except (
DatasetNotFoundError,
KeyValueGetFailedError,
KeyValueParseKeyError,
) as ex:
raise ExplorePermalinkGetFailedError(message=ex.message) from ex
except SQLAlchemyError as ex:
logger.exception("Error running get command")
raise ExplorePermalinkGetFailedError() from ex
def validate(self) -> None:
pass

View File

@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from superset.commands.exceptions import CommandException, CreateFailedError
class ExplorePermalinkInvalidStateError(CreateFailedError):
message = _("Invalid state.")
class ExplorePermalinkCreateFailedError(CreateFailedError):
message = _("An error occurred while creating the value.")
class ExplorePermalinkGetFailedError(CommandException):
message = _("An error occurred while accessing the value.")

View File

@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
class ExplorePermalinkPostSchema(Schema):
formData = fields.Dict(
required=True, allow_none=False, description="Chart form data",
)
urlParams = fields.List(
fields.Tuple(
(
fields.String(required=True, allow_none=True, description="Key"),
fields.String(required=True, allow_none=True, description="Value"),
),
required=False,
allow_none=True,
description="URL Parameter key-value pair",
),
required=False,
allow_none=True,
description="URL Parameters",
)

View File

@ -0,0 +1,29 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional, Tuple, TypedDict
class ExplorePermalinkState(TypedDict, total=False):
formData: Dict[str, Any]
urlParams: Optional[List[Tuple[str, str]]]
class ExplorePermalinkValue(TypedDict):
chartId: Optional[int]
datasetId: int
datasource: str
state: ExplorePermalinkState

View File

@ -136,11 +136,13 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
from superset.dashboards.api import DashboardRestApi from superset.dashboards.api import DashboardRestApi
from superset.dashboards.filter_sets.api import FilterSetRestApi from superset.dashboards.filter_sets.api import FilterSetRestApi
from superset.dashboards.filter_state.api import DashboardFilterStateRestApi from superset.dashboards.filter_state.api import DashboardFilterStateRestApi
from superset.dashboards.permalink.api import DashboardPermalinkRestApi
from superset.databases.api import DatabaseRestApi from superset.databases.api import DatabaseRestApi
from superset.datasets.api import DatasetRestApi from superset.datasets.api import DatasetRestApi
from superset.datasets.columns.api import DatasetColumnsRestApi from superset.datasets.columns.api import DatasetColumnsRestApi
from superset.datasets.metrics.api import DatasetMetricRestApi from superset.datasets.metrics.api import DatasetMetricRestApi
from superset.explore.form_data.api import ExploreFormDataRestApi from superset.explore.form_data.api import ExploreFormDataRestApi
from superset.explore.permalink.api import ExplorePermalinkRestApi
from superset.queries.api import QueryRestApi from superset.queries.api import QueryRestApi
from superset.queries.saved_queries.api import SavedQueryRestApi from superset.queries.saved_queries.api import SavedQueryRestApi
from superset.reports.api import ReportScheduleRestApi from superset.reports.api import ReportScheduleRestApi
@ -208,12 +210,14 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
appbuilder.add_api(CssTemplateRestApi) appbuilder.add_api(CssTemplateRestApi)
appbuilder.add_api(CurrentUserRestApi) appbuilder.add_api(CurrentUserRestApi)
appbuilder.add_api(DashboardFilterStateRestApi) appbuilder.add_api(DashboardFilterStateRestApi)
appbuilder.add_api(DashboardPermalinkRestApi)
appbuilder.add_api(DashboardRestApi) appbuilder.add_api(DashboardRestApi)
appbuilder.add_api(DatabaseRestApi) appbuilder.add_api(DatabaseRestApi)
appbuilder.add_api(DatasetRestApi) appbuilder.add_api(DatasetRestApi)
appbuilder.add_api(DatasetColumnsRestApi) appbuilder.add_api(DatasetColumnsRestApi)
appbuilder.add_api(DatasetMetricRestApi) appbuilder.add_api(DatasetMetricRestApi)
appbuilder.add_api(ExploreFormDataRestApi) appbuilder.add_api(ExploreFormDataRestApi)
appbuilder.add_api(ExplorePermalinkRestApi)
appbuilder.add_api(FilterSetRestApi) appbuilder.add_api(FilterSetRestApi)
appbuilder.add_api(QueryRestApi) appbuilder.add_api(QueryRestApi)
appbuilder.add_api(ReportScheduleRestApi) appbuilder.add_api(ReportScheduleRestApi)

View File

@ -15,24 +15,56 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from abc import ABC, abstractmethod import pickle
from datetime import datetime
from typing import Any, Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.commands.exceptions import KeyValueCreateFailedError from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.commands.parameters import CommandParameters from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.utils import extract_key
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CreateKeyValueCommand(BaseCommand, ABC): class CreateKeyValueCommand(BaseCommand):
def __init__(self, cmd_params: CommandParameters): actor: User
self._cmd_params = cmd_params resource: str
value: Any
key_type: KeyType
expires_on: Optional[datetime]
def __init__(
self,
actor: User,
resource: str,
value: Any,
key_type: KeyType,
expires_on: Optional[datetime] = None,
):
"""
Create a new key-value pair
:param resource: the resource (dashboard, chart etc)
:param value: the value to persist in the key-value store
:param key_type: the type of the key to return
:param expires_on: entry expiration time
:return: the key associated with the persisted value
"""
self.resource = resource
self.actor = actor
self.value = value
self.key_type = key_type
self.expires_on = expires_on
def run(self) -> str: def run(self) -> str:
try: try:
return self.create(self._cmd_params) return self.create()
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running create command") logger.exception("Error running create command")
raise KeyValueCreateFailedError() from ex raise KeyValueCreateFailedError() from ex
@ -40,6 +72,14 @@ class CreateKeyValueCommand(BaseCommand, ABC):
def validate(self) -> None: def validate(self) -> None:
pass pass
@abstractmethod def create(self) -> str:
def create(self, cmd_params: CommandParameters) -> str: entry = KeyValueEntry(
... resource=self.resource,
value=pickle.dumps(self.value),
created_on=datetime.now(),
created_by_fk=None if self.actor.is_anonymous else self.actor.id,
expires_on=self.expires_on,
)
db.session.add(entry)
db.session.commit()
return extract_key(entry, self.key_type)

View File

@ -15,24 +15,45 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from abc import ABC, abstractmethod
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.commands.exceptions import KeyValueDeleteFailedError from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.commands.parameters import CommandParameters from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.utils import get_filter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeleteKeyValueCommand(BaseCommand, ABC): class DeleteKeyValueCommand(BaseCommand):
def __init__(self, cmd_params: CommandParameters): actor: User
self._cmd_params = cmd_params key: str
key_type: KeyType
resource: str
def __init__(
self, actor: User, resource: str, key: str, key_type: KeyType = "uuid"
):
"""
Delete a key-value pair
:param resource: the resource (dashboard, chart etc)
:param key: the key to delete
:param key_type: the type of key
:return: was the entry deleted or not
"""
self.resource = resource
self.actor = actor
self.key = key
self.key_type = key_type
def run(self) -> bool: def run(self) -> bool:
try: try:
return self.delete(self._cmd_params) return self.delete()
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running delete command") logger.exception("Error running delete command")
raise KeyValueDeleteFailedError() from ex raise KeyValueDeleteFailedError() from ex
@ -40,6 +61,16 @@ class DeleteKeyValueCommand(BaseCommand, ABC):
def validate(self) -> None: def validate(self) -> None:
pass pass
@abstractmethod def delete(self) -> bool:
def delete(self, cmd_params: CommandParameters) -> bool: filter_ = get_filter(self.resource, self.key, self.key_type)
... entry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
.autoflush(False)
.first()
)
if entry:
db.session.delete(entry)
db.session.commit()
return True
return False

View File

@ -14,26 +14,45 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from abc import ABC, abstractmethod import pickle
from typing import Optional from datetime import datetime
from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.commands.exceptions import KeyValueGetFailedError from superset.key_value.exceptions import KeyValueGetFailedError
from superset.key_value.commands.parameters import CommandParameters from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.utils import get_filter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GetKeyValueCommand(BaseCommand, ABC): class GetKeyValueCommand(BaseCommand):
def __init__(self, cmd_params: CommandParameters): key: str
self._cmd_params = cmd_params key_type: KeyType
resource: str
def run(self) -> Optional[str]: def __init__(self, resource: str, key: str, key_type: KeyType = "uuid"):
"""
Retrieve a key value entry
:param resource: the resource (dashboard, chart etc)
:param key: the key to retrieve
:param key_type: the type of the key to retrieve
:return: the value associated with the key if present
"""
self.resource = resource
self.key = key
self.key_type = key_type
def run(self) -> Any:
try: try:
return self.get(self._cmd_params) return self.get()
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running get command") logger.exception("Error running get command")
raise KeyValueGetFailedError() from ex raise KeyValueGetFailedError() from ex
@ -41,6 +60,14 @@ class GetKeyValueCommand(BaseCommand, ABC):
def validate(self) -> None: def validate(self) -> None:
pass pass
@abstractmethod def get(self) -> Optional[Any]:
def get(self, cmd_params: CommandParameters) -> Optional[str]: filter_ = get_filter(self.resource, self.key, self.key_type)
... entry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
.autoflush(False)
.first()
)
if entry and (entry.expires_on is None or entry.expires_on > datetime.now()):
return pickle.loads(entry.value)
return None

View File

@ -14,28 +14,62 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging
from abc import ABC, abstractmethod
from typing import Optional
import logging
import pickle
from datetime import datetime
from typing import Any, Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.commands.exceptions import KeyValueUpdateFailedError from superset.key_value.exceptions import KeyValueUpdateFailedError
from superset.key_value.commands.parameters import CommandParameters from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.utils import extract_key, get_filter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UpdateKeyValueCommand(BaseCommand, ABC): class UpdateKeyValueCommand(BaseCommand):
actor: User
resource: str
value: Any
key: str
key_type: KeyType
expires_on: Optional[datetime]
def __init__( def __init__(
self, cmd_params: CommandParameters, self,
actor: User,
resource: str,
key: str,
value: Any,
key_type: KeyType = "uuid",
expires_on: Optional[datetime] = None,
): ):
self._parameters = cmd_params """
Update a key value entry
:param resource: the resource (dashboard, chart etc)
:param key: the key to update
:param value: the value to persist in the key-value store
:param key_type: the type of the key to update
:param expires_on: entry expiration time
:return: the key associated with the updated value
"""
self.actor = actor
self.resource = resource
self.key = key
self.value = value
self.key_type = key_type
self.expires_on = expires_on
def run(self) -> Optional[str]: def run(self) -> Optional[str]:
try: try:
return self.update(self._parameters) return self.update()
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running update command") logger.exception("Error running update command")
raise KeyValueUpdateFailedError() from ex raise KeyValueUpdateFailedError() from ex
@ -43,6 +77,20 @@ class UpdateKeyValueCommand(BaseCommand, ABC):
def validate(self) -> None: def validate(self) -> None:
pass pass
@abstractmethod def update(self) -> Optional[str]:
def update(self, cmd_params: CommandParameters) -> Optional[str]: filter_ = get_filter(self.resource, self.key, self.key_type)
... entry: KeyValueEntry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
.autoflush(False)
.first()
)
if entry:
entry.value = pickle.dumps(self.value)
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = None if self.actor.is_anonymous else self.actor.id
db.session.merge(entry)
db.session.commit()
return extract_key(entry, self.key_type)
return None

View File

@ -23,6 +23,11 @@ from superset.commands.exceptions import (
ForbiddenError, ForbiddenError,
UpdateFailedError, UpdateFailedError,
) )
from superset.exceptions import SupersetException
class KeyValueParseKeyError(SupersetException):
message = _("An error occurred while parsing the key.")
class KeyValueCreateFailedError(CreateFailedError): class KeyValueCreateFailedError(CreateFailedError):

View File

@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_appbuilder import Model
from sqlalchemy import Column, DateTime, ForeignKey, Integer, LargeBinary, String
from sqlalchemy.orm import relationship
from superset import security_manager
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
class KeyValueEntry(Model, AuditMixinNullable, ImportExportMixin):
"""Key value store entity"""
__tablename__ = "key_value"
id = Column(Integer, primary_key=True)
resource = Column(String(32), nullable=False)
value = Column(LargeBinary(), nullable=False)
created_on = Column(DateTime, nullable=True)
created_by_fk = Column(Integer, ForeignKey("ab_user.id"), nullable=True)
changed_on = Column(DateTime, nullable=True)
expires_on = Column(DateTime, nullable=True)
changed_by_fk = Column(Integer, ForeignKey("ab_user.id"), nullable=True)
created_by = relationship(security_manager.user_model, foreign_keys=[created_by_fk])
changed_by = relationship(security_manager.user_model, foreign_keys=[changed_by_fk])

View File

@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from dataclasses import dataclass
from typing import Literal, Optional, TypedDict
from uuid import UUID
@dataclass
class Key:
id: Optional[int]
uuid: Optional[UUID]
KeyType = Literal["id", "uuid"]
class KeyValueFilter(TypedDict, total=False):
resource: str
id: Optional[int]
uuid: Optional[UUID]

View File

@ -14,15 +14,44 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from secrets import token_urlsafe from typing import Literal
from typing import Any from uuid import UUID
SEPARATOR = ";" from flask import current_app
from superset.key_value.exceptions import KeyValueParseKeyError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyType, KeyValueFilter
def cache_key(*args: Any) -> str: def parse_permalink_key(key: str) -> Key:
return SEPARATOR.join(str(arg) for arg in args) key_type: Literal["id", "uuid"] = current_app.config["PERMALINK_KEY_TYPE"]
if key_type == "id":
return Key(id=int(key), uuid=None)
return Key(id=None, uuid=UUID(key))
def random_key() -> str: def format_permalink_key(key: Key) -> str:
return token_urlsafe(48) """
return the string representation of the key
:param key: a key object with either a numerical or uuid key
:return: a formatted string
"""
return str(key.id if key.id is not None else key.uuid)
def extract_key(entry: KeyValueEntry, key_type: KeyType) -> str:
return str(entry.id if key_type == "id" else entry.uuid)
def get_filter(resource: str, key: str, key_type: KeyType) -> KeyValueFilter:
try:
filter_: KeyValueFilter = {"resource": resource}
if key_type == "uuid":
filter_["uuid"] = UUID(key)
else:
filter_["id"] = int(key)
return filter_
except ValueError as ex:
raise KeyValueParseKeyError() from ex

View File

@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add key-value store
Revision ID: 6766938c6065
Revises: 7293b0ca7944
Create Date: 2022-03-04 09:59:26.922329
"""
# revision identifiers, used by Alembic.
revision = "6766938c6065"
down_revision = "7293b0ca7944"
from uuid import uuid4
import sqlalchemy as sa
from alembic import op
from sqlalchemy_utils import UUIDType
def upgrade():
op.create_table(
"key_value",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("resource", sa.String(32), nullable=False),
sa.Column("value", sa.LargeBinary(), nullable=False),
sa.Column("uuid", UUIDType(binary=True), default=uuid4),
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.Column("expires_on", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_key_value_uuid"), "key_value", ["uuid"], unique=True)
op.create_index(
op.f("ix_key_value_expires_on"), "key_value", ["expires_on"], unique=False
)
def downgrade():
op.drop_index(op.f("ix_key_value_expires_on"), table_name="key_value")
op.drop_index(op.f("ix_key_value_uuid"), table_name="key_value")
op.drop_table("key_value")

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -37,17 +37,20 @@ from superset.datasets.commands.exceptions import (
DatasetAccessDeniedError, DatasetAccessDeniedError,
DatasetNotFoundError, DatasetNotFoundError,
) )
from superset.key_value.commands.exceptions import KeyValueAccessDeniedError from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError
from superset.key_value.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.key_value.schemas import KeyValuePostSchema, KeyValuePutSchema from superset.temporary_cache.schemas import (
TemporaryCachePostSchema,
TemporaryCachePutSchema,
)
from superset.views.base_api import requires_json from superset.views.base_api import requires_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class KeyValueRestApi(BaseApi, ABC): class TemporaryCacheRestApi(BaseApi, ABC):
add_model_schema = KeyValuePostSchema() add_model_schema = TemporaryCachePostSchema()
edit_model_schema = KeyValuePutSchema() edit_model_schema = TemporaryCachePutSchema()
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
include_route_methods = { include_route_methods = {
RouteMethod.POST, RouteMethod.POST,
@ -60,10 +63,10 @@ class KeyValueRestApi(BaseApi, ABC):
def add_apispec_components(self, api_spec: APISpec) -> None: def add_apispec_components(self, api_spec: APISpec) -> None:
try: try:
api_spec.components.schema( api_spec.components.schema(
KeyValuePostSchema.__name__, schema=KeyValuePostSchema, TemporaryCachePostSchema.__name__, schema=TemporaryCachePostSchema,
) )
api_spec.components.schema( api_spec.components.schema(
KeyValuePutSchema.__name__, schema=KeyValuePutSchema, TemporaryCachePutSchema.__name__, schema=TemporaryCachePutSchema,
) )
except DuplicateComponentNameError: except DuplicateComponentNameError:
pass pass
@ -85,7 +88,7 @@ class KeyValueRestApi(BaseApi, ABC):
ChartAccessDeniedError, ChartAccessDeniedError,
DashboardAccessDeniedError, DashboardAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex:
@ -111,7 +114,7 @@ class KeyValueRestApi(BaseApi, ABC):
ChartAccessDeniedError, ChartAccessDeniedError,
DashboardAccessDeniedError, DashboardAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex:
@ -128,7 +131,7 @@ class KeyValueRestApi(BaseApi, ABC):
ChartAccessDeniedError, ChartAccessDeniedError,
DashboardAccessDeniedError, DashboardAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex:
@ -145,7 +148,7 @@ class KeyValueRestApi(BaseApi, ABC):
ChartAccessDeniedError, ChartAccessDeniedError,
DashboardAccessDeniedError, DashboardAccessDeniedError,
DatasetAccessDeniedError, DatasetAccessDeniedError,
KeyValueAccessDeniedError, TemporaryCacheAccessDeniedError,
) as ex: ) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex:

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from abc import ABC, abstractmethod
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand
from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError
from superset.temporary_cache.commands.parameters import CommandParameters
logger = logging.getLogger(__name__)
class CreateTemporaryCacheCommand(BaseCommand, ABC):
def __init__(self, cmd_params: CommandParameters):
self._cmd_params = cmd_params
def run(self) -> str:
try:
return self.create(self._cmd_params)
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise TemporaryCacheCreateFailedError() from ex
def validate(self) -> None:
pass
@abstractmethod
def create(self, cmd_params: CommandParameters) -> str:
...

View File

@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from abc import ABC, abstractmethod
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand
from superset.temporary_cache.commands.exceptions import TemporaryCacheDeleteFailedError
from superset.temporary_cache.commands.parameters import CommandParameters
logger = logging.getLogger(__name__)
class DeleteTemporaryCacheCommand(BaseCommand, ABC):
def __init__(self, cmd_params: CommandParameters):
self._cmd_params = cmd_params
def run(self) -> bool:
try:
return self.delete(self._cmd_params)
except SQLAlchemyError as ex:
logger.exception("Error running delete command")
raise TemporaryCacheDeleteFailedError() from ex
def validate(self) -> None:
pass
@abstractmethod
def delete(self, cmd_params: CommandParameters) -> bool:
...

View File

@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from superset.commands.exceptions import (
CommandException,
CreateFailedError,
DeleteFailedError,
ForbiddenError,
UpdateFailedError,
)
class TemporaryCacheCreateFailedError(CreateFailedError):
message = _("An error occurred while creating the value.")
class TemporaryCacheGetFailedError(CommandException):
message = _("An error occurred while accessing the value.")
class TemporaryCacheDeleteFailedError(DeleteFailedError):
message = _("An error occurred while deleting the value.")
class TemporaryCacheUpdateFailedError(UpdateFailedError):
message = _("An error occurred while updating the value.")
class TemporaryCacheAccessDeniedError(ForbiddenError):
message = _("You don't have permission to modify the value.")

View File

@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from abc import ABC, abstractmethod
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand
from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError
from superset.temporary_cache.commands.parameters import CommandParameters
logger = logging.getLogger(__name__)
class GetTemporaryCacheCommand(BaseCommand, ABC):
def __init__(self, cmd_params: CommandParameters):
self._cmd_params = cmd_params
def run(self) -> Optional[str]:
try:
return self.get(self._cmd_params)
except SQLAlchemyError as ex:
logger.exception("Error running get command")
raise TemporaryCacheGetFailedError() from ex
def validate(self) -> None:
pass
@abstractmethod
def get(self, cmd_params: CommandParameters) -> Optional[str]:
...

View File

@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from abc import ABC, abstractmethod
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand
from superset.temporary_cache.commands.exceptions import TemporaryCacheUpdateFailedError
from superset.temporary_cache.commands.parameters import CommandParameters
logger = logging.getLogger(__name__)
class UpdateTemporaryCacheCommand(BaseCommand, ABC):
def __init__(
self, cmd_params: CommandParameters,
):
self._parameters = cmd_params
def run(self) -> Optional[str]:
try:
return self.update(self._parameters)
except SQLAlchemyError as ex:
logger.exception("Error running update command")
raise TemporaryCacheUpdateFailedError() from ex
def validate(self) -> None:
pass
@abstractmethod
def update(self, cmd_params: CommandParameters) -> Optional[str]:
...

View File

@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
from superset.utils.schema import validate_json
class TemporaryCachePostSchema(Schema):
value = fields.String(
required=True,
allow_none=False,
description="Any type of JSON supported text.",
validate=validate_json,
)
class TemporaryCachePutSchema(Schema):
value = fields.String(
required=True,
allow_none=False,
description="Any type of JSON supported text.",
validate=validate_json,
)

View File

@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from secrets import token_urlsafe
from typing import Any
SEPARATOR = ";"
def cache_key(*args: Any) -> str:
return SEPARATOR.join(str(arg) for arg in args)
def random_key() -> str:
return token_urlsafe(48)

View File

@ -57,6 +57,7 @@ from superset import (
sql_lab, sql_lab,
viz, viz,
) )
from superset.charts.commands.exceptions import ChartNotFoundError
from superset.charts.dao import ChartDAO from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus from superset.common.db_query_status import QueryStatus
@ -70,6 +71,8 @@ from superset.connectors.sqla.models import (
) )
from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.dashboards.permalink.commands.get import GetDashboardPermalinkCommand
from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError
from superset.databases.dao import DatabaseDAO from superset.databases.dao import DatabaseDAO
from superset.databases.filters import DatabaseFilter from superset.databases.filters import DatabaseFilter
from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.commands.exceptions import DatasetNotFoundError
@ -88,6 +91,8 @@ from superset.exceptions import (
) )
from superset.explore.form_data.commands.get import GetFormDataCommand from superset.explore.form_data.commands.get import GetFormDataCommand
from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.permalink.commands.get import GetExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.extensions import async_query_manager, cache_manager from superset.extensions import async_query_manager, cache_manager
from superset.jinja_context import get_template_processor from superset.jinja_context import get_template_processor
from superset.models.core import Database, FavStar, Log from superset.models.core import Database, FavStar, Log
@ -733,14 +738,36 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@event_logger.log_this @event_logger.log_this
@expose("/explore/<datasource_type>/<int:datasource_id>/", methods=["GET", "POST"]) @expose("/explore/<datasource_type>/<int:datasource_id>/", methods=["GET", "POST"])
@expose("/explore/", methods=["GET", "POST"]) @expose("/explore/", methods=["GET", "POST"])
@expose("/explore/p/<key>/", methods=["GET"])
# pylint: disable=too-many-locals,too-many-branches,too-many-statements # pylint: disable=too-many-locals,too-many-branches,too-many-statements
def explore( def explore(
self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None self,
datasource_type: Optional[str] = None,
datasource_id: Optional[int] = None,
key: Optional[str] = None,
) -> FlaskResponse: ) -> FlaskResponse:
initial_form_data = {} initial_form_data = {}
form_data_key = request.args.get("form_data_key") form_data_key = request.args.get("form_data_key")
if form_data_key: if key is not None:
key_type = config["PERMALINK_KEY_TYPE"]
command = GetExplorePermalinkCommand(g.user, key, key_type)
try:
permalink_value = command.run()
if permalink_value:
state = permalink_value["state"]
initial_form_data = state["formData"]
url_params = state.get("urlParams")
if url_params:
initial_form_data["url_params"] = dict(url_params)
else:
return json_error_response(
_("Error: permalink state not found"), status=404
)
except (ChartNotFoundError, ExplorePermalinkGetFailedError) as ex:
flash(__("Error: %(msg)s", msg=ex.message), "danger")
return redirect("/chart/list/")
elif form_data_key:
parameters = CommandParameters(actor=g.user, key=form_data_key) parameters = CommandParameters(actor=g.user, key=form_data_key)
value = GetFormDataCommand(parameters).run() value = GetFormDataCommand(parameters).run()
initial_form_data = json.loads(value) if value else {} initial_form_data = json.loads(value) if value else {}
@ -1978,6 +2005,30 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
), ),
) )
@has_access
@expose("/dashboard/p/<key>/", methods=["GET"])
def dashboard_permalink( # pylint: disable=no-self-use
self, key: str,
) -> FlaskResponse:
key_type = config["PERMALINK_KEY_TYPE"]
try:
value = GetDashboardPermalinkCommand(g.user, key, key_type).run()
except DashboardPermalinkGetFailedError as ex:
flash(__("Error: %(msg)s", msg=ex.message), "danger")
return redirect("/dashboard/list/")
if not value:
return json_error_response(_("permalink state not found"), status=404)
dashboard_id = value["dashboardId"]
url = f"/superset/dashboard/{dashboard_id}?permalink_key={key}"
url_params = value["state"].get("urlParams")
if url_params:
params = parse.urlencode(url_params)
url = f"{url}&{params}"
hash_ = value["state"].get("hash")
if hash_:
url = f"{url}#{hash_}"
return redirect(url)
@api @api
@has_access @has_access
@event_logger.log_this @event_logger.log_this

View File

@ -17,9 +17,8 @@
import logging import logging
from typing import Optional from typing import Optional
from flask import flash, request, Response from flask import flash
from flask_appbuilder import expose from flask_appbuilder import expose
from flask_appbuilder.security.decorators import has_access_api
from werkzeug.utils import redirect from werkzeug.utils import redirect
from superset import db, event_logger from superset import db, event_logger
@ -58,21 +57,3 @@ class R(BaseSupersetView): # pylint: disable=invalid-name
flash("URL to nowhere...", "danger") flash("URL to nowhere...", "danger")
return redirect("/") return redirect("/")
@event_logger.log_this
@has_access_api
@expose("/shortner/", methods=["POST"])
def shortner(self) -> FlaskResponse:
url = request.form.get("data")
if not self._validate_url(url):
logger.warning("Invalid URL")
return Response("Invalid URL", 400)
obj = models.Url(url=url)
db.session.add(obj)
db.session.commit()
return Response(
"{scheme}://{request.headers[Host]}/r/{obj.id}".format(
scheme=request.scheme, request=request, obj=obj
),
mimetype="text/plain",
)

View File

@ -690,30 +690,6 @@ class TestCore(SupersetTestCase):
assert ck.datasource_uid == f"{girls_slice.table.id}__table" assert ck.datasource_uid == f"{girls_slice.table.id}__table"
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys
def test_shortner(self):
self.login(username="admin")
data = (
"//superset/explore/table/1/?viz_type=sankey&groupby=source&"
"groupby=target&metric=sum__value&row_limit=5000&where=&having=&"
"flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id=78&slice_name="
"Energy+Sankey&collapsed_fieldsets=&action=&datasource_name="
"energy_usage&datasource_id=1&datasource_type=table&"
"previous_viz_type=sankey"
)
resp = self.client.post("/r/shortner/", data=dict(data=data))
assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8"))
def test_shortner_invalid(self):
self.login(username="admin")
invalid_urls = [
"hhttp://invalid.com",
"hhttps://invalid.com",
"www.invalid.com",
]
for invalid_url in invalid_urls:
resp = self.client.post("/r/shortner/", data=dict(data=invalid_url))
assert resp.status_code == 400
def test_redirect_invalid(self): def test_redirect_invalid(self):
model_url = models.Url(url="hhttp://invalid.com") model_url = models.Url(url="hhttp://invalid.com")
db.session.add(model_url) db.session.add(model_url)

View File

@ -23,25 +23,20 @@ from sqlalchemy.orm import Session
from superset.dashboards.commands.exceptions import DashboardAccessDeniedError from superset.dashboards.commands.exceptions import DashboardAccessDeniedError
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.entry import Entry
from superset.key_value.utils import cache_key
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.temporary_cache.commands.entry import Entry
from superset.temporary_cache.utils import cache_key
from tests.integration_tests.base_tests import login from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import ( from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices, load_world_bank_dashboard_with_slices,
load_world_bank_data, load_world_bank_data,
) )
from tests.integration_tests.test_app import app from tests.integration_tests.test_app import app
key = "test-key" KEY = "test-key"
value = "test" INITIAL_VALUE = json.dumps({"test": "initial value"})
UPDATED_VALUE = json.dumps({"test": "updated value"})
@pytest.fixture
def client():
with app.test_client() as client:
with app.app_context():
yield client
@pytest.fixture @pytest.fixture
@ -62,20 +57,20 @@ def admin_id() -> int:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cache(dashboard_id, admin_id): def cache(dashboard_id, admin_id):
entry: Entry = {"owner": admin_id, "value": value} entry: Entry = {"owner": admin_id, "value": INITIAL_VALUE}
cache_manager.filter_state_cache.set(cache_key(dashboard_id, key), entry) cache_manager.filter_state_cache.set(cache_key(dashboard_id, KEY), entry)
def test_post(client, dashboard_id: int): def test_post(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload)
assert resp.status_code == 201 assert resp.status_code == 201
def test_post_bad_request(client, dashboard_id: int): def test_post_bad_request_non_string(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": 1234, "value": 1234,
@ -84,12 +79,21 @@ def test_post_bad_request(client, dashboard_id: int):
assert resp.status_code == 400 assert resp.status_code == 400
def test_post_bad_request_non_json_string(client, dashboard_id: int):
login(client, "admin")
payload = {
"value": "foo",
}
resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload)
assert resp.status_code == 400
@patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access")
def test_post_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int): def test_post_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int):
login(client, "admin") login(client, "admin")
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload)
assert resp.status_code == 403 assert resp.status_code == 403
@ -98,7 +102,7 @@ def test_post_access_denied(mock_raise_for_dashboard_access, client, dashboard_i
def test_post_same_key_for_same_tab_id(client, dashboard_id: int): def test_post_same_key_for_same_tab_id(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.post( resp = client.post(
f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=1", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=1", json=payload
@ -116,7 +120,7 @@ def test_post_same_key_for_same_tab_id(client, dashboard_id: int):
def test_post_different_key_for_different_tab_id(client, dashboard_id: int): def test_post_different_key_for_different_tab_id(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.post( resp = client.post(
f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=1", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=1", json=payload
@ -134,7 +138,7 @@ def test_post_different_key_for_different_tab_id(client, dashboard_id: int):
def test_post_different_key_for_no_tab_id(client, dashboard_id: int): def test_post_different_key_for_no_tab_id(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
@ -148,10 +152,10 @@ def test_post_different_key_for_no_tab_id(client, dashboard_id: int):
def test_put(client, dashboard_id: int): def test_put(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": "new value", "value": UPDATED_VALUE,
} }
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload
) )
assert resp.status_code == 200 assert resp.status_code == 200
@ -159,15 +163,15 @@ def test_put(client, dashboard_id: int):
def test_put_same_key_for_same_tab_id(client, dashboard_id: int): def test_put_same_key_for_same_tab_id(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}?tab_id=1", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=1", json=payload
) )
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}?tab_id=1", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=1", json=payload
) )
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
second_key = data.get("key") second_key = data.get("key")
@ -177,15 +181,15 @@ def test_put_same_key_for_same_tab_id(client, dashboard_id: int):
def test_put_different_key_for_different_tab_id(client, dashboard_id: int): def test_put_different_key_for_different_tab_id(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}?tab_id=1", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=1", json=payload
) )
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}?tab_id=2", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=2", json=payload
) )
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
second_key = data.get("key") second_key = data.get("key")
@ -195,28 +199,39 @@ def test_put_different_key_for_different_tab_id(client, dashboard_id: int):
def test_put_different_key_for_no_tab_id(client, dashboard_id: int): def test_put_different_key_for_no_tab_id(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": value, "value": INITIAL_VALUE,
} }
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload
) )
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload
) )
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
second_key = data.get("key") second_key = data.get("key")
assert first_key != second_key assert first_key != second_key
def test_put_bad_request(client, dashboard_id: int): def test_put_bad_request_non_string(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"value": 1234, "value": 1234,
} }
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload
)
assert resp.status_code == 400
def test_put_bad_request_non_json_string(client, dashboard_id: int):
login(client, "admin")
payload = {
"value": "foo",
}
resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload
) )
assert resp.status_code == 400 assert resp.status_code == 400
@ -226,10 +241,10 @@ def test_put_access_denied(mock_raise_for_dashboard_access, client, dashboard_id
login(client, "admin") login(client, "admin")
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
payload = { payload = {
"value": "new value", "value": UPDATED_VALUE,
} }
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload
) )
assert resp.status_code == 403 assert resp.status_code == 403
@ -237,10 +252,10 @@ def test_put_access_denied(mock_raise_for_dashboard_access, client, dashboard_id
def test_put_not_owner(client, dashboard_id: int): def test_put_not_owner(client, dashboard_id: int):
login(client, "gamma") login(client, "gamma")
payload = { payload = {
"value": "new value", "value": UPDATED_VALUE,
} }
resp = client.put( resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}", json=payload f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload
) )
assert resp.status_code == 403 assert resp.status_code == 403
@ -253,29 +268,29 @@ def test_get_key_not_found(client, dashboard_id: int):
def test_get_dashboard_not_found(client): def test_get_dashboard_not_found(client):
login(client, "admin") login(client, "admin")
resp = client.get(f"api/v1/dashboard/{-1}/filter_state/{key}") resp = client.get(f"api/v1/dashboard/{-1}/filter_state/{KEY}")
assert resp.status_code == 404 assert resp.status_code == 404
def test_get(client, dashboard_id: int): def test_get(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}") resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}")
assert resp.status_code == 200 assert resp.status_code == 200
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
assert value == data.get("value") assert INITIAL_VALUE == data.get("value")
@patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access")
def test_get_access_denied(mock_raise_for_dashboard_access, client, dashboard_id): def test_get_access_denied(mock_raise_for_dashboard_access, client, dashboard_id):
login(client, "admin") login(client, "admin")
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}") resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}")
assert resp.status_code == 403 assert resp.status_code == 403
def test_delete(client, dashboard_id: int): def test_delete(client, dashboard_id: int):
login(client, "admin") login(client, "admin")
resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}") resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}")
assert resp.status_code == 200 assert resp.status_code == 200
@ -285,11 +300,11 @@ def test_delete_access_denied(
): ):
login(client, "admin") login(client, "admin")
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}") resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}")
assert resp.status_code == 403 assert resp.status_code == 403
def test_delete_not_owner(client, dashboard_id: int): def test_delete_not_owner(client, dashboard_id: int):
login(client, "gamma") login(client, "gamma")
resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}") resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}")
assert resp.status_code == 403 assert resp.status_code == 403

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from unittest.mock import patch
import pytest
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session
from superset import db
from superset.dashboards.commands.exceptions import DashboardAccessDeniedError
from superset.key_value.models import KeyValueEntry
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
load_world_bank_data,
)
from tests.integration_tests.test_app import app
STATE = {
"filterState": {"FILTER_1": "foo",},
"hash": "my-anchor",
}
@pytest.fixture
def dashboard_id(load_world_bank_dashboard_with_slices) -> int:
with app.app_context() as ctx:
session: Session = ctx.app.appbuilder.get_session
dashboard = session.query(Dashboard).filter_by(slug="world_health").one()
return dashboard.id
def test_post(client, dashboard_id: int):
login(client, "admin")
resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE)
assert resp.status_code == 201
data = json.loads(resp.data.decode("utf-8"))
key = data["key"]
url = data["url"]
assert key in url
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
db.session.commit()
@patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access")
def test_post_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int):
login(client, "admin")
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE)
assert resp.status_code == 403
def test_post_invalid_schema(client, dashboard_id: int):
login(client, "admin")
resp = client.post(
f"api/v1/dashboard/{dashboard_id}/permalink", json={"foo": "bar"}
)
assert resp.status_code == 400
def test_get(client, dashboard_id: int):
login(client, "admin")
resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE)
data = json.loads(resp.data.decode("utf-8"))
key = data["key"]
resp = client.get(f"api/v1/dashboard/permalink/{key}")
assert resp.status_code == 200
result = json.loads(resp.data.decode("utf-8"))
assert result["dashboardId"] == str(dashboard_id)
assert result["state"] == STATE
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
db.session.commit()

View File

@ -27,21 +27,16 @@ from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.models.slice import Slice from superset.models.slice import Slice
from tests.integration_tests.base_tests import login from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import ( from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices, load_world_bank_dashboard_with_slices,
load_world_bank_data, load_world_bank_data,
) )
from tests.integration_tests.test_app import app from tests.integration_tests.test_app import app
key = "test-key" KEY = "test-key"
form_data = "test" INITIAL_FORM_DATA = json.dumps({"test": "initial value"})
UPDATED_FORM_DATA = json.dumps({"test": "updated value"})
@pytest.fixture
def client():
with app.test_client() as client:
with app.app_context():
yield client
@pytest.fixture @pytest.fixture
@ -78,9 +73,9 @@ def cache(chart_id, admin_id, dataset_id):
"owner": admin_id, "owner": admin_id,
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": form_data, "form_data": INITIAL_FORM_DATA,
} }
cache_manager.explore_form_data_cache.set(key, entry) cache_manager.explore_form_data_cache.set(KEY, entry)
def test_post(client, chart_id: int, dataset_id: int): def test_post(client, chart_id: int, dataset_id: int):
@ -88,13 +83,13 @@ def test_post(client, chart_id: int, dataset_id: int):
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": form_data, "form_data": INITIAL_FORM_DATA,
} }
resp = client.post("api/v1/explore/form_data", json=payload) resp = client.post("api/v1/explore/form_data", json=payload)
assert resp.status_code == 201 assert resp.status_code == 201
def test_post_bad_request(client, chart_id: int, dataset_id: int): def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int):
login(client, "admin") login(client, "admin")
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
@ -105,12 +100,23 @@ def test_post_bad_request(client, chart_id: int, dataset_id: int):
assert resp.status_code == 400 assert resp.status_code == 400
def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"chart_id": chart_id,
"form_data": "foo",
}
resp = client.post("api/v1/explore/form_data", json=payload)
assert resp.status_code == 400
def test_post_access_denied(client, chart_id: int, dataset_id: int): def test_post_access_denied(client, chart_id: int, dataset_id: int):
login(client, "gamma") login(client, "gamma")
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": form_data, "form_data": INITIAL_FORM_DATA,
} }
resp = client.post("api/v1/explore/form_data", json=payload) resp = client.post("api/v1/explore/form_data", json=payload)
assert resp.status_code == 404 assert resp.status_code == 404
@ -121,7 +127,7 @@ def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int):
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
@ -139,14 +145,14 @@ def test_post_different_key_for_different_context(
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"form_data": "new form_data", "form_data": json.dumps({"test": "initial value"}),
} }
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
@ -159,7 +165,7 @@ def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": json.dumps({"test": "initial value"}),
} }
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
@ -177,7 +183,7 @@ def test_post_different_key_for_different_tab_id(
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": json.dumps({"test": "initial value"}),
} }
resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
@ -193,7 +199,7 @@ def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": INITIAL_FORM_DATA,
} }
resp = client.post("api/v1/explore/form_data", json=payload) resp = client.post("api/v1/explore/form_data", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
@ -209,9 +215,9 @@ def test_put(client, chart_id: int, dataset_id: int):
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.put(f"api/v1/explore/form_data/{key}", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 200 assert resp.status_code == 200
@ -220,12 +226,12 @@ def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int):
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.put(f"api/v1/explore/form_data/{key}?tab_id=1", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
resp = client.put(f"api/v1/explore/form_data/{key}?tab_id=1", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
second_key = data.get("key") second_key = data.get("key")
assert first_key == second_key assert first_key == second_key
@ -236,12 +242,12 @@ def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_i
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.put(f"api/v1/explore/form_data/{key}?tab_id=1", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
resp = client.put(f"api/v1/explore/form_data/{key}?tab_id=2", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=2", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
second_key = data.get("key") second_key = data.get("key")
assert first_key != second_key assert first_key != second_key
@ -252,12 +258,12 @@ def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int)
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.put(f"api/v1/explore/form_data/{key}", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
first_key = data.get("key") first_key = data.get("key")
resp = client.put(f"api/v1/explore/form_data/{key}", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
second_key = data.get("key") second_key = data.get("key")
assert first_key != second_key assert first_key != second_key
@ -270,7 +276,29 @@ def test_put_bad_request(client, chart_id: int, dataset_id: int):
"chart_id": chart_id, "chart_id": chart_id,
"form_data": 1234, "form_data": 1234,
} }
resp = client.put(f"api/v1/explore/form_data/{key}", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 400
def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"chart_id": chart_id,
"form_data": 1234,
}
resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 400
def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int):
login(client, "admin")
payload = {
"dataset_id": dataset_id,
"chart_id": chart_id,
"form_data": "foo",
}
resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 400 assert resp.status_code == 400
@ -279,9 +307,9 @@ def test_put_access_denied(client, chart_id: int, dataset_id: int):
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.put(f"api/v1/explore/form_data/{key}", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 404 assert resp.status_code == 404
@ -290,9 +318,9 @@ def test_put_not_owner(client, chart_id: int, dataset_id: int):
payload = { payload = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": "new form_data", "form_data": UPDATED_FORM_DATA,
} }
resp = client.put(f"api/v1/explore/form_data/{key}", json=payload) resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 404 assert resp.status_code == 404
@ -304,15 +332,15 @@ def test_get_key_not_found(client):
def test_get(client): def test_get(client):
login(client, "admin") login(client, "admin")
resp = client.get(f"api/v1/explore/form_data/{key}") resp = client.get(f"api/v1/explore/form_data/{KEY}")
assert resp.status_code == 200 assert resp.status_code == 200
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
assert form_data == data.get("form_data") assert INITIAL_FORM_DATA == data.get("form_data")
def test_get_access_denied(client): def test_get_access_denied(client):
login(client, "gamma") login(client, "gamma")
resp = client.get(f"api/v1/explore/form_data/{key}") resp = client.get(f"api/v1/explore/form_data/{KEY}")
assert resp.status_code == 404 assert resp.status_code == 404
@ -320,19 +348,19 @@ def test_get_access_denied(client):
def test_get_dataset_access_denied(mock_can_access_datasource, client): def test_get_dataset_access_denied(mock_can_access_datasource, client):
mock_can_access_datasource.side_effect = DatasetAccessDeniedError() mock_can_access_datasource.side_effect = DatasetAccessDeniedError()
login(client, "admin") login(client, "admin")
resp = client.get(f"api/v1/explore/form_data/{key}") resp = client.get(f"api/v1/explore/form_data/{KEY}")
assert resp.status_code == 403 assert resp.status_code == 403
def test_delete(client): def test_delete(client):
login(client, "admin") login(client, "admin")
resp = client.delete(f"api/v1/explore/form_data/{key}") resp = client.delete(f"api/v1/explore/form_data/{KEY}")
assert resp.status_code == 200 assert resp.status_code == 200
def test_delete_access_denied(client): def test_delete_access_denied(client):
login(client, "gamma") login(client, "gamma")
resp = client.delete(f"api/v1/explore/form_data/{key}") resp = client.delete(f"api/v1/explore/form_data/{KEY}")
assert resp.status_code == 404 assert resp.status_code == 404
@ -343,7 +371,7 @@ def test_delete_not_owner(client, chart_id: int, dataset_id: int, admin_id: int)
"owner": another_owner, "owner": another_owner,
"dataset_id": dataset_id, "dataset_id": dataset_id,
"chart_id": chart_id, "chart_id": chart_id,
"form_data": form_data, "form_data": INITIAL_FORM_DATA,
} }
cache_manager.explore_form_data_cache.set(another_key, entry) cache_manager.explore_form_data_cache.set(another_key, entry)
login(client, "admin") login(client, "admin")

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import pickle
from typing import Any, Dict
from uuid import UUID
import pytest
from sqlalchemy.orm import Session
from superset import db
from superset.key_value.models import KeyValueEntry
from superset.models.slice import Slice
from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
load_world_bank_data,
)
from tests.integration_tests.test_app import app
@pytest.fixture
def chart(load_world_bank_dashboard_with_slices) -> Slice:
with app.app_context() as ctx:
session: Session = ctx.app.appbuilder.get_session
chart = session.query(Slice).filter_by(slice_name="World's Population").one()
return chart
@pytest.fixture
def form_data(chart) -> Dict[str, Any]:
datasource = f"{chart.datasource.id}__{chart.datasource.type}"
return {
"chart_id": chart.id,
"datasource": datasource,
}
def test_post(client, form_data):
login(client, "admin")
resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data})
assert resp.status_code == 201
data = json.loads(resp.data.decode("utf-8"))
key = data["key"]
url = data["url"]
assert key in url
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
db.session.commit()
def test_post_access_denied(client, form_data):
login(client, "gamma")
resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data})
assert resp.status_code == 404
def test_get_missing_chart(client, chart):
from superset.key_value.models import KeyValueEntry
key = 1234
uuid_key = "e2ea9d19-7988-4862-aa69-c3a1a7628cb9"
entry = KeyValueEntry(
id=int(key),
uuid=UUID("e2ea9d19-7988-4862-aa69-c3a1a7628cb9"),
resource="explore_permalink",
value=pickle.dumps(
{
"chartId": key,
"datasetId": chart.datasource.id,
"formData": {
"slice_id": key,
"datasource": f"{chart.datasource.id}__{chart.datasource.type}",
},
}
),
)
db.session.add(entry)
db.session.commit()
login(client, "admin")
resp = client.get(f"api/v1/explore/permalink/{uuid_key}")
assert resp.status_code == 404
db.session.delete(entry)
db.session.commit()
def test_post_invalid_schema(client):
login(client, "admin")
resp = client.post(f"api/v1/explore/permalink", json={"abc": 123})
assert resp.status_code == 400
def test_get(client, form_data):
login(client, "admin")
resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data})
data = json.loads(resp.data.decode("utf-8"))
key = data["key"]
resp = client.get(f"api/v1/explore/permalink/{key}")
assert resp.status_code == 200
result = json.loads(resp.data.decode("utf-8"))
assert result["state"]["formData"] == form_data
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
db.session.commit()

View File

@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from tests.integration_tests.test_app import app
@pytest.fixture
def client():
with app.test_client() as client:
with app.app_context():
yield client

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pickle
from uuid import UUID
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from tests.integration_tests.key_value.commands.fixtures import (
admin,
ID_KEY,
RESOURCE,
UUID_KEY,
VALUE,
)
def test_create_id_entry(app_context: AppContext, admin: User) -> None:
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.models import KeyValueEntry
key = CreateKeyValueCommand(
actor=admin, resource=RESOURCE, value=VALUE, key_type="id",
).run()
entry = (
db.session.query(KeyValueEntry).filter_by(id=int(key)).autoflush(False).one()
)
assert pickle.loads(entry.value) == VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
db.session.commit()
def test_create_uuid_entry(app_context: AppContext, admin: User) -> None:
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.models import KeyValueEntry
key = CreateKeyValueCommand(
actor=admin, resource=RESOURCE, value=VALUE, key_type="uuid",
).run()
entry = (
db.session.query(KeyValueEntry).filter_by(uuid=UUID(key)).autoflush(False).one()
)
assert pickle.loads(entry.value) == VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
db.session.commit()

View File

@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pickle
from typing import TYPE_CHECKING
from uuid import UUID
import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from tests.integration_tests.key_value.commands.fixtures import admin, RESOURCE, VALUE
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
ID_KEY = "234"
UUID_KEY = "5aae143c-44f1-478e-9153-ae6154df333a"
@pytest.fixture
def key_value_entry() -> KeyValueEntry:
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=int(ID_KEY),
uuid=UUID(UUID_KEY),
resource=RESOURCE,
value=pickle.dumps(VALUE),
)
db.session.add(entry)
db.session.commit()
return entry
def test_delete_id_entry(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.commands.delete import DeleteKeyValueCommand
from superset.key_value.models import KeyValueEntry
assert (
DeleteKeyValueCommand(
actor=admin, resource=RESOURCE, key=ID_KEY, key_type="id",
).run()
is True
)
def test_delete_uuid_entry(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.commands.delete import DeleteKeyValueCommand
from superset.key_value.models import KeyValueEntry
assert (
DeleteKeyValueCommand(
actor=admin, resource=RESOURCE, key=UUID_KEY, key_type="uuid",
).run()
is True
)
def test_delete_entry_missing(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.commands.delete import DeleteKeyValueCommand
from superset.key_value.models import KeyValueEntry
assert (
DeleteKeyValueCommand(
actor=admin, resource=RESOURCE, key="456", key_type="id",
).run()
is False
)

View File

@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pickle
from typing import Generator, TYPE_CHECKING
from uuid import UUID
import pytest
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session
from superset.extensions import db
from tests.integration_tests.test_app import app
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
ID_KEY = "123"
UUID_KEY = "3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc"
RESOURCE = "my_resource"
VALUE = {"foo": "bar"}
@pytest.fixture
def key_value_entry() -> Generator[KeyValueEntry, None, None]:
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=int(ID_KEY),
uuid=UUID(UUID_KEY),
resource=RESOURCE,
value=pickle.dumps(VALUE),
)
db.session.add(entry)
db.session.commit()
yield entry
db.session.delete(entry)
db.session.commit()
@pytest.fixture
def admin() -> User:
with app.app_context() as ctx:
session: Session = ctx.app.appbuilder.get_session
admin = session.query(User).filter_by(username="admin").one()
return admin

View File

@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pickle
import uuid
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from flask.ctx import AppContext
from superset.extensions import db
from tests.integration_tests.key_value.commands.fixtures import (
ID_KEY,
key_value_entry,
RESOURCE,
UUID_KEY,
VALUE,
)
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
def test_get_id_entry(app_context: AppContext, key_value_entry: KeyValueEntry) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id").run()
assert value == VALUE
def test_get_uuid_entry(
app_context: AppContext, key_value_entry: KeyValueEntry
) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY, key_type="uuid").run()
assert value == VALUE
def test_get_id_entry_missing(
app_context: AppContext, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key="456", key_type="id").run()
assert value is None
def test_get_expired_entry(app_context: AppContext) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=678,
uuid=uuid.uuid4(),
resource=RESOURCE,
value=pickle.dumps(VALUE),
expires_on=datetime.now() - timedelta(days=1),
)
db.session.add(entry)
db.session.commit()
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id").run()
assert value is None
db.session.delete(entry)
db.session.commit()
def test_get_future_expiring_entry(app_context: AppContext) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.models import KeyValueEntry
id_ = 789
entry = KeyValueEntry(
id=id_,
uuid=uuid.uuid4(),
resource=RESOURCE,
value=pickle.dumps(VALUE),
expires_on=datetime.now() + timedelta(days=1),
)
db.session.add(entry)
db.session.commit()
value = GetKeyValueCommand(resource=RESOURCE, key=str(id_), key_type="id").run()
assert value == VALUE
db.session.delete(entry)
db.session.commit()

View File

@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pickle
from typing import TYPE_CHECKING
from uuid import UUID
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from tests.integration_tests.key_value.commands.fixtures import (
admin,
ID_KEY,
key_value_entry,
RESOURCE,
UUID_KEY,
)
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
NEW_VALUE = "new value"
def test_update_id_entry(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.commands.update import UpdateKeyValueCommand
from superset.key_value.models import KeyValueEntry
key = UpdateKeyValueCommand(
actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, key_type="id",
).run()
assert key == ID_KEY
entry = (
db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one()
)
assert pickle.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_update_uuid_entry(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.commands.update import UpdateKeyValueCommand
from superset.key_value.models import KeyValueEntry
key = UpdateKeyValueCommand(
actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, key_type="uuid",
).run()
assert key == UUID_KEY
entry = (
db.session.query(KeyValueEntry)
.filter_by(uuid=UUID(UUID_KEY))
.autoflush(False)
.one()
)
assert pickle.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_update_missing_entry(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.commands.update import UpdateKeyValueCommand
key = UpdateKeyValueCommand(
actor=admin, resource=RESOURCE, key="456", value=NEW_VALUE, key_type="id",
).run()
assert key is None

View File

@ -30,8 +30,8 @@ from superset.datasets.commands.exceptions import (
dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id" dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id"
chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id" chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id"
is_user_admin = "superset.explore.form_data.utils.is_user_admin" is_user_admin = "superset.explore.utils.is_user_admin"
is_owner = "superset.explore.form_data.utils.is_owner" is_owner = "superset.explore.utils.is_owner"
can_access_datasource = ( can_access_datasource = (
"superset.security.SupersetSecurityManager.can_access_datasource" "superset.security.SupersetSecurityManager.can_access_datasource"
) )
@ -39,7 +39,7 @@ can_access = "superset.security.SupersetSecurityManager.can_access"
def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None: def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None:
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
with raises(DatasetNotFoundError): with raises(DatasetNotFoundError):
check_access(dataset_id=0, chart_id=0, actor=User()) check_access(dataset_id=0, chart_id=0, actor=User())
@ -48,7 +48,7 @@ def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None:
def test_unsaved_chart_unknown_dataset_id( def test_unsaved_chart_unknown_dataset_id(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
with raises(DatasetNotFoundError): with raises(DatasetNotFoundError):
mocker.patch(dataset_find_by_id, return_value=None) mocker.patch(dataset_find_by_id, return_value=None)
@ -59,7 +59,7 @@ def test_unsaved_chart_unauthorized_dataset(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data import utils from superset.explore import utils
with raises(DatasetAccessDeniedError): with raises(DatasetAccessDeniedError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -71,7 +71,7 @@ def test_unsaved_chart_authorized_dataset(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True) mocker.patch(can_access_datasource, return_value=True)
@ -82,7 +82,7 @@ def test_saved_chart_unknown_chart_id(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
with raises(ChartNotFoundError): with raises(ChartNotFoundError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -95,7 +95,7 @@ def test_saved_chart_unauthorized_dataset(
mocker: MockFixture, app_context: AppContext mocker: MockFixture, app_context: AppContext
) -> None: ) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data import utils from superset.explore import utils
with raises(DatasetAccessDeniedError): with raises(DatasetAccessDeniedError):
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -105,19 +105,19 @@ def test_saved_chart_unauthorized_dataset(
def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.models.slice import Slice from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
mocker.patch(can_access_datasource, return_value=True) mocker.patch(can_access_datasource, return_value=True)
mocker.patch(is_user_admin, return_value=True) mocker.patch(is_user_admin, return_value=True)
mocker.patch(chart_find_by_id, return_value=Slice()) mocker.patch(chart_find_by_id, return_value=Slice())
assert check_access(dataset_id=1, chart_id=1, actor=User()) == True assert check_access(dataset_id=1, chart_id=1, actor=User()) is True
def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.models.slice import Slice from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -130,7 +130,7 @@ def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> N
def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.models.slice import Slice from superset.models.slice import Slice
mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(dataset_find_by_id, return_value=SqlaTable())
@ -144,7 +144,7 @@ def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) ->
def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None: def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.explore.form_data.utils import check_access from superset.explore.utils import check_access
from superset.models.slice import Slice from superset.models.slice import Slice
with raises(ChartAccessDeniedError): with raises(ChartAccessDeniedError):

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
from typing import TYPE_CHECKING
from unittest.mock import patch
from uuid import UUID
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
import pytest
from flask.ctx import AppContext
from superset.key_value.types import Key
RESOURCE = "my-resource"
UUID_KEY = "3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc"
ID_KEY = "123"
@pytest.fixture
def key_value_entry(app_context: AppContext):
from superset.key_value.models import KeyValueEntry
return KeyValueEntry(
id=int(ID_KEY), uuid=UUID(UUID_KEY), value=json.dumps({"foo": "bar"}),
)
def test_parse_permalink_key_uuid_valid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
assert parse_permalink_key(UUID_KEY) == Key(id=None, uuid=UUID(UUID_KEY))
def test_parse_permalink_key_id_invalid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
with pytest.raises(ValueError):
parse_permalink_key(ID_KEY)
@patch("superset.key_value.utils.current_app.config", {"PERMALINK_KEY_TYPE": "id"})
def test_parse_permalink_key_id_valid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
assert parse_permalink_key(ID_KEY) == Key(id=int(ID_KEY), uuid=None)
@patch("superset.key_value.utils.current_app.config", {"PERMALINK_KEY_TYPE": "id"})
def test_parse_permalink_key_uuid_invalid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
with pytest.raises(ValueError):
parse_permalink_key(UUID_KEY)
def test_format_permalink_key_uuid(app_context: AppContext) -> None:
from superset.key_value.utils import format_permalink_key
assert format_permalink_key(Key(id=None, uuid=UUID(UUID_KEY))) == UUID_KEY
def test_format_permalink_key_id(app_context: AppContext) -> None:
from superset.key_value.utils import format_permalink_key
assert format_permalink_key(Key(id=int(ID_KEY), uuid=None)) == ID_KEY
def test_extract_key_uuid(
app_context: AppContext, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.utils import extract_key
assert extract_key(key_value_entry, "id") == ID_KEY
def test_extract_key_id(
app_context: AppContext, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.utils import extract_key
assert extract_key(key_value_entry, "uuid") == UUID_KEY
def test_get_filter_uuid(app_context: AppContext,) -> None:
from superset.key_value.utils import get_filter
assert get_filter(resource=RESOURCE, key=UUID_KEY, key_type="uuid",) == {
"resource": RESOURCE,
"uuid": UUID(UUID_KEY),
}
def test_get_filter_id(app_context: AppContext,) -> None:
from superset.key_value.utils import get_filter
assert get_filter(resource=RESOURCE, key=ID_KEY, key_type="id",) == {
"resource": RESOURCE,
"id": int(ID_KEY),
}