feat(SIP-85): OAuth2 for databases (#27631)
This commit is contained in:
parent
fdc2dbe7db
commit
9022f5c519
|
|
@ -79,7 +79,7 @@ dependencies = [
|
||||||
"PyJWT>=2.4.0, <3.0",
|
"PyJWT>=2.4.0, <3.0",
|
||||||
"redis>=4.6.0, <5.0",
|
"redis>=4.6.0, <5.0",
|
||||||
"selenium>=3.141.0, <4.10.0",
|
"selenium>=3.141.0, <4.10.0",
|
||||||
"shillelagh[gsheetsapi]>=1.2.10, <2.0",
|
"shillelagh[gsheetsapi]>=1.2.18, <2.0",
|
||||||
"shortid",
|
"shortid",
|
||||||
"sshtunnel>=0.4.0, <0.5",
|
"sshtunnel>=0.4.0, <0.5",
|
||||||
"simplejson>=3.15.0",
|
"simplejson>=3.15.0",
|
||||||
|
|
@ -127,13 +127,14 @@ excel = ["xlrd>=1.2.0, <1.3"]
|
||||||
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
|
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
|
||||||
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
|
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
|
||||||
gevent = ["gevent>=23.9.1"]
|
gevent = ["gevent>=23.9.1"]
|
||||||
gsheets = ["shillelagh[gsheetsapi]>=1.2.10, <2"]
|
gsheets = ["shillelagh[gsheetsapi]>=1.2.18, <2"]
|
||||||
hana = ["hdbcli==2.4.162", "sqlalchemy_hana==0.4.0"]
|
hana = ["hdbcli==2.4.162", "sqlalchemy_hana==0.4.0"]
|
||||||
hive = [
|
hive = [
|
||||||
"pyhive[hive]>=0.6.5;python_version<'3.11'",
|
"pyhive[hive]>=0.6.5;python_version<'3.11'",
|
||||||
"pyhive[hive_pure_sasl]>=0.7.0",
|
"pyhive[hive_pure_sasl]>=0.7.0",
|
||||||
"tableschema",
|
"tableschema",
|
||||||
"thrift>=0.14.1, <1.0.0",
|
"thrift>=0.14.1, <1.0.0",
|
||||||
|
"thrift_sasl>=0.4.3, < 1.0.0",
|
||||||
]
|
]
|
||||||
impala = ["impyla>0.16.2, <0.17"]
|
impala = ["impyla>0.16.2, <0.17"]
|
||||||
kusto = ["sqlalchemy-kusto>=2.0.0, <3"]
|
kusto = ["sqlalchemy-kusto>=2.0.0, <3"]
|
||||||
|
|
@ -155,7 +156,7 @@ trino = ["trino>=0.328.0"]
|
||||||
prophet = ["prophet>=1.1.5, <2"]
|
prophet = ["prophet>=1.1.5, <2"]
|
||||||
redshift = ["sqlalchemy-redshift>=0.8.1, <0.9"]
|
redshift = ["sqlalchemy-redshift>=0.8.1, <0.9"]
|
||||||
rockset = ["rockset-sqlalchemy>=0.0.1, <1"]
|
rockset = ["rockset-sqlalchemy>=0.0.1, <1"]
|
||||||
shillelagh = ["shillelagh[all]>=1.2.10, <2"]
|
shillelagh = ["shillelagh[all]>=1.2.18, <2"]
|
||||||
snowflake = ["snowflake-sqlalchemy>=1.2.4, <2"]
|
snowflake = ["snowflake-sqlalchemy>=1.2.4, <2"]
|
||||||
spark = [
|
spark = [
|
||||||
"pyhive[hive]>=0.6.5;python_version<'3.11'",
|
"pyhive[hive]>=0.6.5;python_version<'3.11'",
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,7 @@ cryptography==42.0.5
|
||||||
# via
|
# via
|
||||||
# apache-superset
|
# apache-superset
|
||||||
# paramiko
|
# paramiko
|
||||||
|
# pyopenssl
|
||||||
deprecated==1.2.13
|
deprecated==1.2.13
|
||||||
# via limits
|
# via limits
|
||||||
deprecation==2.1.0
|
deprecation==2.1.0
|
||||||
|
|
@ -147,7 +148,9 @@ geopy==2.4.1
|
||||||
google-auth==2.27.0
|
google-auth==2.27.0
|
||||||
# via shillelagh
|
# via shillelagh
|
||||||
greenlet==3.0.3
|
greenlet==3.0.3
|
||||||
# via shillelagh
|
# via
|
||||||
|
# shillelagh
|
||||||
|
# sqlalchemy
|
||||||
gunicorn==21.2.0
|
gunicorn==21.2.0
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
hashids==1.3.1
|
hashids==1.3.1
|
||||||
|
|
@ -278,6 +281,8 @@ pyjwt==2.8.0
|
||||||
# flask-jwt-extended
|
# flask-jwt-extended
|
||||||
pynacl==1.5.0
|
pynacl==1.5.0
|
||||||
# via paramiko
|
# via paramiko
|
||||||
|
pyopenssl==24.1.0
|
||||||
|
# via shillelagh
|
||||||
pyparsing==3.1.2
|
pyparsing==3.1.2
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
pyrsistent==0.20.0
|
pyrsistent==0.20.0
|
||||||
|
|
@ -319,7 +324,7 @@ rsa==4.9
|
||||||
# via google-auth
|
# via google-auth
|
||||||
selenium==3.141.0
|
selenium==3.141.0
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
shillelagh[gsheetsapi]==1.2.10
|
shillelagh[gsheetsapi]==1.2.18
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
shortid==0.1.2
|
shortid==0.1.2
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
|
|
|
||||||
|
|
@ -202,6 +202,8 @@ ptyprocess==0.7.0
|
||||||
# via pexpect
|
# via pexpect
|
||||||
pure-eval==0.2.2
|
pure-eval==0.2.2
|
||||||
# via stack-data
|
# via stack-data
|
||||||
|
pure-sasl==0.6.2
|
||||||
|
# via thrift-sasl
|
||||||
pydata-google-auth==1.7.0
|
pydata-google-auth==1.7.0
|
||||||
# via pandas-gbq
|
# via pandas-gbq
|
||||||
pydruid==0.6.6
|
pydruid==0.6.6
|
||||||
|
|
@ -252,6 +254,10 @@ statsd==4.0.1
|
||||||
tableschema==1.20.10
|
tableschema==1.20.10
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
thrift==0.16.0
|
thrift==0.16.0
|
||||||
|
# via
|
||||||
|
# apache-superset
|
||||||
|
# thrift-sasl
|
||||||
|
thrift-sasl==0.4.3
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
tomli==2.0.1
|
tomli==2.0.1
|
||||||
# via
|
# via
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,170 @@
|
||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import React from 'react';
|
||||||
|
import * as reduxHooks from 'react-redux';
|
||||||
|
import { Provider } from 'react-redux';
|
||||||
|
import { createStore } from 'redux';
|
||||||
|
import { render, fireEvent, waitFor } from '@testing-library/react';
|
||||||
|
import '@testing-library/jest-dom';
|
||||||
|
import { ThemeProvider, supersetTheme } from '@superset-ui/core';
|
||||||
|
import OAuth2RedirectMessage from 'src/components/ErrorMessage/OAuth2RedirectMessage';
|
||||||
|
import {
|
||||||
|
ErrorLevel,
|
||||||
|
ErrorSource,
|
||||||
|
ErrorTypeEnum,
|
||||||
|
} from 'src/components/ErrorMessage/types';
|
||||||
|
import { reRunQuery } from 'src/SqlLab/actions/sqlLab';
|
||||||
|
import { triggerQuery } from 'src/components/Chart/chartAction';
|
||||||
|
import { onRefresh } from 'src/dashboard/actions/dashboardState';
|
||||||
|
|
||||||
|
// Mock the Redux store
|
||||||
|
const mockStore = createStore(() => ({
|
||||||
|
sqlLab: {
|
||||||
|
queries: { 'query-id': { sql: 'SELECT * FROM table' } },
|
||||||
|
queryEditors: [{ id: 'editor-id', latestQueryId: 'query-id' }],
|
||||||
|
tabHistory: ['editor-id'],
|
||||||
|
},
|
||||||
|
explore: {
|
||||||
|
slice: { slice_id: 123 },
|
||||||
|
},
|
||||||
|
charts: { '1': {}, '2': {} },
|
||||||
|
dashboardInfo: { id: 'dashboard-id' },
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock actions
|
||||||
|
jest.mock('src/SqlLab/actions/sqlLab', () => ({
|
||||||
|
reRunQuery: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('src/components/Chart/chartAction', () => ({
|
||||||
|
triggerQuery: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('src/dashboard/actions/dashboardState', () => ({
|
||||||
|
onRefresh: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock useDispatch
|
||||||
|
const mockDispatch = jest.fn();
|
||||||
|
jest.spyOn(reduxHooks, 'useDispatch').mockReturnValue(mockDispatch);
|
||||||
|
|
||||||
|
// Mock global window functions
|
||||||
|
const mockOpen = jest.spyOn(window, 'open').mockImplementation(() => null);
|
||||||
|
const mockAddEventListener = jest.spyOn(window, 'addEventListener');
|
||||||
|
const mockRemoveEventListener = jest.spyOn(window, 'removeEventListener');
|
||||||
|
|
||||||
|
// Mock window.postMessage
|
||||||
|
const originalPostMessage = window.postMessage;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
window.postMessage = jest.fn();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
window.postMessage = originalPostMessage;
|
||||||
|
});
|
||||||
|
|
||||||
|
function simulateMessageEvent(data: any, origin: string) {
|
||||||
|
const messageEvent = new MessageEvent('message', { data, origin });
|
||||||
|
window.dispatchEvent(messageEvent);
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultProps = {
|
||||||
|
error: {
|
||||||
|
error_type: ErrorTypeEnum.OAUTH2_REDIRECT,
|
||||||
|
message: "You don't have permission to access the data.",
|
||||||
|
extra: {
|
||||||
|
url: 'https://example.com',
|
||||||
|
tab_id: 'tabId',
|
||||||
|
redirect_uri: 'https://redirect.example.com',
|
||||||
|
},
|
||||||
|
level: 'warning' as ErrorLevel,
|
||||||
|
},
|
||||||
|
source: 'sqllab' as ErrorSource,
|
||||||
|
};
|
||||||
|
|
||||||
|
const setup = (overrides = {}) => (
|
||||||
|
<ThemeProvider theme={supersetTheme}>
|
||||||
|
<Provider store={mockStore}>
|
||||||
|
<OAuth2RedirectMessage {...defaultProps} {...overrides} />;
|
||||||
|
</Provider>
|
||||||
|
</ThemeProvider>
|
||||||
|
);
|
||||||
|
|
||||||
|
describe('OAuth2RedirectMessage Component', () => {
|
||||||
|
it('renders without crashing and displays the correct initial UI elements', () => {
|
||||||
|
const { getByText } = render(setup());
|
||||||
|
|
||||||
|
expect(getByText(/Authorization needed/i)).toBeInTheDocument();
|
||||||
|
expect(getByText(/provide authorization/i)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('opens a new window with the correct URL when the link is clicked', () => {
|
||||||
|
const { getByText } = render(setup());
|
||||||
|
|
||||||
|
const linkElement = getByText(/provide authorization/i);
|
||||||
|
fireEvent.click(linkElement);
|
||||||
|
|
||||||
|
expect(mockOpen).toHaveBeenCalledWith('https://example.com', '_blank');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('cleans up the message event listener on unmount', () => {
|
||||||
|
const { unmount } = render(setup());
|
||||||
|
|
||||||
|
expect(mockAddEventListener).toHaveBeenCalled();
|
||||||
|
unmount();
|
||||||
|
expect(mockRemoveEventListener).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('dispatches reRunQuery action when a message with correct tab ID is received for SQL Lab', async () => {
|
||||||
|
render(setup());
|
||||||
|
|
||||||
|
simulateMessageEvent({ tabId: 'tabId' }, 'https://redirect.example.com');
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(reRunQuery).toHaveBeenCalledWith({ sql: 'SELECT * FROM table' });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('dispatches triggerQuery action for explore source upon receiving a correct message', async () => {
|
||||||
|
render(setup({ source: 'explore' }));
|
||||||
|
|
||||||
|
simulateMessageEvent({ tabId: 'tabId' }, 'https://redirect.example.com');
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(triggerQuery).toHaveBeenCalledWith(true, 123);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('dispatches onRefresh action for dashboard source upon receiving a correct message', async () => {
|
||||||
|
render(setup({ source: 'dashboard' }));
|
||||||
|
|
||||||
|
simulateMessageEvent({ tabId: 'tabId' }, 'https://redirect.example.com');
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(onRefresh).toHaveBeenCalledWith(
|
||||||
|
['1', '2'],
|
||||||
|
true,
|
||||||
|
0,
|
||||||
|
'dashboard-id',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -0,0 +1,179 @@
|
||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
import React, { useEffect, useRef } from 'react';
|
||||||
|
import { useDispatch, useSelector } from 'react-redux';
|
||||||
|
import { QueryEditor, SqlLabRootState } from 'src/SqlLab/types';
|
||||||
|
import { ExplorePageState } from 'src/explore/types';
|
||||||
|
import { RootState } from 'src/dashboard/types';
|
||||||
|
import { reRunQuery } from 'src/SqlLab/actions/sqlLab';
|
||||||
|
import { triggerQuery } from 'src/components/Chart/chartAction';
|
||||||
|
import { onRefresh } from 'src/dashboard/actions/dashboardState';
|
||||||
|
import { QueryResponse, t } from '@superset-ui/core';
|
||||||
|
|
||||||
|
import { ErrorMessageComponentProps } from './types';
|
||||||
|
import ErrorAlert from './ErrorAlert';
|
||||||
|
|
||||||
|
interface OAuth2RedirectExtra {
|
||||||
|
url: string;
|
||||||
|
tab_id: string;
|
||||||
|
redirect_uri: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Component for starting OAuth2 dance.
|
||||||
|
*
|
||||||
|
* When a user without credentials tries to access a database that supports OAuth2, the
|
||||||
|
* backend will raise an exception with the custom error `OAUTH2_REDIRECT`. This will
|
||||||
|
* cause the frontend to display this component, which informs the user that they need
|
||||||
|
* to authenticate in order to access the data.
|
||||||
|
*
|
||||||
|
* The component has a URL that is used to start the OAuth2 dance for the given
|
||||||
|
* database. When the user clicks the link a new browser tab will open, where they can
|
||||||
|
* authorize Superset to access the data. Once authorization is successful the user will
|
||||||
|
* be redirected back to Superset, and their personal access token is stored, so it can
|
||||||
|
* be used in subsequent connections. If a refresh token is also present in the response,
|
||||||
|
* it will also be stored.
|
||||||
|
*
|
||||||
|
* After the token has been stored, the opened tab will send a message to the original
|
||||||
|
* tab and close itself. This component, running on the original tab, will listen for
|
||||||
|
* message events, and once it receives the success message from the opened tab it will
|
||||||
|
* re-run the query for the user, be it in SQL Lab, Explore, or a dashboard. In order to
|
||||||
|
* communicate securely, both tabs share a "tab ID", which is a UUID that is generated
|
||||||
|
* by the backend and sent from the opened tab to the original tab. For extra security,
|
||||||
|
* we also check that the source of the message is the opened tab via a ref.
|
||||||
|
*/
|
||||||
|
function OAuth2RedirectMessage({
|
||||||
|
error,
|
||||||
|
source,
|
||||||
|
}: ErrorMessageComponentProps<OAuth2RedirectExtra>) {
|
||||||
|
const oAuthTab = useRef<Window | null>(null);
|
||||||
|
const { extra, level } = error;
|
||||||
|
|
||||||
|
// store a reference to the OAuth2 browser tab, so we can check that the success
|
||||||
|
// message is coming from it
|
||||||
|
const handleOAuthClick = (event: React.MouseEvent<HTMLAnchorElement>) => {
|
||||||
|
event.preventDefault();
|
||||||
|
oAuthTab.current = window.open(extra.url, '_blank');
|
||||||
|
};
|
||||||
|
|
||||||
|
// state needed for re-running the SQL Lab query
|
||||||
|
const queries = useSelector<
|
||||||
|
SqlLabRootState,
|
||||||
|
Record<string, QueryResponse & { inLocalStorage?: boolean }>
|
||||||
|
>(state => state.sqlLab.queries);
|
||||||
|
const queryEditors = useSelector<SqlLabRootState, QueryEditor[]>(
|
||||||
|
state => state.sqlLab.queryEditors,
|
||||||
|
);
|
||||||
|
const tabHistory = useSelector<SqlLabRootState, string[]>(
|
||||||
|
state => state.sqlLab.tabHistory,
|
||||||
|
);
|
||||||
|
const qe = queryEditors.find(
|
||||||
|
qe => qe.id === tabHistory[tabHistory.length - 1],
|
||||||
|
);
|
||||||
|
const query = qe?.latestQueryId ? queries[qe.latestQueryId] : null;
|
||||||
|
|
||||||
|
// state needed for triggering the chart in Explore
|
||||||
|
const chartId = useSelector<ExplorePageState, number | undefined>(
|
||||||
|
state => state.explore?.slice?.slice_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
// state needed for refreshing dashboard
|
||||||
|
const chartList = useSelector<RootState, string[]>(state =>
|
||||||
|
Object.keys(state.charts),
|
||||||
|
);
|
||||||
|
const dashboardId = useSelector<RootState, number | undefined>(
|
||||||
|
state => state.dashboardInfo?.id,
|
||||||
|
);
|
||||||
|
|
||||||
|
const dispatch = useDispatch();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
/* Listen for messages from the OAuth2 tab.
|
||||||
|
*
|
||||||
|
* After OAuth2 is successful the opened tab will send a message before
|
||||||
|
* closing itself. Once we receive the message we can retrigger the
|
||||||
|
* original query in SQL Lab, explore, or in a dashboard.
|
||||||
|
*/
|
||||||
|
const redirectUrl = new URL(extra.redirect_uri);
|
||||||
|
const handleMessage = (event: MessageEvent) => {
|
||||||
|
if (
|
||||||
|
event.origin === redirectUrl.origin &&
|
||||||
|
event.data.tabId === extra.tab_id &&
|
||||||
|
event.source === oAuthTab.current
|
||||||
|
) {
|
||||||
|
if (source === 'sqllab' && query) {
|
||||||
|
dispatch(reRunQuery(query));
|
||||||
|
} else if (source === 'explore' && chartId) {
|
||||||
|
dispatch(triggerQuery(true, chartId));
|
||||||
|
} else if (source === 'dashboard') {
|
||||||
|
dispatch(onRefresh(chartList, true, 0, dashboardId));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
window.addEventListener('message', handleMessage);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener('message', handleMessage);
|
||||||
|
};
|
||||||
|
}, [
|
||||||
|
source,
|
||||||
|
extra.redirect_uri,
|
||||||
|
extra.tab_id,
|
||||||
|
dispatch,
|
||||||
|
query,
|
||||||
|
chartId,
|
||||||
|
chartList,
|
||||||
|
dashboardId,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const body = (
|
||||||
|
<p>
|
||||||
|
This database uses OAuth2 for authentication. Please click the link above
|
||||||
|
to grant Apache Superset permission to access the data. Your personal
|
||||||
|
access token will be stored encrypted and used only for queries run by
|
||||||
|
you.
|
||||||
|
</p>
|
||||||
|
);
|
||||||
|
const subtitle = (
|
||||||
|
<>
|
||||||
|
You need to{' '}
|
||||||
|
<a
|
||||||
|
href={extra.url}
|
||||||
|
onClick={handleOAuthClick}
|
||||||
|
target="_blank"
|
||||||
|
rel="noreferrer"
|
||||||
|
>
|
||||||
|
provide authorization
|
||||||
|
</a>{' '}
|
||||||
|
in order to run this query.
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ErrorAlert
|
||||||
|
title={t('Authorization needed')}
|
||||||
|
subtitle={subtitle}
|
||||||
|
level={level}
|
||||||
|
source={source}
|
||||||
|
body={body}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default OAuth2RedirectMessage;
|
||||||
|
|
@ -56,6 +56,8 @@ export const ErrorTypeEnum = {
|
||||||
QUERY_SECURITY_ACCESS_ERROR: 'QUERY_SECURITY_ACCESS_ERROR',
|
QUERY_SECURITY_ACCESS_ERROR: 'QUERY_SECURITY_ACCESS_ERROR',
|
||||||
MISSING_OWNERSHIP_ERROR: 'MISSING_OWNERSHIP_ERROR',
|
MISSING_OWNERSHIP_ERROR: 'MISSING_OWNERSHIP_ERROR',
|
||||||
DASHBOARD_SECURITY_ACCESS_ERROR: 'DASHBOARD_SECURITY_ACCESS_ERROR',
|
DASHBOARD_SECURITY_ACCESS_ERROR: 'DASHBOARD_SECURITY_ACCESS_ERROR',
|
||||||
|
OAUTH2_REDIRECT: 'OAUTH2_REDIRECT',
|
||||||
|
OAUTH2_REDIRECT_ERROR: 'OAUTH2_REDIRECT_ERROR',
|
||||||
|
|
||||||
// Other errors
|
// Other errors
|
||||||
BACKEND_TIMEOUT_ERROR: 'BACKEND_TIMEOUT_ERROR',
|
BACKEND_TIMEOUT_ERROR: 'BACKEND_TIMEOUT_ERROR',
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ import DatabaseErrorMessage from 'src/components/ErrorMessage/DatabaseErrorMessa
|
||||||
import MarshmallowErrorMessage from 'src/components/ErrorMessage/MarshmallowErrorMessage';
|
import MarshmallowErrorMessage from 'src/components/ErrorMessage/MarshmallowErrorMessage';
|
||||||
import ParameterErrorMessage from 'src/components/ErrorMessage/ParameterErrorMessage';
|
import ParameterErrorMessage from 'src/components/ErrorMessage/ParameterErrorMessage';
|
||||||
import DatasetNotFoundErrorMessage from 'src/components/ErrorMessage/DatasetNotFoundErrorMessage';
|
import DatasetNotFoundErrorMessage from 'src/components/ErrorMessage/DatasetNotFoundErrorMessage';
|
||||||
|
import OAuth2RedirectMessage from 'src/components/ErrorMessage/OAuth2RedirectMessage';
|
||||||
|
|
||||||
import setupErrorMessagesExtra from './setupErrorMessagesExtra';
|
import setupErrorMessagesExtra from './setupErrorMessagesExtra';
|
||||||
|
|
||||||
|
|
@ -149,5 +150,9 @@ export default function setupErrorMessages() {
|
||||||
ErrorTypeEnum.MARSHMALLOW_ERROR,
|
ErrorTypeEnum.MARSHMALLOW_ERROR,
|
||||||
MarshmallowErrorMessage,
|
MarshmallowErrorMessage,
|
||||||
);
|
);
|
||||||
|
errorMessageComponentRegistry.registerValue(
|
||||||
|
ErrorTypeEnum.OAUTH2_REDIRECT,
|
||||||
|
OAuth2RedirectMessage,
|
||||||
|
);
|
||||||
setupErrorMessagesExtra();
|
setupErrorMessagesExtra();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,6 @@ class ChartDataCommand(BaseCommand):
|
||||||
except CacheLoadError as ex:
|
except CacheLoadError as ex:
|
||||||
raise ChartDataCacheLoadError(ex.message) from ex
|
raise ChartDataCacheLoadError(ex.message) from ex
|
||||||
|
|
||||||
# TODO: QueryContext should support SIP-40 style errors
|
|
||||||
for query in payload["queries"]:
|
for query in payload["queries"]:
|
||||||
if query.get("error"):
|
if query.get("error"):
|
||||||
raise ChartDataQueryFailedError(
|
raise ChartDataQueryFailedError(
|
||||||
|
|
|
||||||
|
|
@ -1407,6 +1407,25 @@ PREFERRED_DATABASES: list[str] = [
|
||||||
# one here.
|
# one here.
|
||||||
TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
|
TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
|
||||||
|
|
||||||
|
# Details needed for databases that allows user to authenticate using personal
|
||||||
|
# OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more
|
||||||
|
# information
|
||||||
|
DATABASE_OAUTH2_CREDENTIALS: dict[str, dict[str, Any]] = {
|
||||||
|
# "Google Sheets": {
|
||||||
|
# "CLIENT_ID": "XXX.apps.googleusercontent.com",
|
||||||
|
# "CLIENT_SECRET": "GOCSPX-YYY",
|
||||||
|
# "BASEURL": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||||
|
# },
|
||||||
|
}
|
||||||
|
# OAuth2 state is encoded in a JWT using the alogorithm below.
|
||||||
|
DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
|
||||||
|
# By default the redirect URI points to /api/v1/database/oauth2/ and doesn't have to be
|
||||||
|
# specified. If you're running multiple Superset instances you might want to have a
|
||||||
|
# proxy handling the redirects, since redirect URIs need to be registered in the OAuth2
|
||||||
|
# applications. In that case, the proxy can forward the request to the correct instance
|
||||||
|
# by looking at the `default_redirect_uri` attribute in the OAuth2 state object.
|
||||||
|
# DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"
|
||||||
|
|
||||||
# Enable/disable CSP warning
|
# Enable/disable CSP warning
|
||||||
CONTENT_SECURITY_POLICY_WARNING = True
|
CONTENT_SECURITY_POLICY_WARNING = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,8 @@ from superset.exceptions import (
|
||||||
DatasetInvalidPermissionEvaluationException,
|
DatasetInvalidPermissionEvaluationException,
|
||||||
QueryClauseValidationException,
|
QueryClauseValidationException,
|
||||||
QueryObjectValidationError,
|
QueryObjectValidationError,
|
||||||
|
SupersetErrorException,
|
||||||
|
SupersetErrorsException,
|
||||||
SupersetGenericDBErrorException,
|
SupersetGenericDBErrorException,
|
||||||
SupersetSecurityException,
|
SupersetSecurityException,
|
||||||
)
|
)
|
||||||
|
|
@ -1744,7 +1746,15 @@ class SqlaTable(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
|
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
|
||||||
|
except (SupersetErrorException, SupersetErrorsException) as ex:
|
||||||
|
# SupersetError(s) exception should not be captured; instead, they should
|
||||||
|
# bubble up to the Flask error handler so they are returned as proper SIP-40
|
||||||
|
# errors. This is particularly important for database OAuth2, see SIP-85.
|
||||||
|
raise ex
|
||||||
except Exception as ex: # pylint: disable=broad-except
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
|
# TODO (betodealmeida): review exception handling while querying the external
|
||||||
|
# database. Ideally we'd expect and handle external database error, but
|
||||||
|
# everything else / the default should be to let things bubble up.
|
||||||
df = pd.DataFrame()
|
df = pd.DataFrame()
|
||||||
status = QueryStatus.FAILED
|
status = QueryStatus.FAILED
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,7 @@ def get_columns_description(
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
query = database.apply_limit_to_sql(query, limit=1)
|
query = database.apply_limit_to_sql(query, limit=1)
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
db_engine_spec.execute(cursor, query)
|
db_engine_spec.execute(cursor, query, database.id)
|
||||||
result = db_engine_spec.fetch_data(cursor, limit=1)
|
result = db_engine_spec.fetch_data(cursor, limit=1)
|
||||||
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
|
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
|
||||||
return result_set.columns
|
return result_set.columns
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from superset.daos.base import BaseDAO
|
||||||
from superset.databases.filters import DatabaseFilter
|
from superset.databases.filters import DatabaseFilter
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.extensions import db
|
from superset.extensions import db
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||||
from superset.models.dashboard import Dashboard
|
from superset.models.dashboard import Dashboard
|
||||||
from superset.models.slice import Slice
|
from superset.models.slice import Slice
|
||||||
from superset.models.sql_lab import TabState
|
from superset.models.sql_lab import TabState
|
||||||
|
|
@ -165,3 +165,9 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
|
||||||
attributes = unmask_password_info(attributes, item)
|
attributes = unmask_password_info(attributes, item)
|
||||||
|
|
||||||
return super().update(item, attributes, commit)
|
return super().update(item, attributes, commit)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]):
|
||||||
|
"""
|
||||||
|
DAO for OAuth2 tokens.
|
||||||
|
"""
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,13 @@
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, cast, Optional
|
from typing import Any, cast, Optional
|
||||||
from zipfile import is_zipfile, ZipFile
|
from zipfile import is_zipfile, ZipFile
|
||||||
|
|
||||||
from deprecation import deprecated
|
from deprecation import deprecated
|
||||||
from flask import request, Response, send_file
|
from flask import make_response, render_template, request, Response, send_file
|
||||||
from flask_appbuilder.api import expose, protect, rison, safe
|
from flask_appbuilder.api import expose, protect, rison, safe
|
||||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
|
@ -62,7 +62,7 @@ from superset.commands.importers.exceptions import (
|
||||||
)
|
)
|
||||||
from superset.commands.importers.v1.utils import get_contents_from_bundle
|
from superset.commands.importers.v1.utils import get_contents_from_bundle
|
||||||
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||||
from superset.daos.database import DatabaseDAO
|
from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
|
||||||
from superset.databases.decorators import check_datasource_access
|
from superset.databases.decorators import check_datasource_access
|
||||||
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
|
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
|
||||||
from superset.databases.schemas import (
|
from superset.databases.schemas import (
|
||||||
|
|
@ -78,6 +78,7 @@ from superset.databases.schemas import (
|
||||||
DatabaseTestConnectionSchema,
|
DatabaseTestConnectionSchema,
|
||||||
DatabaseValidateParametersSchema,
|
DatabaseValidateParametersSchema,
|
||||||
get_export_ids_schema,
|
get_export_ids_schema,
|
||||||
|
OAuth2ProviderResponseSchema,
|
||||||
openapi_spec_methods_override,
|
openapi_spec_methods_override,
|
||||||
SchemasResponseSchema,
|
SchemasResponseSchema,
|
||||||
SelectStarResponseSchema,
|
SelectStarResponseSchema,
|
||||||
|
|
@ -89,11 +90,12 @@ from superset.databases.schemas import (
|
||||||
from superset.databases.utils import get_table_metadata
|
from superset.databases.utils import get_table_metadata
|
||||||
from superset.db_engine_specs import get_available_engine_specs
|
from superset.db_engine_specs import get_available_engine_specs
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import SupersetErrorsException, SupersetException
|
from superset.exceptions import OAuth2Error, SupersetErrorsException, SupersetException
|
||||||
from superset.extensions import security_manager
|
from superset.extensions import security_manager
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.superset_typing import FlaskResponse
|
from superset.superset_typing import FlaskResponse
|
||||||
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
|
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
|
||||||
|
from superset.utils.oauth2 import decode_oauth2_state
|
||||||
from superset.utils.ssh_tunnel import mask_password_info
|
from superset.utils.ssh_tunnel import mask_password_info
|
||||||
from superset.views.base import json_errors_response
|
from superset.views.base import json_errors_response
|
||||||
from superset.views.base_api import (
|
from superset.views.base_api import (
|
||||||
|
|
@ -106,6 +108,7 @@ from superset.views.base_api import (
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=too-many-public-methods
|
||||||
class DatabaseRestApi(BaseSupersetModelRestApi):
|
class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
datamodel = SQLAInterface(Database)
|
datamodel = SQLAInterface(Database)
|
||||||
|
|
||||||
|
|
@ -127,7 +130,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
"delete_ssh_tunnel",
|
"delete_ssh_tunnel",
|
||||||
"schemas_access_for_file_upload",
|
"schemas_access_for_file_upload",
|
||||||
"get_connection",
|
"get_connection",
|
||||||
|
"oauth2",
|
||||||
}
|
}
|
||||||
|
|
||||||
resource_name = "database"
|
resource_name = "database"
|
||||||
class_permission_name = "Database"
|
class_permission_name = "Database"
|
||||||
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
|
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
|
||||||
|
|
@ -1050,6 +1055,98 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
except DatabaseNotFoundError:
|
except DatabaseNotFoundError:
|
||||||
return self.response_404()
|
return self.response_404()
|
||||||
|
|
||||||
|
@expose("/oauth2/", methods=["GET"])
|
||||||
|
@event_logger.log_this_with_context(
|
||||||
|
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.oauth2",
|
||||||
|
log_to_statsd=True,
|
||||||
|
)
|
||||||
|
def oauth2(self) -> FlaskResponse:
|
||||||
|
"""
|
||||||
|
---
|
||||||
|
get:
|
||||||
|
summary: >-
|
||||||
|
Receive personal access tokens from OAuth2
|
||||||
|
description: ->
|
||||||
|
Receive and store personal access tokens from OAuth for user-level
|
||||||
|
authorization
|
||||||
|
parameters:
|
||||||
|
- in: query
|
||||||
|
name: state
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- in: query
|
||||||
|
name: code
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- in: query
|
||||||
|
name: scope
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- in: query
|
||||||
|
name: error
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: A dummy self-closing HTML page
|
||||||
|
content:
|
||||||
|
text/html:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
400:
|
||||||
|
$ref: '#/components/responses/400'
|
||||||
|
404:
|
||||||
|
$ref: '#/components/responses/404'
|
||||||
|
500:
|
||||||
|
$ref: '#/components/responses/500'
|
||||||
|
"""
|
||||||
|
parameters = OAuth2ProviderResponseSchema().load(request.args)
|
||||||
|
|
||||||
|
if "error" in parameters:
|
||||||
|
raise OAuth2Error(parameters["error"])
|
||||||
|
|
||||||
|
# note that when decoding the state we will perform JWT validation, preventing a
|
||||||
|
# malicious payload that would insert a bogus database token, or delete an
|
||||||
|
# existing one.
|
||||||
|
state = decode_oauth2_state(parameters["state"])
|
||||||
|
|
||||||
|
# exchange code for access/refresh tokens
|
||||||
|
database = DatabaseDAO.find_by_id(state["database_id"])
|
||||||
|
if database is None:
|
||||||
|
return self.response_404()
|
||||||
|
|
||||||
|
token_response = database.db_engine_spec.get_oauth2_token(
|
||||||
|
parameters["code"],
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# delete old tokens
|
||||||
|
existing = DatabaseUserOAuth2TokensDAO.find_one_or_none(
|
||||||
|
user_id=state["user_id"],
|
||||||
|
database_id=state["database_id"],
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
DatabaseUserOAuth2TokensDAO.delete([existing], commit=True)
|
||||||
|
|
||||||
|
# store tokens
|
||||||
|
expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
|
||||||
|
DatabaseUserOAuth2TokensDAO.create(
|
||||||
|
attributes={
|
||||||
|
"user_id": state["user_id"],
|
||||||
|
"database_id": state["database_id"],
|
||||||
|
"access_token": token_response["access_token"],
|
||||||
|
"access_token_expiration": expiration,
|
||||||
|
"refresh_token": token_response.get("refresh_token"),
|
||||||
|
},
|
||||||
|
commit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# return blank page that closes itself
|
||||||
|
return make_response(
|
||||||
|
render_template("superset/oauth2.html", tab_id=state["tab_id"]),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
|
||||||
@expose("/export/", methods=("GET",))
|
@expose("/export/", methods=("GET",))
|
||||||
@protect()
|
@protect()
|
||||||
@safe
|
@safe
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument, too-many-lines
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
|
@ -978,3 +978,38 @@ class DatabaseConnectionSchema(Schema):
|
||||||
metadata={"description": sqlalchemy_uri_description},
|
metadata={"description": sqlalchemy_uri_description},
|
||||||
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2ProviderResponseSchema(Schema):
|
||||||
|
"""
|
||||||
|
Schema for the payload sent on OAuth2 redirect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
code = fields.String(
|
||||||
|
required=False,
|
||||||
|
metadata={"description": "The authorization code returned by the provider"},
|
||||||
|
)
|
||||||
|
state = fields.String(
|
||||||
|
required=False,
|
||||||
|
metadata={"description": "The state parameter originally passed by the client"},
|
||||||
|
)
|
||||||
|
scope = fields.String(
|
||||||
|
required=False,
|
||||||
|
metadata={
|
||||||
|
"description": "A space-separated list of scopes granted by the user"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
error = fields.String(
|
||||||
|
required=False,
|
||||||
|
metadata={
|
||||||
|
"description": "In case of an error, this field contains the error code"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
error_description = fields.String(
|
||||||
|
required=False,
|
||||||
|
metadata={"description": "Additional description of the error"},
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta: # pylint: disable=too-few-public-methods
|
||||||
|
# Ignore unknown fields that might be sent by the OAuth2 provider
|
||||||
|
unknown = EXCLUDE
|
||||||
|
|
|
||||||
|
|
@ -529,6 +529,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||||
url: URL,
|
url: URL,
|
||||||
impersonate_user: bool,
|
impersonate_user: bool,
|
||||||
username: str | None,
|
username: str | None,
|
||||||
|
access_token: str | None,
|
||||||
) -> URL:
|
) -> URL:
|
||||||
if impersonate_user and username is not None:
|
if impersonate_user and username is not None:
|
||||||
user = security_manager.find_user(username=username)
|
user = security_manager.find_user(username=username)
|
||||||
|
|
@ -542,6 +543,70 @@ The method `get_url_for_impersonation` updates the SQLAlchemy URI before every q
|
||||||
|
|
||||||
Alternatively, it's also possible to impersonate users by implementing the `update_impersonation_config`. This is a class method which modifies `connect_args` in place. You can use either method, and ideally they [should be consolidated in a single one](https://github.com/apache/superset/issues/24910).
|
Alternatively, it's also possible to impersonate users by implementing the `update_impersonation_config`. This is a class method which modifies `connect_args` in place. You can use either method, and ideally they [should be consolidated in a single one](https://github.com/apache/superset/issues/24910).
|
||||||
|
|
||||||
|
### OAuth2
|
||||||
|
|
||||||
|
Support for authenticating to a database using personal OAuth2 access tokens was introduced in [SIP-85](https://github.com/apache/superset/issues/20300). The Google Sheets DB engine spec is the reference implementation.
|
||||||
|
|
||||||
|
To add support for OAuth2 to a DB engine spec, the following attribute and methods are needed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BaseEngineSpec:
|
||||||
|
|
||||||
|
oauth2_exception = OAuth2RedirectError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_oauth2_enabled() -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_authorization_uri(state: OAuth2State) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
```
|
||||||
|
|
||||||
|
The `oauth2_exception` is an exception that is raised by `cursor.execute` when OAuth2 is needed. This will start the OAuth2 dance when `BaseEngineSpec.execute` is called, by returning the custom error `OAUTH2_REDIRECT` to the frontend. If the database driver doesn't have a specific exception, it might be necessary to overload the `execute` method in the DB engine spec, so that the `BaseEngineSpec.start_oauth2_dance` method gets called whenever OAuth2 is needed.
|
||||||
|
|
||||||
|
The first method, `is_oauth2_enabled`, is used to inform if the database supports OAuth2. This can be dynamic; for example, the Google Sheets DB engine spec checks if the Superset configuration has the necessary section:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
|
||||||
|
class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||||
|
@staticmethod
|
||||||
|
def is_oauth2_enabled() -> bool:
|
||||||
|
return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Where the configuration for OAuth2 would look like this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# superset_config.py
|
||||||
|
DATABASE_OAUTH2_CREDENTIALS = {
|
||||||
|
"Google Sheets": {
|
||||||
|
"CLIENT_ID": "XXX.apps.googleusercontent.com",
|
||||||
|
"CLIENT_SECRET": "GOCSPX-YYY",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
|
||||||
|
DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"
|
||||||
|
```
|
||||||
|
|
||||||
|
The second method, `get_oauth2_authorization_uri`, is responsible for building the URL where the user is sent to initiate OAuth2. This method receives a `state`. The state is an encoded JWT that is passed to the OAuth2 provider, and is received unmodified when the user is redirected back to Superset. The default state contains the user ID and the database ID, so that Superset can know where to store the received OAuth2 tokens.
|
||||||
|
|
||||||
|
Additionally, the state also contains a `tab_id`, which is a random UUID4 used as a shared secret for communication between browser tabs. When OAuth2 starts, Superset will open a new browser tab, where the user will grant permissions to Superset. When authentication is complete and successful this opened tab will send a message to the original tab, so that the original query can be re-run. The `tab_id` is sent by the opened tab and verified by the original tab to prevent malicious messages from other sites. As an additional security measure the origin of the message should match the OAuth2 redirect URL.
|
||||||
|
|
||||||
|
State also contains a `defaul_redirect_uri`, which is the enpoint in Supeset that receives the tokens from the OAuth2 provider (`/api/v1/database/oauth2/`). The redirect URL can be overwritten in the config file via the `DATABASE_OAUTH2_REDIRECT_URI` parameter. This might be useful where you have multiple Superset instances. Since the OAuth2 provider requires the redirect URL to be registered a priori, it might be easier (or needed) to register a single URL for a proxy service; the proxy service can then inspect the JWT and redirect the request to `defaul_redirect_uri`.
|
||||||
|
|
||||||
|
Finally, `get_oauth2_token` and `get_oauth2_fresh_token` are used to actually retrieve a token and refresh an expired token, respectively.
|
||||||
|
|
||||||
### File upload
|
### File upload
|
||||||
|
|
||||||
When a DB engine spec supports file upload it declares so via the `supports_file_upload` class attribute. The base class implementation is very generic and should work for any database that has support for `CREATE TABLE`. It leverages Pandas and the [`df_to_sql`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html) method.
|
When a DB engine spec supports file upload it declares so via the `supports_file_upload` class attribute. The base class implementation is very generic and should work for any database that has support for `CREATE TABLE`. It leverages Pandas and the [`df_to_sql`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html) method.
|
||||||
|
|
@ -615,7 +680,7 @@ SELECT * FROM my_table
|
||||||
|
|
||||||
The table `my_table` should live in the `dev` schema. In order to do that, it's necessary to modify the SQLAlchemy URI before running the query. Since different databases have different ways of doing that, this functionality is implemented via the `adjust_engine_params` class method. The method receives the SQLAlchemy URI and `connect_args`, as well as the schema in which the query should run. It then returns a potentially modified URI and `connect_args` to ensure that the query runs in the specified schema.
|
The table `my_table` should live in the `dev` schema. In order to do that, it's necessary to modify the SQLAlchemy URI before running the query. Since different databases have different ways of doing that, this functionality is implemented via the `adjust_engine_params` class method. The method receives the SQLAlchemy URI and `connect_args`, as well as the schema in which the query should run. It then returns a potentially modified URI and `connect_args` to ensure that the query runs in the specified schema.
|
||||||
|
|
||||||
When a DB engine specs implements `adjust_engine_params` it should have the class attribute `supports_dynamic_schema` set to true. This is critical for security, since **it allows Superset to know to which schema any unqualified table names belong to**. For example, in the query above, if the database supports dynamic schema, Superset would check to see if the user running the query has access to `dev.my_table`. On the other hand, if the database doesn't support dynamic schema, Superset would sue the default database schema instead of `dev`.
|
When a DB engine specs implements `adjust_engine_params` it should have the class attribute `supports_dynamic_schema` set to true. This is critical for security, since **it allows Superset to know to which schema any unqualified table names belong to**. For example, in the query above, if the database supports dynamic schema, Superset would check to see if the user running the query has access to `dev.my_table`. On the other hand, if the database doesn't support dynamic schema, Superset would use the default database schema instead of `dev`.
|
||||||
|
|
||||||
Implementing this method is also important for usability. When the method is not implemented selecting the schema in SQL Lab has no effect on the schema in which the query runs, resulting in a confusing results when using unqualified table names.
|
Implementing this method is also important for usability. When the method is not implemented selecting the schema in SQL Lab has no effect on the schema in which the query runs, resulting in a confusing results when using unqualified table names.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,13 +33,14 @@ from typing import (
|
||||||
TypedDict,
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from apispec import APISpec
|
from apispec import APISpec
|
||||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||||
from deprecation import deprecated
|
from deprecation import deprecated
|
||||||
from flask import current_app
|
from flask import current_app, g, url_for
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
from flask_babel import gettext as __, lazy_gettext as _
|
from flask_babel import gettext as __, lazy_gettext as _
|
||||||
from marshmallow import fields, Schema
|
from marshmallow import fields, Schema
|
||||||
|
|
@ -59,6 +60,7 @@ from superset import security_manager, sql_parse
|
||||||
from superset.constants import TimeGrain as TimeGrainConstants
|
from superset.constants import TimeGrain as TimeGrainConstants
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
|
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||||
from superset.sql_parse import ParsedQuery, SQLScript, Table
|
from superset.sql_parse import ParsedQuery, SQLScript, Table
|
||||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
||||||
from superset.utils import core as utils
|
from superset.utils import core as utils
|
||||||
|
|
@ -71,6 +73,7 @@ if TYPE_CHECKING:
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
|
||||||
|
|
||||||
ColumnTypeMapping = tuple[
|
ColumnTypeMapping = tuple[
|
||||||
Pattern[str],
|
Pattern[str],
|
||||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||||
|
|
@ -170,6 +173,31 @@ class MetricType(TypedDict, total=False):
|
||||||
extra: str | None
|
extra: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2TokenResponse(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Type for an OAuth2 response when exchanging or refreshing tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
access_token: str
|
||||||
|
expires_in: int
|
||||||
|
scope: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
# only present when exchanging code for refresh/access tokens
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2State(TypedDict):
|
||||||
|
"""
|
||||||
|
Type for the state passed during OAuth2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
database_id: int
|
||||||
|
user_id: int
|
||||||
|
default_redirect_uri: str
|
||||||
|
tab_id: str
|
||||||
|
|
||||||
|
|
||||||
class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"""Abstract class for database engine specific configurations
|
"""Abstract class for database engine specific configurations
|
||||||
|
|
||||||
|
|
@ -397,6 +425,79 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
# Can the catalog be changed on a per-query basis?
|
# Can the catalog be changed on a per-query basis?
|
||||||
supports_dynamic_catalog = False
|
supports_dynamic_catalog = False
|
||||||
|
|
||||||
|
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||||
|
oauth2_exception = OAuth2RedirectError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_oauth2_enabled() -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_oauth2_dance(cls, database_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Start the OAuth2 dance.
|
||||||
|
|
||||||
|
This method will raise a custom exception that is captured by the frontend to
|
||||||
|
start the OAuth2 authentication. The frontend will open a new tab where the user
|
||||||
|
can authorize Superset to access the database. Once the user has authorized, the
|
||||||
|
tab sends a message to the original tab informing that authorization was
|
||||||
|
successful (or not), and then closes. The original tab will automatically
|
||||||
|
re-run the query after authorization.
|
||||||
|
"""
|
||||||
|
tab_id = str(uuid4())
|
||||||
|
default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True)
|
||||||
|
redirect_uri = current_app.config.get(
|
||||||
|
"DATABASE_OAUTH2_REDIRECT_URI",
|
||||||
|
default_redirect_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The state is passed to the OAuth2 provider, and sent back to Superset after
|
||||||
|
# the user authorizes the access. The redirect endpoint in Superset can then
|
||||||
|
# inspect the state to figure out to which user/database the access token
|
||||||
|
# belongs to.
|
||||||
|
state: OAuth2State = {
|
||||||
|
# Database ID and user ID are the primary key associated with the token.
|
||||||
|
"database_id": database_id,
|
||||||
|
"user_id": g.user.id,
|
||||||
|
# In multi-instance deployments there might be a single proxy handling
|
||||||
|
# redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since the OAuth2
|
||||||
|
# application requires every redirect URL to be registered a priori, this
|
||||||
|
# allows OAuth2 to be used where new instances are being constantly
|
||||||
|
# deployed. The proxy can extract `default_redirect_uri` from the state and
|
||||||
|
# then forward the token to the instance that initiated the authentication.
|
||||||
|
"default_redirect_uri": default_redirect_uri,
|
||||||
|
# When OAuth2 is complete the browser tab where OAuth2 happened will send a
|
||||||
|
# message to the original browser tab informing that the process was
|
||||||
|
# successful. To allow cross-tab commmunication in a safe way we assign a
|
||||||
|
# UUID to the original tab, and the second tab will use it when sending the
|
||||||
|
# message.
|
||||||
|
"tab_id": tab_id,
|
||||||
|
}
|
||||||
|
oauth_url = cls.get_oauth2_authorization_uri(state)
|
||||||
|
|
||||||
|
raise OAuth2RedirectError(oauth_url, tab_id, redirect_uri)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_authorization_uri(state: OAuth2State) -> str:
|
||||||
|
"""
|
||||||
|
Return URI for initial OAuth2 request.
|
||||||
|
"""
|
||||||
|
raise OAuth2Error("Subclasses must implement `get_oauth2_authorization_uri`")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
|
||||||
|
"""
|
||||||
|
Exchange authorization code for refresh/access tokens.
|
||||||
|
"""
|
||||||
|
raise OAuth2Error("Subclasses must implement `get_oauth2_token`")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
|
||||||
|
"""
|
||||||
|
Refresh an access token that has expired.
|
||||||
|
"""
|
||||||
|
raise OAuth2Error("Subclasses must implement `get_oauth2_fresh_token`")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_allows_alias_in_select(
|
def get_allows_alias_in_select(
|
||||||
cls, database: Database # pylint: disable=unused-argument
|
cls, database: Database # pylint: disable=unused-argument
|
||||||
|
|
@ -1079,7 +1180,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
# TODO: Fix circular import error caused by importing sql_lab.Query
|
# TODO: Fix circular import error caused by importing sql_lab.Query
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None:
|
def execute_with_cursor(
|
||||||
|
cls,
|
||||||
|
cursor: Any,
|
||||||
|
sql: str,
|
||||||
|
query: Query,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Trigger execution of a query and handle the resulting cursor.
|
Trigger execution of a query and handle the resulting cursor.
|
||||||
|
|
||||||
|
|
@ -1090,7 +1196,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
in a timely manner and facilitate operations such as query stop
|
in a timely manner and facilitate operations such as query stop
|
||||||
"""
|
"""
|
||||||
logger.debug("Query %d: Running query: %s", query.id, sql)
|
logger.debug("Query %d: Running query: %s", query.id, sql)
|
||||||
cls.execute(cursor, sql, async_=True)
|
cls.execute(cursor, sql, query.database.id, async_=True)
|
||||||
logger.debug("Query %d: Handling cursor", query.id)
|
logger.debug("Query %d: Handling cursor", query.id)
|
||||||
cls.handle_cursor(cursor, query)
|
cls.handle_cursor(cursor, query)
|
||||||
|
|
||||||
|
|
@ -1536,7 +1642,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_url_for_impersonation(
|
def get_url_for_impersonation(
|
||||||
cls, url: URL, impersonate_user: bool, username: str | None
|
cls,
|
||||||
|
url: URL,
|
||||||
|
impersonate_user: bool,
|
||||||
|
username: str | None,
|
||||||
|
access_token: str | None, # pylint: disable=unused-argument
|
||||||
) -> URL:
|
) -> URL:
|
||||||
"""
|
"""
|
||||||
Return a modified URL with the username set.
|
Return a modified URL with the username set.
|
||||||
|
|
@ -1544,6 +1654,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
:param url: SQLAlchemy URL object
|
:param url: SQLAlchemy URL object
|
||||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||||
:param username: Effective username
|
:param username: Effective username
|
||||||
|
:param access_token: Personal access token
|
||||||
"""
|
"""
|
||||||
if impersonate_user and username is not None:
|
if impersonate_user and username is not None:
|
||||||
url = url.set(username=username)
|
url = url.set(username=username)
|
||||||
|
|
@ -1572,6 +1683,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
cls,
|
cls,
|
||||||
cursor: Any,
|
cursor: Any,
|
||||||
query: str,
|
query: str,
|
||||||
|
database_id: int,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -1579,6 +1691,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
|
|
||||||
:param cursor: Cursor instance
|
:param cursor: Cursor instance
|
||||||
:param query: Query to execute
|
:param query: Query to execute
|
||||||
|
:param database_id: ID of the database where the query will run
|
||||||
:param kwargs: kwargs to be passed to cursor.execute()
|
:param kwargs: kwargs to be passed to cursor.execute()
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
@ -1589,6 +1702,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
cursor.arraysize = cls.arraysize
|
cursor.arraysize = cls.arraysize
|
||||||
try:
|
try:
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
|
except cls.oauth2_exception as ex:
|
||||||
|
if cls.is_oauth2_enabled() and g.user:
|
||||||
|
cls.start_oauth2_dance(database_id)
|
||||||
|
raise cls.get_dbapi_mapped_exception(ex) from ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise cls.get_dbapi_mapped_exception(ex) from ex
|
raise cls.get_dbapi_mapped_exception(ex) from ex
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -100,7 +100,11 @@ class DrillEngineSpec(BaseEngineSpec):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_url_for_impersonation(
|
def get_url_for_impersonation(
|
||||||
cls, url: URL, impersonate_user: bool, username: str | None
|
cls,
|
||||||
|
url: URL,
|
||||||
|
impersonate_user: bool,
|
||||||
|
username: str | None,
|
||||||
|
access_token: str | None,
|
||||||
) -> URL:
|
) -> URL:
|
||||||
"""
|
"""
|
||||||
Return a modified URL with the username set.
|
Return a modified URL with the username set.
|
||||||
|
|
|
||||||
|
|
@ -23,24 +23,30 @@ import logging
|
||||||
import re
|
import re
|
||||||
from re import Pattern
|
from re import Pattern
|
||||||
from typing import Any, TYPE_CHECKING, TypedDict
|
from typing import Any, TYPE_CHECKING, TypedDict
|
||||||
|
from urllib.parse import urlencode, urljoin
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import urllib3
|
||||||
from apispec import APISpec
|
from apispec import APISpec
|
||||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||||
from flask import g
|
from flask import current_app, g
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
from marshmallow import fields, Schema
|
from marshmallow import fields, Schema
|
||||||
from marshmallow.exceptions import ValidationError
|
from marshmallow.exceptions import ValidationError
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
from shillelagh.adapters.api.gsheets.lib import SCOPES
|
||||||
|
from shillelagh.exceptions import UnauthenticatedError
|
||||||
from sqlalchemy.engine import create_engine
|
from sqlalchemy.engine import create_engine
|
||||||
from sqlalchemy.engine.url import URL
|
from sqlalchemy.engine.url import URL
|
||||||
|
|
||||||
from superset import db, security_manager
|
from superset import db, security_manager
|
||||||
from superset.constants import PASSWORD_MASK
|
from superset.constants import PASSWORD_MASK
|
||||||
from superset.databases.schemas import encrypted_field_properties, EncryptedString
|
from superset.databases.schemas import encrypted_field_properties, EncryptedString
|
||||||
|
from superset.db_engine_specs.base import OAuth2State, OAuth2TokenResponse
|
||||||
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
|
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import SupersetException
|
from superset.exceptions import SupersetException
|
||||||
|
from superset.utils.oauth2 import encode_oauth2_state
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
@ -56,6 +62,7 @@ EXAMPLE_GSHEETS_URL = (
|
||||||
SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P<server_error>.*?)": syntax error')
|
SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P<server_error>.*?)": syntax error')
|
||||||
|
|
||||||
ma_plugin = MarshmallowPlugin()
|
ma_plugin = MarshmallowPlugin()
|
||||||
|
http = urllib3.PoolManager()
|
||||||
|
|
||||||
|
|
||||||
class GSheetsParametersSchema(Schema):
|
class GSheetsParametersSchema(Schema):
|
||||||
|
|
@ -104,18 +111,28 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||||
|
|
||||||
supports_file_upload = True
|
supports_file_upload = True
|
||||||
|
|
||||||
|
# exception raised by shillelagh that should trigger OAuth2
|
||||||
|
oauth2_exception = UnauthenticatedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_url_for_impersonation(
|
def get_url_for_impersonation(
|
||||||
cls,
|
cls,
|
||||||
url: URL,
|
url: URL,
|
||||||
impersonate_user: bool,
|
impersonate_user: bool,
|
||||||
username: str | None,
|
username: str | None,
|
||||||
|
access_token: str | None,
|
||||||
) -> URL:
|
) -> URL:
|
||||||
if impersonate_user and username is not None:
|
if not impersonate_user:
|
||||||
|
return url
|
||||||
|
|
||||||
|
if username is not None:
|
||||||
user = security_manager.find_user(username=username)
|
user = security_manager.find_user(username=username)
|
||||||
if user and user.email:
|
if user and user.email:
|
||||||
url = url.update_query_dict({"subject": user.email})
|
url = url.update_query_dict({"subject": user.email})
|
||||||
|
|
||||||
|
if access_token:
|
||||||
|
url = url.update_query_dict({"access_token": access_token})
|
||||||
|
|
||||||
return url
|
return url
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -136,6 +153,82 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||||
|
|
||||||
return {"metadata": metadata["extra"]}
|
return {"metadata": metadata["extra"]}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_oauth2_enabled() -> bool:
|
||||||
|
"""
|
||||||
|
Return if OAuth2 is enabled for GSheets.
|
||||||
|
"""
|
||||||
|
return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_oauth2_authorization_uri(cls, state: OAuth2State) -> str:
|
||||||
|
"""
|
||||||
|
Return URI for initial OAuth2 request.
|
||||||
|
|
||||||
|
https://developers.google.com/identity/protocols/oauth2/web-server#creatingclient
|
||||||
|
"""
|
||||||
|
config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"]
|
||||||
|
baseurl = config.get("BASEURL", "https://accounts.google.com/o/oauth2/v2/auth")
|
||||||
|
redirect_uri = current_app.config.get(
|
||||||
|
"DATABASE_OAUTH2_REDIRECT_URI",
|
||||||
|
state["default_redirect_uri"],
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"scope": " ".join(SCOPES),
|
||||||
|
"access_type": "offline",
|
||||||
|
"include_granted_scopes": "false",
|
||||||
|
"response_type": "code",
|
||||||
|
"state": encode_oauth2_state(state),
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"client_id": config["CLIENT_ID"],
|
||||||
|
"prompt": "consent",
|
||||||
|
}
|
||||||
|
return urljoin(baseurl, "?" + urlencode(params))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
|
||||||
|
"""
|
||||||
|
Exchange authorization code for refresh/access tokens.
|
||||||
|
"""
|
||||||
|
config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"]
|
||||||
|
redirect_uri = current_app.config.get(
|
||||||
|
"DATABASE_OAUTH2_REDIRECT_URI",
|
||||||
|
state["default_redirect_uri"],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = http.request(
|
||||||
|
"POST",
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
fields={
|
||||||
|
"code": code,
|
||||||
|
"client_id": config["CLIENT_ID"],
|
||||||
|
"client_secret": config["CLIENT_SECRET"],
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return json.loads(response.data.decode("utf-8"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
|
||||||
|
"""
|
||||||
|
Refresh an access token that has expired.
|
||||||
|
"""
|
||||||
|
config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"]
|
||||||
|
|
||||||
|
response = http.request(
|
||||||
|
"POST",
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
fields={
|
||||||
|
"client_id": config["CLIENT_ID"],
|
||||||
|
"client_secret": config["CLIENT_SECRET"],
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return json.loads(response.data.decode("utf-8"))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_sqlalchemy_uri(
|
def build_sqlalchemy_uri(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
||||||
|
|
@ -505,7 +505,11 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_url_for_impersonation(
|
def get_url_for_impersonation(
|
||||||
cls, url: URL, impersonate_user: bool, username: str | None
|
cls,
|
||||||
|
url: URL,
|
||||||
|
impersonate_user: bool,
|
||||||
|
username: str | None,
|
||||||
|
access_token: str | None,
|
||||||
) -> URL:
|
) -> URL:
|
||||||
"""
|
"""
|
||||||
Return a modified URL with the username set.
|
Return a modified URL with the username set.
|
||||||
|
|
@ -547,7 +551,10 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def execute( # type: ignore
|
def execute( # type: ignore
|
||||||
cursor, query: str, async_: bool = False
|
cursor,
|
||||||
|
query: str,
|
||||||
|
database_id: int,
|
||||||
|
async_: bool = False,
|
||||||
): # pylint: disable=arguments-differ
|
): # pylint: disable=arguments-differ
|
||||||
kwargs = {"async": async_}
|
kwargs = {"async": async_}
|
||||||
cursor.execute(query, **kwargs)
|
cursor.execute(query, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -93,6 +93,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
|
||||||
cls,
|
cls,
|
||||||
cursor: Any,
|
cursor: Any,
|
||||||
query: str,
|
query: str,
|
||||||
|
database_id: int,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -1271,7 +1271,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
sql = f"SHOW CREATE VIEW {schema}.{table}"
|
sql = f"SHOW CREATE VIEW {schema}.{table}"
|
||||||
try:
|
try:
|
||||||
cls.execute(cursor, sql)
|
cls.execute(cursor, sql, database.id)
|
||||||
rows = cls.fetch_data(cursor, 1)
|
rows = cls.fetch_data(cursor, 1)
|
||||||
|
|
||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_url_for_impersonation(
|
def get_url_for_impersonation(
|
||||||
cls, url: URL, impersonate_user: bool, username: str | None
|
cls,
|
||||||
|
url: URL,
|
||||||
|
impersonate_user: bool,
|
||||||
|
username: str | None,
|
||||||
|
access_token: str | None,
|
||||||
) -> URL:
|
) -> URL:
|
||||||
"""
|
"""
|
||||||
Return a modified URL with the username set.
|
Return a modified URL with the username set.
|
||||||
|
|
@ -191,7 +195,12 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
super().handle_cursor(cursor=cursor, query=query)
|
super().handle_cursor(cursor=cursor, query=query)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute_with_cursor(cls, cursor: Cursor, sql: str, query: Query) -> None:
|
def execute_with_cursor(
|
||||||
|
cls,
|
||||||
|
cursor: Cursor,
|
||||||
|
sql: str,
|
||||||
|
query: Query,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Trigger execution of a query and handle the resulting cursor.
|
Trigger execution of a query and handle the resulting cursor.
|
||||||
|
|
||||||
|
|
@ -210,7 +219,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
logger.debug("Query %d: Running query: %s", query_id, sql)
|
logger.debug("Query %d: Running query: %s", query_id, sql)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cls.execute(cursor, sql)
|
cls.execute(cursor, sql, query.database.id)
|
||||||
except Exception as ex: # pylint: disable=broad-except
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
results["error"] = ex
|
results["error"] = ex
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,8 @@ class SupersetErrorType(StrEnum):
|
||||||
USER_ACTIVITY_SECURITY_ACCESS_ERROR = "USER_ACTIVITY_SECURITY_ACCESS_ERROR"
|
USER_ACTIVITY_SECURITY_ACCESS_ERROR = "USER_ACTIVITY_SECURITY_ACCESS_ERROR"
|
||||||
DASHBOARD_SECURITY_ACCESS_ERROR = "DASHBOARD_SECURITY_ACCESS_ERROR"
|
DASHBOARD_SECURITY_ACCESS_ERROR = "DASHBOARD_SECURITY_ACCESS_ERROR"
|
||||||
CHART_SECURITY_ACCESS_ERROR = "CHART_SECURITY_ACCESS_ERROR"
|
CHART_SECURITY_ACCESS_ERROR = "CHART_SECURITY_ACCESS_ERROR"
|
||||||
|
OAUTH2_REDIRECT = "OAUTH2_REDIRECT"
|
||||||
|
OAUTH2_REDIRECT_ERROR = "OAUTH2_REDIRECT_ERROR"
|
||||||
|
|
||||||
# Other errors
|
# Other errors
|
||||||
BACKEND_TIMEOUT_ERROR = "BACKEND_TIMEOUT_ERROR"
|
BACKEND_TIMEOUT_ERROR = "BACKEND_TIMEOUT_ERROR"
|
||||||
|
|
|
||||||
|
|
@ -312,3 +312,53 @@ class SupersetParseError(SupersetErrorException):
|
||||||
extra={"sql": sql, "engine": engine},
|
extra={"sql": sql, "engine": engine},
|
||||||
)
|
)
|
||||||
super().__init__(error)
|
super().__init__(error)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2RedirectError(SupersetErrorException):
|
||||||
|
"""
|
||||||
|
Exception used to start OAuth2 dance for personal tokens.
|
||||||
|
|
||||||
|
The exception requires 3 parameters:
|
||||||
|
|
||||||
|
- The URL that starts the OAuth2 dance.
|
||||||
|
- The UUID of the browser tab where OAuth2 started, so that the newly opened tab
|
||||||
|
where OAuth2 happens can communicate with the original tab to inform that OAuth2
|
||||||
|
was successful (or not).
|
||||||
|
- The redirect URL, so that the original tab can validate that the message from the
|
||||||
|
second tab is coming from a valid origin.
|
||||||
|
|
||||||
|
See the `OAuth2RedirectMessage.tsx` component for more details of how this
|
||||||
|
information is handled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, url: str, tab_id: str, redirect_uri: str):
|
||||||
|
super().__init__(
|
||||||
|
SupersetError(
|
||||||
|
message="You don't have permission to access the data.",
|
||||||
|
error_type=SupersetErrorType.OAUTH2_REDIRECT,
|
||||||
|
level=ErrorLevel.WARNING,
|
||||||
|
extra={"url": url, "tab_id": tab_id, "redirect_uri": redirect_uri},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2Error(SupersetErrorException):
|
||||||
|
"""
|
||||||
|
Exception for when OAuth2 goes wrong.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, error: str):
|
||||||
|
super().__init__(
|
||||||
|
SupersetError(
|
||||||
|
message="Something went wrong while doing OAuth2",
|
||||||
|
error_type=SupersetErrorType.OAUTH2_REDIRECT_ERROR,
|
||||||
|
level=ErrorLevel.ERROR,
|
||||||
|
extra={"error": error},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateKeyValueDistributedLockFailedException(Exception):
|
||||||
|
"""
|
||||||
|
Exception to signalize failure to acquire lock.
|
||||||
|
"""
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ class KeyValueResource(StrEnum):
|
||||||
DASHBOARD_PERMALINK = "dashboard_permalink"
|
DASHBOARD_PERMALINK = "dashboard_permalink"
|
||||||
EXPLORE_PERMALINK = "explore_permalink"
|
EXPLORE_PERMALINK = "explore_permalink"
|
||||||
METASTORE_CACHE = "superset_metastore_cache"
|
METASTORE_CACHE = "superset_metastore_cache"
|
||||||
|
LOCK = "lock"
|
||||||
|
|
||||||
|
|
||||||
class SharedKey(StrEnum):
|
class SharedKey(StrEnum):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,84 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""Add access token table
|
||||||
|
|
||||||
|
Revision ID: 678eefb4ab44
|
||||||
|
Revises: be1b217cd8cd
|
||||||
|
Create Date: 2024-03-20 16:02:58.515915
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "678eefb4ab44"
|
||||||
|
down_revision = "be1b217cd8cd"
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy_utils import EncryptedType
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
op.create_table(
|
||||||
|
"database_user_oauth2_tokens",
|
||||||
|
sa.Column("created_on", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("changed_on", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("database_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"access_token",
|
||||||
|
EncryptedType(),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("access_token_expiration", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"refresh_token",
|
||||||
|
EncryptedType(),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("created_by_fk", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["changed_by_fk"],
|
||||||
|
["ab_user.id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["created_by_fk"],
|
||||||
|
["ab_user.id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["database_id"],
|
||||||
|
["dbs.id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["ab_user.id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"idx_user_id_database_id",
|
||||||
|
"database_user_oauth2_tokens",
|
||||||
|
["user_id", "database_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
op.drop_index("idx_user_id_database_id", table_name="database_user_oauth2_tokens")
|
||||||
|
op.drop_table("database_user_oauth2_tokens")
|
||||||
|
|
@ -75,6 +75,7 @@ from superset.superset_typing import ResultSetColumnType
|
||||||
from superset.utils import cache as cache_util, core as utils
|
from superset.utils import cache as cache_util, core as utils
|
||||||
from superset.utils.backports import StrEnum
|
from superset.utils.backports import StrEnum
|
||||||
from superset.utils.core import get_username
|
from superset.utils.core import get_username
|
||||||
|
from superset.utils.oauth2 import get_oauth2_access_token
|
||||||
|
|
||||||
config = app.config
|
config = app.config
|
||||||
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
|
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
|
||||||
|
|
@ -465,6 +466,11 @@ class Database(
|
||||||
)
|
)
|
||||||
|
|
||||||
effective_username = self.get_effective_user(sqlalchemy_url)
|
effective_username = self.get_effective_user(sqlalchemy_url)
|
||||||
|
access_token = (
|
||||||
|
get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec)
|
||||||
|
if hasattr(g, "user") and hasattr(g.user, "id")
|
||||||
|
else None
|
||||||
|
)
|
||||||
# If using MySQL or Presto for example, will set url.username
|
# If using MySQL or Presto for example, will set url.username
|
||||||
# If using Hive, will not do anything yet since that relies on a
|
# If using Hive, will not do anything yet since that relies on a
|
||||||
# configuration parameter instead.
|
# configuration parameter instead.
|
||||||
|
|
@ -472,6 +478,7 @@ class Database(
|
||||||
sqlalchemy_url,
|
sqlalchemy_url,
|
||||||
self.impersonate_user,
|
self.impersonate_user,
|
||||||
effective_username,
|
effective_username,
|
||||||
|
access_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
masked_url = self.get_password_masked_url(sqlalchemy_url)
|
masked_url = self.get_password_masked_url(sqlalchemy_url)
|
||||||
|
|
@ -592,7 +599,7 @@ class Database(
|
||||||
database=None,
|
database=None,
|
||||||
)
|
)
|
||||||
_log_query(sql_)
|
_log_query(sql_)
|
||||||
self.db_engine_spec.execute(cursor, sql_)
|
self.db_engine_spec.execute(cursor, sql_, self.id)
|
||||||
cursor.fetchall()
|
cursor.fetchall()
|
||||||
|
|
||||||
if mutate_after_split:
|
if mutate_after_split:
|
||||||
|
|
@ -602,10 +609,10 @@ class Database(
|
||||||
database=None,
|
database=None,
|
||||||
)
|
)
|
||||||
_log_query(last_sql)
|
_log_query(last_sql)
|
||||||
self.db_engine_spec.execute(cursor, last_sql)
|
self.db_engine_spec.execute(cursor, last_sql, self.id)
|
||||||
else:
|
else:
|
||||||
_log_query(sqls[-1])
|
_log_query(sqls[-1])
|
||||||
self.db_engine_spec.execute(cursor, sqls[-1])
|
self.db_engine_spec.execute(cursor, sqls[-1], self.id)
|
||||||
|
|
||||||
data = self.db_engine_spec.fetch_data(cursor)
|
data = self.db_engine_spec.fetch_data(cursor)
|
||||||
result_set = SupersetResultSet(
|
result_set = SupersetResultSet(
|
||||||
|
|
@ -982,6 +989,35 @@ sqla.event.listen(Database, "after_update", security_manager.database_after_upda
|
||||||
sqla.event.listen(Database, "after_delete", security_manager.database_after_delete)
|
sqla.event.listen(Database, "after_delete", security_manager.database_after_delete)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
|
||||||
|
"""
|
||||||
|
Store OAuth2 tokens, for authenticating to DBs using user personal tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "database_user_oauth2_tokens"
|
||||||
|
__table_args__ = (sqla.Index("idx_user_id_database_id", "user_id", "database_id"),)
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
|
||||||
|
user_id = Column(
|
||||||
|
Integer,
|
||||||
|
ForeignKey("ab_user.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
user = relationship(security_manager.user_model, foreign_keys=[user_id])
|
||||||
|
|
||||||
|
database_id = Column(
|
||||||
|
Integer,
|
||||||
|
ForeignKey("dbs.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
database = relationship("Database", foreign_keys=[database_id])
|
||||||
|
|
||||||
|
access_token = Column(encrypted_field_factory.create(Text), nullable=True)
|
||||||
|
access_token_expiration = Column(DateTime, nullable=True)
|
||||||
|
refresh_token = Column(encrypted_field_factory.create(Text), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class Log(Model): # pylint: disable=too-few-public-methods
|
class Log(Model): # pylint: disable=too-few-public-methods
|
||||||
"""ORM object used to log Superset actions to the database"""
|
"""ORM object used to log Superset actions to the database"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,11 @@ from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
|
||||||
from superset.dataframe import df_to_records
|
from superset.dataframe import df_to_records
|
||||||
from superset.db_engine_specs import BaseEngineSpec
|
from superset.db_engine_specs import BaseEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import SupersetErrorException, SupersetErrorsException
|
from superset.exceptions import (
|
||||||
|
OAuth2RedirectError,
|
||||||
|
SupersetErrorException,
|
||||||
|
SupersetErrorsException,
|
||||||
|
)
|
||||||
from superset.extensions import celery_app
|
from superset.extensions import celery_app
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
|
@ -188,7 +192,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
|
||||||
return handle_query_error(ex, query)
|
return handle_query_error(ex, query)
|
||||||
|
|
||||||
|
|
||||||
def execute_sql_statement(
|
def execute_sql_statement( # pylint: disable=too-many-statements
|
||||||
sql_statement: str,
|
sql_statement: str,
|
||||||
query: Query,
|
query: Query,
|
||||||
cursor: Any,
|
cursor: Any,
|
||||||
|
|
@ -308,6 +312,9 @@ def execute_sql_statement(
|
||||||
level=ErrorLevel.ERROR,
|
level=ErrorLevel.ERROR,
|
||||||
)
|
)
|
||||||
) from ex
|
) from ex
|
||||||
|
except OAuth2RedirectError as ex:
|
||||||
|
# user needs to authenticate with OAuth2 in order to run query
|
||||||
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
# query is stopped in another thread/worker
|
# query is stopped in another thread/worker
|
||||||
# stopping raises expected exceptions which we should skip
|
# stopping raises expected exceptions which we should skip
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||||
from pyhive.exc import DatabaseError
|
from pyhive.exc import DatabaseError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_engine_spec.execute(cursor, sql)
|
db_engine_spec.execute(cursor, sql, database.id)
|
||||||
polled = cursor.poll()
|
polled = cursor.poll()
|
||||||
while polled:
|
while polled:
|
||||||
logger.info("polling presto for validation progress")
|
logger.info("polling presto for validation progress")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
{#
|
||||||
|
Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
or more contributor license agreements. See the NOTICE file
|
||||||
|
distributed with this work for additional information
|
||||||
|
regarding copyright ownership. The ASF licenses this file
|
||||||
|
to you under the Apache License, Version 2.0 (the
|
||||||
|
"License"); you may not use this file except in compliance
|
||||||
|
with the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing,
|
||||||
|
software distributed under the License is distributed on an
|
||||||
|
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations
|
||||||
|
under the License.
|
||||||
|
-#}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<script>
|
||||||
|
window.opener.postMessage({ tabId: '{{ tab_id }}' });
|
||||||
|
window.close();
|
||||||
|
</script>
|
||||||
|
<p>You can close this window and re-run the query.</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, cast, TypeVar, Union
|
||||||
|
|
||||||
|
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||||
|
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||||
|
from superset.key_value.types import KeyValueResource, PickleKeyValueCodec
|
||||||
|
|
||||||
|
LOCK_EXPIRATION = timedelta(seconds=30)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize(params: dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Serialize parameters into a string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
T = TypeVar(
|
||||||
|
"T",
|
||||||
|
bound=Union[dict[str, Any], list[Any], int, float, str, bool, None],
|
||||||
|
)
|
||||||
|
|
||||||
|
def sort(obj: T) -> T:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return cast(T, {k: sort(v) for k, v in sorted(obj.items())})
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return cast(T, [sort(x) for x in obj])
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return json.dumps(params)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def KeyValueDistributedLock( # pylint: disable=invalid-name
|
||||||
|
namespace: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[uuid.UUID]:
|
||||||
|
"""
|
||||||
|
KV global lock for refreshing tokens.
|
||||||
|
|
||||||
|
This context manager acquires a distributed lock for a given namespace, with
|
||||||
|
optional parameters (eg, namespace="cache", user_id=1). It yields a UUID for the
|
||||||
|
lock that can be used within the context, and corresponds to the key in the KV
|
||||||
|
store.
|
||||||
|
|
||||||
|
:param namespace: The namespace for which the lock is to be acquired.
|
||||||
|
:type namespace: str
|
||||||
|
:param kwargs: Additional keyword arguments.
|
||||||
|
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
|
||||||
|
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
|
||||||
|
"""
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||||
|
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
||||||
|
from superset.commands.key_value.delete_expired import DeleteExpiredKeyValueCommand
|
||||||
|
|
||||||
|
key = uuid.uuid5(uuid.uuid5(uuid.NAMESPACE_DNS, namespace), serialize(kwargs))
|
||||||
|
logger.debug("Acquiring lock on namespace %s for key %s", namespace, key)
|
||||||
|
try:
|
||||||
|
DeleteExpiredKeyValueCommand(resource=KeyValueResource.LOCK).run()
|
||||||
|
CreateKeyValueCommand(
|
||||||
|
resource=KeyValueResource.LOCK,
|
||||||
|
codec=PickleKeyValueCodec(),
|
||||||
|
key=key,
|
||||||
|
value=True,
|
||||||
|
expires_on=datetime.now() + LOCK_EXPIRATION,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
yield key
|
||||||
|
|
||||||
|
DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
|
||||||
|
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
|
||||||
|
except KeyValueCreateFailedError as ex:
|
||||||
|
raise CreateKeyValueDistributedLockFailedException(
|
||||||
|
"Error acquiring lock"
|
||||||
|
) from ex
|
||||||
|
|
@ -0,0 +1,176 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
import backoff
|
||||||
|
import jwt
|
||||||
|
from flask import current_app
|
||||||
|
from marshmallow import EXCLUDE, fields, post_load, Schema
|
||||||
|
|
||||||
|
from superset import db
|
||||||
|
from superset.db_engine_specs.base import BaseEngineSpec, OAuth2State
|
||||||
|
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||||
|
from superset.utils.lock import KeyValueDistributedLock
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.models.core import DatabaseUserOAuth2Tokens
|
||||||
|
|
||||||
|
JWT_EXPIRATION = timedelta(minutes=5)
|
||||||
|
|
||||||
|
|
||||||
|
@backoff.on_exception(
|
||||||
|
backoff.expo,
|
||||||
|
CreateKeyValueDistributedLockFailedException,
|
||||||
|
factor=10,
|
||||||
|
base=2,
|
||||||
|
max_tries=5,
|
||||||
|
)
|
||||||
|
def get_oauth2_access_token(
|
||||||
|
database_id: int,
|
||||||
|
user_id: int,
|
||||||
|
db_engine_spec: type[BaseEngineSpec],
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Return a valid OAuth2 access token.
|
||||||
|
|
||||||
|
If the token exists but is expired and a refresh token is available the function will
|
||||||
|
return a fresh token and store it in the database for further requests. The function
|
||||||
|
has a retry decorator, in case a dashboard with multiple charts triggers
|
||||||
|
simultaneous requests for refreshing a stale token; in that case only the first
|
||||||
|
process to acquire the lock will perform the refresh, and othe process should find a
|
||||||
|
a valid token when they retry.
|
||||||
|
"""
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from superset.models.core import DatabaseUserOAuth2Tokens
|
||||||
|
|
||||||
|
token = (
|
||||||
|
db.session.query(DatabaseUserOAuth2Tokens)
|
||||||
|
.filter_by(user_id=user_id, database_id=database_id)
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if token.access_token and datetime.now() < token.access_token_expiration:
|
||||||
|
return token.access_token
|
||||||
|
|
||||||
|
if token.refresh_token:
|
||||||
|
return refresh_oauth2_token(database_id, user_id, db_engine_spec, token)
|
||||||
|
|
||||||
|
# since the access token is expired and there's no refresh token, delete the entry
|
||||||
|
db.session.delete(token)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_oauth2_token(
|
||||||
|
database_id: int,
|
||||||
|
user_id: int,
|
||||||
|
db_engine_spec: type[BaseEngineSpec],
|
||||||
|
token: DatabaseUserOAuth2Tokens,
|
||||||
|
) -> str | None:
|
||||||
|
with KeyValueDistributedLock(
|
||||||
|
namespace="refresh_oauth2_token",
|
||||||
|
user_id=user_id,
|
||||||
|
database_id=database_id,
|
||||||
|
):
|
||||||
|
token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token)
|
||||||
|
|
||||||
|
# store new access token; note that the refresh token might be revoked, in which
|
||||||
|
# case there would be no access token in the response
|
||||||
|
if "access_token" not in token_response:
|
||||||
|
return None
|
||||||
|
|
||||||
|
token.access_token = token_response["access_token"]
|
||||||
|
token.access_token_expiration = datetime.now() + timedelta(
|
||||||
|
seconds=token_response["expires_in"]
|
||||||
|
)
|
||||||
|
db.session.add(token)
|
||||||
|
|
||||||
|
return token.access_token
|
||||||
|
|
||||||
|
|
||||||
|
def encode_oauth2_state(state: OAuth2State) -> str:
|
||||||
|
"""
|
||||||
|
Encode the OAuth2 state.
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
|
||||||
|
"database_id": state["database_id"],
|
||||||
|
"user_id": state["user_id"],
|
||||||
|
"default_redirect_uri": state["default_redirect_uri"],
|
||||||
|
"tab_id": state["tab_id"],
|
||||||
|
}
|
||||||
|
encoded_state = jwt.encode(
|
||||||
|
payload=payload,
|
||||||
|
key=current_app.config["SECRET_KEY"],
|
||||||
|
algorithm=current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Google OAuth2 needs periods to be escaped.
|
||||||
|
encoded_state = encoded_state.replace(".", "%2E")
|
||||||
|
|
||||||
|
return encoded_state
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2StateSchema(Schema):
|
||||||
|
database_id = fields.Int(required=True)
|
||||||
|
user_id = fields.Int(required=True)
|
||||||
|
default_redirect_uri = fields.Str(required=True)
|
||||||
|
tab_id = fields.Str(required=True)
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
@post_load
|
||||||
|
def make_oauth2_state(
|
||||||
|
self,
|
||||||
|
data: dict[str, Any],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> OAuth2State:
|
||||||
|
return OAuth2State(
|
||||||
|
database_id=data["database_id"],
|
||||||
|
user_id=data["user_id"],
|
||||||
|
default_redirect_uri=data["default_redirect_uri"],
|
||||||
|
tab_id=data["tab_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta: # pylint: disable=too-few-public-methods
|
||||||
|
# ignore `exp`
|
||||||
|
unknown = EXCLUDE
|
||||||
|
|
||||||
|
|
||||||
|
oauth2_state_schema = OAuth2StateSchema()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_oauth2_state(encoded_state: str) -> OAuth2State:
|
||||||
|
"""
|
||||||
|
Decode the OAuth2 state.
|
||||||
|
"""
|
||||||
|
# Google OAuth2 needs periods to be escaped.
|
||||||
|
encoded_state = encoded_state.replace("%2E", ".")
|
||||||
|
|
||||||
|
payload = jwt.decode(
|
||||||
|
jwt=encoded_state,
|
||||||
|
key=current_app.config["SECRET_KEY"],
|
||||||
|
algorithms=[current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"]],
|
||||||
|
)
|
||||||
|
state = oauth2_state_schema.load(payload)
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
@ -1551,6 +1551,7 @@ class TestRolePermission(SupersetTestCase):
|
||||||
["SecurityApi", "login"],
|
["SecurityApi", "login"],
|
||||||
["SecurityApi", "refresh"],
|
["SecurityApi", "refresh"],
|
||||||
["SupersetIndexView", "index"],
|
["SupersetIndexView", "index"],
|
||||||
|
["DatabaseRestApi", "oauth2"],
|
||||||
]
|
]
|
||||||
unsecured_views = []
|
unsecured_views = []
|
||||||
for view_class in appbuilder.baseviews:
|
for view_class in appbuilder.baseviews:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
from superset.exceptions import OAuth2RedirectError
|
||||||
|
from superset.superset_typing import QueryObjectDict
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_bubbles_errors(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test that the `query` method bubbles exceptions correctly.
|
||||||
|
|
||||||
|
When a user needs to authenticate via OAuth2 to access data, a custom exception is
|
||||||
|
raised. The exception needs to bubble up all the way to the frontend as a SIP-40
|
||||||
|
compliant payload with the error type `DATABASE_OAUTH2_REDIRECT_URI` so that the
|
||||||
|
frontend can initiate the OAuth2 authentication.
|
||||||
|
|
||||||
|
This tests verifies that the method does not capture these exceptions; otherwise the
|
||||||
|
user will be never be prompted to authenticate via OAuth2.
|
||||||
|
"""
|
||||||
|
database = mocker.MagicMock()
|
||||||
|
database.get_df.side_effect = OAuth2RedirectError(
|
||||||
|
url="http://example.com",
|
||||||
|
tab_id="1234",
|
||||||
|
redirect_uri="http://redirect.example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
sqla_table = SqlaTable(
|
||||||
|
table_name="my_sqla_table",
|
||||||
|
columns=[],
|
||||||
|
metrics=[],
|
||||||
|
database=database,
|
||||||
|
)
|
||||||
|
mocker.patch.object(
|
||||||
|
sqla_table,
|
||||||
|
"get_query_str_extended",
|
||||||
|
return_value=mocker.MagicMock(sql="SELECT * FROM my_sqla_table"),
|
||||||
|
)
|
||||||
|
query_obj: QueryObjectDict = {
|
||||||
|
"granularity": None,
|
||||||
|
"from_dttm": None,
|
||||||
|
"to_dttm": None,
|
||||||
|
"groupby": ["id", "username", "email"],
|
||||||
|
"metrics": [],
|
||||||
|
"is_timeseries": False,
|
||||||
|
"filter": [],
|
||||||
|
}
|
||||||
|
with pytest.raises(OAuth2RedirectError):
|
||||||
|
sqla_table.query(query_obj)
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
# pylint: disable=unused-argument, import-outside-toplevel, line-too-long
|
# pylint: disable=unused-argument, import-outside-toplevel, line-too-long
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
@ -25,10 +26,12 @@ from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
from freezegun import freeze_time
|
||||||
from pytest_mock import MockFixture
|
from pytest_mock import MockFixture
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
from superset import db
|
from superset import db
|
||||||
|
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
||||||
|
|
||||||
|
|
||||||
def test_filter_by_uuid(
|
def test_filter_by_uuid(
|
||||||
|
|
@ -638,3 +641,170 @@ def test_apply_dynamic_database_filter(
|
||||||
|
|
||||||
# Ensure that the filter has been called once
|
# Ensure that the filter has been called once
|
||||||
assert base_filter_mock.call_count == 1
|
assert base_filter_mock.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth2_happy_path(
|
||||||
|
mocker: MockFixture,
|
||||||
|
session: Session,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test the OAuth2 endpoint when everything goes well.
|
||||||
|
"""
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||||
|
|
||||||
|
DatabaseRestApi.datamodel.session = session
|
||||||
|
|
||||||
|
# create table for databases
|
||||||
|
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||||
|
db.session.add(
|
||||||
|
Database(
|
||||||
|
database_name="my_db",
|
||||||
|
sqlalchemy_uri="sqlite://",
|
||||||
|
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
|
||||||
|
get_oauth2_token.return_value = {
|
||||||
|
"access_token": "YYY",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "ZZZ",
|
||||||
|
}
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"user_id": 1,
|
||||||
|
"database_id": 1,
|
||||||
|
"tab_id": 42,
|
||||||
|
}
|
||||||
|
decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
|
||||||
|
decode_oauth2_state.return_value = state
|
||||||
|
|
||||||
|
mocker.patch("superset.databases.api.render_template", return_value="OK")
|
||||||
|
|
||||||
|
with freeze_time("2024-01-01T00:00:00Z"):
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/database/oauth2/",
|
||||||
|
query_string={
|
||||||
|
"state": "some%2Estate",
|
||||||
|
"code": "XXX",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
decode_oauth2_state.assert_called_with("some%2Estate")
|
||||||
|
get_oauth2_token.assert_called_with("XXX", state)
|
||||||
|
|
||||||
|
token = db.session.query(DatabaseUserOAuth2Tokens).one()
|
||||||
|
assert token.user_id == 1
|
||||||
|
assert token.database_id == 1
|
||||||
|
assert token.access_token == "YYY"
|
||||||
|
assert token.access_token_expiration == datetime(2024, 1, 1, 1, 0)
|
||||||
|
assert token.refresh_token == "ZZZ"
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth2_multiple_tokens(
|
||||||
|
mocker: MockFixture,
|
||||||
|
session: Session,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test the OAuth2 endpoint when a second token is added.
|
||||||
|
"""
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||||
|
|
||||||
|
DatabaseRestApi.datamodel.session = session
|
||||||
|
|
||||||
|
# create table for databases
|
||||||
|
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||||
|
db.session.add(
|
||||||
|
Database(
|
||||||
|
database_name="my_db",
|
||||||
|
sqlalchemy_uri="sqlite://",
|
||||||
|
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
|
||||||
|
get_oauth2_token.side_effect = [
|
||||||
|
{
|
||||||
|
"access_token": "YYY",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "ZZZ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"access_token": "YYY2",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "ZZZ2",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"user_id": 1,
|
||||||
|
"database_id": 1,
|
||||||
|
"tab_id": 42,
|
||||||
|
}
|
||||||
|
decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
|
||||||
|
decode_oauth2_state.return_value = state
|
||||||
|
|
||||||
|
mocker.patch("superset.databases.api.render_template", return_value="OK")
|
||||||
|
|
||||||
|
with freeze_time("2024-01-01T00:00:00Z"):
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/database/oauth2/",
|
||||||
|
query_string={
|
||||||
|
"state": "some%2Estate",
|
||||||
|
"code": "XXX",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# second request should delete token from the first request
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/database/oauth2/",
|
||||||
|
query_string={
|
||||||
|
"state": "some%2Estate",
|
||||||
|
"code": "XXX",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
tokens = db.session.query(DatabaseUserOAuth2Tokens).all()
|
||||||
|
assert len(tokens) == 1
|
||||||
|
token = tokens[0]
|
||||||
|
assert token.access_token == "YYY2"
|
||||||
|
assert token.refresh_token == "ZZZ2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth2_error(
|
||||||
|
mocker: MockFixture,
|
||||||
|
session: Session,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test the OAuth2 endpoint when OAuth2 errors.
|
||||||
|
"""
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/database/oauth2/",
|
||||||
|
query_string={
|
||||||
|
"error": "Something bad hapened",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 500
|
||||||
|
assert response.json == {
|
||||||
|
"errors": [
|
||||||
|
{
|
||||||
|
"message": "Something went wrong while doing OAuth2",
|
||||||
|
"error_type": "OAUTH2_REDIRECT_ERROR",
|
||||||
|
"level": "error",
|
||||||
|
"extra": {"error": "Something bad hapened"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -225,3 +225,45 @@ def test_rename_encrypted_extra() -> None:
|
||||||
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
|
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
|
||||||
"masked_encrypted_extra": "{}",
|
"masked_encrypted_extra": "{}",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth2_schema_success() -> None:
|
||||||
|
"""
|
||||||
|
Test a successful redirect.
|
||||||
|
"""
|
||||||
|
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
||||||
|
|
||||||
|
schema = OAuth2ProviderResponseSchema()
|
||||||
|
|
||||||
|
payload = schema.load({"code": "SECRET", "state": "12345"})
|
||||||
|
assert payload == {"code": "SECRET", "state": "12345"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth2_schema_error() -> None:
|
||||||
|
"""
|
||||||
|
Test a redirect with an error.
|
||||||
|
"""
|
||||||
|
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
||||||
|
|
||||||
|
schema = OAuth2ProviderResponseSchema()
|
||||||
|
|
||||||
|
payload = schema.load({"error": "access_denied"})
|
||||||
|
assert payload == {"error": "access_denied"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth2_schema_extra() -> None:
|
||||||
|
"""
|
||||||
|
Test a redirect with extra keys.
|
||||||
|
"""
|
||||||
|
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
||||||
|
|
||||||
|
schema = OAuth2ProviderResponseSchema()
|
||||||
|
|
||||||
|
payload = schema.load(
|
||||||
|
{
|
||||||
|
"code": "SECRET",
|
||||||
|
"state": "12345",
|
||||||
|
"optional": "NEW THING",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert payload == {"code": "SECRET", "state": "12345"}
|
||||||
|
|
|
||||||
|
|
@ -65,8 +65,9 @@ def test_execute_connection_error() -> None:
|
||||||
cursor.execute.side_effect = NewConnectionError(
|
cursor.execute.side_effect = NewConnectionError(
|
||||||
HTTPConnection("localhost"), "Exception with sensitive data"
|
HTTPConnection("localhost"), "Exception with sensitive data"
|
||||||
)
|
)
|
||||||
with pytest.raises(SupersetDBAPIDatabaseError) as ex:
|
with pytest.raises(SupersetDBAPIDatabaseError) as excinfo:
|
||||||
ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1")
|
ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1", 1)
|
||||||
|
assert str(excinfo.value) == "Connection failed"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -66,8 +66,9 @@ def test_execute_connection_error() -> None:
|
||||||
cursor.execute.side_effect = NewConnectionError(
|
cursor.execute.side_effect = NewConnectionError(
|
||||||
HTTPConnection("Dummypool"), "Exception with sensitive data"
|
HTTPConnection("Dummypool"), "Exception with sensitive data"
|
||||||
)
|
)
|
||||||
with pytest.raises(SupersetDBAPIDatabaseError) as ex:
|
with pytest.raises(SupersetDBAPIDatabaseError) as excinfo:
|
||||||
DatabendEngineSpec.execute(cursor, "SELECT col1 from table1")
|
DatabendEngineSpec.execute(cursor, "SELECT col1 from table1", 1)
|
||||||
|
assert str(excinfo.value) == "Connection failed"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ def test_odbc_impersonation() -> None:
|
||||||
|
|
||||||
url = URL.create("drill+odbc")
|
url = URL.create("drill+odbc")
|
||||||
username = "DoAsUser"
|
username = "DoAsUser"
|
||||||
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
|
url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
|
||||||
assert url.query["DelegationUID"] == username
|
assert url.query["DelegationUID"] == username
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,7 +54,7 @@ def test_jdbc_impersonation() -> None:
|
||||||
|
|
||||||
url = URL.create("drill+jdbc")
|
url = URL.create("drill+jdbc")
|
||||||
username = "DoAsUser"
|
username = "DoAsUser"
|
||||||
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
|
url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
|
||||||
assert url.query["impersonation_target"] == username
|
assert url.query["impersonation_target"] == username
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,7 +70,7 @@ def test_sadrill_impersonation() -> None:
|
||||||
|
|
||||||
url = URL.create("drill+sadrill")
|
url = URL.create("drill+sadrill")
|
||||||
username = "DoAsUser"
|
username = "DoAsUser"
|
||||||
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
|
url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
|
||||||
assert url.query["impersonation_target"] == username
|
assert url.query["impersonation_target"] == username
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -90,7 +90,7 @@ def test_invalid_impersonation() -> None:
|
||||||
username = "DoAsUser"
|
username = "DoAsUser"
|
||||||
|
|
||||||
with pytest.raises(SupersetDBAPIProgrammingError):
|
with pytest.raises(SupersetDBAPIProgrammingError):
|
||||||
DrillEngineSpec.get_url_for_impersonation(url, True, username)
|
DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,8 @@ def test_opendistro_strip_comments() -> None:
|
||||||
mock_cursor.execute.return_value = []
|
mock_cursor.execute.return_value = []
|
||||||
|
|
||||||
OpenDistroEngineSpec.execute(
|
OpenDistroEngineSpec.execute(
|
||||||
mock_cursor, "-- some comment \nSELECT 1\n --other comment"
|
mock_cursor,
|
||||||
|
"-- some comment \nSELECT 1\n --other comment",
|
||||||
|
1,
|
||||||
)
|
)
|
||||||
mock_cursor.execute.assert_called_once_with("SELECT 1\n")
|
mock_cursor.execute.assert_called_once_with("SELECT 1\n")
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,21 @@
|
||||||
# pylint: disable=import-outside-toplevel, invalid-name, line-too-long
|
# pylint: disable=import-outside-toplevel, invalid-name, line-too-long
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockFixture
|
from pytest_mock import MockFixture
|
||||||
|
from sqlalchemy.engine.url import make_url
|
||||||
|
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import SupersetException
|
from superset.exceptions import SupersetException
|
||||||
from superset.sql_parse import Table
|
from superset.sql_parse import Table
|
||||||
|
from superset.utils.oauth2 import decode_oauth2_state
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.db_engine_specs.base import OAuth2State
|
||||||
|
|
||||||
|
|
||||||
class ProgrammingError(Exception):
|
class ProgrammingError(Exception):
|
||||||
|
|
@ -399,3 +406,223 @@ def test_upload_existing(mocker: MockFixture) -> None:
|
||||||
mocker.call().json(),
|
mocker.call().json(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_url_for_impersonation_username(mocker: MockFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test passing a username to `get_url_for_impersonation`.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||||
|
|
||||||
|
user = mocker.MagicMock()
|
||||||
|
user.email = "alice@example.org"
|
||||||
|
mocker.patch(
|
||||||
|
"superset.db_engine_specs.gsheets.security_manager.find_user",
|
||||||
|
return_value=user,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert GSheetsEngineSpec.get_url_for_impersonation(
|
||||||
|
url=make_url("gsheets://"),
|
||||||
|
impersonate_user=True,
|
||||||
|
username="alice",
|
||||||
|
access_token=None,
|
||||||
|
) == make_url("gsheets://?subject=alice%40example.org")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_url_for_impersonation_access_token() -> None:
|
||||||
|
"""
|
||||||
|
Test passing an access token to `get_url_for_impersonation`.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||||
|
|
||||||
|
assert GSheetsEngineSpec.get_url_for_impersonation(
|
||||||
|
url=make_url("gsheets://"),
|
||||||
|
impersonate_user=True,
|
||||||
|
username=None,
|
||||||
|
access_token="access-token",
|
||||||
|
) == make_url("gsheets://?access_token=access-token")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test `is_oauth2_enabled` when OAuth2 is not configured.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"superset.db_engine_specs.gsheets.current_app.config",
|
||||||
|
new={"DATABASE_OAUTH2_CREDENTIALS": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert GSheetsEngineSpec.is_oauth2_enabled() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_oauth2_enabled_config(mocker: MockFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test `is_oauth2_enabled` when OAuth2 is configured.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"superset.db_engine_specs.gsheets.current_app.config",
|
||||||
|
new={
|
||||||
|
"DATABASE_OAUTH2_CREDENTIALS": {
|
||||||
|
"Google Sheets": {
|
||||||
|
"CLIENT_ID": "XXX.apps.googleusercontent.com",
|
||||||
|
"CLIENT_SECRET": "GOCSPX-YYY",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert GSheetsEngineSpec.is_oauth2_enabled() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test `get_oauth2_authorization_uri`.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"superset.db_engine_specs.gsheets.current_app.config",
|
||||||
|
new={
|
||||||
|
"DATABASE_OAUTH2_CREDENTIALS": {
|
||||||
|
"Google Sheets": {
|
||||||
|
"CLIENT_ID": "XXX.apps.googleusercontent.com",
|
||||||
|
"CLIENT_SECRET": "GOCSPX-YYY",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"SECRET_KEY": "not-a-secret",
|
||||||
|
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
state: OAuth2State = {
|
||||||
|
"database_id": 1,
|
||||||
|
"user_id": 1,
|
||||||
|
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||||
|
"tab_id": "1234",
|
||||||
|
}
|
||||||
|
|
||||||
|
url = GSheetsEngineSpec.get_oauth2_authorization_uri(state)
|
||||||
|
parsed = urlparse(url)
|
||||||
|
assert parsed.netloc == "accounts.google.com"
|
||||||
|
assert parsed.path == "/o/oauth2/v2/auth"
|
||||||
|
|
||||||
|
query = parse_qs(parsed.query)
|
||||||
|
assert query["scope"][0] == (
|
||||||
|
"https://www.googleapis.com/auth/drive.readonly "
|
||||||
|
"https://www.googleapis.com/auth/spreadsheets "
|
||||||
|
"https://spreadsheets.google.com/feeds"
|
||||||
|
)
|
||||||
|
encoded_state = query["state"][0].replace("%2E", ".")
|
||||||
|
assert decode_oauth2_state(encoded_state) == state
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_token(mocker: MockFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test `get_oauth2_token`.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||||
|
|
||||||
|
http = mocker.patch("superset.db_engine_specs.gsheets.http")
|
||||||
|
http.request().data.decode.return_value = json.dumps(
|
||||||
|
{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"scope": "scope",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"superset.db_engine_specs.gsheets.current_app.config",
|
||||||
|
new={
|
||||||
|
"DATABASE_OAUTH2_CREDENTIALS": {
|
||||||
|
"Google Sheets": {
|
||||||
|
"CLIENT_ID": "XXX.apps.googleusercontent.com",
|
||||||
|
"CLIENT_SECRET": "GOCSPX-YYY",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"SECRET_KEY": "not-a-secret",
|
||||||
|
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
state: OAuth2State = {
|
||||||
|
"database_id": 1,
|
||||||
|
"user_id": 1,
|
||||||
|
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||||
|
"tab_id": "1234",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert GSheetsEngineSpec.get_oauth2_token("code", state) == {
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"scope": "scope",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
}
|
||||||
|
http.request.assert_called_with(
|
||||||
|
"POST",
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
fields={
|
||||||
|
"code": "code",
|
||||||
|
"client_id": "XXX.apps.googleusercontent.com",
|
||||||
|
"client_secret": "GOCSPX-YYY",
|
||||||
|
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_fresh_token(mocker: MockFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test `get_oauth2_token`.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||||
|
|
||||||
|
http = mocker.patch("superset.db_engine_specs.gsheets.http")
|
||||||
|
http.request().data.decode.return_value = json.dumps(
|
||||||
|
{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"scope": "scope",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"superset.db_engine_specs.gsheets.current_app.config",
|
||||||
|
new={
|
||||||
|
"DATABASE_OAUTH2_CREDENTIALS": {
|
||||||
|
"Google Sheets": {
|
||||||
|
"CLIENT_ID": "XXX.apps.googleusercontent.com",
|
||||||
|
"CLIENT_SECRET": "GOCSPX-YYY",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"SECRET_KEY": "not-a-secret",
|
||||||
|
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert GSheetsEngineSpec.get_oauth2_fresh_token("refresh-token") == {
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"scope": "scope",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
}
|
||||||
|
http.request.assert_called_with(
|
||||||
|
"POST",
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
fields={
|
||||||
|
"client_id": "XXX.apps.googleusercontent.com",
|
||||||
|
"client_secret": "GOCSPX-YYY",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,9 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
|
||||||
|
|
||||||
database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
|
database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
|
||||||
db_engine_spec.execute_with_cursor.assert_called_with(
|
db_engine_spec.execute_with_cursor.assert_called_with(
|
||||||
cursor, "SELECT 42 AS answer LIMIT 2", query
|
cursor,
|
||||||
|
"SELECT 42 AS answer LIMIT 2",
|
||||||
|
query,
|
||||||
)
|
)
|
||||||
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
|
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
|
||||||
|
|
||||||
|
|
@ -104,7 +106,9 @@ def test_execute_sql_statement_with_rls(
|
||||||
force=True,
|
force=True,
|
||||||
)
|
)
|
||||||
db_engine_spec.execute_with_cursor.assert_called_with(
|
db_engine_spec.execute_with_cursor.assert_called_with(
|
||||||
cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query
|
cursor,
|
||||||
|
"SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
|
||||||
|
query,
|
||||||
)
|
)
|
||||||
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
|
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from freezegun import freeze_time
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||||
|
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||||
|
from superset.key_value.types import KeyValueResource
|
||||||
|
from superset.utils.lock import KeyValueDistributedLock, serialize
|
||||||
|
|
||||||
|
|
||||||
|
def test_KeyValueDistributedLock_happy_path(mocker: MockerFixture) -> None:
    """
    Verify the normal lock lifecycle: acquire on enter, release on exit.

    Entering the context manager must first purge expired lock entries and
    then create the lock's key-value entry with a 30-second expiration; the
    entry must only be deleted once the context manager exits.
    """
    create_cmd = mocker.patch(
        "superset.commands.key_value.create.CreateKeyValueCommand"
    )
    delete_cmd = mocker.patch(
        "superset.commands.key_value.delete.DeleteKeyValueCommand"
    )
    delete_expired_cmd = mocker.patch(
        "superset.commands.key_value.delete_expired.DeleteExpiredKeyValueCommand"
    )
    codec_cls = mocker.patch("superset.utils.lock.PickleKeyValueCodec")

    with freeze_time("2024-01-01"), KeyValueDistributedLock("ns", a=1, b=2) as key:
        # Stale locks are garbage-collected before a new one is created.
        delete_expired_cmd.assert_called_with(
            resource=KeyValueResource.LOCK,
        )
        # The lock entry expires 30 seconds after acquisition (frozen "now").
        create_cmd.assert_called_with(
            resource=KeyValueResource.LOCK,
            codec=codec_cls(),
            key=key,
            value=True,
            expires_on=datetime(2024, 1, 1, 0, 0, 30),
        )
        # While the lock is held nothing may delete it.
        delete_cmd.assert_not_called()

    # Exiting the context manager releases (deletes) the lock entry.
    delete_cmd.assert_called_with(
        resource=KeyValueResource.LOCK,
        key=key,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_KeyValueDistributedLock_no_lock(mocker: MockerFixture) -> None:
    """
    Verify that a failed key-value create surfaces as a lock-acquisition error.

    When the underlying create command raises `KeyValueCreateFailedError`,
    entering the lock must raise
    `CreateKeyValueDistributedLockFailedException` with a fixed message.
    """
    mocker.patch(
        "superset.commands.key_value.create.CreateKeyValueCommand",
        side_effect=KeyValueCreateFailedError(),
    )

    with pytest.raises(CreateKeyValueDistributedLockFailedException) as exc_info:
        # Entering the context manager attempts the create and should fail.
        with KeyValueDistributedLock("ns", a=1, b=2):
            pass

    assert str(exc_info.value) == "Error acquiring lock"
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name, disallowed-name
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from superset.utils.oauth2 import get_oauth2_access_token
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None:
    """
    `get_oauth2_access_token` returns `None` when no token row exists.
    """
    mock_db = mocker.patch("superset.utils.oauth2.db")
    # Simulate an empty lookup: no stored token for this database/user pair.
    mock_db.session.query().filter_by().one_or_none.return_value = None
    spec = mocker.MagicMock()

    assert get_oauth2_access_token(1, 1, spec) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_access_token_base_token_valid(mocker: MockerFixture) -> None:
    """
    `get_oauth2_access_token` returns a stored token that has not expired.
    """
    mock_db = mocker.patch("superset.utils.oauth2.db")
    spec = mocker.MagicMock()

    stored_token = mocker.MagicMock()
    stored_token.access_token = "access-token"
    stored_token.access_token_expiration = datetime(2024, 1, 2)
    mock_db.session.query().filter_by().one_or_none.return_value = stored_token

    # Frozen "now" (2024-01-01) is before the expiration (2024-01-02),
    # so the stored token is returned unchanged.
    with freeze_time("2024-01-01"):
        assert get_oauth2_access_token(1, 1, spec) == "access-token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_access_token_base_refresh(mocker: MockerFixture) -> None:
    """
    An expired token with a refresh token is refreshed and persisted.
    """
    mock_db = mocker.patch("superset.utils.oauth2.db")
    spec = mocker.MagicMock()
    # The engine spec's refresh call yields a new token valid for one hour.
    spec.get_oauth2_fresh_token.return_value = {
        "access_token": "new-token",
        "expires_in": 3600,
    }

    stored_token = mocker.MagicMock()
    stored_token.access_token = "access-token"
    stored_token.access_token_expiration = datetime(2024, 1, 1)
    stored_token.refresh_token = "refresh-token"
    mock_db.session.query().filter_by().one_or_none.return_value = stored_token

    # Frozen "now" (2024-01-02) is past the expiration (2024-01-01),
    # forcing a refresh via the engine spec.
    with freeze_time("2024-01-02"):
        assert get_oauth2_access_token(1, 1, spec) == "new-token"

    # The refreshed token must be written back: new value, and an
    # expiration of now + 3600s (2024-01-02 01:00).
    assert stored_token.access_token == "new-token"
    assert stored_token.access_token_expiration == datetime(2024, 1, 2, 1)
    mock_db.session.add.assert_called_with(stored_token)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None:
    """
    An expired token with no refresh token is deleted and `None` is returned.
    """
    mock_db = mocker.patch("superset.utils.oauth2.db")
    spec = mocker.MagicMock()

    stored_token = mocker.MagicMock()
    stored_token.access_token = "access-token"
    stored_token.access_token_expiration = datetime(2024, 1, 1)
    # No refresh token available, so the expired token cannot be renewed.
    stored_token.refresh_token = None
    mock_db.session.query().filter_by().one_or_none.return_value = stored_token

    # Frozen "now" (2024-01-02) is past the expiration (2024-01-01).
    with freeze_time("2024-01-02"):
        assert get_oauth2_access_token(1, 1, spec) is None

    # The stale token row must be purged from the database.
    mock_db.session.delete.assert_called_with(stored_token)
|
||||||
Loading…
Reference in New Issue