192 lines
6.8 KiB
Python
192 lines
6.8 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from flask import Flask
|
|
from flask_babel import lazy_gettext as _
|
|
from sqlalchemy import text, TypeDecorator
|
|
from sqlalchemy.engine import Connection, Dialect, Row
|
|
from sqlalchemy_utils import EncryptedType
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AbstractEncryptedFieldAdapter(ABC):  # pylint: disable=too-few-public-methods
    """
    Interface for adapters that build the SQLAlchemy column type used for
    encrypted fields.
    """

    @abstractmethod
    def create(
        self,
        app_config: Optional[Dict[str, Any]],
        *args: List[Any],
        **kwargs: Optional[Dict[str, Any]],
    ) -> TypeDecorator:
        """
        Build the type decorator for an encrypted column.

        :param app_config: The application configuration mapping, if available
        :param args: Positional arguments for the underlying type
        :param kwargs: Keyword arguments for the underlying type
        :return: The type decorator to use for the column
        """
        ...
|
|
|
|
|
|
class SQLAlchemyUtilsAdapter(  # pylint: disable=too-few-public-methods
    AbstractEncryptedFieldAdapter
):
    """
    Adapter that produces ``sqlalchemy_utils.EncryptedType`` columns keyed
    with the application's ``SECRET_KEY``.
    """

    def create(
        self,
        app_config: Optional[Dict[str, Any]],
        *args: List[Any],
        **kwargs: Optional[Dict[str, Any]],
    ) -> TypeDecorator:
        """
        Build an ``EncryptedType`` from the given arguments.

        :param app_config: Application configuration; must be non-empty and
            contain ``SECRET_KEY``
        :param args: Positional arguments placed before the key
        :param kwargs: Keyword arguments forwarded to ``EncryptedType``
        :return: The configured ``EncryptedType``
        :raises Exception: If ``app_config`` is ``None`` or empty
        """
        # Guard first: without a configuration there is no key to encrypt with.
        if not app_config:
            raise Exception("Missing app_config kwarg")

        secret_key = app_config["SECRET_KEY"]
        return EncryptedType(*args, secret_key, **kwargs)
|
|
|
|
|
|
class EncryptedFieldFactory:
    """
    Factory for encrypted column types that delegates to the adapter class
    named by the ``SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER`` config entry.

    ``init_app`` must be called before ``create``.
    """

    def __init__(self) -> None:
        # Both stay ``None`` until ``init_app`` binds the factory to an app.
        self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None
        self._config: Optional[Dict[str, Any]] = None

    def init_app(self, app: Flask) -> None:
        """
        Bind the factory to ``app`` and instantiate the configured adapter.

        :param app: The Flask application whose config is used
        """
        self._config = app.config
        # The config entry holds a callable (the adapter class); call it with
        # no arguments to obtain the adapter instance.
        adapter_cls = app.config["SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER"]
        self._concrete_type_adapter = adapter_cls()

    def create(
        self, *args: List[Any], **kwargs: Optional[Dict[str, Any]]
    ) -> TypeDecorator:
        """
        Create an encrypted column type through the configured adapter.

        :param args: Positional arguments forwarded to the adapter
        :param kwargs: Keyword arguments forwarded to the adapter
        :return: The type decorator produced by the adapter
        :raises Exception: If ``init_app`` has not been called yet
        """
        adapter = self._concrete_type_adapter
        if not adapter:
            raise Exception("App not initialized yet. Please call init_app first")

        return adapter.create(self._config, *args, **kwargs)
