feat(oauth): adding necessary changes to support bigquery oauth (#30674)

This commit is contained in:
Jack 2024-10-30 14:56:22 -05:00 committed by GitHub
parent bc5da631c8
commit 849d426e06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 191 additions and 48 deletions

View File

@ -318,3 +318,25 @@ export const forceSSLField = ({
/>
</div>
);
/**
 * Required text input for the `project_id` connection parameter.
 *
 * Displays the current `db.parameters.project_id`, runs `getValidation` when
 * the input loses focus, shows `validationErrors.project_id` as the error
 * message and forwards edits to `changeMethods.onParametersChange`.
 */
export const projectIdfield = ({
  changeMethods,
  getValidation,
  validationErrors,
  db,
}: FieldPropTypes) => {
  const currentProjectId = db?.parameters?.project_id;
  const projectIdError = validationErrors?.project_id;

  return (
    <>
      <ValidatedInput
        id="project_id"
        name="project_id"
        required
        value={currentProjectId}
        validationMethods={{ onBlur: getValidation }}
        errorMessage={projectIdError}
        placeholder="your-project-1234-a1"
        label={t('Project Id')}
        onChange={changeMethods.onParametersChange}
        helpText={t('Enter the unique project id for your database.')}
      />
    </>
  );
};

View File

@ -57,10 +57,11 @@ export const EncryptedField = ({
db?.engine === 'gsheets' ? !isEditMode && !isPublic : !isEditMode;
const isEncrypted = isEditMode && db?.masked_encrypted_extra !== '{}';
const encryptedField = db?.engine && encryptedCredentialsMap[db.engine];
const paramValue = db?.parameters?.[encryptedField];
const encryptedValue =
typeof db?.parameters?.[encryptedField] === 'object'
? JSON.stringify(db?.parameters?.[encryptedField])
: db?.parameters?.[encryptedField];
paramValue && typeof paramValue === 'object'
? JSON.stringify(paramValue)
: paramValue;
return (
<CredentialInfoForm>
{db?.engine === 'gsheets' && (

View File

@ -24,6 +24,14 @@ import { Input } from 'src/components/Input';
import { FormItem } from 'src/components/Form';
import { FieldPropTypes } from '../../types';
/**
 * Display labels for the form items of the OAuth2 client information panel.
 *
 * Declared `as const` so every entry keeps its string-literal type and the
 * table is read-only to the type checker; the runtime value is unchanged.
 */
const LABELS = {
  CLIENT_ID: 'Client ID',
  SECRET: 'Client Secret',
  AUTH_URI: 'Authorization Request URI',
  TOKEN_URI: 'Token Request URI',
  SCOPE: 'Scope',
} as const;
interface OAuth2ClientInfo {
id: string;
secret: string;
@ -44,10 +52,6 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
scope: encryptedExtra.oauth2_client_info?.scope || '',
});
if (db?.engine_information?.supports_oauth2 !== true) {
return null;
}
const handleChange = (key: any) => (e: any) => {
const updatedInfo = {
...oauth2ClientInfo,
@ -68,14 +72,14 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
return (
<Collapse>
<Collapse.Panel header="OAuth2 client information" key="1">
<FormItem label="Client ID">
<FormItem label={LABELS.CLIENT_ID}>
<Input
data-test="client-id"
value={oauth2ClientInfo.id}
onChange={handleChange('id')}
/>
</FormItem>
<FormItem label="Client Secret">
<FormItem label={LABELS.SECRET}>
<Input
data-test="client-secret"
type="password"
@ -83,7 +87,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
onChange={handleChange('secret')}
/>
</FormItem>
<FormItem label="Authorization Request URI">
<FormItem label={LABELS.AUTH_URI}>
<Input
data-test="client-authorization-request-uri"
placeholder="https://"
@ -91,7 +95,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
onChange={handleChange('authorization_request_uri')}
/>
</FormItem>
<FormItem label="Token Request URI">
<FormItem label={LABELS.TOKEN_URI}>
<Input
data-test="client-token-request-uri"
placeholder="https://"
@ -99,7 +103,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
onChange={handleChange('token_request_uri')}
/>
</FormItem>
<FormItem label="Scope">
<FormItem label={LABELS.SCOPE}>
<Input
data-test="client-scope"
value={oauth2ClientInfo.scope}

View File

@ -30,6 +30,7 @@ import {
passwordField,
portField,
queryField,
projectIdfield,
usernameField,
} from './CommonParameters';
import { OAuth2ClientField } from './OAuth2ClientField';
@ -50,6 +51,7 @@ export const FormFieldOrder = [
'http_path',
'http_path_field',
'database_name',
'project_id',
'credentials_info',
'service_account_info',
'catalog',
@ -89,4 +91,5 @@ export const FORM_FIELD_MAP = {
role: validatedInputField,
account: validatedInputField,
ssh: SSHTunnelSwitchComponent,
project_id: projectIdfield,
};

View File

@ -78,6 +78,7 @@ export type DatabaseObject = {
role?: string;
account?: string;
ssh?: boolean;
project_id?: string;
};
// Performance
@ -232,6 +233,7 @@ export enum ConfigurationMethod {
/**
 * String identifiers for the database engines the UI refers to by name.
 * NOTE(review): presumably compared against `db.engine` — confirm at call sites.
 */
export enum Engines {
GSheet = 'gsheets',
BigQuery = 'bigquery',
Snowflake = 'snowflake',
}

View File

@ -93,7 +93,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
self._context = context
self._uri = uri
def run(self) -> None: # pylint: disable=too-many-statements
def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches
self.validate()
ex_str = ""
ssh_tunnel = self._properties.get("ssh_tunnel")
@ -225,6 +225,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
# bubble up the exception to return proper status code
raise
except Exception as ex:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(
ex
):
database.start_oauth2_dance()
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex

View File

@ -121,6 +121,7 @@ from superset.sql_parse import Table
from superset.superset_typing import FlaskResponse
from superset.utils import json
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
from superset.utils.decorators import transaction
from superset.utils.oauth2 import decode_oauth2_state
from superset.utils.ssh_tunnel import mask_password_info
from superset.views.base_api import (
@ -1341,6 +1342,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
return self.response_404()
@expose("/oauth2/", methods=["GET"])
@transaction()
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.oauth2",
log_to_statsd=True,
@ -1428,7 +1430,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"refresh_token": token_response.get("refresh_token"),
},
)
# return blank page that closes itself
return make_response(
render_template("superset/oauth2.html", tab_id=state["tab_id"]),

View File

@ -47,6 +47,11 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelMissingCredentials,
)
from superset.constants import PASSWORD_MASK
from superset.databases.types import ( # pylint:disable=unused-import
EncryptedDict, # noqa: F401
EncryptedField,
EncryptedString, # noqa: F401
)
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
@ -941,20 +946,6 @@ class ImportV1DatabaseSchema(Schema):
return
class EncryptedField: # pylint: disable=too-few-public-methods
"""
A database field that should be stored in encrypted_extra.
"""
class EncryptedString(EncryptedField, fields.String):
pass
class EncryptedDict(EncryptedField, fields.Dict):
pass
def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore
ret = {}
if isinstance(field, EncryptedField):

View File

@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Field has been moved outside of the schemas.py file to
# allow for it to be imported from outside of app_context
from marshmallow import fields
class EncryptedField:  # pylint: disable=too-few-public-methods
    """
    Marker base class for a database field whose value should be stored in
    ``encrypted_extra``.

    It carries no behavior of its own; field types mix it in so they can be
    recognized with ``isinstance``.
    """
class EncryptedString(EncryptedField, fields.String):
    """A marshmallow string field marked for storage in ``encrypted_extra``."""
class EncryptedDict(EncryptedField, fields.Dict):
    """A marshmallow dict field marked for storage in ``encrypted_extra``."""

View File

@ -1707,10 +1707,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return sql
@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs
@ -1781,6 +1784,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cursor = conn.cursor()
return [
cls.estimate_statement_cost(
database,
cls.process_statement(statement, database),
cursor,
)
@ -1809,8 +1813,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return url
@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@ -1820,6 +1825,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: a Database object
:param connect_args: config to be updated
:param uri: URI
:param username: Effective username

View File

@ -409,7 +409,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
pandas_gbq.to_gbq(df, **to_gbq_kwargs)
@classmethod
def _get_client(cls, engine: Engine) -> bigquery.Client:
def _get_client(
cls,
engine: Engine,
database: Database, # pylint: disable=unused-argument
) -> bigquery.Client:
"""
Return the BigQuery client associated with an engine.
"""
@ -453,7 +457,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
catalog=catalog,
schema=schema,
) as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)
return [
cls.custom_estimate_statement_cost(
cls.process_statement(statement, database),
@ -477,7 +481,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return project
with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)
return client.project
@classmethod
@ -493,7 +497,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
"""
engine: Engine
with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)
projects = client.list_projects()
return {project.project_id for project in projects}

View File

@ -537,8 +537,9 @@ class HiveEngineSpec(PrestoEngineSpec):
return url
@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@ -547,6 +548,7 @@ class HiveEngineSpec(PrestoEngineSpec):
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database Object
:param connect_args:
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled

View File

@ -351,7 +351,16 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
return True
@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs
"""
sql = f"EXPLAIN {statement}"
cursor.execute(sql)

View File

@ -365,9 +365,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return parse.unquote(database.split("/")[1])
@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
@ -945,8 +948,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return version is not None and Version(version) >= Version("0.319")
@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@ -955,6 +959,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username

View File

@ -116,8 +116,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
return metadata
@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@ -126,6 +127,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username

View File

@ -29,6 +29,7 @@ from contextlib import closing, contextmanager, nullcontext, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import signature
from typing import Any, Callable, cast, TYPE_CHECKING
import numpy
@ -510,12 +511,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args,
str(sqlalchemy_url),
effective_username,
access_token,
# PR #30674 changed the signature of the method to include database.
# This ensures that the change is backwards compatible
args = [connect_args, str(sqlalchemy_url), effective_username, access_token]
args = self.add_database_to_signature(
self.db_engine_spec.update_impersonation_config,
args,
)
self.db_engine_spec.update_impersonation_config(*args)
if connect_args:
params["connect_args"] = connect_args
@ -543,6 +546,24 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
def add_database_to_signature(
    self,
    func: Callable[..., None],
    args: list[Any],
) -> list[Any]:
    """
    Build the positional arguments for ``func``, injecting this database
    when the function declares a ``database`` parameter.

    PR #30674 changed the signature of some methods to include the database.
    Implementations written against the old signature do not declare it, so
    the signature is inspected to keep the change backwards compatible.

    :param func: Callable whose signature is examined
    :param args: Positional arguments, in order, without the database
    :return: A new list of positional arguments. If ``func`` has a parameter
        named ``database``, ``self`` is inserted at that parameter's position;
        otherwise the arguments are returned as given. The list passed in as
        ``args`` is never modified.
    """
    # Work on a copy so the caller's list is not mutated as a side effect.
    positional = list(args)
    param_names = list(signature(func).parameters)
    if "database" in param_names:
        positional.insert(param_names.index("database"), self)
    return positional
@contextmanager
def get_raw_connection(
self,

View File

@ -151,12 +151,13 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
DB Eng Specs (postgres): Test estimate_statement_cost select star
"""
database = mock.Mock()
cursor = mock.Mock()
cursor.fetchone.return_value = (
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
)
sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
def test_estimate_statement_invalid_syntax(self):
@ -165,6 +166,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
"""
from psycopg2 import errors
database = mock.Mock()
cursor = mock.Mock()
cursor.execute.side_effect = errors.SyntaxError(
"""
@ -175,7 +177,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
)
sql = "DROP TABLE birth_names"
with self.assertRaises(errors.SyntaxError):
PostgresEngineSpec.estimate_statement_cost(sql, cursor)
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
def test_query_cost_formatter_example_costs(self):
"""

View File

@ -905,22 +905,26 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
)
def test_estimate_statement_cost(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
estimate_json = {"a": "b"}
mock_cursor.fetchone.return_value = [
'{"a": "b"}',
]
result = PrestoEngineSpec.estimate_statement_cost(
"SELECT * FROM brth_names", mock_cursor
mock_database,
"SELECT * FROM brth_names",
mock_cursor,
)
assert result == estimate_json
def test_estimate_statement_cost_invalid_syntax(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
mock_cursor.execute.side_effect = Exception()
with self.assertRaises(Exception):
PrestoEngineSpec.estimate_statement_cost(
"DROP TABLE brth_names", mock_cursor
mock_database, "DROP TABLE brth_names", mock_cursor
)
def test_get_create_view(self):

View File

@ -432,6 +432,31 @@ def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
)
def test_add_database_to_signature():
    """
    ``add_database_to_signature`` leaves the arguments alone when the target
    function has no ``database`` parameter and otherwise inserts the database
    at the position that parameter occupies.
    """
    base_args = ["param1", "param2"]

    def func_without_db(param1, param2):
        pass

    def func_with_db_start(database, param1, param2):
        pass

    def func_with_db_end(param1, param2, database):
        pass

    database = Database(
        database_name="my_db",
        sqlalchemy_uri="trino://",
        impersonate_user=True,
    )

    # (function under inspection, expected positional arguments)
    cases = [
        (func_without_db, ["param1", "param2"]),
        (func_with_db_start, [database, "param1", "param2"]),
        (func_with_db_end, ["param1", "param2", database]),
    ]
    for func, expected in cases:
        # each call gets its own copy of the base arguments
        assert database.add_database_to_signature(func, base_args.copy()) == expected
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
"""