From 2a1235c0c28a62a444574b9c65e92d8f1d30df01 Mon Sep 17 00:00:00 2001 From: Ben Reinhart Date: Mon, 26 Apr 2021 14:04:40 -0700 Subject: [PATCH] fix: Cleanup serialization and hashing code (#14317) --- superset/common/query_object.py | 12 +--- superset/db_engine_specs/base.py | 4 +- superset/db_engine_specs/bigquery.py | 6 +- superset/models/dashboard.py | 3 +- superset/models/slice.py | 3 +- superset/utils/cache.py | 11 +--- superset/utils/core.py | 8 +-- superset/utils/hashing.py | 14 ++-- tests/utils/hashing_tests.py | 97 ++++++++++++++++++++++++++++ tests/utils_tests.py | 4 +- 10 files changed, 124 insertions(+), 38 deletions(-) create mode 100644 tests/utils/hashing_tests.py diff --git a/superset/common/query_object.py b/superset/common/query_object.py index ad28b2eed..8fd281cf3 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=R -import hashlib import logging from datetime import datetime, timedelta from typing import Any, Dict, List, NamedTuple, Optional, Union -import simplejson as json from flask_babel import gettext as _ from pandas import DataFrame @@ -40,6 +38,7 @@ from superset.utils.core import ( json_int_dttm_ser, ) from superset.utils.date_parser import get_since_until, parse_human_timedelta +from superset.utils.hashing import md5_sha_from_dict from superset.views.utils import get_time_range_endpoints config = app.config @@ -333,14 +332,7 @@ class QueryObject: if annotation_layers: cache_dict["annotation_layers"] = annotation_layers - json_data = self.json_dumps(cache_dict, sort_keys=True) - return hashlib.md5(json_data.encode("utf-8")).hexdigest() - - @staticmethod - def json_dumps(obj: Any, sort_keys: bool = False) -> str: - return json.dumps( - obj, default=json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys - ) + return md5_sha_from_dict(cache_dict, default=json_int_dttm_ser, ignore_nan=True) def exec_post_processing(self, df: DataFrame) -> DataFrame: """ diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 7bbc4bd9b..7ea73e30f 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument -import hashlib import json import logging import re @@ -63,6 +62,7 @@ from superset.models.sql_types.base import literal_dttm_type_factory from superset.sql_parse import ParsedQuery, Table from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType +from superset.utils.hashing import md5_sha_from_str if TYPE_CHECKING: # prevent circular imports @@ -1145,7 +1145,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param label: Expected expression label :return: Truncated label """ - label = hashlib.md5(label.encode("utf-8")).hexdigest() + label = md5_sha_from_str(label) # truncate hash if it exceeds max length if cls.max_column_name_length and len(label) > cls.max_column_name_length: label = label[: cls.max_column_name_length] diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index fd34bc988..7b8fa84e6 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import hashlib import re from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING @@ -28,6 +27,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType from superset.sql_parse import Table from superset.utils import core as utils +from superset.utils.hashing import md5_sha_from_str if TYPE_CHECKING: from superset.models.core import Database # pragma: no cover @@ -141,7 +141,7 @@ class BigQueryEngineSpec(BaseEngineSpec): :param label: Expected expression label :return: Conditionally mutated label """ - label_hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest() + label_hashed = "_" + md5_sha_from_str(label) # if label starts with number, add underscore as first character label_mutated = "_" + label if re.match(r"^\d", label) else label @@ -163,7 +163,7 @@ class BigQueryEngineSpec(BaseEngineSpec): :param label: expected expression label :return: truncated label """ - return "_" + hashlib.md5(label.encode("utf-8")).hexdigest() + return "_" + md5_sha_from_str(label) @classmethod def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 61607d106..6f1e6bf16 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -56,6 +56,7 @@ from superset.models.user_attributes import UserAttribute from superset.tasks.thumbnails import cache_dashboard_thumbnail from superset.utils import core as utils from superset.utils.decorators import debounce +from superset.utils.hashing import md5_sha_from_str from superset.utils.urls import get_url_path # pylint: disable=too-many-public-methods @@ -199,7 +200,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes Returns a MD5 HEX digest that makes this dashboard unique """ unique_string = f"{self.position_json}.{self.css}.{self.json_metadata}" - return utils.md5_hex(unique_string) + return md5_sha_from_str(unique_string) @property def thumbnail_url(self) -> str: diff --git a/superset/models/slice.py b/superset/models/slice.py index fe41699fb..08421ca54 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -34,6 +34,7 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.models.tags import ChartUpdater from superset.tasks.thumbnails import cache_chart_thumbnail from superset.utils import core as utils +from superset.utils.hashing import md5_sha_from_str from superset.utils.urls import get_url_path from superset.viz import BaseViz, viz_types # type: ignore @@ -202,7 +203,7 @@ class Slice( """ Returns a MD5 HEX digest that makes this dashboard unique """ - return utils.md5_hex(self.params or "") + return md5_sha_from_str(self.params or "") @property def thumbnail_url(self) -> str: diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 0abd76a79..782a497d0 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import hashlib -import json import logging from datetime import datetime, timedelta from functools import wraps @@ -30,20 +28,15 @@ from superset.extensions import cache_manager from superset.models.cache import CacheKey from superset.stats_logger import BaseStatsLogger from superset.utils.core import json_int_dttm_ser +from superset.utils.hashing import md5_sha_from_dict config = app.config # type: ignore stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) -# TODO: DRY up cache key code -def json_dumps(obj: Any, sort_keys: bool = False) -> str: - return json.dumps(obj, default=json_int_dttm_ser, sort_keys=sort_keys) - - def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str: - json_data = json_dumps(values_dict, sort_keys=True) - hash_str = hashlib.md5(json_data.encode("utf-8")).hexdigest() + hash_str = md5_sha_from_dict(values_dict, default=json_int_dttm_ser) return f"{key_prefix}{hash_str}" diff --git a/superset/utils/core.py b/superset/utils/core.py index 4e9e47073..abda7b7ba 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -19,7 +19,6 @@ import collections import decimal import errno import functools -import hashlib import json import logging import os @@ -99,6 +98,7 @@ from superset.exceptions import ( ) from superset.typing import FlaskResponse, FormData, Metric from superset.utils.dates import datetime_to_epoch, EPOCH +from superset.utils.hashing import md5_sha_from_str try: from pydruid.utils.having import Having @@ -484,10 +484,6 @@ def list_minus(l: List[Any], minus: List[Any]) -> List[Any]: return [o for o in l if o not in minus] -def md5_hex(data: str) -> str: - return hashlib.md5(data.encode()).hexdigest() - - class DashboardEncoder(json.JSONEncoder): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1381,7 +1377,7 @@ def create_ssl_cert_file(certificate: str) -> str: :return: The path to the certificate file :raises CertificateException: If certificate is not valid/unparseable """ - filename = f"{hashlib.md5(certificate.encode('utf-8')).hexdigest()}.crt" + filename = f"{md5_sha_from_str(certificate)}.crt" cert_dir = current_app.config["SSL_CERT_PATH"] path = cert_dir if cert_dir else tempfile.gettempdir() path = os.path.join(path, filename) diff --git a/superset/utils/hashing.py b/superset/utils/hashing.py index 72856d3d1..66983582c 100644 --- a/superset/utils/hashing.py +++ b/superset/utils/hashing.py @@ -15,14 +15,20 @@ # specific language governing permissions and limitations # under the License. import hashlib -import json -from typing import Any, Dict +from typing import Any, Callable, Dict, Optional + +import simplejson as json def md5_sha_from_str(val: str) -> str: return hashlib.md5(val.encode("utf-8")).hexdigest() -def md5_sha_from_dict(opts: Dict[Any, Any]) -> str: - json_data = json.dumps(opts, sort_keys=True) +def md5_sha_from_dict( + obj: Dict[Any, Any], + ignore_nan: bool = False, + default: Optional[Callable[[Any], Any]] = None, +) -> str: + json_data = json.dumps(obj, sort_keys=True, ignore_nan=ignore_nan, default=default) + return md5_sha_from_str(json_data) diff --git a/tests/utils/hashing_tests.py b/tests/utils/hashing_tests.py new file mode 100644 index 000000000..8931ff15c --- /dev/null +++ b/tests/utils/hashing_tests.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-self-use +import datetime +import math +from typing import Any + +import pytest + +from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str + + +def test_basic_md5_sha(): + obj = { + "product": "Coffee", + "company": "Gobias Industries", + "price_in_cents": 4000, + } + + serialized_obj = ( + '{"company": "Gobias Industries", "price_in_cents": 4000, "product": "Coffee"}' + ) + + assert md5_sha_from_str(serialized_obj) == md5_sha_from_dict(obj) + assert md5_sha_from_str(serialized_obj) == "35f22273cd6a6798b04f8ddef51135e3" + + +def test_sort_order_md5_sha(): + obj_1 = { + "product": "Coffee", + "price_in_cents": 4000, + "company": "Gobias Industries", + } + + obj_2 = { + "product": "Coffee", + "company": "Gobias Industries", + "price_in_cents": 4000, + } + + assert md5_sha_from_dict(obj_1) == md5_sha_from_dict(obj_2) + assert md5_sha_from_dict(obj_1) == "35f22273cd6a6798b04f8ddef51135e3" + + +def test_custom_default_md5_sha(): + def custom_datetime_serializer(obj: Any): + if isinstance(obj, datetime.datetime): + return "" + + obj = { + "product": "Coffee", + "company": "Gobias Industries", + "datetime": datetime.datetime.now(), + } + + serialized_obj = '{"company": "Gobias Industries", "datetime": "", "product": "Coffee"}' + + assert md5_sha_from_str(serialized_obj) == md5_sha_from_dict( + obj, default=custom_datetime_serializer + ) + assert md5_sha_from_str(serialized_obj) == "dc280121213aabcaeb8087aef268fd0d" + + +def test_ignore_nan_md5_sha(): + obj = { + "product": "Coffee", + "company": "Gobias Industries", + "price": math.nan, + } + + serialized_obj = ( + '{"company": "Gobias Industries", "price": NaN, "product": "Coffee"}' + ) + + assert md5_sha_from_str(serialized_obj) == md5_sha_from_dict(obj) + assert md5_sha_from_str(serialized_obj) == "5d129d1dffebc0bacc734366476d586d" + + serialized_obj = ( + '{"company": "Gobias Industries", "price": null, "product": "Coffee"}' + ) + + assert md5_sha_from_str(serialized_obj) == md5_sha_from_dict(obj, ignore_nan=True) + assert md5_sha_from_str(serialized_obj) == "40e87d61f6add03816bccdeac5713b9f" diff --git a/tests/utils_tests.py b/tests/utils_tests.py index cf79f16aa..c0546245a 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -19,7 +19,6 @@ import unittest import uuid from datetime import date, datetime, time, timedelta from decimal import Decimal -import hashlib import json import os import re @@ -71,6 +70,7 @@ from superset.utils.core import ( zlib_decompress, ) from superset.utils import schema +from superset.utils.hashing import md5_sha_from_str from superset.views.utils import ( build_extra_filters, get_form_data, @@ -960,7 +960,7 @@ class TestUtils(SupersetTestCase): def test_ssl_certificate_file_creation(self): path = create_ssl_cert_file(ssl_certificate) - expected_filename = hashlib.md5(ssl_certificate.encode("utf-8")).hexdigest() + expected_filename = md5_sha_from_str(ssl_certificate) self.assertIn(expected_filename, path) self.assertTrue(os.path.exists(path))