fix(clickhouse): add clickhouse connect driver (#23185)

This commit is contained in:
Ville Brofeldt 2023-02-24 14:04:12 +02:00 committed by GitHub
parent f0f27a486d
commit d0c54cddb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 456 additions and 29 deletions

View File

@ -72,7 +72,6 @@ from superset.utils.hashing import md5_sha_from_str
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
# prevent circular imports
from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database
from superset.models.sql_lab import Query

View File

@ -14,29 +14,43 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
from __future__ import annotations
import logging
import re
from datetime import datetime
from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING
from flask import current_app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import types
from sqlalchemy.engine.url import URL
from urllib3.exceptions import NewConnectionError
from superset.db_engine_specs.base import BaseEngineSpec
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
BasicParametersType,
BasicPropertiesType,
)
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import cache_manager
from superset.utils.core import GenericDataType
from superset.utils.hashing import md5_sha_from_str
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database
logger = logging.getLogger(__name__)
class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"""Dialect for ClickHouse analytical DB."""
engine = "clickhouse"
engine_name = "ClickHouse"
class ClickHouseBaseEngineSpec(BaseEngineSpec):
"""Shared engine spec for ClickHouse."""
time_secondary_columns = True
time_groupby_inline = True
@ -56,8 +70,78 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"P1Y": "toStartOfYear(toDateTime({col}))",
}
_show_functions_column = "name"
column_type_mappings = (
(
re.compile(r".*Enum.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Array.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*UUID.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Bool.*", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
(
re.compile(r".*String.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Int\d+.*", re.IGNORECASE),
types.INTEGER(),
GenericDataType.NUMERIC,
),
(
re.compile(r".*Decimal.*", re.IGNORECASE),
types.DECIMAL(),
GenericDataType.NUMERIC,
),
(
re.compile(r".*DateTime.*", re.IGNORECASE),
types.DateTime(),
GenericDataType.TEMPORAL,
),
(
re.compile(r".*Date.*", re.IGNORECASE),
types.Date(),
GenericDataType.TEMPORAL,
),
)
@classmethod
def epoch_to_dttm(cls) -> str:
return "{col}"
@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None
class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
"""Engine spec for clickhouse_sqlalchemy connector"""
engine = "clickhouse"
engine_name = "ClickHouse"
_show_functions_column = "name"
supports_file_upload = False
@classmethod
@ -73,21 +157,9 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
return exception
return new_exception(str(exception))
@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None
@classmethod
@cache_manager.cache.memoize()
def get_function_names(cls, database: "Database") -> List[str]:
def get_function_names(cls, database: Database) -> List[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@ -123,3 +195,201 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
# otherwise, return no function names to prevent errors
return []
class ClickHouseParametersSchema(Schema):
username = fields.String(allow_none=True, description=__("Username"))
password = fields.String(allow_none=True, description=__("Password"))
host = fields.String(required=True, description=__("Hostname or IP address"))
port = fields.Integer(
allow_none=True,
description=__("Database port"),
validate=Range(min=0, max=65535),
)
database = fields.String(allow_none=True, description=__("Database name"))
encryption = fields.Boolean(
default=True, description=__("Use an encrypted connection to the database")
)
query = fields.Dict(
keys=fields.Str(), values=fields.Raw(), description=__("Additional parameters")
)
try:
from clickhouse_connect.common import set_setting
from clickhouse_connect.datatypes.format import set_default_formats
# override default formats for compatibility
set_default_formats(
"FixedString",
"string",
"IPv*",
"string",
"signed",
"UUID",
"string",
"*Int256",
"string",
"*Int128",
"string",
)
set_setting(
"product_name",
f"superset/{current_app.config.get('VERSION_STRING', 'dev')}",
)
except ImportError: # ClickHouse Connect not installed, do nothing
pass
class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
"""Engine spec for clickhouse-connect connector"""
engine = "clickhousedb"
engine_name = "ClickHouse Connect"
default_driver = "connect"
_function_names: List[str] = []
sqlalchemy_uri_placeholder = (
"clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]"
)
parameters_schema = ClickHouseParametersSchema()
encryption_parameters = {"secure": "true"}
@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
return {}
@classmethod
def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
new_exception = cls.get_dbapi_exception_mapping().get(type(exception))
if new_exception == SupersetDBAPIDatabaseError:
return SupersetDBAPIDatabaseError("Connection failed")
if not new_exception:
return exception
return new_exception(str(exception))
@classmethod
def get_function_names(cls, database: Database) -> List[str]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver.exceptions import ClickHouseError
if cls._function_names:
return cls._function_names
try:
names = database.get_df(
"SELECT name FROM system.functions UNION ALL "
+ "SELECT name FROM system.table_functions LIMIT 10000"
)["name"].tolist()
cls._function_names = names
return names
except ClickHouseError:
logger.exception("Error retrieving system.functions")
return []
@classmethod
def get_datatype(cls, type_code: str) -> str:
# keep it lowercase, as ClickHouse types aren't typical SHOUTCASE ANSI SQL
return type_code
@classmethod
def build_sqlalchemy_uri(
cls,
parameters: BasicParametersType,
encrypted_extra: Optional[Dict[str, str]] = None,
) -> str:
url_params = parameters.copy()
if url_params.get("encryption"):
query = parameters.get("query", {}).copy()
query.update(cls.encryption_parameters)
url_params["query"] = query
if not url_params.get("database"):
url_params["database"] = "__default__"
url_params.pop("encryption", None)
return str(URL(f"{cls.engine}+{cls.default_driver}", **url_params))
@classmethod
def get_parameters_from_uri(
cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
) -> BasicParametersType:
url = make_url_safe(uri)
query = url.query
if "secure" in query:
encryption = url.query.get("secure") == "true"
query.pop("secure")
else:
encryption = False
return BasicParametersType(
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database="" if url.database == "__default__" else cast(str, url.database),
query=dict(query),
encryption=encryption,
)
@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
) -> List[SupersetError]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver import default_port
parameters = properties.get("parameters", {})
host = parameters.get("host", None)
if not host:
return [
SupersetError(
"Hostname is required",
SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
ErrorLevel.WARNING,
{"missing": ["host"]},
)
]
if not is_hostname_valid(host):
return [
SupersetError(
"The hostname provided can't be resolved.",
SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
ErrorLevel.ERROR,
{"invalid": ["host"]},
)
]
port = parameters.get("port")
if port is None:
port = default_port("http", parameters.get("encryption", False))
try:
port = int(port)
except (ValueError, TypeError):
port = -1
if port <= 0 or port >= 65535:
return [
SupersetError(
"Port must be a valid integer between 0 and 65535 (inclusive).",
SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
ErrorLevel.ERROR,
{"invalid": ["port"]},
)
]
if not is_port_open(host, port):
return [
SupersetError(
"The port is closed.",
SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
ErrorLevel.ERROR,
{"invalid": ["port"]},
)
]
return []
@staticmethod
def _mutate_label(label: str) -> str:
"""
Suffix with the first six characters from the md5 of the label to avoid
collisions with original column names
:param label: Expected expression label
:return: Conditionally mutated label
"""
return f"{label}_{md5_sha_from_str(label)[:6]}"

View File

@ -16,12 +16,26 @@
# under the License.
from datetime import datetime
from typing import Optional
from typing import Any, Dict, Optional, Type
from unittest.mock import Mock
import pytest
from sqlalchemy.types import (
Boolean,
Date,
DateTime,
DECIMAL,
Float,
Integer,
String,
TypeEngine,
)
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@ -53,3 +67,147 @@ def test_execute_connection_error() -> None:
)
with pytest.raises(SupersetDBAPIDatabaseError) as ex:
ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1")
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "toDate('2019-01-02')"),
("DateTime", "toDateTime('2019-01-02 03:04:05')"),
("UnknownType", None),
],
)
def test_connect_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("String", String, None, GenericDataType.STRING, False),
("LowCardinality(String)", String, None, GenericDataType.STRING, False),
("Nullable(String)", String, None, GenericDataType.STRING, False),
(
"LowCardinality(Nullable(String))",
String,
None,
GenericDataType.STRING,
False,
),
("Array(UInt8)", String, None, GenericDataType.STRING, False),
("Enum('hello', 'world')", String, None, GenericDataType.STRING, False),
("Enum('UInt32', 'Bool')", String, None, GenericDataType.STRING, False),
(
"LowCardinality(Enum('hello', 'world'))",
String,
None,
GenericDataType.STRING,
False,
),
(
"Nullable(Enum('hello', 'world'))",
String,
None,
GenericDataType.STRING,
False,
),
(
"LowCardinality(Nullable(Enum('hello', 'world')))",
String,
None,
GenericDataType.STRING,
False,
),
("FixedString(16)", String, None, GenericDataType.STRING, False),
("Nullable(FixedString(16))", String, None, GenericDataType.STRING, False),
(
"LowCardinality(Nullable(FixedString(16)))",
String,
None,
GenericDataType.STRING,
False,
),
("UUID", String, None, GenericDataType.STRING, False),
("Int8", Integer, None, GenericDataType.NUMERIC, False),
("Int16", Integer, None, GenericDataType.NUMERIC, False),
("Int32", Integer, None, GenericDataType.NUMERIC, False),
("Int64", Integer, None, GenericDataType.NUMERIC, False),
("Int128", Integer, None, GenericDataType.NUMERIC, False),
("Int256", Integer, None, GenericDataType.NUMERIC, False),
("Nullable(Int256)", Integer, None, GenericDataType.NUMERIC, False),
(
"LowCardinality(Nullable(Int256))",
Integer,
None,
GenericDataType.NUMERIC,
False,
),
("UInt8", Integer, None, GenericDataType.NUMERIC, False),
("UInt16", Integer, None, GenericDataType.NUMERIC, False),
("UInt32", Integer, None, GenericDataType.NUMERIC, False),
("UInt64", Integer, None, GenericDataType.NUMERIC, False),
("UInt128", Integer, None, GenericDataType.NUMERIC, False),
("UInt256", Integer, None, GenericDataType.NUMERIC, False),
("Nullable(UInt256)", Integer, None, GenericDataType.NUMERIC, False),
(
"LowCardinality(Nullable(UInt256))",
Integer,
None,
GenericDataType.NUMERIC,
False,
),
("Float32", Float, None, GenericDataType.NUMERIC, False),
("Float64", Float, None, GenericDataType.NUMERIC, False),
("Decimal(1, 2)", DECIMAL, None, GenericDataType.NUMERIC, False),
("Decimal32(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
("Decimal64(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
("Decimal128(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
("Decimal256(2)", DECIMAL, None, GenericDataType.NUMERIC, False),
("Bool", Boolean, None, GenericDataType.BOOLEAN, False),
("Nullable(Bool)", Boolean, None, GenericDataType.BOOLEAN, False),
("Date", Date, None, GenericDataType.TEMPORAL, True),
("Nullable(Date)", Date, None, GenericDataType.TEMPORAL, True),
("LowCardinality(Nullable(Date))", Date, None, GenericDataType.TEMPORAL, True),
("Date32", Date, None, GenericDataType.TEMPORAL, True),
("Datetime", DateTime, None, GenericDataType.TEMPORAL, True),
("Nullable(Datetime)", DateTime, None, GenericDataType.TEMPORAL, True),
(
"LowCardinality(Nullable(Datetime))",
DateTime,
None,
GenericDataType.TEMPORAL,
True,
),
("Datetime('UTC')", DateTime, None, GenericDataType.TEMPORAL, True),
("Datetime64(3)", DateTime, None, GenericDataType.TEMPORAL, True),
("Datetime64(3, 'UTC')", DateTime, None, GenericDataType.TEMPORAL, True),
],
)
def test_connect_get_column_spec(
native_type: str,
sqla_type: Type[TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.clickhouse import ClickHouseConnectEngineSpec as spec
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
@pytest.mark.parametrize(
"column_name,expected_result",
[
("time", "time_07cc69"),
("count", "count_e2942a"),
],
)
def test_connect_make_label_compatible(column_name: str, expected_result: str) -> None:
from superset.db_engine_specs.clickhouse import ClickHouseConnectEngineSpec as spec
label = spec.make_label_compatible(column_name)
assert label == expected_result

View File

@ -20,7 +20,7 @@ from textwrap import dedent
from typing import Any, Dict, Optional, Type
import pytest
from sqlalchemy import column, table, types
from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select
@ -50,7 +50,7 @@ from tests.unit_tests.fixtures.common import dttm
)
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
sqla_type: Type[TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,