diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx index 33a258d72..696f52baa 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx @@ -318,3 +318,25 @@ export const forceSSLField = ({ /> ); + +export const projectIdfield = ({ + changeMethods, + getValidation, + validationErrors, + db, +}: FieldPropTypes) => ( + <> + + +); diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx index 2b9223141..fd11cd327 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx @@ -57,10 +57,11 @@ export const EncryptedField = ({ db?.engine === 'gsheets' ? !isEditMode && !isPublic : !isEditMode; const isEncrypted = isEditMode && db?.masked_encrypted_extra !== '{}'; const encryptedField = db?.engine && encryptedCredentialsMap[db.engine]; + const paramValue = db?.parameters?.[encryptedField]; const encryptedValue = - typeof db?.parameters?.[encryptedField] === 'object' - ? JSON.stringify(db?.parameters?.[encryptedField]) - : db?.parameters?.[encryptedField]; + paramValue && typeof paramValue === 'object' + ? JSON.stringify(paramValue) + : paramValue; return ( {db?.engine === 'gsheets' && ( diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx index ee0ffdeb3..fac3c3331 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx @@ -24,6 +24,14 @@ import { Input } from 'src/components/Input'; import { FormItem } from 'src/components/Form'; import { FieldPropTypes } from '../../types'; +const LABELS = { + CLIENT_ID: 'Client ID', + SECRET: 'Client Secret', + AUTH_URI: 'Authorization Request URI', + TOKEN_URI: 'Token Request URI', + SCOPE: 'Scope', +}; + interface OAuth2ClientInfo { id: string; secret: string; @@ -44,10 +52,6 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { scope: encryptedExtra.oauth2_client_info?.scope || '', }); - if (db?.engine_information?.supports_oauth2 !== true) { - return null; - } - const handleChange = (key: any) => (e: any) => { const updatedInfo = { ...oauth2ClientInfo, @@ -68,14 +72,14 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { return ( - + - + { onChange={handleChange('secret')} /> - + { onChange={handleChange('authorization_request_uri')} /> - + { onChange={handleChange('token_request_uri')} /> - + None: # pylint: disable=too-many-statements + def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches self.validate() ex_str = "" ssh_tunnel = self._properties.get("ssh_tunnel") @@ -225,6 +225,10 @@ class TestConnectionDatabaseCommand(BaseCommand): # bubble up the exception to return proper status code raise except Exception as ex: + if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2( + ex + ): + database.start_oauth2_dance() event_logger.log_with_context( action=get_log_connection_action( "test_connection_error", ssh_tunnel, ex diff --git a/superset/databases/api.py b/superset/databases/api.py index 88188bed5..542daa93a 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -121,6 +121,7 @@ from superset.sql_parse import Table from superset.superset_typing import FlaskResponse from superset.utils import json from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item +from superset.utils.decorators import transaction from superset.utils.oauth2 import decode_oauth2_state from superset.utils.ssh_tunnel import mask_password_info from superset.views.base_api import ( @@ -1341,6 +1342,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): return self.response_404() @expose("/oauth2/", methods=["GET"]) + @transaction() @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.oauth2", log_to_statsd=True, @@ -1428,7 +1430,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "refresh_token": token_response.get("refresh_token"), }, ) - # return blank page that closes itself return make_response( render_template("superset/oauth2.html", tab_id=state["tab_id"]), diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 27eb043eb..ed4e67d30 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -47,6 +47,11 @@ from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelMissingCredentials, ) from superset.constants import PASSWORD_MASK +from superset.databases.types import ( # pylint:disable=unused-import + EncryptedDict, # noqa: F401 + EncryptedField, + EncryptedString, # noqa: F401 +) from superset.databases.utils import make_url_safe from superset.db_engine_specs import get_engine_spec from superset.exceptions import CertificateException, SupersetSecurityException @@ -941,20 +946,6 @@ class ImportV1DatabaseSchema(Schema): return -class EncryptedField: # pylint: disable=too-few-public-methods - """ - A database field that should be stored in encrypted_extra. - """ - - -class EncryptedString(EncryptedField, fields.String): - pass - - -class EncryptedDict(EncryptedField, fields.Dict): - pass - - def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore ret = {} if isinstance(field, EncryptedField): diff --git a/superset/databases/types.py b/superset/databases/types.py new file mode 100644 index 000000000..4ab442860 --- /dev/null +++ b/superset/databases/types.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Field has been moved outside of the schemas.py file to +# allow for it to be imported from outside of app_context +from marshmallow import fields + + +class EncryptedField: # pylint: disable=too-few-public-methods + """ + A database field that should be stored in encrypted_extra. + """ + + +class EncryptedString(EncryptedField, fields.String): + pass + + +class EncryptedDict(EncryptedField, fields.Dict): + pass diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 2e555e32f..a086f6eff 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1707,10 +1707,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return sql @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. + :param database: A Database object :param statement: A single SQL statement :param cursor: Cursor instance :return: Dictionary with different costs @@ -1781,6 +1784,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cursor = conn.cursor() return [ cls.estimate_statement_cost( + database, cls.process_statement(statement, database), cursor, ) @@ -1809,8 +1813,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return url @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -1820,6 +1825,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods Update a configuration dictionary that can set the correct properties for impersonating users + :param connect_args: a Database object :param connect_args: config to be updated :param uri: URI :param username: Effective username diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 11175d795..70bc4bc84 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -409,7 +409,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met pandas_gbq.to_gbq(df, **to_gbq_kwargs) @classmethod - def _get_client(cls, engine: Engine) -> bigquery.Client: + def _get_client( + cls, + engine: Engine, + database: Database, # pylint: disable=unused-argument + ) -> bigquery.Client: """ Return the BigQuery client associated with an engine. """ @@ -453,7 +457,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met catalog=catalog, schema=schema, ) as engine: - client = cls._get_client(engine) + client = cls._get_client(engine, database) return [ cls.custom_estimate_statement_cost( cls.process_statement(statement, database), @@ -477,7 +481,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met return project with database.get_sqla_engine() as engine: - client = cls._get_client(engine) + client = cls._get_client(engine, database) return client.project @classmethod @@ -493,7 +497,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met """ engine: Engine with database.get_sqla_engine() as engine: - client = cls._get_client(engine) + client = cls._get_client(engine, database) projects = client.list_projects() return {project.project_id for project in projects} diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index e3cf128b7..6288866db 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -537,8 +537,9 @@ class HiveEngineSpec(PrestoEngineSpec): return url @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -547,6 +548,7 @@ class HiveEngineSpec(PrestoEngineSpec): """ Update a configuration dictionary that can set the correct properties for impersonating users + :param database: the Database Object :param connect_args: :param uri: URI string :param impersonate_user: Flag indicating if impersonation is enabled diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 70373927d..6281c6b3b 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -351,7 +351,16 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: + """ + Run a SQL query that estimates the cost of a given statement. + :param database: A Database object + :param statement: A single SQL statement + :param cursor: Cursor instance + :return: JSON response from Trino + """ sql = f"EXPLAIN {statement}" cursor.execute(sql) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index f0664564f..df5e1c643 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -365,9 +365,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): return parse.unquote(database.split("/")[1]) @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. + :param database: A Database object :param statement: A single SQL statement :param cursor: Cursor instance :return: JSON response from Trino @@ -945,8 +948,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): return version is not None and Version(version) >= Version("0.319") @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -955,6 +959,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): """ Update a configuration dictionary that can set the correct properties for impersonating users + + :param connect_args: the Database object :param connect_args: config to be updated :param uri: URI string :param username: Effective username diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 49615c39c..c47352821 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -116,8 +116,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): return metadata @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -126,6 +127,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): """ Update a configuration dictionary that can set the correct properties for impersonating users + :param database: the Database object :param connect_args: config to be updated :param uri: URI string :param username: Effective username diff --git a/superset/models/core.py b/superset/models/core.py index 5d3a6ea74..418141272 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -29,6 +29,7 @@ from contextlib import closing, contextmanager, nullcontext, suppress from copy import deepcopy from datetime import datetime from functools import lru_cache +from inspect import signature from typing import Any, Callable, cast, TYPE_CHECKING import numpy @@ -510,12 +511,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url)) if self.impersonate_user: - self.db_engine_spec.update_impersonation_config( - connect_args, - str(sqlalchemy_url), - effective_username, - access_token, + # PR #30674 changed the signature of the method to include database. + # This ensures that the change is backwards compatible + args = [connect_args, str(sqlalchemy_url), effective_username, access_token] + args = self.add_database_to_signature( + self.db_engine_spec.update_impersonation_config, + args, ) + self.db_engine_spec.update_impersonation_config(*args) if connect_args: params["connect_args"] = connect_args @@ -543,6 +546,24 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + def add_database_to_signature( + self, + func: Callable[..., None], + args: list[Any], + ) -> list[Any]: + """ + Examines a function signature looking for a database param. + If the signature requires a database, the function appends self in the + proper position. + """ + + # PR #30674 changed the signature of the method to include database. + # This ensures that the change is backwards compatible + sig = signature(func) + if "database" in (params := sig.parameters.keys()): + args.insert(list(params).index("database"), self) + return args + @contextmanager def get_raw_connection( self, diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index e4f9462d6..a5ef1cdec 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -151,12 +151,13 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec): DB Eng Specs (postgres): Test estimate_statement_cost select star """ + database = mock.Mock() cursor = mock.Mock() cursor.fetchone.return_value = ( "Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)", ) sql = "SELECT * FROM birth_names" - results = PostgresEngineSpec.estimate_statement_cost(sql, cursor) + results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor) assert results == {"Start-up cost": 0.0, "Total cost": 1537.91} def test_estimate_statement_invalid_syntax(self): @@ -165,6 +166,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec): """ from psycopg2 import errors + database = mock.Mock() cursor = mock.Mock() cursor.execute.side_effect = errors.SyntaxError( """ @@ -175,7 +177,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec): ) sql = "DROP TABLE birth_names" with self.assertRaises(errors.SyntaxError): - PostgresEngineSpec.estimate_statement_cost(sql, cursor) + PostgresEngineSpec.estimate_statement_cost(database, sql, cursor) def test_query_cost_formatter_example_costs(self): """ diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 798e31ee4..94e3ea627 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -905,22 +905,26 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): ) def test_estimate_statement_cost(self): + mock_database = mock.MagicMock() mock_cursor = mock.MagicMock() estimate_json = {"a": "b"} mock_cursor.fetchone.return_value = [ '{"a": "b"}', ] result = PrestoEngineSpec.estimate_statement_cost( - "SELECT * FROM brth_names", mock_cursor + mock_database, + "SELECT * FROM brth_names", + mock_cursor, ) assert result == estimate_json def test_estimate_statement_cost_invalid_syntax(self): + mock_database = mock.MagicMock() mock_cursor = mock.MagicMock() mock_cursor.execute.side_effect = Exception() with self.assertRaises(Exception): PrestoEngineSpec.estimate_statement_cost( - "DROP TABLE brth_names", mock_cursor + mock_database, "DROP TABLE brth_names", mock_cursor ) def test_get_create_view(self): diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 3c591d446..452cbb6f5 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -432,6 +432,31 @@ def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None: ) +def test_add_database_to_signature(): + args = ["param1", "param2"] + + def func_without_db(param1, param2): + pass + + def func_with_db_start(database, param1, param2): + pass + + def func_with_db_end(param1, param2, database): + pass + + database = Database( + database_name="my_db", + sqlalchemy_uri="trino://", + impersonate_user=True, + ) + args1 = database.add_database_to_signature(func_without_db, args.copy()) + assert args1 == ["param1", "param2"] + args2 = database.add_database_to_signature(func_with_db_start, args.copy()) + assert args2 == [database, "param1", "param2"] + args3 = database.add_database_to_signature(func_with_db_end, args.copy()) + assert args3 == ["param1", "param2", database] + + @with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True) def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None: """