feat(oauth): adding necessary changes to support bigquery oauth (#30674)
This commit is contained in:
parent
bc5da631c8
commit
849d426e06
|
|
@ -318,3 +318,25 @@ export const forceSSLField = ({
|
|||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
export const projectIdfield = ({
|
||||
changeMethods,
|
||||
getValidation,
|
||||
validationErrors,
|
||||
db,
|
||||
}: FieldPropTypes) => (
|
||||
<>
|
||||
<ValidatedInput
|
||||
id="project_id"
|
||||
name="project_id"
|
||||
required
|
||||
value={db?.parameters?.project_id}
|
||||
validationMethods={{ onBlur: getValidation }}
|
||||
errorMessage={validationErrors?.project_id}
|
||||
placeholder="your-project-1234-a1"
|
||||
label={t('Project Id')}
|
||||
onChange={changeMethods.onParametersChange}
|
||||
helpText={t('Enter the unique project id for your database.')}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -57,10 +57,11 @@ export const EncryptedField = ({
|
|||
db?.engine === 'gsheets' ? !isEditMode && !isPublic : !isEditMode;
|
||||
const isEncrypted = isEditMode && db?.masked_encrypted_extra !== '{}';
|
||||
const encryptedField = db?.engine && encryptedCredentialsMap[db.engine];
|
||||
const paramValue = db?.parameters?.[encryptedField];
|
||||
const encryptedValue =
|
||||
typeof db?.parameters?.[encryptedField] === 'object'
|
||||
? JSON.stringify(db?.parameters?.[encryptedField])
|
||||
: db?.parameters?.[encryptedField];
|
||||
paramValue && typeof paramValue === 'object'
|
||||
? JSON.stringify(paramValue)
|
||||
: paramValue;
|
||||
return (
|
||||
<CredentialInfoForm>
|
||||
{db?.engine === 'gsheets' && (
|
||||
|
|
|
|||
|
|
@ -24,6 +24,14 @@ import { Input } from 'src/components/Input';
|
|||
import { FormItem } from 'src/components/Form';
|
||||
import { FieldPropTypes } from '../../types';
|
||||
|
||||
const LABELS = {
|
||||
CLIENT_ID: 'Client ID',
|
||||
SECRET: 'Client Secret',
|
||||
AUTH_URI: 'Authorization Request URI',
|
||||
TOKEN_URI: 'Token Request URI',
|
||||
SCOPE: 'Scope',
|
||||
};
|
||||
|
||||
interface OAuth2ClientInfo {
|
||||
id: string;
|
||||
secret: string;
|
||||
|
|
@ -44,10 +52,6 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
|
|||
scope: encryptedExtra.oauth2_client_info?.scope || '',
|
||||
});
|
||||
|
||||
if (db?.engine_information?.supports_oauth2 !== true) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handleChange = (key: any) => (e: any) => {
|
||||
const updatedInfo = {
|
||||
...oauth2ClientInfo,
|
||||
|
|
@ -68,14 +72,14 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
|
|||
return (
|
||||
<Collapse>
|
||||
<Collapse.Panel header="OAuth2 client information" key="1">
|
||||
<FormItem label="Client ID">
|
||||
<FormItem label={LABELS.CLIENT_ID}>
|
||||
<Input
|
||||
data-test="client-id"
|
||||
value={oauth2ClientInfo.id}
|
||||
onChange={handleChange('id')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Client Secret">
|
||||
<FormItem label={LABELS.SECRET}>
|
||||
<Input
|
||||
data-test="client-secret"
|
||||
type="password"
|
||||
|
|
@ -83,7 +87,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
|
|||
onChange={handleChange('secret')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Authorization Request URI">
|
||||
<FormItem label={LABELS.AUTH_URI}>
|
||||
<Input
|
||||
data-test="client-authorization-request-uri"
|
||||
placeholder="https://"
|
||||
|
|
@ -91,7 +95,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
|
|||
onChange={handleChange('authorization_request_uri')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Token Request URI">
|
||||
<FormItem label={LABELS.TOKEN_URI}>
|
||||
<Input
|
||||
data-test="client-token-request-uri"
|
||||
placeholder="https://"
|
||||
|
|
@ -99,7 +103,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
|
|||
onChange={handleChange('token_request_uri')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Scope">
|
||||
<FormItem label={LABELS.SCOPE}>
|
||||
<Input
|
||||
data-test="client-scope"
|
||||
value={oauth2ClientInfo.scope}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import {
|
|||
passwordField,
|
||||
portField,
|
||||
queryField,
|
||||
projectIdfield,
|
||||
usernameField,
|
||||
} from './CommonParameters';
|
||||
import { OAuth2ClientField } from './OAuth2ClientField';
|
||||
|
|
@ -50,6 +51,7 @@ export const FormFieldOrder = [
|
|||
'http_path',
|
||||
'http_path_field',
|
||||
'database_name',
|
||||
'project_id',
|
||||
'credentials_info',
|
||||
'service_account_info',
|
||||
'catalog',
|
||||
|
|
@ -89,4 +91,5 @@ export const FORM_FIELD_MAP = {
|
|||
role: validatedInputField,
|
||||
account: validatedInputField,
|
||||
ssh: SSHTunnelSwitchComponent,
|
||||
project_id: projectIdfield,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ export type DatabaseObject = {
|
|||
role?: string;
|
||||
account?: string;
|
||||
ssh?: boolean;
|
||||
project_id?: string;
|
||||
};
|
||||
|
||||
// Performance
|
||||
|
|
@ -232,6 +233,7 @@ export enum ConfigurationMethod {
|
|||
|
||||
export enum Engines {
|
||||
GSheet = 'gsheets',
|
||||
BigQuery = 'bigquery',
|
||||
Snowflake = 'snowflake',
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
self._context = context
|
||||
self._uri = uri
|
||||
|
||||
def run(self) -> None: # pylint: disable=too-many-statements
|
||||
def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches
|
||||
self.validate()
|
||||
ex_str = ""
|
||||
ssh_tunnel = self._properties.get("ssh_tunnel")
|
||||
|
|
@ -225,6 +225,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
# bubble up the exception to return proper status code
|
||||
raise
|
||||
except Exception as ex:
|
||||
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(
|
||||
ex
|
||||
):
|
||||
database.start_oauth2_dance()
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
|
|
|
|||
|
|
@ -121,6 +121,7 @@ from superset.sql_parse import Table
|
|||
from superset.superset_typing import FlaskResponse
|
||||
from superset.utils import json
|
||||
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
|
||||
from superset.utils.decorators import transaction
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
from superset.utils.ssh_tunnel import mask_password_info
|
||||
from superset.views.base_api import (
|
||||
|
|
@ -1341,6 +1342,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
return self.response_404()
|
||||
|
||||
@expose("/oauth2/", methods=["GET"])
|
||||
@transaction()
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.oauth2",
|
||||
log_to_statsd=True,
|
||||
|
|
@ -1428,7 +1430,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
"refresh_token": token_response.get("refresh_token"),
|
||||
},
|
||||
)
|
||||
|
||||
# return blank page that closes itself
|
||||
return make_response(
|
||||
render_template("superset/oauth2.html", tab_id=state["tab_id"]),
|
||||
|
|
|
|||
|
|
@ -47,6 +47,11 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
|||
SSHTunnelMissingCredentials,
|
||||
)
|
||||
from superset.constants import PASSWORD_MASK
|
||||
from superset.databases.types import ( # pylint:disable=unused-import
|
||||
EncryptedDict, # noqa: F401
|
||||
EncryptedField,
|
||||
EncryptedString, # noqa: F401
|
||||
)
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.exceptions import CertificateException, SupersetSecurityException
|
||||
|
|
@ -941,20 +946,6 @@ class ImportV1DatabaseSchema(Schema):
|
|||
return
|
||||
|
||||
|
||||
class EncryptedField: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A database field that should be stored in encrypted_extra.
|
||||
"""
|
||||
|
||||
|
||||
class EncryptedString(EncryptedField, fields.String):
|
||||
pass
|
||||
|
||||
|
||||
class EncryptedDict(EncryptedField, fields.Dict):
|
||||
pass
|
||||
|
||||
|
||||
def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore
|
||||
ret = {}
|
||||
if isinstance(field, EncryptedField):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# Field has been moved outside of the schemas.py file to
|
||||
# allow for it to be imported from outside of app_context
|
||||
from marshmallow import fields
|
||||
|
||||
|
||||
class EncryptedField: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A database field that should be stored in encrypted_extra.
|
||||
"""
|
||||
|
||||
|
||||
class EncryptedString(EncryptedField, fields.String):
|
||||
pass
|
||||
|
||||
|
||||
class EncryptedDict(EncryptedField, fields.Dict):
|
||||
pass
|
||||
|
|
@ -1707,10 +1707,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return sql
|
||||
|
||||
@classmethod
|
||||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
|
||||
def estimate_statement_cost(
|
||||
cls, database: Database, statement: str, cursor: Any
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a SQL query that estimates the cost of a given statement.
|
||||
|
||||
:param database: A Database object
|
||||
:param statement: A single SQL statement
|
||||
:param cursor: Cursor instance
|
||||
:return: Dictionary with different costs
|
||||
|
|
@ -1781,6 +1784,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
cursor = conn.cursor()
|
||||
return [
|
||||
cls.estimate_statement_cost(
|
||||
database,
|
||||
cls.process_statement(statement, database),
|
||||
cursor,
|
||||
)
|
||||
|
|
@ -1809,8 +1813,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return url
|
||||
|
||||
@classmethod
|
||||
def update_impersonation_config(
|
||||
def update_impersonation_config( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database: Database,
|
||||
connect_args: dict[str, Any],
|
||||
uri: str,
|
||||
username: str | None,
|
||||
|
|
@ -1820,6 +1825,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
Update a configuration dictionary
|
||||
that can set the correct properties for impersonating users
|
||||
|
||||
:param connect_args: a Database object
|
||||
:param connect_args: config to be updated
|
||||
:param uri: URI
|
||||
:param username: Effective username
|
||||
|
|
|
|||
|
|
@ -409,7 +409,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
pandas_gbq.to_gbq(df, **to_gbq_kwargs)
|
||||
|
||||
@classmethod
|
||||
def _get_client(cls, engine: Engine) -> bigquery.Client:
|
||||
def _get_client(
|
||||
cls,
|
||||
engine: Engine,
|
||||
database: Database, # pylint: disable=unused-argument
|
||||
) -> bigquery.Client:
|
||||
"""
|
||||
Return the BigQuery client associated with an engine.
|
||||
"""
|
||||
|
|
@ -453,7 +457,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
catalog=catalog,
|
||||
schema=schema,
|
||||
) as engine:
|
||||
client = cls._get_client(engine)
|
||||
client = cls._get_client(engine, database)
|
||||
return [
|
||||
cls.custom_estimate_statement_cost(
|
||||
cls.process_statement(statement, database),
|
||||
|
|
@ -477,7 +481,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
return project
|
||||
|
||||
with database.get_sqla_engine() as engine:
|
||||
client = cls._get_client(engine)
|
||||
client = cls._get_client(engine, database)
|
||||
return client.project
|
||||
|
||||
@classmethod
|
||||
|
|
@ -493,7 +497,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
"""
|
||||
engine: Engine
|
||||
with database.get_sqla_engine() as engine:
|
||||
client = cls._get_client(engine)
|
||||
client = cls._get_client(engine, database)
|
||||
projects = client.list_projects()
|
||||
|
||||
return {project.project_id for project in projects}
|
||||
|
|
|
|||
|
|
@ -537,8 +537,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return url
|
||||
|
||||
@classmethod
|
||||
def update_impersonation_config(
|
||||
def update_impersonation_config( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database: Database,
|
||||
connect_args: dict[str, Any],
|
||||
uri: str,
|
||||
username: str | None,
|
||||
|
|
@ -547,6 +548,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
"""
|
||||
Update a configuration dictionary
|
||||
that can set the correct properties for impersonating users
|
||||
:param database: the Database Object
|
||||
:param connect_args:
|
||||
:param uri: URI string
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
|
|
|
|||
|
|
@ -351,7 +351,16 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||
return True
|
||||
|
||||
@classmethod
|
||||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
|
||||
def estimate_statement_cost(
|
||||
cls, database: Database, statement: str, cursor: Any
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run a SQL query that estimates the cost of a given statement.
|
||||
:param database: A Database object
|
||||
:param statement: A single SQL statement
|
||||
:param cursor: Cursor instance
|
||||
:return: JSON response from Trino
|
||||
"""
|
||||
sql = f"EXPLAIN {statement}"
|
||||
cursor.execute(sql)
|
||||
|
||||
|
|
|
|||
|
|
@ -365,9 +365,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||
return parse.unquote(database.split("/")[1])
|
||||
|
||||
@classmethod
|
||||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
|
||||
def estimate_statement_cost(
|
||||
cls, database: Database, statement: str, cursor: Any
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run a SQL query that estimates the cost of a given statement.
|
||||
:param database: A Database object
|
||||
:param statement: A single SQL statement
|
||||
:param cursor: Cursor instance
|
||||
:return: JSON response from Trino
|
||||
|
|
@ -945,8 +948,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
return version is not None and Version(version) >= Version("0.319")
|
||||
|
||||
@classmethod
|
||||
def update_impersonation_config(
|
||||
def update_impersonation_config( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database: Database,
|
||||
connect_args: dict[str, Any],
|
||||
uri: str,
|
||||
username: str | None,
|
||||
|
|
@ -955,6 +959,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
"""
|
||||
Update a configuration dictionary
|
||||
that can set the correct properties for impersonating users
|
||||
|
||||
:param connect_args: the Database object
|
||||
:param connect_args: config to be updated
|
||||
:param uri: URI string
|
||||
:param username: Effective username
|
||||
|
|
|
|||
|
|
@ -116,8 +116,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
return metadata
|
||||
|
||||
@classmethod
|
||||
def update_impersonation_config(
|
||||
def update_impersonation_config( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database: Database,
|
||||
connect_args: dict[str, Any],
|
||||
uri: str,
|
||||
username: str | None,
|
||||
|
|
@ -126,6 +127,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
"""
|
||||
Update a configuration dictionary
|
||||
that can set the correct properties for impersonating users
|
||||
:param database: the Database object
|
||||
:param connect_args: config to be updated
|
||||
:param uri: URI string
|
||||
:param username: Effective username
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from contextlib import closing, contextmanager, nullcontext, suppress
|
|||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING
|
||||
|
||||
import numpy
|
||||
|
|
@ -510,12 +511,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
|
||||
|
||||
if self.impersonate_user:
|
||||
self.db_engine_spec.update_impersonation_config(
|
||||
connect_args,
|
||||
str(sqlalchemy_url),
|
||||
effective_username,
|
||||
access_token,
|
||||
# PR #30674 changed the signature of the method to include database.
|
||||
# This ensures that the change is backwards compatible
|
||||
args = [connect_args, str(sqlalchemy_url), effective_username, access_token]
|
||||
args = self.add_database_to_signature(
|
||||
self.db_engine_spec.update_impersonation_config,
|
||||
args,
|
||||
)
|
||||
self.db_engine_spec.update_impersonation_config(*args)
|
||||
|
||||
if connect_args:
|
||||
params["connect_args"] = connect_args
|
||||
|
|
@ -543,6 +546,24 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
def add_database_to_signature(
|
||||
self,
|
||||
func: Callable[..., None],
|
||||
args: list[Any],
|
||||
) -> list[Any]:
|
||||
"""
|
||||
Examines a function signature looking for a database param.
|
||||
If the signature requires a database, the function appends self in the
|
||||
proper position.
|
||||
"""
|
||||
|
||||
# PR #30674 changed the signature of the method to include database.
|
||||
# This ensures that the change is backwards compatible
|
||||
sig = signature(func)
|
||||
if "database" in (params := sig.parameters.keys()):
|
||||
args.insert(list(params).index("database"), self)
|
||||
return args
|
||||
|
||||
@contextmanager
|
||||
def get_raw_connection(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -151,12 +151,13 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
DB Eng Specs (postgres): Test estimate_statement_cost select star
|
||||
"""
|
||||
|
||||
database = mock.Mock()
|
||||
cursor = mock.Mock()
|
||||
cursor.fetchone.return_value = (
|
||||
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
|
||||
)
|
||||
sql = "SELECT * FROM birth_names"
|
||||
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
|
||||
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
|
||||
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
|
||||
|
||||
def test_estimate_statement_invalid_syntax(self):
|
||||
|
|
@ -165,6 +166,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
"""
|
||||
from psycopg2 import errors
|
||||
|
||||
database = mock.Mock()
|
||||
cursor = mock.Mock()
|
||||
cursor.execute.side_effect = errors.SyntaxError(
|
||||
"""
|
||||
|
|
@ -175,7 +177,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
sql = "DROP TABLE birth_names"
|
||||
with self.assertRaises(errors.SyntaxError):
|
||||
PostgresEngineSpec.estimate_statement_cost(sql, cursor)
|
||||
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
|
||||
|
||||
def test_query_cost_formatter_example_costs(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -905,22 +905,26 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
|
||||
def test_estimate_statement_cost(self):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_cursor = mock.MagicMock()
|
||||
estimate_json = {"a": "b"}
|
||||
mock_cursor.fetchone.return_value = [
|
||||
'{"a": "b"}',
|
||||
]
|
||||
result = PrestoEngineSpec.estimate_statement_cost(
|
||||
"SELECT * FROM brth_names", mock_cursor
|
||||
mock_database,
|
||||
"SELECT * FROM brth_names",
|
||||
mock_cursor,
|
||||
)
|
||||
assert result == estimate_json
|
||||
|
||||
def test_estimate_statement_cost_invalid_syntax(self):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_cursor.execute.side_effect = Exception()
|
||||
with self.assertRaises(Exception):
|
||||
PrestoEngineSpec.estimate_statement_cost(
|
||||
"DROP TABLE brth_names", mock_cursor
|
||||
mock_database, "DROP TABLE brth_names", mock_cursor
|
||||
)
|
||||
|
||||
def test_get_create_view(self):
|
||||
|
|
|
|||
|
|
@ -432,6 +432,31 @@ def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_add_database_to_signature():
|
||||
args = ["param1", "param2"]
|
||||
|
||||
def func_without_db(param1, param2):
|
||||
pass
|
||||
|
||||
def func_with_db_start(database, param1, param2):
|
||||
pass
|
||||
|
||||
def func_with_db_end(param1, param2, database):
|
||||
pass
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="trino://",
|
||||
impersonate_user=True,
|
||||
)
|
||||
args1 = database.add_database_to_signature(func_without_db, args.copy())
|
||||
assert args1 == ["param1", "param2"]
|
||||
args2 = database.add_database_to_signature(func_with_db_start, args.copy())
|
||||
assert args2 == [database, "param1", "param2"]
|
||||
args3 = database.add_database_to_signature(func_with_db_end, args.copy())
|
||||
assert args3 == ["param1", "param2", database]
|
||||
|
||||
|
||||
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
|
||||
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue