chore(ssh): Allow users to set TUNNEL_TIMEOUT from config (#24202)
This commit is contained in:
parent
c54eedfdc0
commit
8b0c68c0d2
|
|
@ -515,6 +515,7 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
|
|||
# ----------------------------------------------------------------------
|
||||
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
|
||||
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
|
||||
SSH_TUNNEL_TIMEOUT_SEC = 10.0
|
||||
|
||||
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
|
||||
DEFAULT_FEATURE_FLAGS.update(
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ import logging
|
|||
from io import StringIO
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sshtunnel
|
||||
from flask import Flask
|
||||
from paramiko import RSAKey
|
||||
from sshtunnel import open_tunnel, SSHTunnelForwarder
|
||||
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
||||
|
|
@ -34,9 +34,10 @@ class SSHManager:
|
|||
def __init__(self, app: Flask) -> None:
|
||||
super().__init__()
|
||||
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
|
||||
sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
|
||||
|
||||
def build_sqla_url( # pylint: disable=no-self-use
|
||||
self, sqlalchemy_url: str, server: SSHTunnelForwarder
|
||||
self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder
|
||||
) -> str:
|
||||
# override any ssh tunnel configuration object
|
||||
url = make_url_safe(sqlalchemy_url)
|
||||
|
|
@ -49,7 +50,7 @@ class SSHManager:
|
|||
self,
|
||||
ssh_tunnel: "SSHTunnel",
|
||||
sqlalchemy_database_uri: str,
|
||||
) -> SSHTunnelForwarder:
|
||||
) -> sshtunnel.SSHTunnelForwarder:
|
||||
url = make_url_safe(sqlalchemy_database_uri)
|
||||
params = {
|
||||
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
|
||||
|
|
@ -68,7 +69,7 @@ class SSHManager:
|
|||
)
|
||||
params["ssh_pkey"] = private_key
|
||||
|
||||
return open_tunnel(**params)
|
||||
return sshtunnel.open_tunnel(**params)
|
||||
|
||||
|
||||
class SSHManagerFactory:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import sshtunnel
|
||||
|
||||
from superset.extensions.ssh import SSHManagerFactory
|
||||
|
||||
|
||||
def test_ssh_tunnel_timeout_setting() -> None:
|
||||
app = Mock()
|
||||
app.config = {
|
||||
"SSH_TUNNEL_MAX_RETRIES": 2,
|
||||
"SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
|
||||
"SSH_TUNNEL_TIMEOUT_SEC": 123.0,
|
||||
"SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
|
||||
}
|
||||
factory = SSHManagerFactory()
|
||||
factory.init_app(app)
|
||||
assert sshtunnel.TUNNEL_TIMEOUT == 123.0
|
||||
Loading…
Reference in New Issue