fix(clickhouse): add clickhouse connect driver (#23185)
This commit is contained in:
parent
f0f27a486d
commit
d0c54cddb0
|
|
@ -72,7 +72,6 @@ from superset.utils.hashing import md5_sha_from_str
|
|||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# prevent circular imports
|
||||
from superset.connectors.sqla.models import TableColumn
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
|
|
|||
|
|
@ -14,29 +14,43 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from flask import current_app
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.validate import Range
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.url import URL
|
||||
from urllib3.exceptions import NewConnectionError
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.base import (
|
||||
BaseEngineSpec,
|
||||
BasicParametersMixin,
|
||||
BasicParametersType,
|
||||
BasicPropertiesType,
|
||||
)
|
||||
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.extensions import cache_manager
|
||||
from superset.utils.core import GenericDataType
|
||||
from superset.utils.hashing import md5_sha_from_str
|
||||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# prevent circular imports
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
"""Dialect for ClickHouse analytical DB."""
|
||||
|
||||
engine = "clickhouse"
|
||||
engine_name = "ClickHouse"
|
||||
class ClickHouseBaseEngineSpec(BaseEngineSpec):
|
||||
"""Shared engine spec for ClickHouse."""
|
||||
|
||||
time_secondary_columns = True
|
||||
time_groupby_inline = True
|
||||
|
|
@ -56,8 +70,78 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||
"P1Y": "toStartOfYear(toDateTime({col}))",
|
||||
}
|
||||
|
||||
_show_functions_column = "name"
|
||||
column_type_mappings = (
|
||||
(
|
||||
re.compile(r".*Enum.*", re.IGNORECASE),
|
||||
types.String(),
|
||||
GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r".*Array.*", re.IGNORECASE),
|
||||
types.String(),
|
||||
GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r".*UUID.*", re.IGNORECASE),
|
||||
types.String(),
|
||||
GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r".*Bool.*", re.IGNORECASE),
|
||||
types.Boolean(),
|
||||
GenericDataType.BOOLEAN,
|
||||
),
|
||||
(
|
||||
re.compile(r".*String.*", re.IGNORECASE),
|
||||
types.String(),
|
||||
GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r".*Int\d+.*", re.IGNORECASE),
|
||||
types.INTEGER(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r".*Decimal.*", re.IGNORECASE),
|
||||
types.DECIMAL(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r".*DateTime.*", re.IGNORECASE),
|
||||
types.DateTime(),
|
||||
GenericDataType.TEMPORAL,
|
||||
),
|
||||
(
|
||||
re.compile(r".*Date.*", re.IGNORECASE),
|
||||
types.Date(),
|
||||
GenericDataType.TEMPORAL,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
return "{col}"
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(
|
||||
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
sqla_type = cls.get_sqla_column_type(target_type)
|
||||
|
||||
if isinstance(sqla_type, types.Date):
|
||||
return f"toDate('{dttm.date().isoformat()}')"
|
||||
if isinstance(sqla_type, types.DateTime):
|
||||
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
|
||||
return None
|
||||
|
||||
|
||||
class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
|
||||
"""Engine spec for clickhouse_sqlalchemy connector"""
|
||||
|
||||
engine = "clickhouse"
|
||||
engine_name = "ClickHouse"
|
||||
|
||||
_show_functions_column = "name"
|
||||
supports_file_upload = False
|
||||
|
||||
@classmethod
|
||||
|
|
@ -73,21 +157,9 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||
return exception
|
||||
return new_exception(str(exception))
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(
|
||||
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
sqla_type = cls.get_sqla_column_type(target_type)
|
||||
|
||||
if isinstance(sqla_type, types.Date):
|
||||
return f"toDate('{dttm.date().isoformat()}')"
|
||||
if isinstance(sqla_type, types.DateTime):
|
||||
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@cache_manager.cache.memoize()
|
||||
def get_function_names(cls, database: "Database") -> List[str]:
|
||||
def get_function_names(cls, database: Database) -> List[str]:
|
||||
"""
|
||||
Get a list of function names that are able to be called on the database.
|
||||
Used for SQL Lab autocomplete.
|
||||
|
|
@ -123,3 +195,201 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||
|
||||
# otherwise, return no function names to prevent errors
|
||||
return []
|
||||
|
||||
|
||||
class ClickHouseParametersSchema(Schema):
|
||||
username = fields.String(allow_none=True, description=__("Username"))
|
||||
password = fields.String(allow_none=True, description=__("Password"))
|
||||
host = fields.String(required=True, description=__("Hostname or IP address"))
|
||||
port = fields.Integer(
|
||||
allow_none=True,
|
||||
description=__("Database port"),
|
||||
validate=Range(min=0, max=65535),
|
||||
)
|
||||
database = fields.String(allow_none=True, description=__("Database name"))
|
||||
encryption = fields.Boolean(
|
||||
default=True, description=__("Use an encrypted connection to the database")
|
||||
)
|
||||
query = fields.Dict(
|
||||
keys=fields.Str(), values=fields.Raw(), description=__("Additional parameters")
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from clickhouse_connect.common import set_setting
|
||||
from clickhouse_connect.datatypes.format import set_default_formats
|
||||
|
||||
# override default formats for compatibility
|
||||
set_default_formats(
|
||||
"FixedString",
|
||||
"string",
|
||||
"IPv*",
|
||||
"string",
|
||||
"signed",
|
||||
"UUID",
|
||||
"string",
|
||||
"*Int256",
|
||||
"string",
|
||||
"*Int128",
|
||||
"string",
|
||||
)
|
||||
set_setting(
|
||||
"product_name",
|
||||
f"superset/{current_app.config.get('VERSION_STRING', 'dev')}",
|
||||
)
|
||||
except ImportError: # ClickHouse Connect not installed, do nothing
|
||||
pass
|
||||
|
||||
|
||||
class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
|
||||
"""Engine spec for clickhouse-connect connector"""
|
||||
|
||||
engine = "clickhousedb"
|
||||
engine_name = "ClickHouse Connect"
|
||||
|
||||
default_driver = "connect"
|
||||
_function_names: List[str] = []
|
||||
|
||||
sqlalchemy_uri_placeholder = (
|
||||
"clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]"
|
||||
)
|
||||
parameters_schema = ClickHouseParametersSchema()
|
||||
encryption_parameters = {"secure": "true"}
|
||||
|
||||
@classmethod
|
||||
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
|
||||
new_exception = cls.get_dbapi_exception_mapping().get(type(exception))
|
||||
if new_exception == SupersetDBAPIDatabaseError:
|
||||
return SupersetDBAPIDatabaseError("Connection failed")
|
||||
if not new_exception:
|
||||
return exception
|
||||
return new_exception(str(exception))
|
||||
|
||||
@classmethod
|
||||
def get_function_names(cls, database: Database) -> List[str]:
|
||||
# pylint: disable=import-outside-toplevel,import-error
|
||||
from clickhouse_connect.driver.exceptions import ClickHouseError
|
||||
|
||||
if cls._function_names:
|
||||
return cls._function_names
|
||||
try:
|
||||
names = database.get_df(
|
||||
"SELECT name FROM system.functions UNION ALL "
|
||||
+ "SELECT name FROM system.table_functions LIMIT 10000"
|
||||
)["name"].tolist()
|
||||
cls._function_names = names
|
||||
return names
|
||||
except ClickHouseError:
|
||||
logger.exception("Error retrieving system.functions")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def get_datatype(cls, type_code: str) -> str:
|
||||
# keep it lowercase, as ClickHouse types aren't typical SHOUTCASE ANSI SQL
|
||||
return type_code
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_uri(
|
||||
cls,
|
||||
parameters: BasicParametersType,
|
||||
encrypted_extra: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
url_params = parameters.copy()
|
||||
if url_params.get("encryption"):
|
||||
query = parameters.get("query", {}).copy()
|
||||
query.update(cls.encryption_parameters)
|
||||
url_params["query"] = query
|
||||
if not url_params.get("database"):
|
||||
url_params["database"] = "__default__"
|
||||
url_params.pop("encryption", None)
|
||||
return str(URL(f"{cls.engine}+{cls.default_driver}", **url_params))
|
||||
|
||||
@classmethod
|
||||
def get_parameters_from_uri(
|
||||
cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
|
||||
) -> BasicParametersType:
|
||||
url = make_url_safe(uri)
|
||||
query = url.query
|
||||
if "secure" in query:
|
||||
encryption = url.query.get("secure") == "true"
|
||||
query.pop("secure")
|
||||
else:
|
||||
encryption = False
|
||||
return BasicParametersType(
|
||||
username=url.username,
|
||||
password=url.password,
|
||||
host=url.host,
|
||||
port=url.port,
|
||||
database="" if url.database == "__default__" else cast(str, url.database),
|
||||
query=dict(query),
|
||||
encryption=encryption,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_parameters(
|
||||
cls, properties: BasicPropertiesType
|
||||
) -> List[SupersetError]:
|
||||
# pylint: disable=import-outside-toplevel,import-error
|
||||
from clickhouse_connect.driver import default_port
|
||||
|
||||
parameters = properties.get("parameters", {})
|
||||
host = parameters.get("host", None)
|
||||
if not host:
|
||||
return [
|
||||
SupersetError(
|
||||
"Hostname is required",
|
||||
SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
|
||||
ErrorLevel.WARNING,
|
||||
{"missing": ["host"]},
|
||||
)
|
||||
]
|
||||
if not is_hostname_valid(host):
|
||||
return [
|
||||
SupersetError(
|
||||
"The hostname provided can't be resolved.",
|
||||
SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
|
||||
ErrorLevel.ERROR,
|
||||
{"invalid": ["host"]},
|
||||
)
|
||||
]
|
||||
port = parameters.get("port")
|
||||
if port is None:
|
||||
port = default_port("http", parameters.get("encryption", False))
|
||||
try:
|
||||
port = int(port)
|
||||
except (ValueError, TypeError):
|
||||
port = -1
|
||||
if port <= 0 or port >= 65535:
|
||||
return [
|
||||
SupersetError(
|
||||
"Port must be a valid integer between 0 and 65535 (inclusive).",
|
||||
SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
|
||||
ErrorLevel.ERROR,
|
||||
{"invalid": ["port"]},
|
||||
)
|
||||
]
|
||||
if not is_port_open(host, port):
|
||||
return [
|
||||
SupersetError(
|
||||
"The port is closed.",
|
||||
SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
|
||||
ErrorLevel.ERROR,
|
||||
{"invalid": ["port"]},
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _mutate_label(label: str) -> str:
|
||||
"""
|
||||
Suffix with the first six characters from the md5 of the label to avoid
|
||||
collisions with original column names
|
||||
|
||||
:param label: Expected expression label
|
||||
:return: Conditionally mutated label
|
||||
"""
|
||||
return f"{label}_{md5_sha_from_str(label)[:6]}"
|
||||
|
|
|
|||
|
|
@ -16,12 +16,26 @@
|
|||
# under the License.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.types import (
|
||||
Boolean,
|
||||
Date,
|
||||
DateTime,
|
||||
DECIMAL,
|
||||
Float,
|
||||
Integer,
|
||||
String,
|
||||
TypeEngine,
|
||||
)
|
||||
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import (
|
||||
assert_column_spec,
|
||||
assert_convert_dttm,
|
||||
)
|
||||
from tests.unit_tests.fixtures.common import dttm
|
||||
|
||||
|
||||
|
|
@ -53,3 +67,147 @@ def test_execute_connection_error() -> None:
|
|||
)
|
||||
with pytest.raises(SupersetDBAPIDatabaseError) as ex:
|
||||
ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"target_type,expected_result",
|
||||
[
|
||||
("Date", "toDate('2019-01-02')"),
|
||||
("DateTime", "toDateTime('2019-01-02 03:04:05')"),
|
||||
("UnknownType", None),
|
||||
],
|
||||
)
|
||||
def test_connect_convert_dttm(
|
||||
target_type: str, expected_result: Optional[str], dttm: datetime
|
||||
) -> None:
|
||||
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec as spec
|
||||
|
||||
assert_convert_dttm(spec, target_type, expected_result, dttm)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"native_type,sqla_type,attrs,generic_type,is_dttm",
|
||||
[
|
||||
("String", String, None, GenericDataType.STRING, False),
|
||||
("LowCardinality(String)", String, None, GenericDataType.STRING, False),
|
||||
("Nullable(String)", String, None, GenericDataType.STRING, False),
|
||||
(
|
||||
"LowCardinality(Nullable(String))",
|
||||
String,
|
||||
None,
|
||||
GenericDataType.STRING,
|
||||
False,
|
||||
),
|
||||
("Array(UInt8)", String, None, GenericDataType.STRING, False),
|
||||
("Enum('hello', 'world')", String, None, GenericDataType.STRING, False),
|
||||
("Enum('UInt32', 'Bool')", String, None, GenericDataType.STRING, False),
|
||||
(
|
||||
"LowCardinality(Enum('hello', 'world'))",
|
||||
String,
|
||||
None,
|
||||
GenericDataType.STRING,
|
||||
False,
|
||||
),
|
||||
(
|
||||
"Nullable(Enum('hello', 'world'))",
|
||||
String,
|
||||
None,
|
||||
GenericDataType.STRING,
|
||||
False,
|
||||
),
|
||||
(
|
||||
"LowCardinality(Nullable(Enum('hello', 'world')))",
|
||||
String,
|
||||
None,
|
||||
GenericDataType.STRING,
|
||||
False,
|
||||
),
|
||||
("FixedString(16)", String, None, GenericDataType.STRING, False),
|
||||
("Nullable(FixedString(16))", String, None, GenericDataType.STRING, False),
|
||||
(
|
||||
"LowCardinality(Nullable(FixedString(16)))",
|
||||
String,
|
||||
None,
|
||||
GenericDataType.STRING,
|
||||
False,
|
||||
),
|
||||
("UUID", String, None, GenericDataType.STRING, False),
|
||||
("Int8", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("Int16", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("Int32", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("Int64", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("Int128", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("Int256", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("Nullable(Int256)", Integer, None, GenericDataType.NUMERIC, False),
|
||||
(
|
||||
"LowCardinality(Nullable(Int256))",
|
||||
Integer,
|
||||
None,
|
||||
GenericDataType.NUMERIC,
|
||||
False,
|
||||
),
|
||||
("UInt8", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("UInt16", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("UInt32", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("UInt64", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("UInt128", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("UInt256", Integer, None, GenericDataType.NUMERIC, False),
|
||||
("Nullable(UInt256)", Integer, None, GenericDataType.NUMERIC, False),
|
||||
(
|
||||
"LowCardinality(Nullable(UInt256))",
|
||||
Integer,
|
||||
None,
|
||||
GenericDataType.NUMERIC,
|
||||
False,
|
||||
),
|
||||
("Float32", Float, None, GenericDataType.NUMERIC, False),
|
||||
("Float64", Float, None, GenericDataType.NUMERIC, False),
|
||||
("Decimal(1, 2)", DECIMAL, None, GenericDataType.NUMERIC, False),
|
||||
("Decimal32(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
|
||||
("Decimal64(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
|
||||
("Decimal128(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
|
||||
("Decimal256(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
|
||||
("Bool", Boolean, None, GenericDataType.BOOLEAN, False),
|
||||
("Nullable(Bool)", Boolean, None, GenericDataType.BOOLEAN, False),
|
||||
("Date", Date, None, GenericDataType.TEMPORAL, True),
|
||||
("Nullable(Date)", Date, None, GenericDataType.TEMPORAL, True),
|
||||
("LowCardinality(Nullable(Date))", Date, None, GenericDataType.TEMPORAL, True),
|
||||
("Date32", Date, None, GenericDataType.TEMPORAL, True),
|
||||
("Datetime", DateTime, None, GenericDataType.TEMPORAL, True),
|
||||
("Nullable(Datetime)", DateTime, None, GenericDataType.TEMPORAL, True),
|
||||
(
|
||||
"LowCardinality(Nullable(Datetime))",
|
||||
DateTime,
|
||||
None,
|
||||
GenericDataType.TEMPORAL,
|
||||
True,
|
||||
),
|
||||
("Datetime('UTC')", DateTime, None, GenericDataType.TEMPORAL, True),
|
||||
("Datetime64(3)", DateTime, None, GenericDataType.TEMPORAL, True),
|
||||
("Datetime64(3, 'UTC')", DateTime, None, GenericDataType.TEMPORAL, True),
|
||||
],
|
||||
)
|
||||
def test_connect_get_column_spec(
|
||||
native_type: str,
|
||||
sqla_type: Type[TypeEngine],
|
||||
attrs: Optional[Dict[str, Any]],
|
||||
generic_type: GenericDataType,
|
||||
is_dttm: bool,
|
||||
) -> None:
|
||||
from superset.db_engine_specs.clickhouse import ClickHouseConnectEngineSpec as spec
|
||||
|
||||
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"column_name,expected_result",
|
||||
[
|
||||
("time", "time_07cc69"),
|
||||
("count", "count_e2942a"),
|
||||
],
|
||||
)
|
||||
def test_connect_make_label_compatible(column_name: str, expected_result: str) -> None:
|
||||
from superset.db_engine_specs.clickhouse import ClickHouseConnectEngineSpec as spec
|
||||
|
||||
label = spec.make_label_compatible(column_name)
|
||||
assert label == expected_result
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from textwrap import dedent
|
|||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import column, table, types
|
||||
from sqlalchemy import column, table
|
||||
from sqlalchemy.dialects import mssql
|
||||
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
|
||||
from sqlalchemy.sql import select
|
||||
|
|
@ -50,7 +50,7 @@ from tests.unit_tests.fixtures.common import dttm
|
|||
)
|
||||
def test_get_column_spec(
|
||||
native_type: str,
|
||||
sqla_type: Type[types.TypeEngine],
|
||||
sqla_type: Type[TypeEngine],
|
||||
attrs: Optional[Dict[str, Any]],
|
||||
generic_type: GenericDataType,
|
||||
is_dttm: bool,
|
||||
|
|
|
|||
Loading…
Reference in New Issue