From 5f4e3adfd242ca3e1d8a22511e2ec4b71006acd9 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 26 Nov 2019 11:49:19 -0800 Subject: [PATCH] Pass full response to `query_cost_formatter` (#8652) * Return full info when doing query cost estimation * Add unit test * Fix isort --- superset/db_engine_specs/presto.py | 13 +++--- tests/db_engine_specs/presto_tests.py | 63 +++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 661fe3517..f3564cf98 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -445,7 +445,7 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def estimate_statement_cost( # pylint: disable=too-many-locals cls, statement: str, database, cursor, user_name: str - ) -> Dict[str, float]: + ) -> Dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. @@ -453,7 +453,7 @@ class PrestoEngineSpec(BaseEngineSpec): :param database: Database instance :param cursor: Cursor instance :param username: Effective username - :return: JSON estimate from Presto + :return: JSON response from Presto """ parsed_query = ParsedQuery(statement) sql = parsed_query.stripped() @@ -479,11 +479,11 @@ class PrestoEngineSpec(BaseEngineSpec): # } # } result = json.loads(cursor.fetchone()[0]) - return result["estimate"] + return result @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, float]] + cls, raw_cost: List[Dict[str, Any]] ) -> List[Dict[str, str]]: """ Format cost estimate. @@ -516,10 +516,11 @@ class PrestoEngineSpec(BaseEngineSpec): ("networkCost", "Network cost", ""), ] for row in raw_cost: + estimate: Dict[str, float] = row.get("estimate", {}) statement_cost = {} for key, label, suffix in columns: - if key in row: - statement_cost[label] = humanize(row[key], suffix).strip() + if key in estimate: + statement_cost[label] = humanize(estimate[key], suffix).strip() cost.append(statement_cost) return cost diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py index bfb032294..cf62b282d 100644 --- a/tests/db_engine_specs/presto_tests.py +++ b/tests/db_engine_specs/presto_tests.py @@ -372,3 +372,66 @@ class PrestoTests(DbEngineSpecTestCase): PrestoEngineSpec.convert_dttm("TIMESTAMP", dttm), "from_iso8601_timestamp('2019-01-02T03:04:05.678900')", ) + + def test_query_cost_formatter(self): + raw_cost = [ + { + "inputTableColumnInfos": [ + { + "table": { + "catalog": "hive", + "schemaTable": { + "schema": "default", + "table": "fact_passenger_state", + }, + }, + "columnConstraints": [ + { + "columnName": "ds", + "typeSignature": "varchar", + "domain": { + "nullsAllowed": False, + "ranges": [ + { + "low": { + "value": "2019-07-10", + "bound": "EXACTLY", + }, + "high": { + "value": "2019-07-10", + "bound": "EXACTLY", + }, + } + ], + }, + } + ], + "estimate": { + "outputRowCount": 9.04969899e8, + "outputSizeInBytes": 3.54143678301e11, + "cpuCost": 3.54143678301e11, + "maxMemory": 0.0, + "networkCost": 0.0, + }, + } + ], + "estimate": { + "outputRowCount": 9.04969899e8, + "outputSizeInBytes": 3.54143678301e11, + "cpuCost": 3.54143678301e11, + "maxMemory": 0.0, + "networkCost": 3.54143678301e11, + }, + } + ] + formatted_cost = PrestoEngineSpec.query_cost_formatter(raw_cost) + expected = [ + { + "Output count": "904 M rows", + "Output size": "354 GB", + "CPU cost": "354 G", + "Max memory": "0 B", + "Network cost": "354 G", + } + ] + self.assertEqual(formatted_cost, expected)