diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
index 529fc1841..4d864ba11 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
@@ -116,6 +116,69 @@ export const databaseField = ({
helpText={t('Copy the name of the database you are trying to connect to.')}
/>
);
+export const defaultCatalogField = ({
+ required,
+ changeMethods,
+ getValidation,
+ validationErrors,
+ db,
+}: FieldPropTypes) => (
+
+);
+export const defaultSchemaField = ({
+ required,
+ changeMethods,
+ getValidation,
+ validationErrors,
+ db,
+}: FieldPropTypes) => (
+
+);
+export const httpPathField = ({
+ required,
+ changeMethods,
+ getValidation,
+ validationErrors,
+ db,
+}: FieldPropTypes) => {
+ console.error(db);
+ return (
+
+ );
+};
export const usernameField = ({
required,
changeMethods,
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
index 509103ea2..aff755b95 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
@@ -27,10 +27,13 @@ import { Form } from 'src/components/Form';
import {
accessTokenField,
databaseField,
+ defaultCatalogField,
+ defaultSchemaField,
displayField,
forceSSLField,
hostField,
httpPath,
+ httpPathField,
passwordField,
portField,
queryField,
@@ -47,10 +50,13 @@ export const FormFieldOrder = [
'host',
'port',
'database',
+ 'default_catalog',
+ 'default_schema',
'username',
'password',
'access_token',
'http_path',
+ 'http_path_field',
'database_name',
'credentials_info',
'service_account_info',
@@ -71,8 +77,11 @@ const SSHTunnelSwitchComponent =
const FORM_FIELD_MAP = {
host: hostField,
http_path: httpPath,
+ http_path_field: httpPathField,
port: portField,
database: databaseField,
+ default_catalog: defaultCatalogField,
+ default_schema: defaultSchemaField,
username: usernameField,
password: passwordField,
access_token: accessTokenField,
diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
index 47c9a8b65..4e1e58ebb 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
@@ -633,11 +633,23 @@ const DatabaseModal: FunctionComponent = ({
const history = useHistory();
const dbModel: DatabaseForm =
+ // TODO: we need a centralized engine in one place
+
+ // first try to match both engine and driver
+ availableDbs?.databases?.find(
+ (available: {
+ engine: string | undefined;
+ default_driver: string | undefined;
+ }) =>
+ available.engine === (isEditMode ? db?.backend : db?.engine) &&
+ available.default_driver === db?.driver,
+ ) ||
+ // alternatively try to match only engine
availableDbs?.databases?.find(
(available: { engine: string | undefined }) =>
- // TODO: we need a centralized engine in one place
available.engine === (isEditMode ? db?.backend : db?.engine),
- ) || {};
+ ) ||
+ {};
// Test Connection logic
const testConnection = () => {
diff --git a/superset-frontend/src/features/databases/types.ts b/superset-frontend/src/features/databases/types.ts
index a09ad174a..c46296a2a 100644
--- a/superset-frontend/src/features/databases/types.ts
+++ b/superset-frontend/src/features/databases/types.ts
@@ -63,6 +63,9 @@ export type DatabaseObject = {
host?: string;
port?: number;
database?: string;
+ default_catalog?: string;
+ default_schema?: string;
+ http_path_field?: string;
username?: string;
password?: string;
encryption?: boolean;
@@ -126,6 +129,18 @@ export type DatabaseForm = {
description: string;
type: string;
};
+ default_catalog: {
+ description: string;
+ type: string;
+ };
+ default_schema: {
+ description: string;
+ type: string;
+ };
+ http_path_field: {
+ description: string;
+ type: string;
+ };
host: {
description: string;
type: string;
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 4b2f93ca5..6fc753c00 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import json
from datetime import datetime
-from typing import Any, TYPE_CHECKING, TypedDict
+from typing import Any, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
@@ -40,10 +40,10 @@ if TYPE_CHECKING:
#
-class DatabricksParametersSchema(Schema):
+class DatabricksBaseSchema(Schema):
"""
- This is the list of fields that are expected
- from the client in order to build the sqlalchemy string
+ Fields that are required for both Databricks drivers that uses a
+ dynamic form.
"""
access_token = fields.Str(required=True)
@@ -53,44 +53,85 @@ class DatabricksParametersSchema(Schema):
metadata={"description": __("Database port")},
validate=Range(min=0, max=2**16, max_inclusive=False),
)
- database = fields.Str(required=True)
encryption = fields.Boolean(
required=False,
metadata={"description": __("Use an encrypted connection to the database")},
)
-class DatabricksPropertiesSchema(DatabricksParametersSchema):
+class DatabricksBaseParametersType(TypedDict):
"""
- This is the list of fields expected
- for successful database creation execution
- """
-
- http_path = fields.Str(required=True)
-
-
-class DatabricksParametersType(TypedDict):
- """
- The parameters are all the keys that do
- not exist on the Database model.
- These are used to build the sqlalchemy uri
+ The parameters are all the keys that do not exist on the Database model.
+ These are used to build the sqlalchemy uri.
"""
access_token: str
host: str
port: int
- database: str
encryption: bool
-class DatabricksPropertiesType(TypedDict):
+class DatabricksNativeSchema(DatabricksBaseSchema):
"""
- All properties that need to be available to
- this engine in order to create a connection
- if the dynamic form is used
+ Additional fields required only for the DatabricksNativeEngineSpec.
"""
- parameters: DatabricksParametersType
+ database = fields.Str(required=True)
+
+
+class DatabricksNativePropertiesSchema(DatabricksNativeSchema):
+ """
+ Properties required only for the DatabricksNativeEngineSpec.
+ """
+
+ http_path = fields.Str(required=True)
+
+
+class DatabricksNativeParametersType(DatabricksBaseParametersType):
+ """
+ Additional parameters required only for the DatabricksNativeEngineSpec.
+ """
+
+ database: str
+
+
+class DatabricksNativePropertiesType(TypedDict):
+ """
+ All properties that need to be available to the DatabricksNativeEngineSpec
+ in order tocreate a connection if the dynamic form is used.
+ """
+
+ parameters: DatabricksNativeParametersType
+ extra: str
+
+
+class DatabricksPythonConnectorSchema(DatabricksBaseSchema):
+ """
+ Additional fields required only for the DatabricksPythonConnectorEngineSpec.
+ """
+
+ http_path_field = fields.Str(required=True)
+ default_catalog = fields.Str(required=True)
+ default_schema = fields.Str(required=True)
+
+
+class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType):
+ """
+ Additional parameters required only for the DatabricksPythonConnectorEngineSpec.
+ """
+
+ http_path_field: str
+ default_catalog: str
+ default_schema: str
+
+
+class DatabricksPythonConnectorPropertiesType(TypedDict):
+ """
+ All properties that need to be available to the DatabricksPythonConnectorEngineSpec
+ in order to create a connection if the dynamic form is used.
+ """
+
+ parameters: DatabricksPythonConnectorParametersType
extra: str
@@ -125,13 +166,7 @@ class DatabricksHiveEngineSpec(HiveEngineSpec):
_time_grain_expressions = time_grain_expressions
-class DatabricksODBCEngineSpec(BaseEngineSpec):
- engine_name = "Databricks SQL Endpoint"
-
- engine = "databricks"
- drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
- default_driver = "pyodbc"
-
+class DatabricksBaseEngineSpec(BaseEngineSpec):
_time_grain_expressions = time_grain_expressions
@classmethod
@@ -145,20 +180,23 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
return HiveEngineSpec.epoch_to_dttm()
-class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec):
- engine_name = "Databricks"
+class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec):
+ engine_name = "Databricks SQL Endpoint"
engine = "databricks"
- drivers = {"connector": "Native all-purpose driver"}
- default_driver = "connector"
+ drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
+ default_driver = "pyodbc"
- parameters_schema = DatabricksParametersSchema()
- properties_schema = DatabricksPropertiesSchema()
- sqlalchemy_uri_placeholder = (
- "databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
- )
+class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngineSpec):
+ default_driver = ""
encryption_parameters = {"ssl": "1"}
+ required_parameters = {"access_token", "host", "port"}
+ context_key_mapping = {
+ "access_token": "password",
+ "host": "hostname",
+ "port": "port",
+ }
@staticmethod
def get_extra_params(database: Database) -> dict[str, Any]:
@@ -190,30 +228,6 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
- @classmethod
- def build_sqlalchemy_uri( # type: ignore
- cls, parameters: DatabricksParametersType, *_
- ) -> str:
- query = {}
- if parameters.get("encryption"):
- if not cls.encryption_parameters:
- raise Exception( # pylint: disable=broad-exception-raised
- "Unable to build a URL with encryption enabled"
- )
- query.update(cls.encryption_parameters)
-
- return str(
- URL.create(
- f"{cls.engine}+{cls.default_driver}".rstrip("+"),
- username="token",
- password=parameters.get("access_token"),
- host=parameters["host"],
- port=parameters["port"],
- database=parameters["database"],
- query=query,
- )
- )
-
@classmethod
def extract_errors(
cls, ex: Exception, context: dict[str, Any] | None = None
@@ -224,13 +238,10 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
# access_token isn't currently parseable from the
# databricks error response, but adding it in here
# for reference if their error message changes
- context = {
- "host": context.get("hostname"),
- "access_token": context.get("password"),
- "port": context.get("port"),
- "username": context.get("username"),
- "database": context.get("database"),
- }
+
+ for key, value in cls.context_key_mapping.items():
+ context[key] = context.get(value)
+
for regex, (message, error_type, extra) in cls.custom_errors.items():
match = regex.search(raw_message)
if match:
@@ -254,32 +265,18 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
)
]
- @classmethod
- def get_parameters_from_uri( # type: ignore
- cls, uri: str, *_, **__
- ) -> DatabricksParametersType:
- url = make_url_safe(uri)
- encryption = all(
- item in url.query.items() for item in cls.encryption_parameters.items()
- )
- return {
- "access_token": url.password,
- "host": url.host,
- "port": url.port,
- "database": url.database,
- "encryption": encryption,
- }
-
@classmethod
def validate_parameters( # type: ignore
cls,
- properties: DatabricksPropertiesType,
+ properties: Union[
+ DatabricksNativePropertiesType,
+ DatabricksPythonConnectorPropertiesType,
+ ],
) -> list[SupersetError]:
errors: list[SupersetError] = []
- required = {"access_token", "host", "port", "database", "extra"}
- extra = json.loads(properties.get("extra", "{}"))
- engine_params = extra.get("engine_params", {})
- connect_args = engine_params.get("connect_args", {})
+ if extra := json.loads(properties.get("extra")): # type: ignore
+ engine_params = extra.get("engine_params", {})
+ connect_args = engine_params.get("connect_args", {})
parameters = {
**properties,
**properties.get("parameters", {}),
@@ -289,7 +286,7 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
present = {key for key in parameters if parameters.get(key, ())}
- if missing := sorted(required - present):
+ if missing := sorted(cls.required_parameters - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
@@ -351,6 +348,69 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
)
return errors
+
+class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
+ engine = "databricks"
+ engine_name = "Databricks"
+ drivers = {"connector": "Native all-purpose driver"}
+ default_driver = "connector"
+
+ parameters_schema = DatabricksNativeSchema()
+ properties_schema = DatabricksNativePropertiesSchema()
+
+ sqlalchemy_uri_placeholder = (
+ "databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
+ )
+ context_key_mapping = {
+ **DatabricksDynamicBaseEngineSpec.context_key_mapping,
+ "database": "database",
+ "username": "username",
+ }
+ required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
+ "database",
+ "extra",
+ }
+
+ @classmethod
+ def build_sqlalchemy_uri( # type: ignore
+ cls, parameters: DatabricksNativeParametersType, *_
+ ) -> str:
+ query = {}
+ if parameters.get("encryption"):
+ if not cls.encryption_parameters:
+ raise Exception( # pylint: disable=broad-exception-raised
+ "Unable to build a URL with encryption enabled"
+ )
+ query.update(cls.encryption_parameters)
+
+ return str(
+ URL.create(
+ f"{cls.engine}+{cls.default_driver}".rstrip("+"),
+ username="token",
+ password=parameters.get("access_token"),
+ host=parameters["host"],
+ port=parameters["port"],
+ database=parameters["database"],
+ query=query,
+ )
+ )
+
+ @classmethod
+ def get_parameters_from_uri( # type: ignore
+ cls, uri: str, *_, **__
+ ) -> DatabricksNativeParametersType:
+ url = make_url_safe(uri)
+ encryption = all(
+ item in url.query.items() for item in cls.encryption_parameters.items()
+ )
+ return {
+ "access_token": url.password,
+ "host": url.host,
+ "port": url.port,
+ "database": url.database,
+ "encryption": encryption,
+ }
+
@classmethod
def parameters_json_schema(cls) -> Any:
"""
@@ -367,3 +427,78 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
)
spec.components.schema(cls.__name__, schema=cls.properties_schema)
return spec.to_dict()["components"]["schemas"][cls.__name__]
+
+
+class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
+ engine = "databricks"
+ engine_name = "Databricks Python Connector"
+ default_driver = "databricks-sql-python"
+ drivers = {"databricks-sql-python": "Databricks SQL Python"}
+
+ parameters_schema = DatabricksPythonConnectorSchema()
+
+ sqlalchemy_uri_placeholder = (
+ "databricks://token:{access_token}@{host}:{port}?http_path={http_path}"
+ "&catalog={default_catalog}&schema={default_schema}"
+ )
+
+ context_key_mapping = {
+ **DatabricksDynamicBaseEngineSpec.context_key_mapping,
+ "default_catalog": "catalog",
+ "default_schema": "schema",
+ "http_path_field": "http_path",
+ }
+
+ required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
+ "default_catalog",
+ "default_schema",
+ "http_path_field",
+ }
+
+ @classmethod
+ def build_sqlalchemy_uri( # type: ignore
+ cls, parameters: DatabricksPythonConnectorParametersType, *_
+ ) -> str:
+ query = {}
+ if http_path := parameters.get("http_path_field"):
+ query["http_path"] = http_path
+ if catalog := parameters.get("default_catalog"):
+ query["catalog"] = catalog
+ if schema := parameters.get("default_schema"):
+ query["schema"] = schema
+ if parameters.get("encryption"):
+ query.update(cls.encryption_parameters)
+
+ return str(
+ URL.create(
+ cls.engine,
+ username="token",
+ password=parameters.get("access_token"),
+ host=parameters["host"],
+ port=parameters["port"],
+ query=query,
+ )
+ )
+
+ @classmethod
+ def get_parameters_from_uri( # type: ignore
+ cls, uri: str, *_: Any, **__: Any
+ ) -> DatabricksPythonConnectorParametersType:
+ url = make_url_safe(uri)
+ query = {
+ key: value
+ for (key, value) in url.query.items()
+ if (key, value) not in cls.encryption_parameters.items()
+ }
+ encryption = all(
+ item in url.query.items() for item in cls.encryption_parameters.items()
+ )
+ return {
+ "access_token": url.password,
+ "host": url.host,
+ "port": url.port,
+ "http_path_field": query["http_path"],
+ "default_catalog": query["catalog"],
+ "default_schema": query["schema"],
+ "encryption": encryption,
+ }
diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py
index de06f919b..8709833d3 100644
--- a/tests/unit_tests/db_engine_specs/test_databricks.py
+++ b/tests/unit_tests/db_engine_specs/test_databricks.py
@@ -35,13 +35,13 @@ def test_get_parameters_from_uri() -> None:
"""
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
- DatabricksParametersType,
+ DatabricksNativeParametersType,
)
parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(
"databricks+connector://token:abc12345@my_hostname:1234/test"
)
- assert parameters == DatabricksParametersType(
+ assert parameters == DatabricksNativeParametersType(
{
"access_token": "abc12345",
"host": "my_hostname",
@@ -60,10 +60,10 @@ def test_build_sqlalchemy_uri() -> None:
"""
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
- DatabricksParametersType,
+ DatabricksNativeParametersType,
)
- parameters = DatabricksParametersType(
+ parameters = DatabricksNativeParametersType(
{
"access_token": "abc12345",
"host": "my_hostname",