From 93e1db4bd9d045b8a9b345733a60139cb213ab86 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Tue, 20 Jun 2023 13:54:19 -0400 Subject: [PATCH] fix: save columns reference from sqllab save datasets flow (#24248) --- .../src/constants.ts | 2 +- .../src/utils/columnChoices.ts | 13 +- .../test/utils/getTemporalColumns.test.ts | 2 +- .../superset-ui-core/src/query/types/Query.ts | 22 +- .../components/ResultSet/ResultSet.test.tsx | 4 +- .../src/SqlLab/components/ResultSet/index.tsx | 6 +- .../components/SaveDatasetModal/index.tsx | 8 +- superset-frontend/src/SqlLab/fixtures.ts | 46 ++--- .../Datasource/DatasourceEditor.jsx | 11 +- .../src/explore/components/SaveModal.tsx | 2 - superset/connectors/base/models.py | 9 +- superset/connectors/sqla/models.py | 26 ++- superset/connectors/sqla/utils.py | 4 +- superset/daos/query.py | 3 + superset/databases/utils.py | 4 +- superset/datasets/api.py | 2 +- superset/datasets/commands/create.py | 4 + superset/db_engine_specs/base.py | 25 ++- superset/db_engine_specs/bigquery.py | 18 +- superset/db_engine_specs/druid.py | 3 +- superset/db_engine_specs/hive.py | 11 +- superset/db_engine_specs/presto.py | 58 ++++-- superset/models/core.py | 5 +- superset/models/sql_lab.py | 2 +- superset/result_set.py | 2 +- superset/superset_typing.py | 21 +- superset/tables/models.py | 5 +- superset/views/utils.py | 4 + tests/integration_tests/datasets/api_tests.py | 2 +- tests/integration_tests/datasource_tests.py | 12 +- .../db_engine_specs/presto_tests.py | 192 ++++++++++++++---- tests/integration_tests/result_set_tests.py | 22 +- tests/unit_tests/config_test.py | 34 ++-- tests/unit_tests/db_engine_specs/test_base.py | 30 +++ .../db_engine_specs/test_bigquery.py | 22 +- .../unit_tests/db_engine_specs/test_presto.py | 10 +- tests/unit_tests/queries/dao_test.py | 45 ++++ 37 files changed, 489 insertions(+), 202 deletions(-) create mode 100644 tests/unit_tests/queries/dao_test.py diff --git a/superset-frontend/packages/superset-ui-chart-controls/src/constants.ts b/superset-frontend/packages/superset-ui-chart-controls/src/constants.ts index 8ae89efbf..cbde46b0e 100644 --- a/superset-frontend/packages/superset-ui-chart-controls/src/constants.ts +++ b/superset-frontend/packages/superset-ui-chart-controls/src/constants.ts @@ -48,7 +48,7 @@ export const DATASET_TIME_COLUMN_OPTION: ColumnMeta = { }; export const QUERY_TIME_COLUMN_OPTION: QueryColumn = { - name: DTTM_ALIAS, + column_name: DTTM_ALIAS, type: DatasourceType.Query, is_dttm: false, }; diff --git a/superset-frontend/packages/superset-ui-chart-controls/src/utils/columnChoices.ts b/superset-frontend/packages/superset-ui-chart-controls/src/utils/columnChoices.ts index fd4e1fb51..c76cd7903 100644 --- a/superset-frontend/packages/superset-ui-chart-controls/src/utils/columnChoices.ts +++ b/superset-frontend/packages/superset-ui-chart-controls/src/utils/columnChoices.ts @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -import { ensureIsArray, QueryResponse } from '@superset-ui/core'; -import { Dataset, isColumnMeta, isDataset, isQueryResponse } from '../types'; +import { QueryResponse } from '@superset-ui/core'; +import { Dataset, isColumnMeta, isDataset } from '../types'; /** * Convert Datasource columns to column choices @@ -35,14 +35,5 @@ export default function columnChoices( opt1[1].toLowerCase() > opt2[1].toLowerCase() ? 1 : -1, ); } - - if (isQueryResponse(datasource)) { - return ensureIsArray(datasource.columns) - .map((col): [string, string] => [col.name, col.name]) - .sort((opt1, opt2) => - opt1[1].toLowerCase() > opt2[1].toLowerCase() ? 1 : -1, - ); - } - return []; } diff --git a/superset-frontend/packages/superset-ui-chart-controls/test/utils/getTemporalColumns.test.ts b/superset-frontend/packages/superset-ui-chart-controls/test/utils/getTemporalColumns.test.ts index 1921540ea..722717304 100644 --- a/superset-frontend/packages/superset-ui-chart-controls/test/utils/getTemporalColumns.test.ts +++ b/superset-frontend/packages/superset-ui-chart-controls/test/utils/getTemporalColumns.test.ts @@ -54,7 +54,7 @@ test('get temporal columns from a QueryResponse', () => { expect(getTemporalColumns(testQueryResponse)).toEqual({ temporalColumns: [ { - name: 'Column 2', + column_name: 'Column 2', type: 'TIMESTAMP', is_dttm: true, }, diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts index d71928420..f42b01abd 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts @@ -247,8 +247,8 @@ export const CtasEnum = { }; export type QueryColumn = { - name: string; - column_name?: string; + name?: string; + column_name: string; type: string | null; is_dttm: boolean; }; @@ -380,17 +380,17 @@ export const testQuery: Query = { type: DatasourceType.Query, columns: [ { - name: 'Column 1', + column_name: 'Column 1', type: 'STRING', is_dttm: false, }, { - name: 'Column 3', + column_name: 'Column 3', type: 'STRING', is_dttm: false, }, { - name: 'Column 2', + column_name: 'Column 2', type: 'TIMESTAMP', is_dttm: true, }, @@ -402,17 +402,17 @@ export const testQueryResults = { displayLimitReached: false, columns: [ { - name: 'Column 1', + column_name: 'Column 1', type: 'STRING', is_dttm: false, }, { - name: 'Column 3', + column_name: 'Column 3', type: 'STRING', is_dttm: false, }, { - name: 'Column 2', + column_name: 'Column 2', type: 'TIMESTAMP', is_dttm: true, }, @@ -423,17 +423,17 @@ export const testQueryResults = { expanded_columns: [], selected_columns: [ { - name: 'Column 1', + column_name: 'Column 1', type: 'STRING', is_dttm: false, }, { - name: 'Column 3', + column_name: 'Column 3', type: 'STRING', is_dttm: false, }, { - name: 'Column 2', + column_name: 'Column 2', type: 'TIMESTAMP', is_dttm: true, }, diff --git a/superset-frontend/src/SqlLab/components/ResultSet/ResultSet.test.tsx b/superset-frontend/src/SqlLab/components/ResultSet/ResultSet.test.tsx index 1d9fd58be..5e2a0455b 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet/ResultSet.test.tsx +++ b/superset-frontend/src/SqlLab/components/ResultSet/ResultSet.test.tsx @@ -122,10 +122,10 @@ describe('ResultSet', () => { expect(table).toBeInTheDocument(); const firstColumn = queryAllByText( - mockedProps.query.results?.columns[0].name ?? '', + mockedProps.query.results?.columns[0].column_name ?? '', )[0]; const secondColumn = queryAllByText( - mockedProps.query.results?.columns[1].name ?? '', + mockedProps.query.results?.columns[1].column_name ?? '', )[0]; expect(firstColumn).toBeInTheDocument(); expect(secondColumn).toBeInTheDocument(); diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx index b614a0efb..4c7569314 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx +++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx @@ -205,7 +205,7 @@ const ResultSet = ({ ...EXPLORE_CHART_DEFAULT, datasource: `${results.query_id}__query`, ...{ - all_columns: results.columns.map(column => column.name), + all_columns: results.columns.map(column => column.column_name), }, }); const url = mountExploreUrl(null, { @@ -491,7 +491,7 @@ const ResultSet = ({ } if (data && data.length > 0) { const expandedColumns = results.expanded_columns - ? results.expanded_columns.map(col => col.name) + ? results.expanded_columns.map(col => col.column_name) : []; return ( <> @@ -500,7 +500,7 @@ const ResultSet = ({ {sql} col.name)} + orderedColumnKeys={results.columns.map(col => col.column_name)} height={rowsHeight} filterText={searchText} expandedColumns={expandedColumns} diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx index c51c6158d..a42928608 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx @@ -62,6 +62,7 @@ export type ExploreQuery = QueryResponse & { export interface ISimpleColumn { column_name?: string | null; + name?: string | null; type?: string | null; is_dttm?: boolean | null; } @@ -199,14 +200,15 @@ export const SaveDatasetModal = ({ return; } setLoading(true); + const [, key] = await Promise.all([ updateDataset( datasource?.dbId, datasetToOverwrite?.datasetid, datasource?.sql, datasource?.columns?.map( - (d: { name: string; type: string; is_dttm: boolean }) => ({ - column_name: d.name, + (d: { column_name: string; type: string; is_dttm: boolean }) => ({ + column_name: d.column_name, type: d.type, is_dttm: d.is_dttm, }), @@ -292,12 +294,10 @@ export const SaveDatasetModal = ({ dispatch( createDatasource({ - schema: datasource.schema, sql: datasource.sql, dbId: datasource.dbId || datasource?.database?.id, templateParams, datasourceName: datasetName, - columns: selectedColumns, }), ) .then((data: { id: number }) => diff --git a/superset-frontend/src/SqlLab/fixtures.ts b/superset-frontend/src/SqlLab/fixtures.ts index ebfd01888..6fd94e78f 100644 --- a/superset-frontend/src/SqlLab/fixtures.ts +++ b/superset-frontend/src/SqlLab/fixtures.ts @@ -236,24 +236,24 @@ export const queries = [ columns: [ { is_dttm: true, - name: 'ds', + column_name: 'ds', type: 'STRING', }, { is_dttm: false, - name: 'gender', + column_name: 'gender', type: 'STRING', }, ], selected_columns: [ { is_dttm: true, - name: 'ds', + column_name: 'ds', type: 'STRING', }, { is_dttm: false, - name: 'gender', + column_name: 'gender', type: 'STRING', }, ], @@ -326,7 +326,7 @@ export const queryWithNoQueryLimit = { columns: [ { is_dttm: true, - name: 'ds', + column_name: 'ds', type: 'STRING', }, { @@ -338,12 +338,12 @@ export const queryWithNoQueryLimit = { selected_columns: [ { is_dttm: true, - name: 'ds', + column_name: 'ds', type: 'STRING', }, { is_dttm: false, - name: 'gender', + column_name: 'gender', type: 'STRING', }, ], @@ -364,57 +364,57 @@ export const queryWithBadColumns = { selected_columns: [ { is_dttm: true, - name: 'COUNT(*)', + column_name: 'COUNT(*)', type: 'STRING', }, { is_dttm: false, - name: 'this_col_is_ok', + column_name: 'this_col_is_ok', type: 'STRING', }, { is_dttm: false, - name: 'a', + column_name: 'a', type: 'STRING', }, { is_dttm: false, - name: '1', + column_name: '1', type: 'STRING', }, { is_dttm: false, - name: '123', + column_name: '123', type: 'STRING', }, { is_dttm: false, - name: 'CASE WHEN 1=1 THEN 1 ELSE 0 END', + column_name: 'CASE WHEN 1=1 THEN 1 ELSE 0 END', type: 'STRING', }, { is_dttm: true, - name: '_TIMESTAMP', + column_name: '_TIMESTAMP', type: 'TIMESTAMP', }, { is_dttm: true, - name: '__TIME', + column_name: '__TIME', type: 'TIMESTAMP', }, { is_dttm: false, - name: 'my_dupe_col__2', + column_name: 'my_dupe_col__2', type: 'STRING', }, { is_dttm: true, - name: '__timestamp', + column_name: '__timestamp', type: 'TIMESTAMP', }, { is_dttm: true, - name: '__TIMESTAMP', + column_name: '__TIMESTAMP', type: 'TIMESTAMP', }, ], @@ -572,31 +572,31 @@ const baseQuery: QueryResponse = { columns: [ { is_dttm: true, - name: 'ds', + column_name: 'ds', type: 'STRING', }, { is_dttm: false, - name: 'gender', + column_name: 'gender', type: 'STRING', }, ], selected_columns: [ { is_dttm: true, - name: 'ds', + column_name: 'ds', type: 'STRING', }, { is_dttm: false, - name: 'gender', + column_name: 'gender', type: 'STRING', }, ], expanded_columns: [ { is_dttm: true, - name: 'ds', + column_name: 'ds', type: 'STRING', }, ], diff --git a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx index 89ad14180..79b10b8fc 100644 --- a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx +++ b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx @@ -688,8 +688,9 @@ class DatasourceEditor extends React.PureComponent { } updateColumns(cols) { + // cols: Array<{column_name: string; is_dttm: boolean; type: string;}> const { databaseColumns } = this.state; - const databaseColumnNames = cols.map(col => col.name); + const databaseColumnNames = cols.map(col => col.column_name); const currentCols = databaseColumns.reduce( (agg, col) => ({ ...agg, @@ -706,18 +707,18 @@ class DatasourceEditor extends React.PureComponent { .filter(col => !databaseColumnNames.includes(col)), }; cols.forEach(col => { - const currentCol = currentCols[col.name]; + const currentCol = currentCols[col.column_name]; if (!currentCol) { // new column finalColumns.push({ id: shortid.generate(), - column_name: col.name, + column_name: col.column_name, type: col.type, groupby: true, filterable: true, is_dttm: col.is_dttm, }); - results.added.push(col.name); + results.added.push(col.column_name); } else if ( currentCol.type !== col.type || (!currentCol.is_dttm && col.is_dttm) @@ -728,7 +729,7 @@ class DatasourceEditor extends React.PureComponent { type: col.type, is_dttm: currentCol.is_dttm || col.is_dttm, }); - results.modified.push(col.name); + results.modified.push(col.column_name); } else { // unchanged finalColumns.push(currentCol); diff --git a/superset-frontend/src/explore/components/SaveModal.tsx b/superset-frontend/src/explore/components/SaveModal.tsx index 6b4061ed7..1de97b926 100644 --- a/superset-frontend/src/explore/components/SaveModal.tsx +++ b/superset-frontend/src/explore/components/SaveModal.tsx @@ -174,7 +174,6 @@ class SaveModal extends React.Component { if (this.props.datasource?.type === DatasourceType.Query) { const { schema, sql, database } = this.props.datasource; const { templateParams } = this.props.datasource; - const columns = this.props.datasource?.columns || []; await this.props.actions.saveDataset({ schema, @@ -182,7 +181,6 @@ class SaveModal extends React.Component { database, templateParams, datasourceName: this.state.datasetName, - columns, }); } diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index f370a9f64..8396da911 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -35,7 +35,12 @@ from superset.constants import EMPTY_STRING, NULL_STRING from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult from superset.models.slice import Slice -from superset.superset_typing import FilterValue, FilterValues, QueryObjectDict +from superset.superset_typing import ( + FilterValue, + FilterValues, + QueryObjectDict, + ResultSetColumnType, +) from superset.utils import core as utils from superset.utils.core import GenericDataType, MediumText @@ -456,7 +461,7 @@ class BaseDatasource( values = values[0] if values else None return values - def external_metadata(self) -> list[dict[str, str]]: + def external_metadata(self) -> list[ResultSetColumnType]: """Returns column information from the external system""" raise NotImplementedError() diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c44b79062..4eebec6be 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -105,7 +105,13 @@ from superset.models.helpers import ( validate_adhoc_subquery, ) from superset.sql_parse import ParsedQuery, sanitize_clause -from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict +from superset.superset_typing import ( + AdhocColumn, + AdhocMetric, + Metric, + QueryObjectDict, + ResultSetColumnType, +) from superset.utils import core as utils from superset.utils.core import GenericDataType, MediumText @@ -700,10 +706,10 @@ class SqlaTable( def sql_url(self) -> str: return self.database.sql_url + "?table_name=" + str(self.table_name) - def external_metadata(self) -> list[dict[str, str]]: + def external_metadata(self) -> list[ResultSetColumnType]: # todo(yongjie): create a physical table column type in a separate PR if self.sql: - return get_virtual_table_metadata(dataset=self) # type: ignore + return get_virtual_table_metadata(dataset=self) return get_physical_table_metadata( database=self.database, table_name=self.table_name, @@ -995,7 +1001,7 @@ class SqlaTable( qry = sa.select([sqla_column]).limit(1).select_from(tbl) sql = self.database.compile_sqla_query(qry) col_desc = get_columns_description(self.database, sql) - is_dttm = col_desc[0]["is_dttm"] + is_dttm = col_desc[0]["is_dttm"] # type: ignore except SupersetGenericDBErrorException as ex: raise ColumnNotFoundException(message=str(ex)) from ex @@ -1260,18 +1266,18 @@ class SqlaTable( removed=[ col for col in old_columns_by_name - if col not in {col["name"] for col in new_columns} + if col not in {col["column_name"] for col in new_columns} ] ) # clear old columns before adding modified columns back columns = [] for col in new_columns: - old_column = old_columns_by_name.pop(col["name"], None) + old_column = old_columns_by_name.pop(col["column_name"], None) if not old_column: - results.added.append(col["name"]) + results.added.append(col["column_name"]) new_column = TableColumn( - column_name=col["name"], + column_name=col["column_name"], type=col["type"], table=self, ) @@ -1280,14 +1286,14 @@ class SqlaTable( else: new_column = old_column if new_column.type != col["type"]: - results.modified.append(col["name"]) + results.modified.append(col["column_name"]) new_column.type = col["type"] new_column.expression = "" new_column.groupby = True new_column.filterable = True columns.append(new_column) if not any_date_col and new_column.is_temporal: - any_date_col = col["name"] + any_date_col = col["column_name"] # add back calculated (virtual) columns columns.extend([col for col in old_columns if col.expression]) diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index d41c0555d..f761b2dca 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -19,7 +19,7 @@ from __future__ import annotations import logging from collections.abc import Iterable, Iterator from functools import lru_cache -from typing import Any, Callable, TYPE_CHECKING, TypeVar +from typing import Callable, TYPE_CHECKING, TypeVar from uuid import UUID from flask_babel import lazy_gettext as _ @@ -49,7 +49,7 @@ def get_physical_table_metadata( database: Database, table_name: str, schema_name: str | None = None, -) -> list[dict[str, Any]]: +) -> list[ResultSetColumnType]: """Use SQLAlchemy inspector to get table metadata""" db_engine_spec = database.db_engine_spec db_dialect = database.get_dialect() diff --git a/superset/daos/query.py b/superset/daos/query.py index 8aca1a4e2..8996e27a3 100644 --- a/superset/daos/query.py +++ b/superset/daos/query.py @@ -63,6 +63,9 @@ class QueryDAO(BaseDAO): def save_metadata(query: Query, payload: dict[str, Any]) -> None: # pull relevant data from payload and store in extra_json columns = payload.get("columns", {}) + for col in columns: + if "name" in col: + col["column_name"] = col.get("name") db.session.add(query) query.set_extra_json_key("columns", columns) diff --git a/superset/databases/utils.py b/superset/databases/utils.py index 74943f474..fa163e4d9 100644 --- a/superset/databases/utils.py +++ b/superset/databases/utils.py @@ -79,10 +79,10 @@ def get_table_metadata( dtype = get_col_type(col) payload_columns.append( { - "name": col["name"], + "name": col["column_name"], "type": dtype.split("(")[0] if "(" in dtype else dtype, "longType": dtype, - "keys": [k for k in keys if col["name"] in k["column_names"]], + "keys": [k for k in keys if col["column_name"] in k["column_names"]], "comment": col.get("comment"), } ) diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 6e6cf38aa..2b6f417e3 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -311,7 +311,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): try: new_model = CreateDatasetCommand(item).run() - return self.response(201, id=new_model.id, result=item) + return self.response(201, id=new_model.id, result=item, data=new_model.data) except DatasetInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except DatasetCreateFailedError as ex: diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 28b0250ab..38b8bf436 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -22,6 +22,7 @@ from marshmallow import ValidationError from sqlalchemy.exc import SQLAlchemyError from superset.commands.base import BaseCommand, CreateMixin +from superset.connectors.sqla.models import SqlMetric from superset.daos.dataset import DatasetDAO from superset.daos.exceptions import DAOCreateFailedError from superset.datasets.commands.exceptions import ( @@ -45,7 +46,10 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): try: # Creates SqlaTable (Dataset) dataset = DatasetDAO.create(self._properties, commit=False) + # Updates columns and metrics from the dataset + dataset.metrics = [SqlMetric(metric_name="count", expression="COUNT(*)")] + dataset.fetch_metadata(commit=False) db.session.commit() except (SQLAlchemyError, DAOCreateFailedError) as ex: diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 01d878ce0..766f17e7e 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -23,7 +23,7 @@ import logging import re from datetime import datetime from re import Match, Pattern -from typing import Any, Callable, ContextManager, NamedTuple, TYPE_CHECKING, Union +from typing import Any, Callable, cast, ContextManager, NamedTuple, TYPE_CHECKING, Union import pandas as pd import sqlparse @@ -53,7 +53,7 @@ from superset.constants import TimeGrain as TimeGrainConstants from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import ParsedQuery, Table -from superset.superset_typing import ResultSetColumnType +from superset.superset_typing import ResultSetColumnType, SQLAColumnType from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType from superset.utils.hashing import md5_sha_from_str @@ -73,6 +73,13 @@ ColumnTypeMapping = tuple[ logger = logging.getLogger() +def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]: + result_set_columns: list[ResultSetColumnType] = [] + for col in cols: + result_set_columns.append({"column_name": col.get("name"), **col}) # type: ignore + return result_set_columns + + class TimeGrain(NamedTuple): name: str # TODO: redundant field, remove label: str @@ -1223,7 +1230,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: str | None - ) -> list[dict[str, Any]]: + ) -> list[ResultSetColumnType]: """ Get all columns from a given schema and table @@ -1232,7 +1239,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param schema: Schema name. If omitted, uses default schema for database :return: All columns in table """ - return inspector.get_columns(table_name, schema) + return convert_inspector_columns( + cast(list[SQLAColumnType], inspector.get_columns(table_name, schema)) + ) @classmethod def get_metrics( # pylint: disable=unused-argument @@ -1261,7 +1270,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods schema: str | None, database: Database, query: Select, - columns: list[dict[str, Any]] | None = None, + columns: list[ResultSetColumnType] | None = None, ) -> Select | None: """ Add a where clause to a query to reference only the most recent partition @@ -1278,8 +1287,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return None @classmethod - def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]: - return [column(c["name"]) for c in cols] + def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: + return [column(c["column_name"]) for c in cols] @classmethod def select_star( # pylint: disable=too-many-arguments,too-many-locals @@ -1292,7 +1301,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: list[dict[str, Any]] | None = None, + cols: list[ResultSetColumnType] | None = None, ) -> str: """ Generate a "SELECT * from [schema.]table_name" query with appropriate limit. diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index e69194f50..a47a32841 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -43,6 +43,7 @@ from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError from superset.errors import SupersetError, SupersetErrorType from superset.exceptions import SupersetException from superset.sql_parse import Table +from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils from superset.utils.hashing import md5_sha_from_str @@ -637,7 +638,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[list[dict[str, Any]]] = None, + cols: Optional[list[ResultSetColumnType]] = None, ) -> str: """ Remove array structures from `SELECT *`. @@ -678,13 +679,15 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met # For arrays of structs, remove the child columns, otherwise the query # will fail. array_prefixes = { - col["name"] for col in cols if isinstance(col["type"], sqltypes.ARRAY) + col["column_name"] + for col in cols + if isinstance(col["type"], sqltypes.ARRAY) } cols = [ col for col in cols - if "." not in col["name"] - or col["name"].split(".")[0] not in array_prefixes + if "." not in col["column_name"] + or col["column_name"].split(".")[0] not in array_prefixes ] return super().select_star( @@ -700,7 +703,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met ) @classmethod - def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]: + def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: """ Label columns using their fully qualified name. @@ -725,7 +728,10 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met the columns using their fully qualified name, so we end up with "author", "author__name" and "author__email", respectively. """ - return [column(c["name"]).label(c["name"].replace(".", "__")) for c in cols] + return [ + column(c["column_name"]).label(c["column_name"].replace(".", "__")) + for c in cols + ] @classmethod def parse_error_exception(cls, exception: Exception) -> Exception: diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 478f3e949..9bba3a727 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -30,6 +30,7 @@ from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError from superset.exceptions import SupersetException +from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils if TYPE_CHECKING: @@ -132,7 +133,7 @@ class DruidEngineSpec(BaseEngineSpec): @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: str | None - ) -> list[dict[str, Any]]: + ) -> list[ResultSetColumnType]: """ Update the Druid type map. """ diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index d7c2465ba..e76d90293 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -46,6 +46,7 @@ from superset.exceptions import SupersetException from superset.extensions import cache_manager from superset.models.sql_lab import Query from superset.sql_parse import ParsedQuery, Table +from superset.superset_typing import ResultSetColumnType if TYPE_CHECKING: # prevent circular imports @@ -407,8 +408,8 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: str | None - ) -> list[dict[str, Any]]: - return inspector.get_columns(table_name, schema) + ) -> list[ResultSetColumnType]: + return BaseEngineSpec.get_columns(inspector, table_name, schema) @classmethod def where_latest_partition( # pylint: disable=too-many-arguments @@ -417,7 +418,7 @@ class HiveEngineSpec(PrestoEngineSpec): schema: str | None, database: Database, query: Select, - columns: list[dict[str, Any]] | None = None, + columns: list[ResultSetColumnType] | None = None, ) -> Select | None: try: col_names, values = cls.latest_partition( @@ -436,7 +437,7 @@ class HiveEngineSpec(PrestoEngineSpec): return None @classmethod - def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]: + def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]: return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access @classmethod @@ -481,7 +482,7 @@ class HiveEngineSpec(PrestoEngineSpec): show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: list[dict[str, Any]] | None = None, + cols: list[ResultSetColumnType] | None = None, ) -> str: return super( # pylint: disable=bad-super-call PrestoEngineSpec, cls diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 861e82234..d24405d9c 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -126,7 +126,14 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]: type_ = group["type"].upper() children_type = group["children"] if type_ == "ARRAY": - return [{"name": column["name"], "type": children_type, "is_dttm": False}] + return [ + { + "column_name": column["column_name"], + "name": column["column_name"], + "type": children_type, + "is_dttm": False, + } + ] if type_ == "ROW": nameless_columns = 0 @@ -141,7 +148,8 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]: type_ = parts[0] nameless_columns += 1 _column: ResultSetColumnType = { - "name": f"{column['name']}.{name.lower()}", + "column_name": f"{column['column_name']}.{name.lower()}", + "name": f"{column['column_name']}.{name.lower()}", "type": type_, "is_dttm": False, } @@ -482,7 +490,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): schema: str | None, database: Database, query: Select, - columns: list[dict[str, Any]] | None = None, + columns: list[ResultSetColumnType] | None = None, ) -> Select | None: try: col_names, values = cls.latest_partition( @@ -496,7 +504,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): return None column_type_by_name = { - column.get("name"): column.get("type") for column in columns or [] + column.get("column_name"): column.get("type") for column in columns or [] } for col_name, value in zip(col_names, values): @@ -813,14 +821,20 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def _create_column_info( cls, name: str, data_type: types.TypeEngine - ) -> dict[str, Any]: + ) -> ResultSetColumnType: """ Create column info object :param name: column name :param data_type: column data type :return: column info object """ - return {"name": name, "type": f"{data_type}"} + return { + "column_name": name, + "name": name, + "type": f"{data_type}", + "is_dttm": None, + "type_generic": None, + } @classmethod def _get_full_name(cls, names: list[tuple[str, str]]) -> str: @@ -863,7 +877,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): cls, parent_column_name: str, parent_data_type: str, - result: list[dict[str, Any]], + result: list[ResultSetColumnType], ) -> None: """ Parse a row or array column @@ -941,7 +955,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): # Unquote the column name if necessary if formatted_parent_column_name != parent_column_name: for index in range(original_result_len, len(result)): - result[index]["name"] = result[index]["name"].replace( + result[index]["column_name"] = result[index]["column_name"].replace( formatted_parent_column_name, parent_column_name ) @@ -965,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: str | None - ) -> list[dict[str, Any]]: + ) -> list[ResultSetColumnType]: """ Get columns from a Presto data source. This includes handling row and array data types @@ -976,7 +990,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): (i.e. column name and data type) """ columns = cls._show_columns(inspector, table_name, schema) - result: list[dict[str, Any]] = [] + result: list[ResultSetColumnType] = [] for column in columns: # parse column if it is a row or array if is_feature_enabled("PRESTO_EXPAND_DATA") and ( @@ -1003,6 +1017,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): column_info = cls._create_column_info(column.Column, column_type) column_info["nullable"] = getattr(column, "Null", True) column_info["default"] = None + column_info["column_name"] = column.Column result.append(column_info) return result @@ -1016,7 +1031,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): return column_name.startswith('"') and column_name.endswith('"') @classmethod - def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]: + def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]: """ Format column clauses where names are in quotes and labels are specified :param cols: columns @@ -1034,7 +1049,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): dot_regex = re.compile(dot_pattern, re.VERBOSE) for col in cols: # get individual column names - col_names = re.split(dot_regex, col["name"]) + col_names = re.split(dot_regex, col["column_name"]) # quote each column name if it is not already quoted for index, col_name in enumerate(col_names): if not cls._is_column_name_quoted(col_name): @@ -1044,7 +1059,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): for col_name in col_names ) # create column clause in the format "name"."name" AS "name.name" - column_clause = literal_column(quoted_col_name).label(col["name"]) + column_clause = literal_column(quoted_col_name).label(col["column_name"]) column_clauses.append(column_clause) return column_clauses @@ -1059,7 +1074,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: list[dict[str, Any]] | None = None, + cols: list[ResultSetColumnType] | None = None, ) -> str: """ Include selecting properties of row objects. We cannot easily break arrays into @@ -1071,7 +1086,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): if is_feature_enabled("PRESTO_EXPAND_DATA") and show_cols: dot_regex = r"\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)" presto_cols = [ - col for col in presto_cols if not re.search(dot_regex, col["name"]) + col + for col in presto_cols + if not re.search(dot_regex, col["column_name"]) ] return super().select_star( database, @@ -1123,7 +1140,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): current_array_level = None while to_process: column, level = to_process.popleft() - if column["name"] not in [column["name"] for column in all_columns]: + if column["column_name"] not in [ + column["column_name"] for column in all_columns + ]: all_columns.append(column) # When unnesting arrays we need to keep track of how many extra rows @@ -1135,7 +1154,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): unnested_rows: dict[int, int] = defaultdict(int) current_array_level = level - name = column["name"] + name = column["column_name"] values: str | list[Any] | None if column["type"] and column["type"].startswith("ARRAY("): @@ -1186,10 +1205,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): values = cast(Optional[list[Any]], destringify(values)) row[name] = values for value, col in zip(values or [], expanded): - row[col["name"]] = value + row[col["column_name"]] = value data = [ - {k["name"]: row.get(k["name"], "") for k in all_columns} for row in data + {k["column_name"]: row.get(k["column_name"], "") for k in all_columns} + for row in data ] return all_columns, data, expanded_columns diff --git a/superset/models/core.py b/superset/models/core.py index 92e6f2dbb..4ff56145e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -69,6 +69,7 @@ from superset.extensions import ( ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet +from superset.superset_typing import ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.core import get_username @@ -632,7 +633,7 @@ class Database( show_cols: bool = False, indent: bool = True, latest_partition: bool = False, - cols: Optional[list[dict[str, Any]]] = None, + cols: Optional[list[ResultSetColumnType]] = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) @@ -837,7 +838,7 @@ class Database( def get_columns( self, table_name: str, schema: Optional[str] = None - ) -> list[dict[str, Any]]: + ) -> list[ResultSetColumnType]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_columns(inspector, table_name, schema) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index fbadaaa2f..96ec4010b 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -193,7 +193,7 @@ class Query( return [ TableColumn( - column_name=col["name"], + column_name=col["column_name"], database=self.database, is_dttm=col["is_dttm"], filterable=True, diff --git a/superset/result_set.py b/superset/result_set.py index f707b91dc..4ca39cba2 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -253,10 +253,10 @@ class SupersetResultSet: for col in self.table.schema: db_type_str = self.data_type(col.name, col.type) column: ResultSetColumnType = { + "column_name": col.name, "name": col.name, "type": db_type_str, "is_dttm": self.is_temporal(db_type_str), } columns.append(column) - return columns diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 7c21df6a8..c58fa567a 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, Literal, Optional, TYPE_CHECKING, Union -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict from werkzeug.wrappers import Response if TYPE_CHECKING: @@ -60,14 +60,29 @@ class AdhocColumn(TypedDict, total=False): timeGrain: Optional[str] +class SQLAColumnType(TypedDict): + name: str + type: Optional[str] + is_dttm: bool + + class ResultSetColumnType(TypedDict): """ Superset virtual dataset column interface """ - name: str + name: str # legacy naming convention keeping this for backwards compatibility + column_name: str type: Optional[str] - is_dttm: bool + is_dttm: Optional[bool] + type_generic: NotRequired[Optional["GenericDataType"]] + + nullable: NotRequired[Any] + default: NotRequired[Any] + comment: NotRequired[Any] + precision: NotRequired[Any] + scale: NotRequired[Any] + max_length: NotRequired[Any] CacheConfig = dict[str, Any] diff --git a/superset/tables/models.py b/superset/tables/models.py index a24035fb9..0b9741405 100644 --- a/superset/tables/models.py +++ b/superset/tables/models.py @@ -43,6 +43,7 @@ from superset.models.helpers import ( ImportExportMixin, ) from superset.sql_parse import Table as TableName +from superset.superset_typing import ResultSetColumnType if TYPE_CHECKING: from superset.datasets.models import Dataset @@ -131,8 +132,8 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): existing_columns = {column.name: column for column in self.columns} quote_identifier = self.database.quote_identifier - def update_or_create_column(column_meta: dict[str, Any]) -> Column: - column_name: str = column_meta["name"] + def update_or_create_column(column_meta: ResultSetColumnType) -> Column: + column_name: str = column_meta["column_name"] if column_name in existing_columns: column = existing_columns[column_name] else: diff --git a/superset/views/utils.py b/superset/views/utils.py index 29f5a4c7e..cbe9a0ee5 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -536,6 +536,10 @@ def _deserialize_results_payload( df = result_set.SupersetResultSet.convert_table_to_df(pa_table) ds_payload["data"] = dataframe.df_to_records(df) or [] + for column in ds_payload["selected_columns"]: + if "name" in column: + column["column_name"] = column.get("name") + db_engine_spec = query.database.db_engine_spec all_columns, data, expanded_columns = db_engine_spec.expand_data( ds_payload["selected_columns"], ds_payload["data"] diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 2f55a1e97..c0d86a876 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -791,7 +791,7 @@ class TestDatasetApi(SupersetTestCase): mock_get_columns.return_value = [ { - "name": "col", + "column_name": "col", "type": "VARCHAR", "type_generic": None, "is_dttm": None, diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index b73acc268..5de1cf6ef 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -71,7 +71,7 @@ class TestDatasource(SupersetTestCase): tbl = self.get_table(name="birth_names") url = f"/datasource/external_metadata/table/{tbl.id}/" resp = self.get_json_resp(url) - col_names = {o.get("name") for o in resp} + col_names = {o.get("column_name") for o in resp} self.assertEqual( col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"} ) @@ -91,7 +91,7 @@ class TestDatasource(SupersetTestCase): table = self.get_table(name="dummy_sql_table") url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) - assert {o.get("name") for o in resp} == {"intcol", "strcol"} + assert {o.get("column_name") for o in resp} == {"intcol", "strcol"} session.delete(table) session.commit() @@ -109,7 +109,7 @@ class TestDatasource(SupersetTestCase): ) url = f"/datasource/external_metadata_by_name/?q={params}" resp = self.get_json_resp(url) - col_names = {o.get("name") for o in resp} + col_names = {o.get("column_name") for o in resp} self.assertEqual( col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"} ) @@ -137,7 +137,7 @@ class TestDatasource(SupersetTestCase): ) url = f"/datasource/external_metadata_by_name/?q={params}" resp = self.get_json_resp(url) - assert {o.get("name") for o in resp} == {"intcol", "strcol"} + assert {o.get("column_name") for o in resp} == {"intcol", "strcol"} session.delete(tbl) session.commit() @@ -155,7 +155,7 @@ class TestDatasource(SupersetTestCase): ) url = f"/datasource/external_metadata_by_name/?q={params}" resp = self.get_json_resp(url) - col_names = {o.get("name") for o in resp} + col_names = {o.get("column_name") for o in resp} self.assertEqual(col_names, {"first", "second"}) # No databases found @@ -216,7 +216,7 @@ class TestDatasource(SupersetTestCase): table = self.get_table(name="dummy_sql_table_with_template_params") url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) - assert {o.get("name") for o in resp} == {"intcol"} + assert {o.get("column_name") for o in resp} == {"intcol"} session.delete(table) session.commit() diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 4fe74c0ca..393f89621 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -85,7 +85,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): results = PrestoEngineSpec.get_columns(inspector, "", "") self.assertEqual(len(expected_results), len(results)) for expected_result, result in zip(expected_results, results): - self.assertEqual(expected_result[0], result["name"]) + self.assertEqual(expected_result[0], result["column_name"]) self.assertEqual(expected_result[1], str(result["type"])) def test_presto_get_column(self): @@ -175,21 +175,21 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): def test_presto_get_fields(self): cols = [ - {"name": "column"}, - {"name": "column.nested_obj"}, - {"name": 'column."quoted.nested obj"'}, + {"column_name": "column"}, + {"column_name": "column.nested_obj"}, + {"column_name": 'column."quoted.nested obj"'}, ] actual_results = PrestoEngineSpec._get_fields(cols) expected_results = [ - {"name": '"column"', "label": "column"}, - {"name": '"column"."nested_obj"', "label": "column.nested_obj"}, + {"column_name": '"column"', "label": "column"}, + {"column_name": '"column"."nested_obj"', "label": "column.nested_obj"}, { - "name": '"column"."quoted.nested obj"', + "column_name": '"column"."quoted.nested obj"', "label": 'column."quoted.nested obj"', }, ] for actual_result, expected_result in zip(actual_results, expected_results): - self.assertEqual(actual_result.element.name, expected_result["name"]) + self.assertEqual(actual_result.element.name, expected_result["column_name"]) self.assertEqual(actual_result.name, expected_result["label"]) @mock.patch.dict( @@ -199,8 +199,18 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): ) def test_presto_expand_data_with_simple_structural_columns(self): cols = [ - {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False}, - {"name": "array_column", "type": "ARRAY(BIGINT)", "is_dttm": False}, + { + "column_name": "row_column", + "name": "row_column", + "type": "ROW(NESTED_OBJ VARCHAR)", + "is_dttm": False, + }, + { + "column_name": "array_column", + "name": "array_column", + "type": "ARRAY(BIGINT)", + "is_dttm": False, + }, ] data = [ {"row_column": ["a"], "array_column": [1, 2, 3]}, @@ -210,9 +220,24 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): cols, data ) expected_cols = [ - {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False}, - {"name": "row_column.nested_obj", "type": "VARCHAR", "is_dttm": False}, - {"name": "array_column", "type": "ARRAY(BIGINT)", "is_dttm": False}, + { + "column_name": "row_column", + "name": "row_column", + "type": "ROW(NESTED_OBJ VARCHAR)", + "is_dttm": False, + }, + { + "column_name": "row_column.nested_obj", + "name": "row_column.nested_obj", + "type": "VARCHAR", + "is_dttm": False, + }, + { + "column_name": "array_column", + "name": "array_column", + "type": "ARRAY(BIGINT)", + "is_dttm": False, + }, ] expected_data = [ @@ -225,7 +250,12 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): ] expected_expanded_cols = [ - {"name": "row_column.nested_obj", "type": "VARCHAR", "is_dttm": False} + { + "name": "row_column.nested_obj", + "column_name": "row_column.nested_obj", + "type": "VARCHAR", + "is_dttm": False, + } ] self.assertEqual(actual_cols, expected_cols) self.assertEqual(actual_data, expected_data) @@ -240,6 +270,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): cols = [ { "name": "row_column", + "column_name": "row_column", "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))", "is_dttm": False, } @@ -251,17 +282,25 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): expected_cols = [ { "name": "row_column", + "column_name": "row_column", "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))", "is_dttm": False, }, - {"name": "row_column.nested_obj1", "type": "VARCHAR", "is_dttm": False}, + { + "name": "row_column.nested_obj1", + "column_name": "row_column.nested_obj1", + "type": "VARCHAR", + "is_dttm": False, + }, { "name": "row_column.nested_row", + "column_name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)", "is_dttm": False, }, { "name": "row_column.nested_row.nested_obj2", + "column_name": "row_column.nested_row.nested_obj2", "type": "VARCHAR", "is_dttm": False, }, @@ -282,14 +321,21 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): ] expected_expanded_cols = [ - {"name": "row_column.nested_obj1", "type": "VARCHAR", "is_dttm": False}, + { + "name": "row_column.nested_obj1", + "column_name": "row_column.nested_obj1", + "type": "VARCHAR", + "is_dttm": False, + }, { "name": "row_column.nested_row", + "column_name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)", "is_dttm": False, }, { "name": "row_column.nested_row.nested_obj2", + "column_name": "row_column.nested_row.nested_obj2", "type": "VARCHAR", "is_dttm": False, }, @@ -307,6 +353,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): cols = [ { "name": "row_column", + "column_name": "row_column", "type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))", "is_dttm": False, } @@ -323,16 +370,19 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): expected_cols = [ { "name": "row_column", + "column_name": "row_column", "type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))", "is_dttm": False, }, { "name": "row_column.nested_row", + "column_name": "row_column.nested_row", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False, }, { "name": "row_column.nested_row.nested_obj", + "column_name": "row_column.nested_row.nested_obj", "type": "VARCHAR", "is_dttm": False, }, @@ -363,11 +413,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): expected_expanded_cols = [ { "name": "row_column.nested_row", + "column_name": "row_column.nested_row", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False, }, { "name": "row_column.nested_row.nested_obj", + "column_name": "row_column.nested_row.nested_obj", "type": "VARCHAR", "is_dttm": False, }, @@ -383,9 +435,15 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): ) def test_presto_expand_data_with_complex_array_columns(self): cols = [ - {"name": "int_column", "type": "BIGINT", "is_dttm": False}, + { + "name": "int_column", + "column_name": "int_column", + "type": "BIGINT", + "is_dttm": False, + }, { "name": "array_column", + "column_name": "array_column", "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))", "is_dttm": False, }, @@ -398,19 +456,27 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): cols, data ) expected_cols = [ - {"name": "int_column", "type": "BIGINT", "is_dttm": False}, + { + "name": "int_column", + "column_name": "int_column", + "type": "BIGINT", + "is_dttm": False, + }, { "name": "array_column", + "column_name": "array_column", "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))", "is_dttm": False, }, { "name": "array_column.nested_array", + "column_name": "array_column.nested_array", "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))", "is_dttm": False, }, { "name": "array_column.nested_array.nested_obj", + "column_name": "array_column.nested_array.nested_obj", "type": "VARCHAR", "is_dttm": False, }, @@ -468,11 +534,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): expected_expanded_cols = [ { "name": "array_column.nested_array", + "column_name": "array_column.nested_array", "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))", "is_dttm": False, }, { "name": "array_column.nested_array.nested_obj", + "column_name": "array_column.nested_array.nested_obj", "type": "VARCHAR", "is_dttm": False, }, @@ -575,9 +643,20 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): ) def test_presto_expand_data_array(self): cols = [ - {"name": "event_id", "type": "VARCHAR", "is_dttm": False}, - {"name": "timestamp", "type": "BIGINT", "is_dttm": False}, { + "column_name": "event_id", + "name": "event_id", + "type": "VARCHAR", + "is_dttm": False, + }, + { + "column_name": "timestamp", + "name": "timestamp", + "type": "BIGINT", + "is_dttm": False, + }, + { + "column_name": "user", "name": "user", "type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)", "is_dttm": False, @@ -594,16 +673,42 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): cols, data ) expected_cols = [ - {"name": "event_id", "type": "VARCHAR", "is_dttm": False}, - {"name": "timestamp", "type": "BIGINT", "is_dttm": False}, { + "column_name": "event_id", + "name": "event_id", + "type": "VARCHAR", + "is_dttm": False, + }, + { + "column_name": "timestamp", + "name": "timestamp", + "type": "BIGINT", + "is_dttm": False, + }, + { + "column_name": "user", "name": "user", "type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)", "is_dttm": False, }, - {"name": "user.id", "type": "BIGINT", "is_dttm": False}, - {"name": "user.first_name", "type": "VARCHAR", "is_dttm": False}, - {"name": "user.last_name", "type": "VARCHAR", "is_dttm": False}, + { + "column_name": "user.id", + "name": "user.id", + "type": "BIGINT", + "is_dttm": False, + }, + { + "column_name": "user.first_name", + "name": "user.first_name", + "type": "VARCHAR", + "is_dttm": False, + }, + { + "column_name": "user.last_name", + "name": "user.last_name", + "type": "VARCHAR", + "is_dttm": False, + }, ] expected_data = [ { @@ -616,9 +721,24 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): } ] expected_expanded_cols = [ - {"name": "user.id", "type": "BIGINT", "is_dttm": False}, - {"name": "user.first_name", "type": "VARCHAR", "is_dttm": False}, - {"name": "user.last_name", "type": "VARCHAR", "is_dttm": False}, + { + "column_name": "user.id", + "name": "user.id", + "type": "BIGINT", + "is_dttm": False, + }, + { + "column_name": "user.first_name", + "name": "user.first_name", + "type": "VARCHAR", + "is_dttm": False, + }, + { + "column_name": "user.last_name", + "name": "user.last_name", + "type": "VARCHAR", + "is_dttm": False, + }, ] self.assertEqual(actual_cols, expected_cols) @@ -736,12 +856,12 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): table_name = "table_name" engine = mock.Mock() cols = [ - {"name": "val1"}, - {"name": "val2 None: mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ - {"name": "ds", "type": "TIMESTAMP", "is_dttm": True}, - {"name": "event_time", "type": "TIMESTAMP", "is_dttm": True}, - {"name": "id", "type": "INTEGER", "is_dttm": False}, + {"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True}, + {"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True}, + {"column_name": "id", "type": "INTEGER", "is_dttm": False}, ], ) @@ -154,9 +154,9 @@ def test_main_dttm_col_nonexistent( mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ - {"name": "ds", "type": "TIMESTAMP", "is_dttm": True}, - {"name": "event_time", "type": "TIMESTAMP", "is_dttm": True}, - {"name": "id", "type": "INTEGER", "is_dttm": False}, + {"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True}, + {"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True}, + {"column_name": "id", "type": "INTEGER", "is_dttm": False}, ], ) @@ -188,9 +188,9 @@ def test_main_dttm_col_nondttm( mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ - {"name": "ds", "type": "TIMESTAMP", "is_dttm": True}, - {"name": "event_time", "type": "TIMESTAMP", "is_dttm": True}, - {"name": "id", "type": "INTEGER", "is_dttm": False}, + {"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True}, + {"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True}, + {"column_name": "id", "type": "INTEGER", "is_dttm": False}, ], ) @@ -226,9 +226,9 @@ def test_python_date_format_by_column_name( mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ - {"name": "id", "type": "INTEGER", "is_dttm": False}, - {"name": "dttm", "type": "INTEGER", "is_dttm": False}, - {"name": "duration_ms", "type": "INTEGER", "is_dttm": False}, + {"column_name": "id", "type": "INTEGER", "is_dttm": False}, + {"column_name": "dttm", "type": "INTEGER", "is_dttm": False}, + {"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False}, ], ) @@ -274,8 +274,8 @@ def test_expression_by_column_name( mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ - {"name": "dttm", "type": "INTEGER", "is_dttm": False}, - {"name": "duration_ms", "type": "INTEGER", "is_dttm": False}, + {"column_name": "dttm", "type": "INTEGER", "is_dttm": False}, + {"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False}, ], ) @@ -311,9 +311,9 @@ def test_full_setting( mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ - {"name": "id", "type": "INTEGER", "is_dttm": False}, - {"name": "dttm", "type": "INTEGER", "is_dttm": False}, - {"name": "duration_ms", "type": "INTEGER", "is_dttm": False}, + {"column_name": "id", "type": "INTEGER", "is_dttm": False}, + {"column_name": "dttm", "type": "INTEGER", "is_dttm": False}, + {"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False}, ], ) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 33083f039..7b1977ff1 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -22,6 +22,7 @@ from typing import Any, Optional import pytest from sqlalchemy import types +from superset.superset_typing import ResultSetColumnType, SQLAColumnType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import assert_column_spec @@ -138,3 +139,32 @@ def test_get_column_spec( from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) + + +@pytest.mark.parametrize( + "cols, expected_result", + [ + ( + [SQLAColumnType(name="John", type="integer", is_dttm=False)], + [ + ResultSetColumnType( + column_name="John", name="John", type="integer", is_dttm=False + ) + ], + ), + ( + [SQLAColumnType(name="hugh", type="integer", is_dttm=False)], + [ + ResultSetColumnType( + column_name="hugh", name="hugh", type="integer", is_dttm=False + ) + ], + ), + ], +) +def test_convert_inspector_columns( + cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType] +): + from superset.db_engine_specs.base import convert_inspector_columns + + assert convert_inspector_columns(cols) == expected_result diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index 5b9c6a956..37d04defc 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -27,6 +27,7 @@ from sqlalchemy import select from sqlalchemy.sql import sqltypes from sqlalchemy_bigquery import BigQueryDialect +from superset.superset_typing import ResultSetColumnType from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm @@ -64,7 +65,16 @@ def test_get_fields() -> None: """ from superset.db_engine_specs.bigquery import BigQueryEngineSpec - columns = [{"name": "limit"}, {"name": "name"}, {"name": "project.name"}] + columns: list[ResultSetColumnType] = [ + {"column_name": "limit", "name": "limit", "type": "STRING", "is_dttm": False}, + {"column_name": "name", "name": "name", "type": "STRING", "is_dttm": False}, + { + "column_name": "project.name", + "name": "project.name", + "type": "STRING", + "is_dttm": False, + }, + ] fields = BigQueryEngineSpec._get_fields(columns) query = select(fields) @@ -84,8 +94,9 @@ def test_select_star(mocker: MockFixture) -> None: """ from superset.db_engine_specs.bigquery import BigQueryEngineSpec - cols = [ + cols: list[ResultSetColumnType] = [ { + "column_name": "trailer", "name": "trailer", "type": sqltypes.ARRAY(sqltypes.JSON()), "nullable": True, @@ -94,8 +105,10 @@ def test_select_star(mocker: MockFixture) -> None: "precision": None, "scale": None, "max_length": None, + "is_dttm": False, }, { + "column_name": "trailer.key", "name": "trailer.key", "type": sqltypes.String(), "nullable": True, @@ -104,8 +117,10 @@ def test_select_star(mocker: MockFixture) -> None: "precision": None, "scale": None, "max_length": None, + "is_dttm": False, }, { + "column_name": "trailer.value", "name": "trailer.value", "type": sqltypes.String(), "nullable": True, @@ -114,8 +129,10 @@ def test_select_star(mocker: MockFixture) -> None: "precision": None, "scale": None, "max_length": None, + "is_dttm": False, }, { + "column_name": "trailer.email", "name": "trailer.email", "type": sqltypes.String(), "nullable": True, @@ -124,6 +141,7 @@ def test_select_star(mocker: MockFixture) -> None: "precision": None, "scale": None, "max_length": None, + "is_dttm": False, }, ] diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 7739361cf..8d57d4ed1 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -24,6 +24,7 @@ from pyhive.sqlalchemy_presto import PrestoDialect from sqlalchemy import sql, text, types from sqlalchemy.engine.url import make_url +from superset.superset_typing import ResultSetColumnType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -131,7 +132,14 @@ def test_where_latest_partition( mock_latest_partition.return_value = (["partition_key"], [column_value]) query = sql.select(text("* FROM table")) - columns = [{"name": "partition_key", "type": column_type}] + columns: list[ResultSetColumnType] = [ + { + "column_name": "partition_key", + "name": "partition_key", + "type": column_type, + "is_dttm": False, + } + ] expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}""" result = spec.where_latest_partition( diff --git a/tests/unit_tests/queries/dao_test.py b/tests/unit_tests/queries/dao_test.py new file mode 100644 index 000000000..a0221b801 --- /dev/null +++ b/tests/unit_tests/queries/dao_test.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json + + +def test_column_attributes_on_query(): + from superset.daos.query import QueryDAO + from superset.models.core import Database + from superset.models.sql_lab import Query + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="foo", + database=db, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=100, + error_message="none", + results_key="abc", + ) + + columns = [{"name": "test", "is_dttm": False, "type": "INT"}] + payload = {"columns": columns} + + QueryDAO.save_metadata(query_obj, payload) + assert "column_name" in json.loads(query_obj.extra_json).get("columns")[0]