feat: OAuth2 database field (#30126)

This commit is contained in:
Beto Dealmeida 2024-09-03 20:57:55 -04:00 committed by GitHub
parent 6009023fad
commit ff449ad8ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 366 additions and 17 deletions

View File

@ -161,23 +161,20 @@ export const httpPathField = ({
getValidation,
validationErrors,
db,
}: FieldPropTypes) => {
console.error(db);
return (
<ValidatedInput
id="http_path_field"
name="http_path_field"
required={required}
value={db?.parameters?.http_path_field}
validationMethods={{ onBlur: getValidation }}
errorMessage={validationErrors?.http_path}
placeholder={t('e.g. sql/protocolv1/o/12345')}
label="HTTP Path"
onChange={changeMethods.onParametersChange}
helpText={t('Copy the name of the HTTP Path of your cluster.')}
/>
);
};
}: FieldPropTypes) => (
<ValidatedInput
id="http_path_field"
name="http_path_field"
required={required}
value={db?.parameters?.http_path_field}
validationMethods={{ onBlur: getValidation }}
errorMessage={validationErrors?.http_path}
placeholder={t('e.g. sql/protocolv1/o/12345')}
label="HTTP Path"
onChange={changeMethods.onParametersChange}
helpText={t('Copy the name of the HTTP Path of your cluster.')}
/>
);
export const usernameField = ({
required,
changeMethods,

View File

@ -0,0 +1,181 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { render, fireEvent } from '@testing-library/react';
import '@testing-library/jest-dom/extend-expect';
import { ThemeProvider, supersetTheme } from '@superset-ui/core';
import { DatabaseObject } from 'src/features/databases/types';
import { OAuth2ClientField } from './OAuth2ClientField';
const renderWithTheme = (component: JSX.Element) =>
render(<ThemeProvider theme={supersetTheme}>{component}</ThemeProvider>);
describe('OAuth2ClientField', () => {
const mockChangeMethods = {
onEncryptedExtraInputChange: jest.fn(),
onParametersChange: jest.fn(),
onChange: jest.fn(),
onQueryChange: jest.fn(),
onParametersUploadFileChange: jest.fn(),
onAddTableCatalog: jest.fn(),
onRemoveTableCatalog: jest.fn(),
onExtraInputChange: jest.fn(),
onSSHTunnelParametersChange: jest.fn(),
};
const defaultProps = {
required: false,
onParametersChange: jest.fn(),
onParametersUploadFileChange: jest.fn(),
changeMethods: mockChangeMethods,
validationErrors: null,
getValidation: jest.fn(),
clearValidationErrors: jest.fn(),
field: 'test',
db: {
configuration_method: 'dynamic_form',
database_name: 'test',
driver: 'test',
id: 1,
name: 'test',
is_managed_externally: false,
engine_information: {
supports_oauth2: true,
},
masked_encrypted_extra: JSON.stringify({
oauth2_client_info: {
id: 'test-id',
secret: 'test-secret',
authorization_request_uri: 'https://auth-uri',
token_request_uri: 'https://token-uri',
scope: 'test-scope',
},
}),
} as DatabaseObject,
};
afterEach(() => {
jest.clearAllMocks();
});
it('does not show input fields until the collapse trigger is clicked', () => {
const { getByText, getByTestId, queryByTestId } = renderWithTheme(
<OAuth2ClientField {...defaultProps} />,
);
expect(queryByTestId('client-id')).not.toBeInTheDocument();
expect(queryByTestId('client-secret')).not.toBeInTheDocument();
expect(
queryByTestId('client-authorization-request-uri'),
).not.toBeInTheDocument();
expect(queryByTestId('client-token-request-uri')).not.toBeInTheDocument();
expect(queryByTestId('client-scope')).not.toBeInTheDocument();
const collapseTrigger = getByText('OAuth2 client information');
fireEvent.click(collapseTrigger);
expect(getByTestId('client-id')).toBeInTheDocument();
expect(getByTestId('client-secret')).toBeInTheDocument();
expect(getByTestId('client-authorization-request-uri')).toBeInTheDocument();
expect(getByTestId('client-token-request-uri')).toBeInTheDocument();
expect(getByTestId('client-scope')).toBeInTheDocument();
});
it('renders the OAuth2ClientField component with initial values', () => {
const { getByTestId, getByText } = renderWithTheme(
<OAuth2ClientField {...defaultProps} />,
);
const collapseTrigger = getByText('OAuth2 client information');
fireEvent.click(collapseTrigger);
expect(getByTestId('client-id')).toHaveValue('test-id');
expect(getByTestId('client-secret')).toHaveValue('test-secret');
expect(getByTestId('client-authorization-request-uri')).toHaveValue(
'https://auth-uri',
);
expect(getByTestId('client-token-request-uri')).toHaveValue(
'https://token-uri',
);
expect(getByTestId('client-scope')).toHaveValue('test-scope');
});
it('handles input changes and triggers onEncryptedExtraInputChange', () => {
const { getByTestId, getByText } = renderWithTheme(
<OAuth2ClientField {...defaultProps} />,
);
const collapseTrigger = getByText('OAuth2 client information');
fireEvent.click(collapseTrigger);
const clientIdInput = getByTestId('client-id');
fireEvent.change(clientIdInput, { target: { value: 'new-id' } });
expect(mockChangeMethods.onEncryptedExtraInputChange).toHaveBeenCalledWith(
expect.objectContaining({
target: {
name: 'oauth2_client_info',
value: expect.objectContaining({ id: 'new-id' }),
},
}),
);
});
it('does not render when supports_oauth2 is false', () => {
const props = {
...defaultProps,
db: {
...defaultProps.db,
engine_information: {
supports_oauth2: false,
},
},
};
const { queryByTestId } = renderWithTheme(<OAuth2ClientField {...props} />);
expect(queryByTestId('client-id')).not.toBeInTheDocument();
});
it('renders empty fields when masked_encrypted_extra is empty', () => {
const props = {
...defaultProps,
db: {
...defaultProps.db,
engine_information: {
supports_oauth2: true,
},
masked_encrypted_extra: '{}',
},
};
const { getByTestId, getByText } = renderWithTheme(
<OAuth2ClientField {...props} />,
);
const collapseTrigger = getByText('OAuth2 client information');
fireEvent.click(collapseTrigger);
expect(getByTestId('client-id')).toHaveValue('');
expect(getByTestId('client-secret')).toHaveValue('');
expect(getByTestId('client-authorization-request-uri')).toHaveValue('');
expect(getByTestId('client-token-request-uri')).toHaveValue('');
expect(getByTestId('client-scope')).toHaveValue('');
});
});

View File

@ -0,0 +1,112 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useState } from 'react';
import Collapse from 'src/components/Collapse';
import { Input } from 'src/components/Input';
import { FormItem } from 'src/components/Form';
import { FieldPropTypes } from '../../types';
interface OAuth2ClientInfo {
id: string;
secret: string;
authorization_request_uri: string;
token_request_uri: string;
scope: string;
}
export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
const encryptedExtra = JSON.parse(db?.masked_encrypted_extra || '{}');
const [oauth2ClientInfo, setOauth2ClientInfo] = useState<OAuth2ClientInfo>({
id: encryptedExtra.oauth2_client_info?.id || '',
secret: encryptedExtra.oauth2_client_info?.secret || '',
authorization_request_uri:
encryptedExtra.oauth2_client_info?.authorization_request_uri || '',
token_request_uri:
encryptedExtra.oauth2_client_info?.token_request_uri || '',
scope: encryptedExtra.oauth2_client_info?.scope || '',
});
if (db?.engine_information?.supports_oauth2 !== true) {
return null;
}
const handleChange = (key: any) => (e: any) => {
const updatedInfo = {
...oauth2ClientInfo,
[key]: e.target.value,
};
setOauth2ClientInfo(updatedInfo);
const event = {
target: {
name: 'oauth2_client_info',
value: updatedInfo,
},
};
changeMethods.onEncryptedExtraInputChange(event);
};
return (
<Collapse>
<Collapse.Panel header="OAuth2 client information" key="1">
<FormItem label="Client ID">
<Input
data-test="client-id"
value={oauth2ClientInfo.id}
onChange={handleChange('id')}
/>
</FormItem>
<FormItem label="Client Secret">
<Input
data-test="client-secret"
type="password"
value={oauth2ClientInfo.secret}
onChange={handleChange('secret')}
/>
</FormItem>
<FormItem label="Authorization Request URI">
<Input
data-test="client-authorization-request-uri"
placeholder="https://"
value={oauth2ClientInfo.authorization_request_uri}
onChange={handleChange('authorization_request_uri')}
/>
</FormItem>
<FormItem label="Token Request URI">
<Input
data-test="client-token-request-uri"
placeholder="https://"
value={oauth2ClientInfo.token_request_uri}
onChange={handleChange('token_request_uri')}
/>
</FormItem>
<FormItem label="Scope">
<Input
data-test="client-scope"
value={oauth2ClientInfo.scope}
onChange={handleChange('scope')}
/>
</FormItem>
</Collapse.Panel>
</Collapse>
);
};

