diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
index 33a258d72..696f52baa 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
@@ -318,3 +318,25 @@ export const forceSSLField = ({
/>
);
+
+export const projectIdfield = ({
+ changeMethods,
+ getValidation,
+ validationErrors,
+ db,
+}: FieldPropTypes) => (
+ <>
+
+ >
+);
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx
index 2b9223141..fd11cd327 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx
@@ -57,10 +57,11 @@ export const EncryptedField = ({
db?.engine === 'gsheets' ? !isEditMode && !isPublic : !isEditMode;
const isEncrypted = isEditMode && db?.masked_encrypted_extra !== '{}';
const encryptedField = db?.engine && encryptedCredentialsMap[db.engine];
+ const paramValue = db?.parameters?.[encryptedField];
const encryptedValue =
- typeof db?.parameters?.[encryptedField] === 'object'
- ? JSON.stringify(db?.parameters?.[encryptedField])
- : db?.parameters?.[encryptedField];
+ paramValue && typeof paramValue === 'object'
+ ? JSON.stringify(paramValue)
+ : paramValue;
return (
{db?.engine === 'gsheets' && (
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx
index ee0ffdeb3..fac3c3331 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx
@@ -24,6 +24,14 @@ import { Input } from 'src/components/Input';
import { FormItem } from 'src/components/Form';
import { FieldPropTypes } from '../../types';
+const LABELS = {
+ CLIENT_ID: 'Client ID',
+ SECRET: 'Client Secret',
+ AUTH_URI: 'Authorization Request URI',
+ TOKEN_URI: 'Token Request URI',
+ SCOPE: 'Scope',
+};
+
interface OAuth2ClientInfo {
id: string;
secret: string;
@@ -44,10 +52,6 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
scope: encryptedExtra.oauth2_client_info?.scope || '',
});
- if (db?.engine_information?.supports_oauth2 !== true) {
- return null;
- }
-
const handleChange = (key: any) => (e: any) => {
const updatedInfo = {
...oauth2ClientInfo,
@@ -68,14 +72,14 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
return (
-
+
-
+
{
onChange={handleChange('secret')}
/>
-
+
{
onChange={handleChange('authorization_request_uri')}
/>
-
+
{
onChange={handleChange('token_request_uri')}
/>
-
+
None: # pylint: disable=too-many-statements
+ def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches
self.validate()
ex_str = ""
ssh_tunnel = self._properties.get("ssh_tunnel")
@@ -225,6 +225,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
# bubble up the exception to return proper status code
raise
except Exception as ex:
+ if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(
+ ex
+ ):
+ database.start_oauth2_dance()
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 88188bed5..542daa93a 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -121,6 +121,7 @@ from superset.sql_parse import Table
from superset.superset_typing import FlaskResponse
from superset.utils import json
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
+from superset.utils.decorators import transaction
from superset.utils.oauth2 import decode_oauth2_state
from superset.utils.ssh_tunnel import mask_password_info
from superset.views.base_api import (
@@ -1341,6 +1342,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
return self.response_404()
@expose("/oauth2/", methods=["GET"])
+ @transaction()
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.oauth2",
log_to_statsd=True,
@@ -1428,7 +1430,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"refresh_token": token_response.get("refresh_token"),
},
)
-
# return blank page that closes itself
return make_response(
render_template("superset/oauth2.html", tab_id=state["tab_id"]),
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 27eb043eb..ed4e67d30 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -47,6 +47,11 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelMissingCredentials,
)
from superset.constants import PASSWORD_MASK
+from superset.databases.types import ( # pylint:disable=unused-import
+ EncryptedDict, # noqa: F401
+ EncryptedField,
+ EncryptedString, # noqa: F401
+)
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
@@ -941,20 +946,6 @@ class ImportV1DatabaseSchema(Schema):
return
-class EncryptedField: # pylint: disable=too-few-public-methods
- """
- A database field that should be stored in encrypted_extra.
- """
-
-
-class EncryptedString(EncryptedField, fields.String):
- pass
-
-
-class EncryptedDict(EncryptedField, fields.Dict):
- pass
-
-
def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore
ret = {}
if isinstance(field, EncryptedField):
diff --git a/superset/databases/types.py b/superset/databases/types.py
new file mode 100644
index 000000000..4ab442860
--- /dev/null
+++ b/superset/databases/types.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Field has been moved outside of the schemas.py file to
+# allow for it to be imported from outside of app_context
+from marshmallow import fields
+
+
+class EncryptedField: # pylint: disable=too-few-public-methods
+ """
+ A database field that should be stored in encrypted_extra.
+ """
+
+
+class EncryptedString(EncryptedField, fields.String):
+ pass
+
+
+class EncryptedDict(EncryptedField, fields.Dict):
+ pass
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 2e555e32f..a086f6eff 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1707,10 +1707,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return sql
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
+ def estimate_statement_cost(
+ cls, database: Database, statement: str, cursor: Any
+ ) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
+ :param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs
@@ -1781,6 +1784,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cursor = conn.cursor()
return [
cls.estimate_statement_cost(
+ database,
cls.process_statement(statement, database),
cursor,
)
@@ -1809,8 +1813,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return url
@classmethod
- def update_impersonation_config(
+ def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
+ database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@@ -1820,6 +1825,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Update a configuration dictionary
that can set the correct properties for impersonating users
+ :param connect_args: a Database object
:param connect_args: config to be updated
:param uri: URI
:param username: Effective username
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 11175d795..70bc4bc84 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -409,7 +409,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
pandas_gbq.to_gbq(df, **to_gbq_kwargs)
@classmethod
- def _get_client(cls, engine: Engine) -> bigquery.Client:
+ def _get_client(
+ cls,
+ engine: Engine,
+ database: Database, # pylint: disable=unused-argument
+ ) -> bigquery.Client:
"""
Return the BigQuery client associated with an engine.
"""
@@ -453,7 +457,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
catalog=catalog,
schema=schema,
) as engine:
- client = cls._get_client(engine)
+ client = cls._get_client(engine, database)
return [
cls.custom_estimate_statement_cost(
cls.process_statement(statement, database),
@@ -477,7 +481,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return project
with database.get_sqla_engine() as engine:
- client = cls._get_client(engine)
+ client = cls._get_client(engine, database)
return client.project
@classmethod
@@ -493,7 +497,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
"""
engine: Engine
with database.get_sqla_engine() as engine:
- client = cls._get_client(engine)
+ client = cls._get_client(engine, database)
projects = client.list_projects()
return {project.project_id for project in projects}
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index e3cf128b7..6288866db 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -537,8 +537,9 @@ class HiveEngineSpec(PrestoEngineSpec):
return url
@classmethod
- def update_impersonation_config(
+ def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
+ database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@@ -547,6 +548,7 @@ class HiveEngineSpec(PrestoEngineSpec):
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
+ :param database: the Database Object
:param connect_args:
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index 70373927d..6281c6b3b 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -351,7 +351,16 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
return True
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
+ def estimate_statement_cost(
+ cls, database: Database, statement: str, cursor: Any
+ ) -> dict[str, Any]:
+ """
+ Run a SQL query that estimates the cost of a given statement.
+ :param database: A Database object
+ :param statement: A single SQL statement
+ :param cursor: Cursor instance
+ :return: JSON response from Trino
+ """
sql = f"EXPLAIN {statement}"
cursor.execute(sql)
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index f0664564f..df5e1c643 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -365,9 +365,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return parse.unquote(database.split("/")[1])
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
+ def estimate_statement_cost(
+ cls, database: Database, statement: str, cursor: Any
+ ) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
+ :param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
@@ -945,8 +948,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return version is not None and Version(version) >= Version("0.319")
@classmethod
- def update_impersonation_config(
+ def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
+ database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@@ -955,6 +959,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
+
+ :param connect_args: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 49615c39c..c47352821 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -116,8 +116,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
return metadata
@classmethod
- def update_impersonation_config(
+ def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
+ database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
@@ -126,6 +127,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
+ :param database: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username
diff --git a/superset/models/core.py b/superset/models/core.py
index 5d3a6ea74..418141272 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -29,6 +29,7 @@ from contextlib import closing, contextmanager, nullcontext, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
+from inspect import signature
from typing import Any, Callable, cast, TYPE_CHECKING
import numpy
@@ -510,12 +511,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
if self.impersonate_user:
- self.db_engine_spec.update_impersonation_config(
- connect_args,
- str(sqlalchemy_url),
- effective_username,
- access_token,
+ # PR #30674 changed the signature of the method to include database.
+ # This ensures that the change is backwards compatible
+ args = [connect_args, str(sqlalchemy_url), effective_username, access_token]
+ args = self.add_database_to_signature(
+ self.db_engine_spec.update_impersonation_config,
+ args,
)
+ self.db_engine_spec.update_impersonation_config(*args)
if connect_args:
params["connect_args"] = connect_args
@@ -543,6 +546,24 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
+ def add_database_to_signature(
+ self,
+ func: Callable[..., None],
+ args: list[Any],
+ ) -> list[Any]:
+ """
+ Examines a function signature looking for a database param.
+ If the signature requires a database, the function appends self in the
+ proper position.
+ """
+
+ # PR #30674 changed the signature of the method to include database.
+ # This ensures that the change is backwards compatible
+ sig = signature(func)
+ if "database" in (params := sig.parameters.keys()):
+ args.insert(list(params).index("database"), self)
+ return args
+
@contextmanager
def get_raw_connection(
self,
diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py
index e4f9462d6..a5ef1cdec 100644
--- a/tests/integration_tests/db_engine_specs/postgres_tests.py
+++ b/tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -151,12 +151,13 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
DB Eng Specs (postgres): Test estimate_statement_cost select star
"""
+ database = mock.Mock()
cursor = mock.Mock()
cursor.fetchone.return_value = (
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
)
sql = "SELECT * FROM birth_names"
- results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
+ results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
def test_estimate_statement_invalid_syntax(self):
@@ -165,6 +166,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
"""
from psycopg2 import errors
+ database = mock.Mock()
cursor = mock.Mock()
cursor.execute.side_effect = errors.SyntaxError(
"""
@@ -175,7 +177,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
)
sql = "DROP TABLE birth_names"
with self.assertRaises(errors.SyntaxError):
- PostgresEngineSpec.estimate_statement_cost(sql, cursor)
+ PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
def test_query_cost_formatter_example_costs(self):
"""
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 798e31ee4..94e3ea627 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -905,22 +905,26 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
)
def test_estimate_statement_cost(self):
+ mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
estimate_json = {"a": "b"}
mock_cursor.fetchone.return_value = [
'{"a": "b"}',
]
result = PrestoEngineSpec.estimate_statement_cost(
- "SELECT * FROM brth_names", mock_cursor
+ mock_database,
+ "SELECT * FROM brth_names",
+ mock_cursor,
)
assert result == estimate_json
def test_estimate_statement_cost_invalid_syntax(self):
+ mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
mock_cursor.execute.side_effect = Exception()
with self.assertRaises(Exception):
PrestoEngineSpec.estimate_statement_cost(
- "DROP TABLE brth_names", mock_cursor
+ mock_database, "DROP TABLE brth_names", mock_cursor
)
def test_get_create_view(self):
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 3c591d446..452cbb6f5 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -432,6 +432,31 @@ def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
)
+def test_add_database_to_signature():
+ args = ["param1", "param2"]
+
+ def func_without_db(param1, param2):
+ pass
+
+ def func_with_db_start(database, param1, param2):
+ pass
+
+ def func_with_db_end(param1, param2, database):
+ pass
+
+ database = Database(
+ database_name="my_db",
+ sqlalchemy_uri="trino://",
+ impersonate_user=True,
+ )
+ args1 = database.add_database_to_signature(func_without_db, args.copy())
+ assert args1 == ["param1", "param2"]
+ args2 = database.add_database_to_signature(func_with_db_start, args.copy())
+ assert args2 == [database, "param1", "param2"]
+ args3 = database.add_database_to_signature(func_with_db_end, args.copy())
+ assert args3 == ["param1", "param2", database]
+
+
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
"""