chore(Databricks): New Databricks driver (#28393)

This commit is contained in:
Vitor Avila 2024-05-09 15:58:03 -03:00 committed by GitHub
parent e6a85c5901
commit 307ebeaa19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 333 additions and 99 deletions

View File

@ -116,6 +116,69 @@ export const databaseField = ({
helpText={t('Copy the name of the database you are trying to connect to.')}
/>
);
/**
 * Validated text input for the connection's default catalog
 * (e.g. "hive_metastore"). Value is stored under
 * `db.parameters.default_catalog`.
 */
export const defaultCatalogField = ({
  required,
  changeMethods,
  getValidation,
  validationErrors,
  db,
}: FieldPropTypes) => {
  return (
    <ValidatedInput
      id="default_catalog"
      name="default_catalog"
      label={t('Default Catalog')}
      placeholder={t('e.g. hive_metastore')}
      helpText={t('The default catalog that should be used for the connection.')}
      required={required}
      value={db?.parameters?.default_catalog}
      onChange={changeMethods.onParametersChange}
      validationMethods={{ onBlur: getValidation }}
      errorMessage={validationErrors?.default_catalog}
    />
  );
};
/**
 * Validated text input for the connection's default schema
 * (e.g. "default"). Value is stored under
 * `db.parameters.default_schema`.
 */
export const defaultSchemaField = ({
  required,
  changeMethods,
  getValidation,
  validationErrors,
  db,
}: FieldPropTypes) => {
  return (
    <ValidatedInput
      id="default_schema"
      name="default_schema"
      label={t('Default Schema')}
      placeholder={t('e.g. default')}
      helpText={t('The default schema that should be used for the connection.')}
      required={required}
      value={db?.parameters?.default_schema}
      onChange={changeMethods.onParametersChange}
      validationMethods={{ onBlur: getValidation }}
      errorMessage={validationErrors?.default_schema}
    />
  );
};
/**
 * Validated text input for the cluster's HTTP path
 * (e.g. "sql/protocolv1/o/12345"). Value is stored under
 * `db.parameters.http_path_field`.
 */
export const httpPathField = ({
  required,
  changeMethods,
  getValidation,
  validationErrors,
  db,
}: FieldPropTypes) => (
  // Fix: removed leftover `console.error(db)` debug statement.
  <ValidatedInput
    id="http_path_field"
    name="http_path_field"
    required={required}
    value={db?.parameters?.http_path_field}
    validationMethods={{ onBlur: getValidation }}
    // NOTE(review): the error key is `http_path` while the input's
    // name is `http_path_field` — confirm which key the backend
    // returns validation errors under.
    errorMessage={validationErrors?.http_path}
    placeholder={t('e.g. sql/protocolv1/o/12345')}
    // Fix: wrap label in t() for i18n, consistent with sibling fields.
    label={t('HTTP Path')}
    onChange={changeMethods.onParametersChange}
    helpText={t('Copy the name of the HTTP Path of your cluster.')}
  />
);
export const usernameField = ({
required,
changeMethods,

View File

@ -27,10 +27,13 @@ import { Form } from 'src/components/Form';
import {
accessTokenField,
databaseField,
defaultCatalogField,
defaultSchemaField,
displayField,
forceSSLField,
hostField,
httpPath,
httpPathField,
passwordField,
portField,
queryField,
@ -47,10 +50,13 @@ export const FormFieldOrder = [
'host',
'port',
'database',
'default_catalog',
'default_schema',
'username',
'password',
'access_token',
'http_path',
'http_path_field',
'database_name',
'credentials_info',
'service_account_info',
@ -71,8 +77,11 @@ const SSHTunnelSwitchComponent =
const FORM_FIELD_MAP = {
host: hostField,
http_path: httpPath,
http_path_field: httpPathField,
port: portField,
database: databaseField,
default_catalog: defaultCatalogField,
default_schema: defaultSchemaField,
username: usernameField,
password: passwordField,
access_token: accessTokenField,

View File

@ -633,11 +633,23 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const history = useHistory();
const dbModel: DatabaseForm =
// TODO: we need a centralized engine in one place
// first try to match both engine and driver
availableDbs?.databases?.find(
(available: {
engine: string | undefined;
default_driver: string | undefined;
}) =>
available.engine === (isEditMode ? db?.backend : db?.engine) &&
available.default_driver === db?.driver,
) ||
// alternatively try to match only engine
availableDbs?.databases?.find(
(available: { engine: string | undefined }) =>
// TODO: we need a centralized engine in one place
available.engine === (isEditMode ? db?.backend : db?.engine),
) || {};
) ||
{};
// Test Connection logic
const testConnection = () => {

View File

@ -63,6 +63,9 @@ export type DatabaseObject = {
host?: string;
port?: number;
database?: string;
default_catalog?: string;
default_schema?: string;
http_path_field?: string;
username?: string;
password?: string;
encryption?: boolean;
@ -126,6 +129,18 @@ export type DatabaseForm = {
description: string;
type: string;
};
default_catalog: {
description: string;
type: string;
};
default_schema: {
description: string;
type: string;
};
http_path_field: {
description: string;
type: string;
};
host: {
description: string;
type: string;

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import json
from datetime import datetime
from typing import Any, TYPE_CHECKING, TypedDict
from typing import Any, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
@ -40,10 +40,10 @@ if TYPE_CHECKING:
#
class DatabricksParametersSchema(Schema):
class DatabricksBaseSchema(Schema):
"""
This is the list of fields that are expected
from the client in order to build the sqlalchemy string
Fields that are required by both Databricks drivers that use a
dynamic form.
"""
access_token = fields.Str(required=True)
@ -53,44 +53,85 @@ class DatabricksParametersSchema(Schema):
metadata={"description": __("Database port")},
validate=Range(min=0, max=2**16, max_inclusive=False),
)
database = fields.Str(required=True)
encryption = fields.Boolean(
required=False,
metadata={"description": __("Use an encrypted connection to the database")},
)
class DatabricksPropertiesSchema(DatabricksParametersSchema):
class DatabricksBaseParametersType(TypedDict):
"""
This is the list of fields expected
for successful database creation execution
"""
http_path = fields.Str(required=True)
class DatabricksParametersType(TypedDict):
"""
The parameters are all the keys that do
not exist on the Database model.
These are used to build the sqlalchemy uri
The parameters are all the keys that do not exist on the Database model.
These are used to build the sqlalchemy uri.
"""
access_token: str
host: str
port: int
database: str
encryption: bool
class DatabricksPropertiesType(TypedDict):
class DatabricksNativeSchema(DatabricksBaseSchema):
"""
All properties that need to be available to
this engine in order to create a connection
if the dynamic form is used
Additional fields required only for the DatabricksNativeEngineSpec.
"""
parameters: DatabricksParametersType
database = fields.Str(required=True)
class DatabricksNativePropertiesSchema(DatabricksNativeSchema):
"""
Properties required only for the DatabricksNativeEngineSpec.
"""
http_path = fields.Str(required=True)
class DatabricksNativeParametersType(DatabricksBaseParametersType):
"""
Additional parameters required only for the DatabricksNativeEngineSpec.
"""
database: str
class DatabricksNativePropertiesType(TypedDict):
"""
All properties that need to be available to the DatabricksNativeEngineSpec
in order to create a connection if the dynamic form is used.
"""
parameters: DatabricksNativeParametersType
extra: str
class DatabricksPythonConnectorSchema(DatabricksBaseSchema):
"""
Additional fields required only for the DatabricksPythonConnectorEngineSpec.
"""
http_path_field = fields.Str(required=True)
default_catalog = fields.Str(required=True)
default_schema = fields.Str(required=True)
class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType):
"""
Additional parameters required only for the DatabricksPythonConnectorEngineSpec.
"""
http_path_field: str
default_catalog: str
default_schema: str
class DatabricksPythonConnectorPropertiesType(TypedDict):
"""
All properties that need to be available to the DatabricksPythonConnectorEngineSpec
in order to create a connection if the dynamic form is used.
"""
parameters: DatabricksPythonConnectorParametersType
extra: str
@ -125,13 +166,7 @@ class DatabricksHiveEngineSpec(HiveEngineSpec):
_time_grain_expressions = time_grain_expressions
class DatabricksODBCEngineSpec(BaseEngineSpec):
engine_name = "Databricks SQL Endpoint"
engine = "databricks"
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "pyodbc"
class DatabricksBaseEngineSpec(BaseEngineSpec):
_time_grain_expressions = time_grain_expressions
@classmethod
@ -145,20 +180,23 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
return HiveEngineSpec.epoch_to_dttm()
class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec):
engine_name = "Databricks"
class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec):
engine_name = "Databricks SQL Endpoint"
engine = "databricks"
drivers = {"connector": "Native all-purpose driver"}
default_driver = "connector"
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "pyodbc"
parameters_schema = DatabricksParametersSchema()
properties_schema = DatabricksPropertiesSchema()
sqlalchemy_uri_placeholder = (
"databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
)
class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngineSpec):
default_driver = ""
encryption_parameters = {"ssl": "1"}
required_parameters = {"access_token", "host", "port"}
context_key_mapping = {
"access_token": "password",
"host": "hostname",
"port": "port",
}
@staticmethod
def get_extra_params(database: Database) -> dict[str, Any]:
@ -190,30 +228,6 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksParametersType, *_
) -> str:
query = {}
if parameters.get("encryption"):
if not cls.encryption_parameters:
raise Exception( # pylint: disable=broad-exception-raised
"Unable to build a URL with encryption enabled"
)
query.update(cls.encryption_parameters)
return str(
URL.create(
f"{cls.engine}+{cls.default_driver}".rstrip("+"),
username="token",
password=parameters.get("access_token"),
host=parameters["host"],
port=parameters["port"],
database=parameters["database"],
query=query,
)
)
@classmethod
def extract_errors(
cls, ex: Exception, context: dict[str, Any] | None = None
@ -224,13 +238,10 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
# access_token isn't currently parseable from the
# databricks error response, but adding it in here
# for reference if their error message changes
context = {
"host": context.get("hostname"),
"access_token": context.get("password"),
"port": context.get("port"),
"username": context.get("username"),
"database": context.get("database"),
}
for key, value in cls.context_key_mapping.items():
context[key] = context.get(value)
for regex, (message, error_type, extra) in cls.custom_errors.items():
match = regex.search(raw_message)
if match:
@ -254,32 +265,18 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
)
]
@classmethod
def get_parameters_from_uri( # type: ignore
cls, uri: str, *_, **__
) -> DatabricksParametersType:
url = make_url_safe(uri)
encryption = all(
item in url.query.items() for item in cls.encryption_parameters.items()
)
return {
"access_token": url.password,
"host": url.host,
"port": url.port,
"database": url.database,
"encryption": encryption,
}
@classmethod
def validate_parameters( # type: ignore
cls,
properties: DatabricksPropertiesType,
properties: Union[
DatabricksNativePropertiesType,
DatabricksPythonConnectorPropertiesType,
],
) -> list[SupersetError]:
errors: list[SupersetError] = []
required = {"access_token", "host", "port", "database", "extra"}
extra = json.loads(properties.get("extra", "{}"))
engine_params = extra.get("engine_params", {})
connect_args = engine_params.get("connect_args", {})
if extra := json.loads(properties.get("extra")): # type: ignore
engine_params = extra.get("engine_params", {})
connect_args = engine_params.get("connect_args", {})
parameters = {
**properties,
**properties.get("parameters", {}),
@ -289,7 +286,7 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
present = {key for key in parameters if parameters.get(key, ())}
if missing := sorted(required - present):
if missing := sorted(cls.required_parameters - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
@ -351,6 +348,69 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
)
return errors
class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
    """
    Engine spec for the ``databricks+connector`` dialect.

    Builds the SQLAlchemy URI with the target database in the URL path
    (``databricks+connector://token:{access_token}@{host}:{port}/{database}``),
    as opposed to passing connection details via query-string arguments.
    """

    engine = "databricks"
    engine_name = "Databricks"
    drivers = {"connector": "Native all-purpose driver"}
    default_driver = "connector"

    parameters_schema = DatabricksNativeSchema()
    properties_schema = DatabricksNativePropertiesSchema()
    sqlalchemy_uri_placeholder = (
        "databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
    )

    # Extra keys merged into the error-message context on top of the base
    # mapping (access_token/host/port); values are the context keys read
    # from the driver error context.
    context_key_mapping = {
        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
        "database": "database",
        "username": "username",
    }
    # This driver additionally requires `database` and `extra` beyond the
    # base access_token/host/port set.
    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
        "database",
        "extra",
    }

    @classmethod
    def build_sqlalchemy_uri(  # type: ignore
        cls, parameters: DatabricksNativeParametersType, *_
    ) -> str:
        """Build the SQLAlchemy URI from dynamic-form parameters."""
        query = {}
        if parameters.get("encryption"):
            # Encryption was requested but this spec defines no way to
            # express it in the URI — fail loudly rather than silently
            # producing an unencrypted connection string.
            if not cls.encryption_parameters:
                raise Exception(  # pylint: disable=broad-exception-raised
                    "Unable to build a URL with encryption enabled"
                )
            query.update(cls.encryption_parameters)
        return str(
            URL.create(
                # rstrip("+") keeps the URI valid if default_driver is "".
                f"{cls.engine}+{cls.default_driver}".rstrip("+"),
                username="token",
                password=parameters.get("access_token"),
                host=parameters["host"],
                port=parameters["port"],
                database=parameters["database"],
                query=query,
            )
        )

    @classmethod
    def get_parameters_from_uri(  # type: ignore
        cls, uri: str, *_, **__
    ) -> DatabricksNativeParametersType:
        """Reverse of build_sqlalchemy_uri: extract form parameters from a URI."""
        url = make_url_safe(uri)
        # Encryption is considered enabled only if every encryption
        # query parameter is present in the URI.
        encryption = all(
            item in url.query.items() for item in cls.encryption_parameters.items()
        )
        return {
            "access_token": url.password,
            "host": url.host,
            "port": url.port,
            "database": url.database,
            "encryption": encryption,
        }
@classmethod
def parameters_json_schema(cls) -> Any:
"""
@ -367,3 +427,78 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
)
spec.components.schema(cls.__name__, schema=cls.properties_schema)
return spec.to_dict()["components"]["schemas"][cls.__name__]
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
    """
    Engine spec for the ``databricks-sql-connector`` driver.

    Unlike the native spec, the HTTP path, catalog and schema travel as
    query-string arguments of the SQLAlchemy URI instead of the URL path.
    """

    engine = "databricks"
    engine_name = "Databricks Python Connector"
    default_driver = "databricks-sql-python"
    drivers = {"databricks-sql-python": "Databricks SQL Python"}

    parameters_schema = DatabricksPythonConnectorSchema()
    sqlalchemy_uri_placeholder = (
        "databricks://token:{access_token}@{host}:{port}?http_path={http_path}"
        "&catalog={default_catalog}&schema={default_schema}"
    )

    # Extra keys merged into the error-message context on top of the base
    # mapping; note the form field names map to shorter context keys.
    context_key_mapping = {
        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
        "default_catalog": "catalog",
        "default_schema": "schema",
        "http_path_field": "http_path",
    }
    # This driver additionally requires catalog/schema/http_path beyond
    # the base access_token/host/port set.
    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
        "default_catalog",
        "default_schema",
        "http_path_field",
    }

    @classmethod
    def build_sqlalchemy_uri(  # type: ignore
        cls, parameters: DatabricksPythonConnectorParametersType, *_
    ) -> str:
        """Build the SQLAlchemy URI from dynamic-form parameters."""
        # http_path/catalog/schema are passed as query-string arguments;
        # each is included only when non-empty.
        query = {}
        if http_path := parameters.get("http_path_field"):
            query["http_path"] = http_path
        if catalog := parameters.get("default_catalog"):
            query["catalog"] = catalog
        if schema := parameters.get("default_schema"):
            query["schema"] = schema
        if parameters.get("encryption"):
            query.update(cls.encryption_parameters)
        return str(
            URL.create(
                cls.engine,
                username="token",
                password=parameters.get("access_token"),
                host=parameters["host"],
                port=parameters["port"],
                query=query,
            )
        )

    @classmethod
    def get_parameters_from_uri(  # type: ignore
        cls, uri: str, *_: Any, **__: Any
    ) -> DatabricksPythonConnectorParametersType:
        """Reverse of build_sqlalchemy_uri: extract form parameters from a URI."""
        url = make_url_safe(uri)
        # Strip the encryption parameters so they are not mistaken for
        # http_path/catalog/schema values.
        query = {
            key: value
            for (key, value) in url.query.items()
            if (key, value) not in cls.encryption_parameters.items()
        }
        encryption = all(
            item in url.query.items() for item in cls.encryption_parameters.items()
        )
        # NOTE(review): direct indexing assumes the URI always carries
        # http_path, catalog and schema — raises KeyError otherwise.
        return {
            "access_token": url.password,
            "host": url.host,
            "port": url.port,
            "http_path_field": query["http_path"],
            "default_catalog": query["catalog"],
            "default_schema": query["schema"],
            "encryption": encryption,
        }

View File

@ -35,13 +35,13 @@ def test_get_parameters_from_uri() -> None:
"""
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksParametersType,
DatabricksNativeParametersType,
)
parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(
"databricks+connector://token:abc12345@my_hostname:1234/test"
)
assert parameters == DatabricksParametersType(
assert parameters == DatabricksNativeParametersType(
{
"access_token": "abc12345",
"host": "my_hostname",
@ -60,10 +60,10 @@ def test_build_sqlalchemy_uri() -> None:
"""
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksParametersType,
DatabricksNativeParametersType,
)
parameters = DatabricksParametersType(
parameters = DatabricksNativeParametersType(
{
"access_token": "abc12345",
"host": "my_hostname",