fix: save columns reference from sqllab save datasets flow (#24248)
This commit is contained in:
parent
fdef9cbc96
commit
93e1db4bd9
|
|
@ -48,7 +48,7 @@ export const DATASET_TIME_COLUMN_OPTION: ColumnMeta = {
|
|||
};
|
||||
|
||||
export const QUERY_TIME_COLUMN_OPTION: QueryColumn = {
|
||||
name: DTTM_ALIAS,
|
||||
column_name: DTTM_ALIAS,
|
||||
type: DatasourceType.Query,
|
||||
is_dttm: false,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { ensureIsArray, QueryResponse } from '@superset-ui/core';
|
||||
import { Dataset, isColumnMeta, isDataset, isQueryResponse } from '../types';
|
||||
import { QueryResponse } from '@superset-ui/core';
|
||||
import { Dataset, isColumnMeta, isDataset } from '../types';
|
||||
|
||||
/**
|
||||
* Convert Datasource columns to column choices
|
||||
|
|
@ -35,14 +35,5 @@ export default function columnChoices(
|
|||
opt1[1].toLowerCase() > opt2[1].toLowerCase() ? 1 : -1,
|
||||
);
|
||||
}
|
||||
|
||||
if (isQueryResponse(datasource)) {
|
||||
return ensureIsArray(datasource.columns)
|
||||
.map((col): [string, string] => [col.name, col.name])
|
||||
.sort((opt1, opt2) =>
|
||||
opt1[1].toLowerCase() > opt2[1].toLowerCase() ? 1 : -1,
|
||||
);
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ test('get temporal columns from a QueryResponse', () => {
|
|||
expect(getTemporalColumns(testQueryResponse)).toEqual({
|
||||
temporalColumns: [
|
||||
{
|
||||
name: 'Column 2',
|
||||
column_name: 'Column 2',
|
||||
type: 'TIMESTAMP',
|
||||
is_dttm: true,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -247,8 +247,8 @@ export const CtasEnum = {
|
|||
};
|
||||
|
||||
export type QueryColumn = {
|
||||
name: string;
|
||||
column_name?: string;
|
||||
name?: string;
|
||||
column_name: string;
|
||||
type: string | null;
|
||||
is_dttm: boolean;
|
||||
};
|
||||
|
|
@ -380,17 +380,17 @@ export const testQuery: Query = {
|
|||
type: DatasourceType.Query,
|
||||
columns: [
|
||||
{
|
||||
name: 'Column 1',
|
||||
column_name: 'Column 1',
|
||||
type: 'STRING',
|
||||
is_dttm: false,
|
||||
},
|
||||
{
|
||||
name: 'Column 3',
|
||||
column_name: 'Column 3',
|
||||
type: 'STRING',
|
||||
is_dttm: false,
|
||||
},
|
||||
{
|
||||
name: 'Column 2',
|
||||
column_name: 'Column 2',
|
||||
type: 'TIMESTAMP',
|
||||
is_dttm: true,
|
||||
},
|
||||
|
|
@ -402,17 +402,17 @@ export const testQueryResults = {
|
|||
displayLimitReached: false,
|
||||
columns: [
|
||||
{
|
||||
name: 'Column 1',
|
||||
column_name: 'Column 1',
|
||||
type: 'STRING',
|
||||
is_dttm: false,
|
||||
},
|
||||
{
|
||||
name: 'Column 3',
|
||||
column_name: 'Column 3',
|
||||
type: 'STRING',
|
||||
is_dttm: false,
|
||||
},
|
||||
{
|
||||
name: 'Column 2',
|
||||
column_name: 'Column 2',
|
||||
type: 'TIMESTAMP',
|
||||
is_dttm: true,
|
||||
},
|
||||
|
|
@ -423,17 +423,17 @@ export const testQueryResults = {
|
|||
expanded_columns: [],
|
||||
selected_columns: [
|
||||
{
|
||||
name: 'Column 1',
|
||||
column_name: 'Column 1',
|
||||
type: 'STRING',
|
||||
is_dttm: false,
|
||||
},
|
||||
{
|
||||
name: 'Column 3',
|
||||
column_name: 'Column 3',
|
||||
type: 'STRING',
|
||||
is_dttm: false,
|
||||
},
|
||||
{
|
||||
name: 'Column 2',
|
||||
column_name: 'Column 2',
|
||||
type: 'TIMESTAMP',
|
||||
is_dttm: true,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -122,10 +122,10 @@ describe('ResultSet', () => {
|
|||
expect(table).toBeInTheDocument();
|
||||
|
||||
const firstColumn = queryAllByText(
|
||||
mockedProps.query.results?.columns[0].name ?? '',
|
||||
mockedProps.query.results?.columns[0].column_name ?? '',
|
||||
)[0];
|
||||
const secondColumn = queryAllByText(
|
||||
mockedProps.query.results?.columns[1].name ?? '',
|
||||
mockedProps.query.results?.columns[1].column_name ?? '',
|
||||
)[0];
|
||||
expect(firstColumn).toBeInTheDocument();
|
||||
expect(secondColumn).toBeInTheDocument();
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ const ResultSet = ({
|
|||
...EXPLORE_CHART_DEFAULT,
|
||||
datasource: `${results.query_id}__query`,
|
||||
...{
|
||||
all_columns: results.columns.map(column => column.name),
|
||||
all_columns: results.columns.map(column => column.column_name),
|
||||
},
|
||||
});
|
||||
const url = mountExploreUrl(null, {
|
||||
|
|
@ -491,7 +491,7 @@ const ResultSet = ({
|
|||
}
|
||||
if (data && data.length > 0) {
|
||||
const expandedColumns = results.expanded_columns
|
||||
? results.expanded_columns.map(col => col.name)
|
||||
? results.expanded_columns.map(col => col.column_name)
|
||||
: [];
|
||||
return (
|
||||
<>
|
||||
|
|
@ -500,7 +500,7 @@ const ResultSet = ({
|
|||
{sql}
|
||||
<FilterableTable
|
||||
data={data}
|
||||
orderedColumnKeys={results.columns.map(col => col.name)}
|
||||
orderedColumnKeys={results.columns.map(col => col.column_name)}
|
||||
height={rowsHeight}
|
||||
filterText={searchText}
|
||||
expandedColumns={expandedColumns}
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ export type ExploreQuery = QueryResponse & {
|
|||
|
||||
export interface ISimpleColumn {
|
||||
column_name?: string | null;
|
||||
name?: string | null;
|
||||
type?: string | null;
|
||||
is_dttm?: boolean | null;
|
||||
}
|
||||
|
|
@ -199,14 +200,15 @@ export const SaveDatasetModal = ({
|
|||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
|
||||
const [, key] = await Promise.all([
|
||||
updateDataset(
|
||||
datasource?.dbId,
|
||||
datasetToOverwrite?.datasetid,
|
||||
datasource?.sql,
|
||||
datasource?.columns?.map(
|
||||
(d: { name: string; type: string; is_dttm: boolean }) => ({
|
||||
column_name: d.name,
|
||||
(d: { column_name: string; type: string; is_dttm: boolean }) => ({
|
||||
column_name: d.column_name,
|
||||
type: d.type,
|
||||
is_dttm: d.is_dttm,
|
||||
}),
|
||||
|
|
@ -292,12 +294,10 @@ export const SaveDatasetModal = ({
|
|||
|
||||
dispatch(
|
||||
createDatasource({
|
||||
schema: datasource.schema,
|
||||
sql: datasource.sql,
|
||||
dbId: datasource.dbId || datasource?.database?.id,
|
||||
templateParams,
|
||||
datasourceName: datasetName,
|
||||
columns: selectedColumns,
|
||||
}),
|
||||
)
|
||||
.then((data: { id: number }) =>
|
||||
|
|
|
|||
|
|
@ -236,24 +236,24 @@ export const queries = [
|
|||
columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'ds',
|
||||
column_name: 'ds',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'gender',
|
||||
column_name: 'gender',
|
||||
type: 'STRING',
|
||||
},
|
||||
],
|
||||
selected_columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'ds',
|
||||
column_name: 'ds',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'gender',
|
||||
column_name: 'gender',
|
||||
type: 'STRING',
|
||||
},
|
||||
],
|
||||
|
|
@ -326,7 +326,7 @@ export const queryWithNoQueryLimit = {
|
|||
columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'ds',
|
||||
column_name: 'ds',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
|
|
@ -338,12 +338,12 @@ export const queryWithNoQueryLimit = {
|
|||
selected_columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'ds',
|
||||
column_name: 'ds',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'gender',
|
||||
column_name: 'gender',
|
||||
type: 'STRING',
|
||||
},
|
||||
],
|
||||
|
|
@ -364,57 +364,57 @@ export const queryWithBadColumns = {
|
|||
selected_columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'COUNT(*)',
|
||||
column_name: 'COUNT(*)',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'this_col_is_ok',
|
||||
column_name: 'this_col_is_ok',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'a',
|
||||
column_name: 'a',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: '1',
|
||||
column_name: '1',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: '123',
|
||||
column_name: '123',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'CASE WHEN 1=1 THEN 1 ELSE 0 END',
|
||||
column_name: 'CASE WHEN 1=1 THEN 1 ELSE 0 END',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: true,
|
||||
name: '_TIMESTAMP',
|
||||
column_name: '_TIMESTAMP',
|
||||
type: 'TIMESTAMP',
|
||||
},
|
||||
{
|
||||
is_dttm: true,
|
||||
name: '__TIME',
|
||||
column_name: '__TIME',
|
||||
type: 'TIMESTAMP',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'my_dupe_col__2',
|
||||
column_name: 'my_dupe_col__2',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: true,
|
||||
name: '__timestamp',
|
||||
column_name: '__timestamp',
|
||||
type: 'TIMESTAMP',
|
||||
},
|
||||
{
|
||||
is_dttm: true,
|
||||
name: '__TIMESTAMP',
|
||||
column_name: '__TIMESTAMP',
|
||||
type: 'TIMESTAMP',
|
||||
},
|
||||
],
|
||||
|
|
@ -572,31 +572,31 @@ const baseQuery: QueryResponse = {
|
|||
columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'ds',
|
||||
column_name: 'ds',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'gender',
|
||||
column_name: 'gender',
|
||||
type: 'STRING',
|
||||
},
|
||||
],
|
||||
selected_columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'ds',
|
||||
column_name: 'ds',
|
||||
type: 'STRING',
|
||||
},
|
||||
{
|
||||
is_dttm: false,
|
||||
name: 'gender',
|
||||
column_name: 'gender',
|
||||
type: 'STRING',
|
||||
},
|
||||
],
|
||||
expanded_columns: [
|
||||
{
|
||||
is_dttm: true,
|
||||
name: 'ds',
|
||||
column_name: 'ds',
|
||||
type: 'STRING',
|
||||
},
|
||||
],
|
||||
|
|
|
|||
|
|
@ -688,8 +688,9 @@ class DatasourceEditor extends React.PureComponent {
|
|||
}
|
||||
|
||||
updateColumns(cols) {
|
||||
// cols: Array<{column_name: string; is_dttm: boolean; type: string;}>
|
||||
const { databaseColumns } = this.state;
|
||||
const databaseColumnNames = cols.map(col => col.name);
|
||||
const databaseColumnNames = cols.map(col => col.column_name);
|
||||
const currentCols = databaseColumns.reduce(
|
||||
(agg, col) => ({
|
||||
...agg,
|
||||
|
|
@ -706,18 +707,18 @@ class DatasourceEditor extends React.PureComponent {
|
|||
.filter(col => !databaseColumnNames.includes(col)),
|
||||
};
|
||||
cols.forEach(col => {
|
||||
const currentCol = currentCols[col.name];
|
||||
const currentCol = currentCols[col.column_name];
|
||||
if (!currentCol) {
|
||||
// new column
|
||||
finalColumns.push({
|
||||
id: shortid.generate(),
|
||||
column_name: col.name,
|
||||
column_name: col.column_name,
|
||||
type: col.type,
|
||||
groupby: true,
|
||||
filterable: true,
|
||||
is_dttm: col.is_dttm,
|
||||
});
|
||||
results.added.push(col.name);
|
||||
results.added.push(col.column_name);
|
||||
} else if (
|
||||
currentCol.type !== col.type ||
|
||||
(!currentCol.is_dttm && col.is_dttm)
|
||||
|
|
@ -728,7 +729,7 @@ class DatasourceEditor extends React.PureComponent {
|
|||
type: col.type,
|
||||
is_dttm: currentCol.is_dttm || col.is_dttm,
|
||||
});
|
||||
results.modified.push(col.name);
|
||||
results.modified.push(col.column_name);
|
||||
} else {
|
||||
// unchanged
|
||||
finalColumns.push(currentCol);
|
||||
|
|
|
|||
|
|
@ -174,7 +174,6 @@ class SaveModal extends React.Component<SaveModalProps, SaveModalState> {
|
|||
if (this.props.datasource?.type === DatasourceType.Query) {
|
||||
const { schema, sql, database } = this.props.datasource;
|
||||
const { templateParams } = this.props.datasource;
|
||||
const columns = this.props.datasource?.columns || [];
|
||||
|
||||
await this.props.actions.saveDataset({
|
||||
schema,
|
||||
|
|
@ -182,7 +181,6 @@ class SaveModal extends React.Component<SaveModalProps, SaveModalState> {
|
|||
database,
|
||||
templateParams,
|
||||
datasourceName: this.state.datasetName,
|
||||
columns,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,12 @@ from superset.constants import EMPTY_STRING, NULL_STRING
|
|||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult
|
||||
from superset.models.slice import Slice
|
||||
from superset.superset_typing import FilterValue, FilterValues, QueryObjectDict
|
||||
from superset.superset_typing import (
|
||||
FilterValue,
|
||||
FilterValues,
|
||||
QueryObjectDict,
|
||||
ResultSetColumnType,
|
||||
)
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import GenericDataType, MediumText
|
||||
|
||||
|
|
@ -456,7 +461,7 @@ class BaseDatasource(
|
|||
values = values[0] if values else None
|
||||
return values
|
||||
|
||||
def external_metadata(self) -> list[dict[str, str]]:
|
||||
def external_metadata(self) -> list[ResultSetColumnType]:
|
||||
"""Returns column information from the external system"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
|||
|
|
@ -105,7 +105,13 @@ from superset.models.helpers import (
|
|||
validate_adhoc_subquery,
|
||||
)
|
||||
from superset.sql_parse import ParsedQuery, sanitize_clause
|
||||
from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict
|
||||
from superset.superset_typing import (
|
||||
AdhocColumn,
|
||||
AdhocMetric,
|
||||
Metric,
|
||||
QueryObjectDict,
|
||||
ResultSetColumnType,
|
||||
)
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import GenericDataType, MediumText
|
||||
|
||||
|
|
@ -700,10 +706,10 @@ class SqlaTable(
|
|||
def sql_url(self) -> str:
|
||||
return self.database.sql_url + "?table_name=" + str(self.table_name)
|
||||
|
||||
def external_metadata(self) -> list[dict[str, str]]:
|
||||
def external_metadata(self) -> list[ResultSetColumnType]:
|
||||
# todo(yongjie): create a physical table column type in a separate PR
|
||||
if self.sql:
|
||||
return get_virtual_table_metadata(dataset=self) # type: ignore
|
||||
return get_virtual_table_metadata(dataset=self)
|
||||
return get_physical_table_metadata(
|
||||
database=self.database,
|
||||
table_name=self.table_name,
|
||||
|
|
@ -995,7 +1001,7 @@ class SqlaTable(
|
|||
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
|
||||
sql = self.database.compile_sqla_query(qry)
|
||||
col_desc = get_columns_description(self.database, sql)
|
||||
is_dttm = col_desc[0]["is_dttm"]
|
||||
is_dttm = col_desc[0]["is_dttm"] # type: ignore
|
||||
except SupersetGenericDBErrorException as ex:
|
||||
raise ColumnNotFoundException(message=str(ex)) from ex
|
||||
|
||||
|
|
@ -1260,18 +1266,18 @@ class SqlaTable(
|
|||
removed=[
|
||||
col
|
||||
for col in old_columns_by_name
|
||||
if col not in {col["name"] for col in new_columns}
|
||||
if col not in {col["column_name"] for col in new_columns}
|
||||
]
|
||||
)
|
||||
|
||||
# clear old columns before adding modified columns back
|
||||
columns = []
|
||||
for col in new_columns:
|
||||
old_column = old_columns_by_name.pop(col["name"], None)
|
||||
old_column = old_columns_by_name.pop(col["column_name"], None)
|
||||
if not old_column:
|
||||
results.added.append(col["name"])
|
||||
results.added.append(col["column_name"])
|
||||
new_column = TableColumn(
|
||||
column_name=col["name"],
|
||||
column_name=col["column_name"],
|
||||
type=col["type"],
|
||||
table=self,
|
||||
)
|
||||
|
|
@ -1280,14 +1286,14 @@ class SqlaTable(
|
|||
else:
|
||||
new_column = old_column
|
||||
if new_column.type != col["type"]:
|
||||
results.modified.append(col["name"])
|
||||
results.modified.append(col["column_name"])
|
||||
new_column.type = col["type"]
|
||||
new_column.expression = ""
|
||||
new_column.groupby = True
|
||||
new_column.filterable = True
|
||||
columns.append(new_column)
|
||||
if not any_date_col and new_column.is_temporal:
|
||||
any_date_col = col["name"]
|
||||
any_date_col = col["column_name"]
|
||||
|
||||
# add back calculated (virtual) columns
|
||||
columns.extend([col for col in old_columns if col.expression])
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from __future__ import annotations
|
|||
import logging
|
||||
from collections.abc import Iterable, Iterator
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
from typing import Callable, TYPE_CHECKING, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
|
@ -49,7 +49,7 @@ def get_physical_table_metadata(
|
|||
database: Database,
|
||||
table_name: str,
|
||||
schema_name: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[ResultSetColumnType]:
|
||||
"""Use SQLAlchemy inspector to get table metadata"""
|
||||
db_engine_spec = database.db_engine_spec
|
||||
db_dialect = database.get_dialect()
|
||||
|
|
|
|||
|
|
@ -63,6 +63,9 @@ class QueryDAO(BaseDAO):
|
|||
def save_metadata(query: Query, payload: dict[str, Any]) -> None:
|
||||
# pull relevant data from payload and store in extra_json
|
||||
columns = payload.get("columns", {})
|
||||
for col in columns:
|
||||
if "name" in col:
|
||||
col["column_name"] = col.get("name")
|
||||
db.session.add(query)
|
||||
query.set_extra_json_key("columns", columns)
|
||||
|
||||
|
|
|
|||
|
|
@ -79,10 +79,10 @@ def get_table_metadata(
|
|||
dtype = get_col_type(col)
|
||||
payload_columns.append(
|
||||
{
|
||||
"name": col["name"],
|
||||
"name": col["column_name"],
|
||||
"type": dtype.split("(")[0] if "(" in dtype else dtype,
|
||||
"longType": dtype,
|
||||
"keys": [k for k in keys if col["name"] in k["column_names"]],
|
||||
"keys": [k for k in keys if col["column_name"] in k["column_names"]],
|
||||
"comment": col.get("comment"),
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
|
|||
|
||||
try:
|
||||
new_model = CreateDatasetCommand(item).run()
|
||||
return self.response(201, id=new_model.id, result=item)
|
||||
return self.response(201, id=new_model.id, result=item, data=new_model.data)
|
||||
except DatasetInvalidError as ex:
|
||||
return self.response_422(message=ex.normalized_messages())
|
||||
except DatasetCreateFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from marshmallow import ValidationError
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.commands.base import BaseCommand, CreateMixin
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError
|
||||
from superset.datasets.commands.exceptions import (
|
||||
|
|
@ -45,7 +46,10 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
|
|||
try:
|
||||
# Creates SqlaTable (Dataset)
|
||||
dataset = DatasetDAO.create(self._properties, commit=False)
|
||||
|
||||
# Updates columns and metrics from the dataset
|
||||
dataset.metrics = [SqlMetric(metric_name="count", expression="COUNT(*)")]
|
||||
|
||||
dataset.fetch_metadata(commit=False)
|
||||
db.session.commit()
|
||||
except (SQLAlchemyError, DAOCreateFailedError) as ex:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import logging
|
|||
import re
|
||||
from datetime import datetime
|
||||
from re import Match, Pattern
|
||||
from typing import Any, Callable, ContextManager, NamedTuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, cast, ContextManager, NamedTuple, TYPE_CHECKING, Union
|
||||
|
||||
import pandas as pd
|
||||
import sqlparse
|
||||
|
|
@ -53,7 +53,7 @@ from superset.constants import TimeGrain as TimeGrainConstants
|
|||
from superset.databases.utils import make_url_safe
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
from superset.utils.hashing import md5_sha_from_str
|
||||
|
|
@ -73,6 +73,13 @@ ColumnTypeMapping = tuple[
|
|||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
|
||||
result_set_columns: list[ResultSetColumnType] = []
|
||||
for col in cols:
|
||||
result_set_columns.append({"column_name": col.get("name"), **col}) # type: ignore
|
||||
return result_set_columns
|
||||
|
||||
|
||||
class TimeGrain(NamedTuple):
|
||||
name: str # TODO: redundant field, remove
|
||||
label: str
|
||||
|
|
@ -1223,7 +1230,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[ResultSetColumnType]:
|
||||
"""
|
||||
Get all columns from a given schema and table
|
||||
|
||||
|
|
@ -1232,7 +1239,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param schema: Schema name. If omitted, uses default schema for database
|
||||
:return: All columns in table
|
||||
"""
|
||||
return inspector.get_columns(table_name, schema)
|
||||
return convert_inspector_columns(
|
||||
cast(list[SQLAColumnType], inspector.get_columns(table_name, schema))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_metrics( # pylint: disable=unused-argument
|
||||
|
|
@ -1261,7 +1270,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
schema: str | None,
|
||||
database: Database,
|
||||
query: Select,
|
||||
columns: list[dict[str, Any]] | None = None,
|
||||
columns: list[ResultSetColumnType] | None = None,
|
||||
) -> Select | None:
|
||||
"""
|
||||
Add a where clause to a query to reference only the most recent partition
|
||||
|
|
@ -1278,8 +1287,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
|
||||
return [column(c["name"]) for c in cols]
|
||||
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
|
||||
return [column(c["column_name"]) for c in cols]
|
||||
|
||||
@classmethod
|
||||
def select_star( # pylint: disable=too-many-arguments,too-many-locals
|
||||
|
|
@ -1292,7 +1301,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = True,
|
||||
cols: list[dict[str, Any]] | None = None,
|
||||
cols: list[ResultSetColumnType] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
|
|||
from superset.errors import SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.hashing import md5_sha_from_str
|
||||
|
||||
|
|
@ -637,7 +638,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = True,
|
||||
cols: Optional[list[dict[str, Any]]] = None,
|
||||
cols: Optional[list[ResultSetColumnType]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Remove array structures from `SELECT *`.
|
||||
|
|
@ -678,13 +679,15 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
# For arrays of structs, remove the child columns, otherwise the query
|
||||
# will fail.
|
||||
array_prefixes = {
|
||||
col["name"] for col in cols if isinstance(col["type"], sqltypes.ARRAY)
|
||||
col["column_name"]
|
||||
for col in cols
|
||||
if isinstance(col["type"], sqltypes.ARRAY)
|
||||
}
|
||||
cols = [
|
||||
col
|
||||
for col in cols
|
||||
if "." not in col["name"]
|
||||
or col["name"].split(".")[0] not in array_prefixes
|
||||
if "." not in col["column_name"]
|
||||
or col["column_name"].split(".")[0] not in array_prefixes
|
||||
]
|
||||
|
||||
return super().select_star(
|
||||
|
|
@ -700,7 +703,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
|
||||
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
|
||||
"""
|
||||
Label columns using their fully qualified name.
|
||||
|
||||
|
|
@ -725,7 +728,10 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
the columns using their fully qualified name, so we end up with "author",
|
||||
"author__name" and "author__email", respectively.
|
||||
"""
|
||||
return [column(c["name"]).label(c["name"].replace(".", "__")) for c in cols]
|
||||
return [
|
||||
column(c["column_name"]).label(c["column_name"].replace(".", "__"))
|
||||
for c in cols
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def parse_error_exception(cls, exception: Exception) -> Exception:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from superset.constants import TimeGrain
|
|||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -132,7 +133,7 @@ class DruidEngineSpec(BaseEngineSpec):
|
|||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[ResultSetColumnType]:
|
||||
"""
|
||||
Update the Druid type map.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from superset.exceptions import SupersetException
|
|||
from superset.extensions import cache_manager
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# prevent circular imports
|
||||
|
|
@ -407,8 +408,8 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
return inspector.get_columns(table_name, schema)
|
||||
) -> list[ResultSetColumnType]:
|
||||
return BaseEngineSpec.get_columns(inspector, table_name, schema)
|
||||
|
||||
@classmethod
|
||||
def where_latest_partition( # pylint: disable=too-many-arguments
|
||||
|
|
@ -417,7 +418,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
schema: str | None,
|
||||
database: Database,
|
||||
query: Select,
|
||||
columns: list[dict[str, Any]] | None = None,
|
||||
columns: list[ResultSetColumnType] | None = None,
|
||||
) -> Select | None:
|
||||
try:
|
||||
col_names, values = cls.latest_partition(
|
||||
|
|
@ -436,7 +437,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
|
||||
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]:
|
||||
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
|
||||
|
||||
@classmethod
|
||||
|
|
@ -481,7 +482,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = True,
|
||||
cols: list[dict[str, Any]] | None = None,
|
||||
cols: list[ResultSetColumnType] | None = None,
|
||||
) -> str:
|
||||
return super( # pylint: disable=bad-super-call
|
||||
PrestoEngineSpec, cls
|
||||
|
|
|
|||
|
|
@ -126,7 +126,14 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
|
|||
type_ = group["type"].upper()
|
||||
children_type = group["children"]
|
||||
if type_ == "ARRAY":
|
||||
return [{"name": column["name"], "type": children_type, "is_dttm": False}]
|
||||
return [
|
||||
{
|
||||
"column_name": column["column_name"],
|
||||
"name": column["column_name"],
|
||||
"type": children_type,
|
||||
"is_dttm": False,
|
||||
}
|
||||
]
|
||||
|
||||
if type_ == "ROW":
|
||||
nameless_columns = 0
|
||||
|
|
@ -141,7 +148,8 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
|
|||
type_ = parts[0]
|
||||
nameless_columns += 1
|
||||
_column: ResultSetColumnType = {
|
||||
"name": f"{column['name']}.{name.lower()}",
|
||||
"column_name": f"{column['column_name']}.{name.lower()}",
|
||||
"name": f"{column['column_name']}.{name.lower()}",
|
||||
"type": type_,
|
||||
"is_dttm": False,
|
||||
}
|
||||
|
|
@ -482,7 +490,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||
schema: str | None,
|
||||
database: Database,
|
||||
query: Select,
|
||||
columns: list[dict[str, Any]] | None = None,
|
||||
columns: list[ResultSetColumnType] | None = None,
|
||||
) -> Select | None:
|
||||
try:
|
||||
col_names, values = cls.latest_partition(
|
||||
|
|
@ -496,7 +504,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||
return None
|
||||
|
||||
column_type_by_name = {
|
||||
column.get("name"): column.get("type") for column in columns or []
|
||||
column.get("column_name"): column.get("type") for column in columns or []
|
||||
}
|
||||
|
||||
for col_name, value in zip(col_names, values):
|
||||
|
|
@ -813,14 +821,20 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
@classmethod
|
||||
def _create_column_info(
|
||||
cls, name: str, data_type: types.TypeEngine
|
||||
) -> dict[str, Any]:
|
||||
) -> ResultSetColumnType:
|
||||
"""
|
||||
Create column info object
|
||||
:param name: column name
|
||||
:param data_type: column data type
|
||||
:return: column info object
|
||||
"""
|
||||
return {"name": name, "type": f"{data_type}"}
|
||||
return {
|
||||
"column_name": name,
|
||||
"name": name,
|
||||
"type": f"{data_type}",
|
||||
"is_dttm": None,
|
||||
"type_generic": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
|
||||
|
|
@ -863,7 +877,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
cls,
|
||||
parent_column_name: str,
|
||||
parent_data_type: str,
|
||||
result: list[dict[str, Any]],
|
||||
result: list[ResultSetColumnType],
|
||||
) -> None:
|
||||
"""
|
||||
Parse a row or array column
|
||||
|
|
@ -941,7 +955,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
# Unquote the column name if necessary
|
||||
if formatted_parent_column_name != parent_column_name:
|
||||
for index in range(original_result_len, len(result)):
|
||||
result[index]["name"] = result[index]["name"].replace(
|
||||
result[index]["column_name"] = result[index]["column_name"].replace(
|
||||
formatted_parent_column_name, parent_column_name
|
||||
)
|
||||
|
||||
|
|
@ -965,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[ResultSetColumnType]:
|
||||
"""
|
||||
Get columns from a Presto data source. This includes handling row and
|
||||
array data types
|
||||
|
|
@ -976,7 +990,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
(i.e. column name and data type)
|
||||
"""
|
||||
columns = cls._show_columns(inspector, table_name, schema)
|
||||
result: list[dict[str, Any]] = []
|
||||
result: list[ResultSetColumnType] = []
|
||||
for column in columns:
|
||||
# parse column if it is a row or array
|
||||
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
|
||||
|
|
@ -1003,6 +1017,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
column_info = cls._create_column_info(column.Column, column_type)
|
||||
column_info["nullable"] = getattr(column, "Null", True)
|
||||
column_info["default"] = None
|
||||
column_info["column_name"] = column.Column
|
||||
result.append(column_info)
|
||||
return result
|
||||
|
||||
|
|
@ -1016,7 +1031,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
return column_name.startswith('"') and column_name.endswith('"')
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
|
||||
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]:
|
||||
"""
|
||||
Format column clauses where names are in quotes and labels are specified
|
||||
:param cols: columns
|
||||
|
|
@ -1034,7 +1049,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
dot_regex = re.compile(dot_pattern, re.VERBOSE)
|
||||
for col in cols:
|
||||
# get individual column names
|
||||
col_names = re.split(dot_regex, col["name"])
|
||||
col_names = re.split(dot_regex, col["column_name"])
|
||||
# quote each column name if it is not already quoted
|
||||
for index, col_name in enumerate(col_names):
|
||||
if not cls._is_column_name_quoted(col_name):
|
||||
|
|
@ -1044,7 +1059,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
for col_name in col_names
|
||||
)
|
||||
# create column clause in the format "name"."name" AS "name.name"
|
||||
column_clause = literal_column(quoted_col_name).label(col["name"])
|
||||
column_clause = literal_column(quoted_col_name).label(col["column_name"])
|
||||
column_clauses.append(column_clause)
|
||||
return column_clauses
|
||||
|
||||
|
|
@ -1059,7 +1074,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = True,
|
||||
cols: list[dict[str, Any]] | None = None,
|
||||
cols: list[ResultSetColumnType] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Include selecting properties of row objects. We cannot easily break arrays into
|
||||
|
|
@ -1071,7 +1086,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
if is_feature_enabled("PRESTO_EXPAND_DATA") and show_cols:
|
||||
dot_regex = r"\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
|
||||
presto_cols = [
|
||||
col for col in presto_cols if not re.search(dot_regex, col["name"])
|
||||
col
|
||||
for col in presto_cols
|
||||
if not re.search(dot_regex, col["column_name"])
|
||||
]
|
||||
return super().select_star(
|
||||
database,
|
||||
|
|
@ -1123,7 +1140,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
current_array_level = None
|
||||
while to_process:
|
||||
column, level = to_process.popleft()
|
||||
if column["name"] not in [column["name"] for column in all_columns]:
|
||||
if column["column_name"] not in [
|
||||
column["column_name"] for column in all_columns
|
||||
]:
|
||||
all_columns.append(column)
|
||||
|
||||
# When unnesting arrays we need to keep track of how many extra rows
|
||||
|
|
@ -1135,7 +1154,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
unnested_rows: dict[int, int] = defaultdict(int)
|
||||
current_array_level = level
|
||||
|
||||
name = column["name"]
|
||||
name = column["column_name"]
|
||||
values: str | list[Any] | None
|
||||
|
||||
if column["type"] and column["type"].startswith("ARRAY("):
|
||||
|
|
@ -1186,10 +1205,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
values = cast(Optional[list[Any]], destringify(values))
|
||||
row[name] = values
|
||||
for value, col in zip(values or [], expanded):
|
||||
row[col["name"]] = value
|
||||
row[col["column_name"]] = value
|
||||
|
||||
data = [
|
||||
{k["name"]: row.get(k["name"], "") for k in all_columns} for row in data
|
||||
{k["column_name"]: row.get(k["column_name"], "") for k in all_columns}
|
||||
for row in data
|
||||
]
|
||||
|
||||
return all_columns, data, expanded_columns
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ from superset.extensions import (
|
|||
)
|
||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import cache as cache_util, core as utils
|
||||
from superset.utils.core import get_username
|
||||
|
||||
|
|
@ -632,7 +633,7 @@ class Database(
|
|||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = False,
|
||||
cols: Optional[list[dict[str, Any]]] = None,
|
||||
cols: Optional[list[ResultSetColumnType]] = None,
|
||||
) -> str:
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
|
||||
|
|
@ -837,7 +838,7 @@ class Database(
|
|||
|
||||
def get_columns(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[ResultSetColumnType]:
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
return self.db_engine_spec.get_columns(inspector, table_name, schema)
|
||||
|
||||
|
|
|
|||
|
|
@ -193,7 +193,7 @@ class Query(
|
|||
|
||||
return [
|
||||
TableColumn(
|
||||
column_name=col["name"],
|
||||
column_name=col["column_name"],
|
||||
database=self.database,
|
||||
is_dttm=col["is_dttm"],
|
||||
filterable=True,
|
||||
|
|
|
|||
|
|
@ -253,10 +253,10 @@ class SupersetResultSet:
|
|||
for col in self.table.schema:
|
||||
db_type_str = self.data_type(col.name, col.type)
|
||||
column: ResultSetColumnType = {
|
||||
"column_name": col.name,
|
||||
"name": col.name,
|
||||
"type": db_type_str,
|
||||
"is_dttm": self.is_temporal(db_type_str),
|
||||
}
|
||||
columns.append(column)
|
||||
|
||||
return columns
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from collections.abc import Sequence
|
|||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -60,14 +60,29 @@ class AdhocColumn(TypedDict, total=False):
|
|||
timeGrain: Optional[str]
|
||||
|
||||
|
||||
class SQLAColumnType(TypedDict):
|
||||
name: str
|
||||
type: Optional[str]
|
||||
is_dttm: bool
|
||||
|
||||
|
||||
class ResultSetColumnType(TypedDict):
|
||||
"""
|
||||
Superset virtual dataset column interface
|
||||
"""
|
||||
|
||||
name: str
|
||||
name: str # legacy naming convention keeping this for backwards compatibility
|
||||
column_name: str
|
||||
type: Optional[str]
|
||||
is_dttm: bool
|
||||
is_dttm: Optional[bool]
|
||||
type_generic: NotRequired[Optional["GenericDataType"]]
|
||||
|
||||
nullable: NotRequired[Any]
|
||||
default: NotRequired[Any]
|
||||
comment: NotRequired[Any]
|
||||
precision: NotRequired[Any]
|
||||
scale: NotRequired[Any]
|
||||
max_length: NotRequired[Any]
|
||||
|
||||
|
||||
CacheConfig = dict[str, Any]
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from superset.models.helpers import (
|
|||
ImportExportMixin,
|
||||
)
|
||||
from superset.sql_parse import Table as TableName
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.datasets.models import Dataset
|
||||
|
|
@ -131,8 +132,8 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
|
|||
existing_columns = {column.name: column for column in self.columns}
|
||||
quote_identifier = self.database.quote_identifier
|
||||
|
||||
def update_or_create_column(column_meta: dict[str, Any]) -> Column:
|
||||
column_name: str = column_meta["name"]
|
||||
def update_or_create_column(column_meta: ResultSetColumnType) -> Column:
|
||||
column_name: str = column_meta["column_name"]
|
||||
if column_name in existing_columns:
|
||||
column = existing_columns[column_name]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -536,6 +536,10 @@ def _deserialize_results_payload(
|
|||
df = result_set.SupersetResultSet.convert_table_to_df(pa_table)
|
||||
ds_payload["data"] = dataframe.df_to_records(df) or []
|
||||
|
||||
for column in ds_payload["selected_columns"]:
|
||||
if "name" in column:
|
||||
column["column_name"] = column.get("name")
|
||||
|
||||
db_engine_spec = query.database.db_engine_spec
|
||||
all_columns, data, expanded_columns = db_engine_spec.expand_data(
|
||||
ds_payload["selected_columns"], ds_payload["data"]
|
||||
|
|
|
|||
|
|
@ -791,7 +791,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
|
||||
mock_get_columns.return_value = [
|
||||
{
|
||||
"name": "col",
|
||||
"column_name": "col",
|
||||
"type": "VARCHAR",
|
||||
"type_generic": None,
|
||||
"is_dttm": None,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class TestDatasource(SupersetTestCase):
|
|||
tbl = self.get_table(name="birth_names")
|
||||
url = f"/datasource/external_metadata/table/{tbl.id}/"
|
||||
resp = self.get_json_resp(url)
|
||||
col_names = {o.get("name") for o in resp}
|
||||
col_names = {o.get("column_name") for o in resp}
|
||||
self.assertEqual(
|
||||
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
|
||||
)
|
||||
|
|
@ -91,7 +91,7 @@ class TestDatasource(SupersetTestCase):
|
|||
table = self.get_table(name="dummy_sql_table")
|
||||
url = f"/datasource/external_metadata/table/{table.id}/"
|
||||
resp = self.get_json_resp(url)
|
||||
assert {o.get("name") for o in resp} == {"intcol", "strcol"}
|
||||
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
|
||||
session.delete(table)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -109,7 +109,7 @@ class TestDatasource(SupersetTestCase):
|
|||
)
|
||||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.get_json_resp(url)
|
||||
col_names = {o.get("name") for o in resp}
|
||||
col_names = {o.get("column_name") for o in resp}
|
||||
self.assertEqual(
|
||||
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
|
||||
)
|
||||
|
|
@ -137,7 +137,7 @@ class TestDatasource(SupersetTestCase):
|
|||
)
|
||||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.get_json_resp(url)
|
||||
assert {o.get("name") for o in resp} == {"intcol", "strcol"}
|
||||
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
|
||||
session.delete(tbl)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ class TestDatasource(SupersetTestCase):
|
|||
)
|
||||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.get_json_resp(url)
|
||||
col_names = {o.get("name") for o in resp}
|
||||
col_names = {o.get("column_name") for o in resp}
|
||||
self.assertEqual(col_names, {"first", "second"})
|
||||
|
||||
# No databases found
|
||||
|
|
@ -216,7 +216,7 @@ class TestDatasource(SupersetTestCase):
|
|||
table = self.get_table(name="dummy_sql_table_with_template_params")
|
||||
url = f"/datasource/external_metadata/table/{table.id}/"
|
||||
resp = self.get_json_resp(url)
|
||||
assert {o.get("name") for o in resp} == {"intcol"}
|
||||
assert {o.get("column_name") for o in resp} == {"intcol"}
|
||||
session.delete(table)
|
||||
session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
results = PrestoEngineSpec.get_columns(inspector, "", "")
|
||||
self.assertEqual(len(expected_results), len(results))
|
||||
for expected_result, result in zip(expected_results, results):
|
||||
self.assertEqual(expected_result[0], result["name"])
|
||||
self.assertEqual(expected_result[0], result["column_name"])
|
||||
self.assertEqual(expected_result[1], str(result["type"]))
|
||||
|
||||
def test_presto_get_column(self):
|
||||
|
|
@ -175,21 +175,21 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
|
||||
def test_presto_get_fields(self):
|
||||
cols = [
|
||||
{"name": "column"},
|
||||
{"name": "column.nested_obj"},
|
||||
{"name": 'column."quoted.nested obj"'},
|
||||
{"column_name": "column"},
|
||||
{"column_name": "column.nested_obj"},
|
||||
{"column_name": 'column."quoted.nested obj"'},
|
||||
]
|
||||
actual_results = PrestoEngineSpec._get_fields(cols)
|
||||
expected_results = [
|
||||
{"name": '"column"', "label": "column"},
|
||||
{"name": '"column"."nested_obj"', "label": "column.nested_obj"},
|
||||
{"column_name": '"column"', "label": "column"},
|
||||
{"column_name": '"column"."nested_obj"', "label": "column.nested_obj"},
|
||||
{
|
||||
"name": '"column"."quoted.nested obj"',
|
||||
"column_name": '"column"."quoted.nested obj"',
|
||||
"label": 'column."quoted.nested obj"',
|
||||
},
|
||||
]
|
||||
for actual_result, expected_result in zip(actual_results, expected_results):
|
||||
self.assertEqual(actual_result.element.name, expected_result["name"])
|
||||
self.assertEqual(actual_result.element.name, expected_result["column_name"])
|
||||
self.assertEqual(actual_result.name, expected_result["label"])
|
||||
|
||||
@mock.patch.dict(
|
||||
|
|
@ -199,8 +199,18 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
def test_presto_expand_data_with_simple_structural_columns(self):
|
||||
cols = [
|
||||
{"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False},
|
||||
{"name": "array_column", "type": "ARRAY(BIGINT)", "is_dttm": False},
|
||||
{
|
||||
"column_name": "row_column",
|
||||
"name": "row_column",
|
||||
"type": "ROW(NESTED_OBJ VARCHAR)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "array_column",
|
||||
"name": "array_column",
|
||||
"type": "ARRAY(BIGINT)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
data = [
|
||||
{"row_column": ["a"], "array_column": [1, 2, 3]},
|
||||
|
|
@ -210,9 +220,24 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
cols, data
|
||||
)
|
||||
expected_cols = [
|
||||
{"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False},
|
||||
{"name": "row_column.nested_obj", "type": "VARCHAR", "is_dttm": False},
|
||||
{"name": "array_column", "type": "ARRAY(BIGINT)", "is_dttm": False},
|
||||
{
|
||||
"column_name": "row_column",
|
||||
"name": "row_column",
|
||||
"type": "ROW(NESTED_OBJ VARCHAR)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "row_column.nested_obj",
|
||||
"name": "row_column.nested_obj",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "array_column",
|
||||
"name": "array_column",
|
||||
"type": "ARRAY(BIGINT)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
|
||||
expected_data = [
|
||||
|
|
@ -225,7 +250,12 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
]
|
||||
|
||||
expected_expanded_cols = [
|
||||
{"name": "row_column.nested_obj", "type": "VARCHAR", "is_dttm": False}
|
||||
{
|
||||
"name": "row_column.nested_obj",
|
||||
"column_name": "row_column.nested_obj",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
}
|
||||
]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
|
|
@ -240,6 +270,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
cols = [
|
||||
{
|
||||
"name": "row_column",
|
||||
"column_name": "row_column",
|
||||
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
|
||||
"is_dttm": False,
|
||||
}
|
||||
|
|
@ -251,17 +282,25 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
expected_cols = [
|
||||
{
|
||||
"name": "row_column",
|
||||
"column_name": "row_column",
|
||||
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{"name": "row_column.nested_obj1", "type": "VARCHAR", "is_dttm": False},
|
||||
{
|
||||
"name": "row_column.nested_obj1",
|
||||
"column_name": "row_column.nested_obj1",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "row_column.nested_row",
|
||||
"column_name": "row_column.nested_row",
|
||||
"type": "ROW(NESTED_OBJ2 VARCHAR)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "row_column.nested_row.nested_obj2",
|
||||
"column_name": "row_column.nested_row.nested_obj2",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
|
|
@ -282,14 +321,21 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
]
|
||||
|
||||
expected_expanded_cols = [
|
||||
{"name": "row_column.nested_obj1", "type": "VARCHAR", "is_dttm": False},
|
||||
{
|
||||
"name": "row_column.nested_obj1",
|
||||
"column_name": "row_column.nested_obj1",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "row_column.nested_row",
|
||||
"column_name": "row_column.nested_row",
|
||||
"type": "ROW(NESTED_OBJ2 VARCHAR)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "row_column.nested_row.nested_obj2",
|
||||
"column_name": "row_column.nested_row.nested_obj2",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
|
|
@ -307,6 +353,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
cols = [
|
||||
{
|
||||
"name": "row_column",
|
||||
"column_name": "row_column",
|
||||
"type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))",
|
||||
"is_dttm": False,
|
||||
}
|
||||
|
|
@ -323,16 +370,19 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
expected_cols = [
|
||||
{
|
||||
"name": "row_column",
|
||||
"column_name": "row_column",
|
||||
"type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "row_column.nested_row",
|
||||
"column_name": "row_column.nested_row",
|
||||
"type": "ROW(NESTED_OBJ VARCHAR)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "row_column.nested_row.nested_obj",
|
||||
"column_name": "row_column.nested_row.nested_obj",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
|
|
@ -363,11 +413,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
expected_expanded_cols = [
|
||||
{
|
||||
"name": "row_column.nested_row",
|
||||
"column_name": "row_column.nested_row",
|
||||
"type": "ROW(NESTED_OBJ VARCHAR)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "row_column.nested_row.nested_obj",
|
||||
"column_name": "row_column.nested_row.nested_obj",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
|
|
@ -383,9 +435,15 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
def test_presto_expand_data_with_complex_array_columns(self):
|
||||
cols = [
|
||||
{"name": "int_column", "type": "BIGINT", "is_dttm": False},
|
||||
{
|
||||
"name": "int_column",
|
||||
"column_name": "int_column",
|
||||
"type": "BIGINT",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "array_column",
|
||||
"column_name": "array_column",
|
||||
"type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
|
||||
"is_dttm": False,
|
||||
},
|
||||
|
|
@ -398,19 +456,27 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
cols, data
|
||||
)
|
||||
expected_cols = [
|
||||
{"name": "int_column", "type": "BIGINT", "is_dttm": False},
|
||||
{
|
||||
"name": "int_column",
|
||||
"column_name": "int_column",
|
||||
"type": "BIGINT",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "array_column",
|
||||
"column_name": "array_column",
|
||||
"type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "array_column.nested_array",
|
||||
"column_name": "array_column.nested_array",
|
||||
"type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "array_column.nested_array.nested_obj",
|
||||
"column_name": "array_column.nested_array.nested_obj",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
|
|
@ -468,11 +534,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
expected_expanded_cols = [
|
||||
{
|
||||
"name": "array_column.nested_array",
|
||||
"column_name": "array_column.nested_array",
|
||||
"type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"name": "array_column.nested_array.nested_obj",
|
||||
"column_name": "array_column.nested_array.nested_obj",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
|
|
@ -575,9 +643,20 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
def test_presto_expand_data_array(self):
|
||||
cols = [
|
||||
{"name": "event_id", "type": "VARCHAR", "is_dttm": False},
|
||||
{"name": "timestamp", "type": "BIGINT", "is_dttm": False},
|
||||
{
|
||||
"column_name": "event_id",
|
||||
"name": "event_id",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "timestamp",
|
||||
"name": "timestamp",
|
||||
"type": "BIGINT",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "user",
|
||||
"name": "user",
|
||||
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
|
||||
"is_dttm": False,
|
||||
|
|
@ -594,16 +673,42 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
cols, data
|
||||
)
|
||||
expected_cols = [
|
||||
{"name": "event_id", "type": "VARCHAR", "is_dttm": False},
|
||||
{"name": "timestamp", "type": "BIGINT", "is_dttm": False},
|
||||
{
|
||||
"column_name": "event_id",
|
||||
"name": "event_id",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "timestamp",
|
||||
"name": "timestamp",
|
||||
"type": "BIGINT",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "user",
|
||||
"name": "user",
|
||||
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{"name": "user.id", "type": "BIGINT", "is_dttm": False},
|
||||
{"name": "user.first_name", "type": "VARCHAR", "is_dttm": False},
|
||||
{"name": "user.last_name", "type": "VARCHAR", "is_dttm": False},
|
||||
{
|
||||
"column_name": "user.id",
|
||||
"name": "user.id",
|
||||
"type": "BIGINT",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "user.first_name",
|
||||
"name": "user.first_name",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "user.last_name",
|
||||
"name": "user.last_name",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
expected_data = [
|
||||
{
|
||||
|
|
@ -616,9 +721,24 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
}
|
||||
]
|
||||
expected_expanded_cols = [
|
||||
{"name": "user.id", "type": "BIGINT", "is_dttm": False},
|
||||
{"name": "user.first_name", "type": "VARCHAR", "is_dttm": False},
|
||||
{"name": "user.last_name", "type": "VARCHAR", "is_dttm": False},
|
||||
{
|
||||
"column_name": "user.id",
|
||||
"name": "user.id",
|
||||
"type": "BIGINT",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "user.first_name",
|
||||
"name": "user.first_name",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "user.last_name",
|
||||
"name": "user.last_name",
|
||||
"type": "VARCHAR",
|
||||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
|
|
@ -736,12 +856,12 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
table_name = "table_name"
|
||||
engine = mock.Mock()
|
||||
cols = [
|
||||
{"name": "val1"},
|
||||
{"name": "val2<?!@#$312,/'][p098"},
|
||||
{"name": ".val2"},
|
||||
{"name": "val2."},
|
||||
{"name": "val.2"},
|
||||
{"name": ".val2."},
|
||||
{"column_name": "val1"},
|
||||
{"column_name": "val2<?!@#$312,/'][p098"},
|
||||
{"column_name": ".val2"},
|
||||
{"column_name": "val2."},
|
||||
{"column_name": "val.2"},
|
||||
{"column_name": ".val2."},
|
||||
]
|
||||
PrestoEngineSpec.select_star(
|
||||
database, table_name, engine, show_cols=True, cols=cols
|
||||
|
|
@ -756,8 +876,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
True,
|
||||
True,
|
||||
[
|
||||
{"name": "val1"},
|
||||
{"name": "val2<?!@#$312,/'][p098"},
|
||||
{"column_name": "val1"},
|
||||
{"column_name": "val2<?!@#$312,/'][p098"},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -48,9 +48,9 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
self.assertEqual(
|
||||
results.columns,
|
||||
[
|
||||
{"is_dttm": False, "type": "STRING", "name": "a"},
|
||||
{"is_dttm": False, "type": "STRING", "name": "b"},
|
||||
{"is_dttm": False, "type": "STRING", "name": "c"},
|
||||
{"is_dttm": False, "type": "STRING", "column_name": "a", "name": "a"},
|
||||
{"is_dttm": False, "type": "STRING", "column_name": "b", "name": "b"},
|
||||
{"is_dttm": False, "type": "STRING", "column_name": "c", "name": "c"},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -61,8 +61,8 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
self.assertEqual(
|
||||
results.columns,
|
||||
[
|
||||
{"is_dttm": False, "type": "STRING", "name": "a"},
|
||||
{"is_dttm": False, "type": "INT", "name": "b"},
|
||||
{"is_dttm": False, "type": "STRING", "column_name": "a", "name": "a"},
|
||||
{"is_dttm": False, "type": "INT", "column_name": "b", "name": "b"},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -76,11 +76,11 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
self.assertEqual(
|
||||
results.columns,
|
||||
[
|
||||
{"is_dttm": False, "type": "FLOAT", "name": "a"},
|
||||
{"is_dttm": False, "type": "INT", "name": "b"},
|
||||
{"is_dttm": False, "type": "STRING", "name": "c"},
|
||||
{"is_dttm": True, "type": "DATETIME", "name": "d"},
|
||||
{"is_dttm": False, "type": "BOOL", "name": "e"},
|
||||
{"is_dttm": False, "type": "FLOAT", "column_name": "a", "name": "a"},
|
||||
{"is_dttm": False, "type": "INT", "column_name": "b", "name": "b"},
|
||||
{"is_dttm": False, "type": "STRING", "column_name": "c", "name": "c"},
|
||||
{"is_dttm": True, "type": "DATETIME", "column_name": "d", "name": "d"},
|
||||
{"is_dttm": False, "type": "BOOL", "column_name": "e", "name": "e"},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -100,7 +100,7 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
data = [("a", 1), ("a", 2)]
|
||||
cursor_descr = (("a", "string"), ("a", "string"))
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
column_names = [col["name"] for col in results.columns]
|
||||
column_names = [col["column_name"] for col in results.columns]
|
||||
self.assertListEqual(column_names, ["a", "a__1"])
|
||||
|
||||
def test_int64_with_missing_data(self):
|
||||
|
|
|
|||
|
|
@ -121,9 +121,9 @@ def test_main_dttm_col(mocker: MockerFixture, test_table: "SqlaTable") -> None:
|
|||
mocker.patch(
|
||||
"superset.connectors.sqla.models.get_physical_table_metadata",
|
||||
return_value=[
|
||||
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -154,9 +154,9 @@ def test_main_dttm_col_nonexistent(
|
|||
mocker.patch(
|
||||
"superset.connectors.sqla.models.get_physical_table_metadata",
|
||||
return_value=[
|
||||
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -188,9 +188,9 @@ def test_main_dttm_col_nondttm(
|
|||
mocker.patch(
|
||||
"superset.connectors.sqla.models.get_physical_table_metadata",
|
||||
return_value=[
|
||||
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
|
||||
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -226,9 +226,9 @@ def test_python_date_format_by_column_name(
|
|||
mocker.patch(
|
||||
"superset.connectors.sqla.models.get_physical_table_metadata",
|
||||
return_value=[
|
||||
{"name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
|
||||
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "dttm", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -274,8 +274,8 @@ def test_expression_by_column_name(
|
|||
mocker.patch(
|
||||
"superset.connectors.sqla.models.get_physical_table_metadata",
|
||||
return_value=[
|
||||
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
|
||||
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "dttm", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -311,9 +311,9 @@ def test_full_setting(
|
|||
mocker.patch(
|
||||
"superset.connectors.sqla.models.get_physical_table_metadata",
|
||||
return_value=[
|
||||
{"name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
|
||||
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "dttm", "type": "INTEGER", "is_dttm": False},
|
||||
{"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from typing import Any, Optional
|
|||
import pytest
|
||||
from sqlalchemy import types
|
||||
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
|
||||
|
||||
|
|
@ -138,3 +139,32 @@ def test_get_column_spec(
|
|||
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
|
||||
|
||||
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cols, expected_result",
|
||||
[
|
||||
(
|
||||
[SQLAColumnType(name="John", type="integer", is_dttm=False)],
|
||||
[
|
||||
ResultSetColumnType(
|
||||
column_name="John", name="John", type="integer", is_dttm=False
|
||||
)
|
||||
],
|
||||
),
|
||||
(
|
||||
[SQLAColumnType(name="hugh", type="integer", is_dttm=False)],
|
||||
[
|
||||
ResultSetColumnType(
|
||||
column_name="hugh", name="hugh", type="integer", is_dttm=False
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_inspector_columns(
|
||||
cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType]
|
||||
):
|
||||
from superset.db_engine_specs.base import convert_inspector_columns
|
||||
|
||||
assert convert_inspector_columns(cols) == expected_result
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.sql import sqltypes
|
||||
from sqlalchemy_bigquery import BigQueryDialect
|
||||
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from tests.unit_tests.fixtures.common import dttm
|
||||
|
||||
|
|
@ -64,7 +65,16 @@ def test_get_fields() -> None:
|
|||
"""
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
|
||||
|
||||
columns = [{"name": "limit"}, {"name": "name"}, {"name": "project.name"}]
|
||||
columns: list[ResultSetColumnType] = [
|
||||
{"column_name": "limit", "name": "limit", "type": "STRING", "is_dttm": False},
|
||||
{"column_name": "name", "name": "name", "type": "STRING", "is_dttm": False},
|
||||
{
|
||||
"column_name": "project.name",
|
||||
"name": "project.name",
|
||||
"type": "STRING",
|
||||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
fields = BigQueryEngineSpec._get_fields(columns)
|
||||
|
||||
query = select(fields)
|
||||
|
|
@ -84,8 +94,9 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
"""
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
|
||||
|
||||
cols = [
|
||||
cols: list[ResultSetColumnType] = [
|
||||
{
|
||||
"column_name": "trailer",
|
||||
"name": "trailer",
|
||||
"type": sqltypes.ARRAY(sqltypes.JSON()),
|
||||
"nullable": True,
|
||||
|
|
@ -94,8 +105,10 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
"precision": None,
|
||||
"scale": None,
|
||||
"max_length": None,
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "trailer.key",
|
||||
"name": "trailer.key",
|
||||
"type": sqltypes.String(),
|
||||
"nullable": True,
|
||||
|
|
@ -104,8 +117,10 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
"precision": None,
|
||||
"scale": None,
|
||||
"max_length": None,
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "trailer.value",
|
||||
"name": "trailer.value",
|
||||
"type": sqltypes.String(),
|
||||
"nullable": True,
|
||||
|
|
@ -114,8 +129,10 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
"precision": None,
|
||||
"scale": None,
|
||||
"max_length": None,
|
||||
"is_dttm": False,
|
||||
},
|
||||
{
|
||||
"column_name": "trailer.email",
|
||||
"name": "trailer.email",
|
||||
"type": sqltypes.String(),
|
||||
"nullable": True,
|
||||
|
|
@ -124,6 +141,7 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
"precision": None,
|
||||
"scale": None,
|
||||
"max_length": None,
|
||||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from pyhive.sqlalchemy_presto import PrestoDialect
|
|||
from sqlalchemy import sql, text, types
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import (
|
||||
assert_column_spec,
|
||||
|
|
@ -131,7 +132,14 @@ def test_where_latest_partition(
|
|||
mock_latest_partition.return_value = (["partition_key"], [column_value])
|
||||
|
||||
query = sql.select(text("* FROM table"))
|
||||
columns = [{"name": "partition_key", "type": column_type}]
|
||||
columns: list[ResultSetColumnType] = [
|
||||
{
|
||||
"column_name": "partition_key",
|
||||
"name": "partition_key",
|
||||
"type": column_type,
|
||||
"is_dttm": False,
|
||||
}
|
||||
]
|
||||
|
||||
expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
|
||||
result = spec.where_latest_partition(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
|
||||
|
||||
def test_column_attributes_on_query():
|
||||
from superset.daos.query import QueryDAO
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
select_sql="select * from bar",
|
||||
executed_sql="select * from bar",
|
||||
limit=100,
|
||||
select_as_cta=False,
|
||||
rows=100,
|
||||
error_message="none",
|
||||
results_key="abc",
|
||||
)
|
||||
|
||||
columns = [{"name": "test", "is_dttm": False, "type": "INT"}]
|
||||
payload = {"columns": columns}
|
||||
|
||||
QueryDAO.save_metadata(query_obj, payload)
|
||||
assert "column_name" in json.loads(query_obj.extra_json).get("columns")[0]
|
||||
Loading…
Reference in New Issue