From dd0bc472e3022ed1bdf2944ac61b6c40c5ae31af Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Fri, 28 Aug 2020 21:12:03 +0300 Subject: [PATCH] refactor(database): use SupersetResultSet on SqlaTable.get_df() (#10707) * refactor(database): use SupersetResultSet on SqlaTable.get_df() * lint * change cypress test --- .../explore/visualizations/table.test.ts | 21 ++++++++++--------- superset/db_engine_specs/base.py | 6 ++++-- superset/db_engine_specs/bigquery.py | 4 +++- superset/db_engine_specs/exasol.py | 6 ++++-- superset/db_engine_specs/hive.py | 6 ++++-- superset/db_engine_specs/mssql.py | 4 +++- superset/db_engine_specs/postgres.py | 4 +++- superset/models/core.py | 14 ++++++------- superset/typing.py | 4 ++-- superset/viz.py | 1 - 10 files changed, 40 insertions(+), 30 deletions(-) diff --git a/superset-frontend/cypress-base/cypress/integration/explore/visualizations/table.test.ts b/superset-frontend/cypress-base/cypress/integration/explore/visualizations/table.test.ts index c7015d913..77f9c6f12 100644 --- a/superset-frontend/cypress-base/cypress/integration/explore/visualizations/table.test.ts +++ b/superset-frontend/cypress-base/cypress/integration/explore/visualizations/table.test.ts @@ -29,6 +29,16 @@ import readResponseBlob from '../../../utils/readResponseBlob'; describe('Visualization > Table', () => { const VIZ_DEFAULTS = { ...FORM_DATA_DEFAULTS, viz_type: 'table' }; + const PERCENT_METRIC = { + expressionType: 'SQL', + sqlExpression: 'CAST(SUM(sum_girls)+AS+FLOAT)/SUM(num)', + column: null, + aggregate: null, + hasCustomLabel: true, + label: 'Girls', + optionName: 'metric_6qwzgc8bh2v_zox7hil1mzs', + }; + beforeEach(() => { cy.login(); cy.server(); @@ -119,7 +129,7 @@ describe('Visualization > Table', () => { it('Test table with percent metrics and groupby', () => { const formData = { ...VIZ_DEFAULTS, - percent_metrics: NUM_METRIC, + percent_metrics: PERCENT_METRIC, metrics: [], groupby: ['name'], }; @@ -214,15 +224,6 @@ describe('Visualization > Table', () => { }); it('Tests table number formatting with % in metric name', () => { - const PERCENT_METRIC = { - expressionType: 'SQL', - sqlExpression: 'CAST(SUM(sum_girls)+AS+FLOAT)/SUM(num)', - column: null, - aggregate: null, - hasCustomLabel: true, - label: 'Girls', - optionName: 'metric_6qwzgc8bh2v_zox7hil1mzs', - }; const formData = { ...VIZ_DEFAULTS, percent_metrics: PERCENT_METRIC, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index d3d9dab0c..cfb3671e3 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -305,7 +305,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return select_exprs @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: + def fetch_data( + cls, cursor: Any, limit: Optional[int] = None + ) -> List[Tuple[Any, ...]]: """ :param cursor: Cursor instance @@ -314,7 +316,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ if cls.arraysize: cursor.arraysize = cls.arraysize - if cls.limit_method == LimitMethod.FETCH_MANY: + if cls.limit_method == LimitMethod.FETCH_MANY and limit: return cursor.fetchmany(limit) return cursor.fetchall() diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index ea33531b4..71ae8280b 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -85,7 +85,9 @@ class BigQueryEngineSpec(BaseEngineSpec): return None @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: + def fetch_data( + cls, cursor: Any, limit: Optional[int] = None + ) -> List[Tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Support type BigQuery Row, introduced here PR #4071 # google.cloud.bigquery.table.Row diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py index a485be542..327cc3adb 100644 --- a/superset/db_engine_specs/exasol.py +++ b/superset/db_engine_specs/exasol.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple from superset.db_engine_specs.base import BaseEngineSpec @@ -40,7 +40,9 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method } @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: + def fetch_data( + cls, cursor: Any, limit: Optional[int] = None + ) -> List[Tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 82570533e..918128fa4 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -132,7 +132,9 @@ class HiveEngineSpec(PrestoEngineSpec): return BaseEngineSpec.get_all_datasource_names(database, datasource_type) @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: + def fetch_data( + cls, cursor: Any, limit: Optional[int] = None + ) -> List[Tuple[Any, ...]]: import pyhive from TCLIService import ttypes @@ -140,7 +142,7 @@ class HiveEngineSpec(PrestoEngineSpec): if state.operationState == ttypes.TOperationState.ERROR_STATE: raise Exception("Query error", state.errorMessage) try: - return super(HiveEngineSpec, cls).fetch_data(cursor, limit) + return super().fetch_data(cursor, limit) except pyhive.exc.ProgrammingError: return [] diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 70bd9b5e3..d1bb99c26 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -68,7 +68,9 @@ class MssqlEngineSpec(BaseEngineSpec): return None @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: + def fetch_data( + cls, cursor: Any, limit: Optional[int] = None + ) -> List[Tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 0ccf51dba..1ec433fd5 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -53,7 +53,9 @@ class PostgresBaseEngineSpec(BaseEngineSpec): } @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: + def fetch_data( + cls, cursor: Any, limit: Optional[int] = None + ) -> List[Tuple[Any, ...]]: cursor.tzinfo_factory = FixedOffsetTimezone if not cursor.description: return [] diff --git a/superset/models/core.py b/superset/models/core.py index 7660150e1..775a9f09d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -57,6 +57,7 @@ from superset.db_engine_specs.base import TimeGrain from superset.models.dashboard import Dashboard from superset.models.helpers import AuditMixinNullable, ImportMixin from superset.models.tags import DashboardUpdater, FavStarUpdater +from superset.result_set import SupersetResultSet from superset.utils import cache as cache_util, core as utils config = app.config @@ -392,21 +393,18 @@ class Database( _log_query(sqls[-1]) self.db_engine_spec.execute(cursor, sqls[-1]) - if cursor.description is not None: - columns = [col_desc[0] for col_desc in cursor.description] - else: - columns = [] - - df = pd.DataFrame.from_records( - data=list(cursor.fetchall()), columns=columns, coerce_float=True + data = self.db_engine_spec.fetch_data(cursor) + result_set = SupersetResultSet( + data, cursor.description, self.db_engine_spec ) - + df = result_set.to_pandas_df() if mutator: mutator(df) for k, v in df.dtypes.items(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates) + return df def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str: diff --git a/superset/typing.py b/superset/typing.py index e2380000e..6f1fa2ef3 100644 --- a/superset/typing.py +++ b/superset/typing.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from flask import Flask from flask_caching import Cache @@ -25,7 +25,7 @@ DbapiDescriptionRow = Tuple[ str, str, Optional[str], Optional[str], Optional[int], Optional[int], bool ] DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]] -DbapiResult = List[Union[List[Any], Tuple[Any, ...]]] +DbapiResult = Sequence[Union[List[Any], Tuple[Any, ...]]] FilterValue = Union[float, int, str] FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] FormData = Dict[str, Any] diff --git a/superset/viz.py b/superset/viz.py index 17bebc757..d4487b427 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -26,7 +26,6 @@ import inspect import logging import math import re -import uuid from collections import defaultdict, OrderedDict from datetime import datetime, timedelta from itertools import product