[SQL Lab] Async query results serialization with MessagePack and PyArrow (#8069)
* Add support for msgpack results_backend serialization
* Serialize DataFrame with PyArrow rather than JSON
* Adjust dependencies, de-lint
* Add tests for (de)serialization methods
* Add MessagePack config info to Installation docs
* Enable msgpack/arrow serialization by default
* [Fix] Prevent msgpack serialization on synchronous queries
* Add type annotations
This commit is contained in:
parent 56566c2645
commit 7595d9e5fd
@@ -23,6 +23,12 @@ assists people when migrating to a new version.

## Next Version

* [8069](https://github.com/apache/incubator-superset/pull/8069): introduces
  [MessagePack](https://github.com/msgpack/msgpack-python) and
  [PyArrow](https://arrow.apache.org/docs/python/) for async query results
  backend serialization. To disable set `RESULTS_BACKEND_USE_MSGPACK = False`
  in your configuration.

* [7848](https://github.com/apache/incubator-superset/pull/7848): If you are
  running redis with celery, celery bump to 4.3.0 requires redis-py upgrade to
  3.2.0 or later.
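To stay on the previous JSON serialization, the new flag can be switched off in `superset_config.py`. A minimal sketch (only the flag name comes from this PR; the rest of the file is illustrative):

```python
# superset_config.py -- illustrative override, not part of this commit.
# Revert async query results to the pre-existing JSON serialization path.
RESULTS_BACKEND_USE_MSGPACK = False
```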
@@ -846,6 +846,12 @@ look something like:

    RESULTS_BACKEND = RedisCache(
        host='localhost', port=6379, key_prefix='superset_results')

For performance gains, `MessagePack <https://github.com/msgpack/msgpack-python>`_
and `PyArrow <https://arrow.apache.org/docs/python/>`_ are now used for results
serialization. This can be disabled by setting ``RESULTS_BACKEND_USE_MSGPACK = False``
in your configuration, should any issues arise. Please clear your existing results
cache store when upgrading an existing environment.

**Important notes**

* It is important that all the worker nodes and web servers in
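Because entries written before the upgrade were serialized as JSON, the docs ask you to clear the existing results cache. A hedged sketch for a Redis-backed results cache, assuming the host, port, and `superset_results` key prefix from the example above:

```python
# Illustrative only: remove previously cached result sets so they are not
# read back with the wrong deserializer after upgrading.
import redis

client = redis.Redis(host="localhost", port=6379)
for key in client.scan_iter(match="superset_results*"):
    client.delete(key)
```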
@@ -49,13 +49,15 @@ markupsafe==1.1.1 # via jinja2, mako
 marshmallow-enum==1.4.1 # via flask-appbuilder
 marshmallow-sqlalchemy==0.17.0 # via flask-appbuilder
 marshmallow==2.19.5 # via flask-appbuilder, marshmallow-enum, marshmallow-sqlalchemy
-numpy==1.17.0 # via pandas
+msgpack==0.6.1
+numpy==1.17.0 # via pandas, pyarrow
 pandas==0.24.2
 parsedatetime==2.4
 pathlib2==2.3.4
 polyline==1.4.0
 prison==0.1.2 # via flask-appbuilder
 py==1.8.0 # via retry
+pyarrow==0.14.1
 pycparser==2.19 # via cffi
 pyjwt==1.7.1 # via flask-appbuilder, flask-jwt-extended
 pyrsistent==0.15.4 # via jsonschema

@@ -69,7 +71,7 @@ pyyaml==5.1.2
 retry==0.9.2
 selenium==3.141.0
 simplejson==3.16.0
-six==1.12.0 # via bleach, cryptography, flask-jwt-extended, flask-talisman, isodate, jsonschema, pathlib2, polyline, prison, pyrsistent, python-dateutil, sqlalchemy-utils, wtforms-json
+six==1.12.0 # via bleach, cryptography, flask-jwt-extended, flask-talisman, isodate, jsonschema, pathlib2, polyline, prison, pyarrow, pyrsistent, python-dateutil, sqlalchemy-utils, wtforms-json
 sqlalchemy-utils==0.34.1
 sqlalchemy==1.3.6
 sqlparse==0.3.0
setup.py

@@ -84,6 +84,7 @@ setup(
         "humanize",
         "isodate",
         "markdown>=3.0",
+        "msgpack>=0.6.1, <0.7.0",
         "pandas>=0.24.2, <0.25.0",
         "parsedatetime",
         "pathlib2",

@@ -91,6 +92,7 @@ setup(
         "python-dateutil",
         "python-dotenv",
         "python-geohash",
+        "pyarrow>=0.14.1, <0.15.0",
         "pyyaml>=5.1",
         "retry>=0.9.2",
         "selenium>=3.141.0",
@@ -193,6 +193,7 @@ with app.app_context():
 security_manager = appbuilder.sm

 results_backend = app.config.get("RESULTS_BACKEND")
+results_backend_use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")

 # Merge user defined feature flags with default feature flags
 _feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
@@ -440,6 +440,12 @@ SQLLAB_ASYNC_TIME_LIMIT_SEC = 60 * 60 * 6
 # in SQL Lab by using the "Run Async" button/feature
 RESULTS_BACKEND = None

+# Use PyArrow and MessagePack for async query results serialization,
+# rather than JSON. This feature requires additional testing from the
+# community before it is fully adopted, so this config option is provided
+# in order to disable should breaking issues be discovered.
+RESULTS_BACKEND_USE_MSGPACK = True
+
 # The S3 bucket where you want to store your external hive tables created
 # from CSV files. For example, 'companyname-superset'
 CSV_TO_HIVE_UPLOAD_S3_BUCKET = None
@@ -100,19 +100,27 @@ class SupersetDataFrame(object):
         except Exception as e:
             logging.exception(e)

+    @property
+    def raw_df(self):
+        return self.df
+
     @property
     def size(self):
         return len(self.df.index)

     @property
     def data(self):
+        return self.format_data(self.df)
+
+    @classmethod
+    def format_data(cls, df):
         # work around for https://github.com/pandas-dev/pandas/issues/18372
         data = [
             dict(
                 (k, maybe_box_datetimelike(v))
-                for k, v in zip(self.df.columns, np.atleast_1d(row))
+                for k, v in zip(df.columns, np.atleast_1d(row))
             )
-            for row in self.df.values
+            for row in df.values
         ]
         for d in data:
             for k, v in list(d.items()):
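The reworked `format_data` is a classmethod so rows can be formatted from a plain DataFrame after it is loaded back from the results backend. A minimal sketch of the same row-to-dict conversion on an invented pandas frame:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"name": ["a", "b"], "value": [4, 5]})

# One dict per row, keyed by column name -- the shape SQL Lab hands to the UI.
data = [dict(zip(df.columns, np.atleast_1d(row))) for row in df.values]
print(data)  # e.g. [{'name': 'a', 'value': 4}, {'name': 'b', 'value': 5}]
```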
@@ -18,18 +18,30 @@
 from contextlib import closing
 from datetime import datetime
 import logging
+from sys import getsizeof
 from time import sleep
+from typing import Optional, Tuple, Union
 import uuid

 from celery.exceptions import SoftTimeLimitExceeded
 from contextlib2 import contextmanager
 from flask_babel import lazy_gettext as _
+import msgpack
+import pyarrow as pa
 import simplejson as json
 import sqlalchemy
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import NullPool

-from superset import app, dataframe, db, results_backend, security_manager
+from superset import (
+    app,
+    db,
+    results_backend,
+    results_backend_use_msgpack,
+    security_manager,
+)
+from superset.dataframe import SupersetDataFrame
 from superset.db_engine_specs import BaseEngineSpec
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
 from superset.tasks.celery_app import app as celery_app
@@ -226,7 +238,46 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):

     logging.debug(f"Query {query_id}: Fetching cursor description")
     cursor_description = cursor.description
-    return dataframe.SupersetDataFrame(data, cursor_description, db_engine_spec)
+    return SupersetDataFrame(data, cursor_description, db_engine_spec)
+
+
+def _serialize_payload(
+    payload: dict, use_msgpack: Optional[bool] = False
+) -> Union[bytes, str]:
+    logging.debug(f"Serializing to msgpack: {use_msgpack}")
+    if use_msgpack:
+        return msgpack.dumps(payload, default=json_iso_dttm_ser, use_bin_type=True)
+    else:
+        return json.dumps(payload, default=json_iso_dttm_ser, ignore_nan=True)
+
+
+def _serialize_and_expand_data(
+    cdf: SupersetDataFrame,
+    db_engine_spec: BaseEngineSpec,
+    use_msgpack: Optional[bool] = False,
+) -> Tuple[Union[bytes, str], list, list, list]:
+    selected_columns: list = cdf.columns or []
+    expanded_columns: list
+
+    if use_msgpack:
+        with stats_timing(
+            "sqllab.query.results_backend_pa_serialization", stats_logger
+        ):
+            data = (
+                pa.default_serialization_context()
+                .serialize(cdf.raw_df)
+                .to_buffer()
+                .to_pybytes()
+            )
+        # expand when loading data from results backend
+        all_columns, expanded_columns = (selected_columns, [])
+    else:
+        data = cdf.data or []
+        all_columns, data, expanded_columns = db_engine_spec.expand_data(
+            selected_columns, data
+        )
+
+    return (data, selected_columns, all_columns, expanded_columns)
+
+
 def execute_sql_statements(
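Taken together, the new helpers swap the old JSON path for an Arrow-serialized DataFrame wrapped in a msgpack envelope. A standalone sketch of that round trip outside Superset, using the pyarrow 0.14 / msgpack 0.6 APIs pinned above and an invented DataFrame:

```python
import msgpack
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"name": ["a", "b"], "value": [4, 5]})

# Serialize the DataFrame itself with Arrow...
context = pa.default_serialization_context()
data_bytes = context.serialize(df).to_buffer().to_pybytes()

# ...then wrap it, plus whatever metadata is needed, in a msgpack payload.
payload = msgpack.dumps(
    {"data": data_bytes, "selected_columns": list(df.columns)}, use_bin_type=True
)

# Reading results back reverses both steps.
restored = msgpack.loads(payload, raw=False)
df_restored = pa.deserialize(restored["data"])
assert df.equals(df_restored)
```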
@@ -310,10 +361,8 @@ def execute_sql_statements(
     )
     query.end_time = now_as_float()

-    selected_columns = cdf.columns or []
-    data = cdf.data or []
-    all_columns, data, expanded_columns = db_engine_spec.expand_data(
-        selected_columns, data
+    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
+        cdf, db_engine_spec, store_results and results_backend_use_msgpack
     )

     payload.update(
@@ -334,13 +383,22 @@ def execute_sql_statements(
            f"Query {query_id}: Storing results in results backend, key: {key}"
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
-            json_payload = json.dumps(
-                payload, default=json_iso_dttm_ser, ignore_nan=True
-            )
+            with stats_timing(
+                "sqllab.query.results_backend_write_serialization", stats_logger
+            ):
+                serialized_payload = _serialize_payload(
+                    payload, results_backend_use_msgpack
+                )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config.get("CACHE_DEFAULT_TIMEOUT", 0)
-            results_backend.set(key, zlib_compress(json_payload), cache_timeout)
+
+            compressed = zlib_compress(serialized_payload)
+            logging.debug(
+                f"*** serialized payload size: {getsizeof(serialized_payload)}"
+            )
+            logging.debug(f"*** compressed payload size: {getsizeof(compressed)}")
+            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

        query.status = QueryStatus.SUCCESS
@@ -33,7 +33,7 @@ import smtplib
 import sys
 from time import struct_time
 import traceback
-from typing import List, NamedTuple, Optional, Tuple
+from typing import List, NamedTuple, Optional, Tuple, Union
 from urllib.parse import unquote_plus
 import uuid
 import zlib

@@ -803,12 +803,12 @@ def zlib_compress(data):
     return zlib.compress(data)


-def zlib_decompress_to_string(blob):
+def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, str]:
     """
     Decompress things to a string in a py2/3 safe fashion
     >>> json_str = '{"test": 1}'
     >>> blob = zlib_compress(json_str)
-    >>> got_str = zlib_decompress_to_string(blob)
+    >>> got_str = zlib_decompress(blob)
     >>> got_str == json_str
     True
     """

@@ -817,7 +817,7 @@ def zlib_decompress_to_string(blob):
             decompressed = zlib.decompress(blob)
         else:
             decompressed = zlib.decompress(bytes(blob, "utf-8"))
-        return decompressed.decode("utf-8")
+        return decompressed.decode("utf-8") if decode else decompressed
     return zlib.decompress(blob)
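With the rename, the helper can return raw bytes for the msgpack path instead of always decoding to text. A small sketch of the same idea using the standard library directly (not the Superset helper itself):

```python
import zlib

blob = zlib.compress(b'{"test": 1}')

as_text = zlib.decompress(blob).decode("utf-8")  # decode=True behaviour: str for JSON payloads
as_bytes = zlib.decompress(blob)                 # decode=False behaviour: bytes for msgpack payloads

assert as_text == '{"test": 1}'
assert isinstance(as_bytes, bytes)
```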
@@ -19,7 +19,7 @@ from contextlib import closing
 from datetime import datetime, timedelta
 import logging
 import re
-from typing import Dict, List  # noqa: F401
+from typing import Dict, List, Optional, Union  # noqa: F401
 from urllib import parse

 from flask import (

@@ -40,7 +40,9 @@ from flask_appbuilder.security.decorators import has_access, has_access_api
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_babel import gettext as __
 from flask_babel import lazy_gettext as _
+import msgpack
 import pandas as pd
+import pyarrow as pa
 import simplejson as json
 from sqlalchemy import and_, or_, select
 from werkzeug.routing import BaseConverter

@@ -50,11 +52,13 @@ from superset import (
     appbuilder,
     cache,
     conf,
+    dataframe,
     db,
     event_logger,
     get_feature_flags,
     is_feature_enabled,
     results_backend,
+    results_backend_use_msgpack,
     security_manager,
     sql_lab,
     viz,

@@ -76,7 +80,7 @@ from superset.sql_validators import get_validator_by_name
 from superset.utils import core as utils
 from superset.utils import dashboard_import_export
 from superset.utils.dates import now_as_float
-from superset.utils.decorators import etag_cache
+from superset.utils.decorators import etag_cache, stats_timing
 from .base import (
     api,
     BaseSupersetView,
@@ -186,6 +190,38 @@ def check_slice_perms(self, slice_id):
         security_manager.assert_datasource_permission(viz_obj.datasource)


+def _deserialize_results_payload(
+    payload: Union[bytes, str], query, use_msgpack: Optional[bool] = False
+) -> dict:
+    logging.debug(f"Deserializing from msgpack: {use_msgpack}")
+    if use_msgpack:
+        with stats_timing(
+            "sqllab.query.results_backend_msgpack_deserialize", stats_logger
+        ):
+            ds_payload = msgpack.loads(payload, raw=False)
+
+        with stats_timing("sqllab.query.results_backend_pa_deserialize", stats_logger):
+            df = pa.deserialize(ds_payload["data"])
+
+        # TODO: optimize this, perhaps via df.to_dict, then traversing
+        ds_payload["data"] = dataframe.SupersetDataFrame.format_data(df) or []
+
+        db_engine_spec = query.database.db_engine_spec
+        all_columns, data, expanded_columns = db_engine_spec.expand_data(
+            ds_payload["selected_columns"], ds_payload["data"]
+        )
+        ds_payload.update(
+            {"data": data, "columns": all_columns, "expanded_columns": expanded_columns}
+        )
+
+        return ds_payload
+    else:
+        with stats_timing(
+            "sqllab.query.results_backend_json_deserialize", stats_logger
+        ):
+            return json.loads(payload)  # noqa
+
+
 class SliceFilter(SupersetFilter):
     def apply(self, query, func):  # noqa
         if security_manager.all_datasource_access():
@@ -2416,12 +2452,12 @@ class Superset(BaseSupersetView):
                 status=403,
             )

-        payload = utils.zlib_decompress_to_string(blob)
-        payload_json = json.loads(payload)
+        payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
+        obj = _deserialize_results_payload(payload, query, results_backend_use_msgpack)

        return json_success(
            json.dumps(
-                apply_display_max_row_limit(payload_json),
+                apply_display_max_row_limit(obj),
                default=utils.json_iso_dttm_ser,
                ignore_nan=True,
            )
@@ -2669,8 +2705,12 @@ class Superset(BaseSupersetView):
         blob = results_backend.get(query.results_key)
         if blob:
             logging.info("Decompressing")
-            json_payload = utils.zlib_decompress_to_string(blob)
-            obj = json.loads(json_payload)
+            payload = utils.zlib_decompress(
+                blob, decode=not results_backend_use_msgpack
+            )
+            obj = _deserialize_results_payload(
+                payload, query, results_backend_use_msgpack
+            )
             columns = [c["name"] for c in obj["columns"]]
             df = pd.DataFrame.from_records(obj["data"], columns=columns)
             logging.info("Using pandas to convert to CSV")
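For CSV export the deserialized payload is rebuilt into a DataFrame before conversion. A hedged sketch of that final step with made-up rows matching the payload shape used above:

```python
import pandas as pd

# Shape of a deserialized payload: list-of-dict rows plus column metadata.
obj = {
    "columns": [{"name": "name"}, {"name": "value"}],
    "data": [{"name": "a", "value": 4}, {"name": "b", "value": 5}],
}

columns = [c["name"] for c in obj["columns"]]
df = pd.DataFrame.from_records(obj["data"], columns=columns)
print(df.to_csv(index=False))
```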
@@ -15,12 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for Superset Celery worker"""
+import datetime
 import json
 import subprocess
 import time
 import unittest
+import unittest.mock as mock

-from superset import app, db
+from superset import app, db, sql_lab
+from superset.dataframe import SupersetDataFrame
+from superset.db_engine_specs.base import BaseEngineSpec
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
@@ -242,6 +246,114 @@ class CeleryTestCase(SupersetTestCase):
         self.assertEqual(True, query.select_as_cta)
         self.assertEqual(True, query.select_as_cta_used)

+    def test_default_data_serialization(self):
+        data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
+        cursor_descr = (
+            ("a", "string"),
+            ("b", "int"),
+            ("c", "float"),
+            ("d", "datetime"),
+        )
+        db_engine_spec = BaseEngineSpec()
+        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
+
+        with mock.patch.object(
+            db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
+        ) as expand_data:
+            data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
+                cdf, db_engine_spec, False
+            )
+            expand_data.assert_called_once()
+
+        self.assertIsInstance(data, list)
+
+    def test_new_data_serialization(self):
+        data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
+        cursor_descr = (
+            ("a", "string"),
+            ("b", "int"),
+            ("c", "float"),
+            ("d", "datetime"),
+        )
+        db_engine_spec = BaseEngineSpec()
+        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
+
+        with mock.patch.object(
+            db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
+        ) as expand_data:
+            data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
+                cdf, db_engine_spec, True
+            )
+            expand_data.assert_not_called()
+
+        self.assertIsInstance(data, bytes)
+
+    def test_default_payload_serialization(self):
+        use_new_deserialization = False
+        data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
+        cursor_descr = (
+            ("a", "string"),
+            ("b", "int"),
+            ("c", "float"),
+            ("d", "datetime"),
+        )
+        db_engine_spec = BaseEngineSpec()
+        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
+        query = {
+            "database_id": 1,
+            "sql": "SELECT * FROM birth_names LIMIT 100",
+            "status": QueryStatus.PENDING,
+        }
+        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
+            cdf, db_engine_spec, use_new_deserialization
+        )
+        payload = {
+            "query_id": 1,
+            "status": QueryStatus.SUCCESS,
+            "state": QueryStatus.SUCCESS,
+            "data": serialized_data,
+            "columns": all_columns,
+            "selected_columns": selected_columns,
+            "expanded_columns": expanded_columns,
+            "query": query,
+        }
+
+        serialized = sql_lab._serialize_payload(payload, use_new_deserialization)
+        self.assertIsInstance(serialized, str)
+
+    def test_msgpack_payload_serialization(self):
+        use_new_deserialization = True
+        data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
+        cursor_descr = (
+            ("a", "string"),
+            ("b", "int"),
+            ("c", "float"),
+            ("d", "datetime"),
+        )
+        db_engine_spec = BaseEngineSpec()
+        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
+        query = {
+            "database_id": 1,
+            "sql": "SELECT * FROM birth_names LIMIT 100",
+            "status": QueryStatus.PENDING,
+        }
+        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
+            cdf, db_engine_spec, use_new_deserialization
+        )
+        payload = {
+            "query_id": 1,
+            "status": QueryStatus.SUCCESS,
+            "state": QueryStatus.SUCCESS,
+            "data": serialized_data,
+            "columns": all_columns,
+            "selected_columns": selected_columns,
+            "expanded_columns": expanded_columns,
+            "query": query,
+        }
+
+        serialized = sql_lab._serialize_payload(payload, use_new_deserialization)
+        self.assertIsInstance(serialized, bytes)
+
     @staticmethod
     def de_unicode_dict(d):
         def str_if_basestring(o):
@@ -39,6 +39,7 @@ from superset.db_engine_specs.mssql import MssqlEngineSpec
 from superset.models import core as models
 from superset.models.sql_lab import Query
 from superset.utils import core as utils
+from superset.views import core as views
 from superset.views.database.views import DatabaseView
 from .base_tests import SupersetTestCase
 from .fixtures.pyodbcRow import Row
@@ -776,6 +777,98 @@ class CoreTests(SupersetTestCase):
         resp = self.get_resp(f"/superset/select_star/{examples_db.id}/birth_names")
         self.assertIn("gender", resp)

+    def test_results_default_deserialization(self):
+        use_new_deserialization = False
+        data = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")]
+        cursor_descr = (
+            ("a", "string"),
+            ("b", "int"),
+            ("c", "float"),
+            ("d", "datetime"),
+        )
+        db_engine_spec = BaseEngineSpec()
+        cdf = dataframe.SupersetDataFrame(data, cursor_descr, db_engine_spec)
+        query = {
+            "database_id": 1,
+            "sql": "SELECT * FROM birth_names LIMIT 100",
+            "status": utils.QueryStatus.PENDING,
+        }
+        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
+            cdf, db_engine_spec, use_new_deserialization
+        )
+        payload = {
+            "query_id": 1,
+            "status": utils.QueryStatus.SUCCESS,
+            "state": utils.QueryStatus.SUCCESS,
+            "data": serialized_data,
+            "columns": all_columns,
+            "selected_columns": selected_columns,
+            "expanded_columns": expanded_columns,
+            "query": query,
+        }
+
+        serialized_payload = sql_lab._serialize_payload(
+            payload, use_new_deserialization
+        )
+        self.assertIsInstance(serialized_payload, str)
+
+        query_mock = mock.Mock()
+        deserialized_payload = views._deserialize_results_payload(
+            serialized_payload, query_mock, use_new_deserialization
+        )
+
+        self.assertDictEqual(deserialized_payload, payload)
+        query_mock.assert_not_called()
+
+    def test_results_msgpack_deserialization(self):
+        use_new_deserialization = True
+        data = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")]
+        cursor_descr = (
+            ("a", "string"),
+            ("b", "int"),
+            ("c", "float"),
+            ("d", "datetime"),
+        )
+        db_engine_spec = BaseEngineSpec()
+        cdf = dataframe.SupersetDataFrame(data, cursor_descr, db_engine_spec)
+        query = {
+            "database_id": 1,
+            "sql": "SELECT * FROM birth_names LIMIT 100",
+            "status": utils.QueryStatus.PENDING,
+        }
+        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
+            cdf, db_engine_spec, use_new_deserialization
+        )
+        payload = {
+            "query_id": 1,
+            "status": utils.QueryStatus.SUCCESS,
+            "state": utils.QueryStatus.SUCCESS,
+            "data": serialized_data,
+            "columns": all_columns,
+            "selected_columns": selected_columns,
+            "expanded_columns": expanded_columns,
+            "query": query,
+        }
+
+        serialized_payload = sql_lab._serialize_payload(
+            payload, use_new_deserialization
+        )
+        self.assertIsInstance(serialized_payload, bytes)
+
+        with mock.patch.object(
+            db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
+        ) as expand_data:
+            query_mock = mock.Mock()
+            query_mock.database.db_engine_spec.expand_data = expand_data
+
+            deserialized_payload = views._deserialize_results_payload(
+                serialized_payload, query_mock, use_new_deserialization
+            )
+            payload["data"] = dataframe.SupersetDataFrame.format_data(cdf.raw_df)
+
+            self.assertDictEqual(deserialized_payload, payload)
+            expand_data.assert_called_once()
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -47,7 +47,7 @@ from superset.utils.core import (
     setup_cache,
     validate_json,
     zlib_compress,
-    zlib_decompress_to_string,
+    zlib_decompress,
 )


@@ -140,7 +140,7 @@ class UtilsTestCase(unittest.TestCase):
     def test_zlib_compression(self):
         json_str = '{"test": 1}'
         blob = zlib_compress(json_str)
-        got_str = zlib_decompress_to_string(blob)
+        got_str = zlib_decompress(blob)
         self.assertEquals(json_str, got_str)

     @patch("superset.utils.core.to_adhoc", mock_to_adhoc)