|
|
|
|
|
|
class SecretsMigrator:
    """
    Re-encrypts every ``EncryptedType`` column registered in the SQLAlchemy
    metadata: values are decrypted with the previous secret key and written
    back encrypted with each column's current ``EncryptedType``.
    """

    def __init__(self, previous_secret_key: str) -> None:
        """
        :param previous_secret_key: The key the stored values were encrypted with
        """
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the ``superset`` package -- confirm before hoisting.
        from superset import db  # pylint: disable=import-outside-toplevel

        self._db = db
        self._previous_secret_key = previous_secret_key
        self._dialect: Dialect = db.engine.url.get_dialect()

    def discover_encrypted_fields(self) -> Dict[str, Dict[str, EncryptedType]]:
        """
        Iterates over SQLAlchemy's metadata, looking for EncryptedType
        columns along the way.

        :return: A dict of table_name -> dict of col_name -> EncryptedType
            instance; tables without encrypted columns are omitted
        """
        meta_info: Dict[str, Dict[str, EncryptedType]] = {}

        for table_name, table in self._db.metadata.tables.items():
            for col_name, col in table.columns.items():
                if isinstance(col.type, EncryptedType):
                    meta_info.setdefault(table_name, {})[col_name] = col.type

        return meta_info

    @staticmethod
    def _read_bytes(col_name: str, value: Any) -> Optional[bytes]:
        """
        Normalizes a raw column value read from the database into bytes.

        :param col_name: Column name, only used in the error message
        :param value: Raw value as returned by the database driver
        :return: ``value`` as bytes, or ``None`` if ``value`` is ``None``
        :raises ValueError: If ``value`` is of an unsupported type
        """
        if value is None or isinstance(value, bytes):
            return value
        # Note that the Postgres driver returns memoryviews for BLOB types
        if isinstance(value, memoryview):
            return value.tobytes()
        if isinstance(value, str):
            return value.encode("utf8")

        # Just bail if we haven't seen this type before...
        raise ValueError(
            _(
                "DB column %(col_name)s has unknown type: %(value_type)s",
                col_name=col_name,
                value_type=type(value),
            )
        )

    @staticmethod
    def _select_columns_from_table(
        conn: Connection, column_names: List[str], table_name: str
    ) -> Row:
        """
        Selects ``id`` plus the given columns for every row of a table.

        :param conn: Connection to execute the query on
        :param column_names: Names of the columns to read
        :param table_name: Name of the table to read from
        :return: Whatever ``conn.execute`` returns for the SELECT; ``run``
            iterates over it to obtain the rows
        """
        # NOTE(review): the ``Row`` annotation is kept as-is for compatibility,
        # although the value is iterated as a collection of rows.
        # Identifiers are interpolated into the statement; within this class
        # they only ever come from the application's own SQLAlchemy metadata.
        # The statement is wrapped in ``text()`` (as the UPDATE below is)
        # rather than being passed to ``execute`` as a raw string.
        return conn.execute(
            text(f"SELECT id, {','.join(column_names)} FROM {table_name}")
        )

    def _re_encrypt_row(
        self,
        conn: Connection,
        row: Row,
        table_name: str,
        columns: Dict[str, EncryptedType],
    ) -> None:
        """
        Re-encrypts all columns in a Row.

        Columns that can only be decrypted with the current secret are left
        untouched; all remaining columns of the row are still processed.

        :param conn: Connection used to issue the UPDATE
        :param row: Current row to re-encrypt; must expose ``id`` and every
            column named in ``columns``
        :param table_name: Name of the table the row belongs to
        :param columns: Meta info from columns (col_name -> EncryptedType)
        :raises Exception: If a value can be decrypted with neither the
            previous nor the current secret key
        """
        re_encrypted_columns = {}

        for column_name, encrypted_type in columns.items():
            previous_encrypted_type = EncryptedType(
                type_in=encrypted_type.underlying_type, key=self._previous_secret_key
            )
            try:
                unencrypted_value = previous_encrypted_type.process_result_value(
                    self._read_bytes(column_name, row[column_name]), self._dialect
                )
            except ValueError as exc:
                # Failed to decrypt with the previous key: check whether the
                # value is already readable with the current one.
                try:
                    encrypted_type.process_result_value(
                        self._read_bytes(column_name, row[column_name]), self._dialect
                    )
                except Exception:
                    raise Exception(
                        f"Unable to decrypt column [{table_name}.{column_name}] "
                        "with either the previous or the current secret key"
                    ) from exc
                logger.info(
                    "Current secret is able to decrypt value on column [%s.%s],"
                    " nothing to do",
                    table_name,
                    column_name,
                )
                # Only this column is up to date; the other columns of the row
                # may still need re-encryption, so keep going.
                continue

            re_encrypted_columns[column_name] = encrypted_type.process_bind_param(
                unencrypted_value,
                self._dialect,
            )

        if not re_encrypted_columns:
            # Nothing to rewrite; an UPDATE with an empty SET clause is invalid.
            return

        set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns)
        logger.info("Processing table: %s", table_name)
        # Bind parameters are passed as a single dict rather than as keyword
        # arguments.
        conn.execute(
            text(f"UPDATE {table_name} SET {set_cols} WHERE id = :id"),
            {"id": row["id"], **re_encrypted_columns},
        )

    def run(self) -> None:
        """
        Re-encrypts every discovered encrypted column of every table, using a
        single ``engine.begin()`` block for all reads and updates.
        """
        encrypted_meta_info = self.discover_encrypted_fields()

        with self._db.engine.begin() as conn:
            logger.info("Collecting info for re encryption")
            for table_name, columns in encrypted_meta_info.items():
                column_names = list(columns.keys())
                rows = self._select_columns_from_table(conn, column_names, table_name)

                for row in rows:
                    self._re_encrypt_row(conn, row, table_name, columns)
        logger.info("All tables processed")
|