View File

@ -32,6 +32,7 @@ import {
queryField,
usernameField,
} from './CommonParameters';
import { OAuth2ClientField } from './OAuth2ClientField';
import { validatedInputField } from './ValidatedInputField';
import { EncryptedField } from './EncryptedField';
import { TableCatalog } from './TableCatalog';
@ -58,6 +59,7 @@ export const FormFieldOrder = [
'warehouse',
'role',
'ssh',
'oauth2_client',
];
const extensionsRegistry = getExtensionsRegistry();
@ -75,6 +77,7 @@ export const FORM_FIELD_MAP = {
default_schema: defaultSchemaField,
username: usernameField,
password: passwordField,
oauth2_client: OAuth2ClientField,
access_token: accessTokenField,
database_name: displayField,
query: queryField,

View File

@ -32,6 +32,7 @@ const DatabaseConnectionForm = ({
onAddTableCatalog,
onChange,
onExtraInputChange,
onEncryptedExtraInputChange,
onParametersChange,
onParametersUploadFileChange,
onQueryChange,
@ -75,6 +76,7 @@ const DatabaseConnectionForm = ({
onAddTableCatalog,
onRemoveTableCatalog,
onExtraInputChange,
onEncryptedExtraInputChange,
},
validationErrors,
getValidation,

View File

@ -1723,6 +1723,20 @@ describe('dbReducer', () => {
});
});
test('it will set state to payload from encrypted extra input change', () => {
const action: DBReducerActionType = {
type: ActionType.EncryptedExtraInputChange,
payload: { name: 'foo', value: 'bar' },
};
const currentState = dbReducer(databaseFixture, action);
// extra should be serialized
expect(currentState).toEqual({
...databaseFixture,
masked_encrypted_extra: '{"foo":"bar"}',
});
});
test('it will set state to payload from extra input change when checkbox', () => {
const action: DBReducerActionType = {
type: ActionType.ExtraInputChange,

View File

@ -154,6 +154,7 @@ export enum ActionType {
EditorChange,
ExtraEditorChange,
ExtraInputChange,
EncryptedExtraInputChange,
Fetched,
InputChange,
ParametersChange,
@ -185,6 +186,7 @@ export type DBReducerActionType =
type:
| ActionType.ExtraEditorChange
| ActionType.ExtraInputChange
| ActionType.EncryptedExtraInputChange
| ActionType.TextChange
| ActionType.QueryChange
| ActionType.InputChange
@ -269,6 +271,14 @@ export function dbReducer(
[action.payload.name]: actionPayloadJson,
}),
};
case ActionType.EncryptedExtraInputChange:
return {
...trimmedState,
masked_encrypted_extra: JSON.stringify({
...JSON.parse(trimmedState.masked_encrypted_extra || '{}'),
[action.payload.name]: action.payload.value,
}),
};
case ActionType.ExtraInputChange:
// "extra" payload in state is a string
if (
@ -1656,6 +1666,16 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
value: target.value,
})
}
onEncryptedExtraInputChange={({
target,
}: {
target: HTMLInputElement;
}) =>
onChange(ActionType.EncryptedExtraInputChange, {
name: target.name,
value: target.value,
})
}
onRemoveTableCatalog={(idx: number) => {
setDB({
type: ActionType.RemoveTableCatalogSheet,

View File

@ -113,6 +113,7 @@ export type DatabaseObject = {
supports_file_upload?: boolean;
disable_ssh_tunneling?: boolean;
supports_dynamic_catalog?: boolean;
supports_oauth2?: boolean;
};
// SSH Tunnel information
@ -301,6 +302,7 @@ export interface FieldPropTypes {
onRemoveTableCatalog: (idx: number) => void;
} & {
onExtraInputChange: (value: any) => void;
onEncryptedExtraInputChange: (value: any) => void;
onSSHTunnelParametersChange: CustomEventHandlerType;
};
validationErrors: JsonObject | null;
@ -352,6 +354,9 @@ export interface DatabaseConnectionFormProps {
onExtraInputChange: (
event: FormEvent<InputProps> | { target: HTMLInputElement },
) => void;
onEncryptedExtraInputChange: (
event: FormEvent<InputProps> | { target: HTMLInputElement },
) => void;
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
validationErrors: JsonObject | null;

View File

@ -985,6 +985,9 @@ class EngineInformationSchema(Schema):
"description": "The database supports multiple catalogs in a single connection"
}
)
supports_oauth2 = fields.Boolean(
metadata={"description": "The database supports OAuth2"}
)
class DatabaseConnectionSchema(Schema):

View File

@ -2230,6 +2230,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"supports_file_upload": cls.supports_file_upload,
"disable_ssh_tunneling": cls.disable_ssh_tunneling,
"supports_dynamic_catalog": cls.supports_dynamic_catalog,
"supports_oauth2": cls.supports_oauth2,
}
@classmethod
@ -2351,6 +2352,7 @@ class BasicParametersMixin:
parameters: BasicParametersType,
encrypted_extra: dict[str, str] | None = None,
) -> str:
# TODO (betodealmeida): this method should also build `connect_args`
# make a copy so that we don't update the original
query = parameters.get("query", {}).copy()
if parameters.get("encryption"):

View File

@ -3254,6 +3254,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": True,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
@ -3279,6 +3280,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": True,
"disable_ssh_tunneling": True,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
@ -3336,6 +3338,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
@ -3361,6 +3364,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": True,
"supports_oauth2": True,
},
"supports_oauth2": True,
},
@ -3418,6 +3422,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
@ -3431,6 +3436,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
@ -3465,6 +3471,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
@ -3478,6 +3485,7 @@ class TestDatabaseApi(SupersetTestCase):
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},

View File

@ -239,6 +239,7 @@ def test_database_connection(
"disable_ssh_tunneling": True,
"supports_dynamic_catalog": False,
"supports_file_upload": True,
"supports_oauth2": True,
},
"expose_in_sqllab": True,
"extra": '{\n "metadata_params": {},\n "engine_params": {},\n "metadata_cache_timeout": {},\n "schemas_allowed_for_file_upload": []\n}\n',
@ -311,6 +312,7 @@ def test_database_connection(
"disable_ssh_tunneling": True,
"supports_dynamic_catalog": False,
"supports_file_upload": True,
"supports_oauth2": True,
},
"expose_in_sqllab": True,
"force_ctas_schema": None,