fix: save columns reference from sqllab save datasets flow (#24248)

This commit is contained in:
Hugh A. Miles II 2023-06-20 13:54:19 -04:00 committed by GitHub
parent fdef9cbc96
commit 93e1db4bd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 489 additions and 202 deletions

View File

@ -48,7 +48,7 @@ export const DATASET_TIME_COLUMN_OPTION: ColumnMeta = {
};
export const QUERY_TIME_COLUMN_OPTION: QueryColumn = {
name: DTTM_ALIAS,
column_name: DTTM_ALIAS,
type: DatasourceType.Query,
is_dttm: false,
};

View File

@ -16,8 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
import { ensureIsArray, QueryResponse } from '@superset-ui/core';
import { Dataset, isColumnMeta, isDataset, isQueryResponse } from '../types';
import { QueryResponse } from '@superset-ui/core';
import { Dataset, isColumnMeta, isDataset } from '../types';
/**
* Convert Datasource columns to column choices
@ -35,14 +35,5 @@ export default function columnChoices(
opt1[1].toLowerCase() > opt2[1].toLowerCase() ? 1 : -1,
);
}
if (isQueryResponse(datasource)) {
return ensureIsArray(datasource.columns)
.map((col): [string, string] => [col.name, col.name])
.sort((opt1, opt2) =>
opt1[1].toLowerCase() > opt2[1].toLowerCase() ? 1 : -1,
);
}
return [];
}

View File

@ -54,7 +54,7 @@ test('get temporal columns from a QueryResponse', () => {
expect(getTemporalColumns(testQueryResponse)).toEqual({
temporalColumns: [
{
name: 'Column 2',
column_name: 'Column 2',
type: 'TIMESTAMP',
is_dttm: true,
},

View File

@ -247,8 +247,8 @@ export const CtasEnum = {
};
export type QueryColumn = {
name: string;
column_name?: string;
name?: string;
column_name: string;
type: string | null;
is_dttm: boolean;
};
@ -380,17 +380,17 @@ export const testQuery: Query = {
type: DatasourceType.Query,
columns: [
{
name: 'Column 1',
column_name: 'Column 1',
type: 'STRING',
is_dttm: false,
},
{
name: 'Column 3',
column_name: 'Column 3',
type: 'STRING',
is_dttm: false,
},
{
name: 'Column 2',
column_name: 'Column 2',
type: 'TIMESTAMP',
is_dttm: true,
},
@ -402,17 +402,17 @@ export const testQueryResults = {
displayLimitReached: false,
columns: [
{
name: 'Column 1',
column_name: 'Column 1',
type: 'STRING',
is_dttm: false,
},
{
name: 'Column 3',
column_name: 'Column 3',
type: 'STRING',
is_dttm: false,
},
{
name: 'Column 2',
column_name: 'Column 2',
type: 'TIMESTAMP',
is_dttm: true,
},
@ -423,17 +423,17 @@ export const testQueryResults = {
expanded_columns: [],
selected_columns: [
{
name: 'Column 1',
column_name: 'Column 1',
type: 'STRING',
is_dttm: false,
},
{
name: 'Column 3',
column_name: 'Column 3',
type: 'STRING',
is_dttm: false,
},
{
name: 'Column 2',
column_name: 'Column 2',
type: 'TIMESTAMP',
is_dttm: true,
},

View File

@ -122,10 +122,10 @@ describe('ResultSet', () => {
expect(table).toBeInTheDocument();
const firstColumn = queryAllByText(
mockedProps.query.results?.columns[0].name ?? '',
mockedProps.query.results?.columns[0].column_name ?? '',
)[0];
const secondColumn = queryAllByText(
mockedProps.query.results?.columns[1].name ?? '',
mockedProps.query.results?.columns[1].column_name ?? '',
)[0];
expect(firstColumn).toBeInTheDocument();
expect(secondColumn).toBeInTheDocument();

View File

@ -205,7 +205,7 @@ const ResultSet = ({
...EXPLORE_CHART_DEFAULT,
datasource: `${results.query_id}__query`,
...{
all_columns: results.columns.map(column => column.name),
all_columns: results.columns.map(column => column.column_name),
},
});
const url = mountExploreUrl(null, {
@ -491,7 +491,7 @@ const ResultSet = ({
}
if (data && data.length > 0) {
const expandedColumns = results.expanded_columns
? results.expanded_columns.map(col => col.name)
? results.expanded_columns.map(col => col.column_name)
: [];
return (
<>
@ -500,7 +500,7 @@ const ResultSet = ({
{sql}
<FilterableTable
data={data}
orderedColumnKeys={results.columns.map(col => col.name)}
orderedColumnKeys={results.columns.map(col => col.column_name)}
height={rowsHeight}
filterText={searchText}
expandedColumns={expandedColumns}

View File

@ -62,6 +62,7 @@ export type ExploreQuery = QueryResponse & {
/**
 * Loosely-typed column descriptor.
 *
 * Every field is optional and may also be `null`, so consumers must guard
 * each property before using it.
 */
export interface ISimpleColumn {
// Column identifier under the `column_name` key.
column_name?: string | null;
// Column identifier under the `name` key.
// NOTE(review): presumably the legacy alias of `column_name` — confirm.
name?: string | null;
// Column type as a string; its vocabulary is not defined here.
type?: string | null;
// Whether the column is flagged as a datetime column.
is_dttm?: boolean | null;
}
@ -199,14 +200,15 @@ export const SaveDatasetModal = ({
return;
}
setLoading(true);
const [, key] = await Promise.all([
updateDataset(
datasource?.dbId,
datasetToOverwrite?.datasetid,
datasource?.sql,
datasource?.columns?.map(
(d: { name: string; type: string; is_dttm: boolean }) => ({
column_name: d.name,
(d: { column_name: string; type: string; is_dttm: boolean }) => ({
column_name: d.column_name,
type: d.type,
is_dttm: d.is_dttm,
}),
@ -292,12 +294,10 @@ export const SaveDatasetModal = ({
dispatch(
createDatasource({
schema: datasource.schema,
sql: datasource.sql,
dbId: datasource.dbId || datasource?.database?.id,
templateParams,
datasourceName: datasetName,
columns: selectedColumns,
}),
)
.then((data: { id: number }) =>

View File

@ -236,24 +236,24 @@ export const queries = [
columns: [
{
is_dttm: true,
name: 'ds',
column_name: 'ds',
type: 'STRING',
},
{
is_dttm: false,
name: 'gender',
column_name: 'gender',
type: 'STRING',
},
],
selected_columns: [
{
is_dttm: true,
name: 'ds',
column_name: 'ds',
type: 'STRING',
},
{
is_dttm: false,
name: 'gender',
column_name: 'gender',
type: 'STRING',
},
],
@ -326,7 +326,7 @@ export const queryWithNoQueryLimit = {
columns: [
{
is_dttm: true,
name: 'ds',
column_name: 'ds',
type: 'STRING',
},
{
@ -338,12 +338,12 @@ export const queryWithNoQueryLimit = {
selected_columns: [
{
is_dttm: true,
name: 'ds',
column_name: 'ds',
type: 'STRING',
},
{
is_dttm: false,
name: 'gender',
column_name: 'gender',
type: 'STRING',
},
],
@ -364,57 +364,57 @@ export const queryWithBadColumns = {
selected_columns: [
{
is_dttm: true,
name: 'COUNT(*)',
column_name: 'COUNT(*)',
type: 'STRING',
},
{
is_dttm: false,
name: 'this_col_is_ok',
column_name: 'this_col_is_ok',
type: 'STRING',
},
{
is_dttm: false,
name: 'a',
column_name: 'a',
type: 'STRING',
},
{
is_dttm: false,
name: '1',
column_name: '1',
type: 'STRING',
},
{
is_dttm: false,
name: '123',
column_name: '123',
type: 'STRING',
},
{
is_dttm: false,
name: 'CASE WHEN 1=1 THEN 1 ELSE 0 END',
column_name: 'CASE WHEN 1=1 THEN 1 ELSE 0 END',
type: 'STRING',
},
{
is_dttm: true,
name: '_TIMESTAMP',
column_name: '_TIMESTAMP',
type: 'TIMESTAMP',
},
{
is_dttm: true,
name: '__TIME',
column_name: '__TIME',
type: 'TIMESTAMP',
},
{
is_dttm: false,
name: 'my_dupe_col__2',
column_name: 'my_dupe_col__2',
type: 'STRING',
},
{
is_dttm: true,
name: '__timestamp',
column_name: '__timestamp',
type: 'TIMESTAMP',
},
{
is_dttm: true,
name: '__TIMESTAMP',
column_name: '__TIMESTAMP',
type: 'TIMESTAMP',
},
],
@ -572,31 +572,31 @@ const baseQuery: QueryResponse = {
columns: [
{
is_dttm: true,
name: 'ds',
column_name: 'ds',
type: 'STRING',
},
{
is_dttm: false,
name: 'gender',
column_name: 'gender',
type: 'STRING',
},
],
selected_columns: [
{
is_dttm: true,
name: 'ds',
column_name: 'ds',
type: 'STRING',
},
{
is_dttm: false,
name: 'gender',
column_name: 'gender',
type: 'STRING',
},
],
expanded_columns: [
{
is_dttm: true,
name: 'ds',
column_name: 'ds',
type: 'STRING',
},
],

View File

@ -688,8 +688,9 @@ class DatasourceEditor extends React.PureComponent {
}
updateColumns(cols) {
// cols: Array<{column_name: string; is_dttm: boolean; type: string;}>
const { databaseColumns } = this.state;
const databaseColumnNames = cols.map(col => col.name);
const databaseColumnNames = cols.map(col => col.column_name);
const currentCols = databaseColumns.reduce(
(agg, col) => ({
...agg,
@ -706,18 +707,18 @@ class DatasourceEditor extends React.PureComponent {
.filter(col => !databaseColumnNames.includes(col)),
};
cols.forEach(col => {
const currentCol = currentCols[col.name];
const currentCol = currentCols[col.column_name];
if (!currentCol) {
// new column
finalColumns.push({
id: shortid.generate(),
column_name: col.name,
column_name: col.column_name,
type: col.type,
groupby: true,
filterable: true,
is_dttm: col.is_dttm,
});
results.added.push(col.name);
results.added.push(col.column_name);
} else if (
currentCol.type !== col.type ||
(!currentCol.is_dttm && col.is_dttm)
@ -728,7 +729,7 @@ class DatasourceEditor extends React.PureComponent {
type: col.type,
is_dttm: currentCol.is_dttm || col.is_dttm,
});
results.modified.push(col.name);
results.modified.push(col.column_name);
} else {
// unchanged
finalColumns.push(currentCol);

View File

@ -174,7 +174,6 @@ class SaveModal extends React.Component<SaveModalProps, SaveModalState> {
if (this.props.datasource?.type === DatasourceType.Query) {
const { schema, sql, database } = this.props.datasource;
const { templateParams } = this.props.datasource;
const columns = this.props.datasource?.columns || [];
await this.props.actions.saveDataset({
schema,
@ -182,7 +181,6 @@ class SaveModal extends React.Component<SaveModalProps, SaveModalState> {
database,
templateParams,
datasourceName: this.state.datasetName,
columns,
});
}

View File

@ -35,7 +35,12 @@ from superset.constants import EMPTY_STRING, NULL_STRING
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult
from superset.models.slice import Slice
from superset.superset_typing import FilterValue, FilterValues, QueryObjectDict
from superset.superset_typing import (
FilterValue,
FilterValues,
QueryObjectDict,
ResultSetColumnType,
)
from superset.utils import core as utils
from superset.utils.core import GenericDataType, MediumText
@ -456,7 +461,7 @@ class BaseDatasource(
values = values[0] if values else None
return values
def external_metadata(self) -> list[dict[str, str]]:
def external_metadata(self) -> list[ResultSetColumnType]:
"""Returns column information from the external system"""
raise NotImplementedError()

View File

@ -105,7 +105,13 @@ from superset.models.helpers import (
validate_adhoc_subquery,
)
from superset.sql_parse import ParsedQuery, sanitize_clause
from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
Metric,
QueryObjectDict,
ResultSetColumnType,
)
from superset.utils import core as utils
from superset.utils.core import GenericDataType, MediumText
@ -700,10 +706,10 @@ class SqlaTable(
def sql_url(self) -> str:
return self.database.sql_url + "?table_name=" + str(self.table_name)
def external_metadata(self) -> list[dict[str, str]]:
def external_metadata(self) -> list[ResultSetColumnType]:
# todo(yongjie): create a physical table column type in a separate PR
if self.sql:
return get_virtual_table_metadata(dataset=self) # type: ignore
return get_virtual_table_metadata(dataset=self)
return get_physical_table_metadata(
database=self.database,
table_name=self.table_name,
@ -995,7 +1001,7 @@ class SqlaTable(
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(self.database, sql)
is_dttm = col_desc[0]["is_dttm"]
is_dttm = col_desc[0]["is_dttm"] # type: ignore
except SupersetGenericDBErrorException as ex:
raise ColumnNotFoundException(message=str(ex)) from ex
@ -1260,18 +1266,18 @@ class SqlaTable(
removed=[
col
for col in old_columns_by_name
if col not in {col["name"] for col in new_columns}
if col not in {col["column_name"] for col in new_columns}
]
)
# clear old columns before adding modified columns back
columns = []
for col in new_columns:
old_column = old_columns_by_name.pop(col["name"], None)
old_column = old_columns_by_name.pop(col["column_name"], None)
if not old_column:
results.added.append(col["name"])
results.added.append(col["column_name"])
new_column = TableColumn(
column_name=col["name"],
column_name=col["column_name"],
type=col["type"],
table=self,
)
@ -1280,14 +1286,14 @@ class SqlaTable(
else:
new_column = old_column
if new_column.type != col["type"]:
results.modified.append(col["name"])
results.modified.append(col["column_name"])
new_column.type = col["type"]
new_column.expression = ""
new_column.groupby = True
new_column.filterable = True
columns.append(new_column)
if not any_date_col and new_column.is_temporal:
any_date_col = col["name"]
any_date_col = col["column_name"]
# add back calculated (virtual) columns
columns.extend([col for col in old_columns if col.expression])

View File

@ -19,7 +19,7 @@ from __future__ import annotations
import logging
from collections.abc import Iterable, Iterator
from functools import lru_cache
from typing import Any, Callable, TYPE_CHECKING, TypeVar
from typing import Callable, TYPE_CHECKING, TypeVar
from uuid import UUID
from flask_babel import lazy_gettext as _
@ -49,7 +49,7 @@ def get_physical_table_metadata(
database: Database,
table_name: str,
schema_name: str | None = None,
) -> list[dict[str, Any]]:
) -> list[ResultSetColumnType]:
"""Use SQLAlchemy inspector to get table metadata"""
db_engine_spec = database.db_engine_spec
db_dialect = database.get_dialect()

View File

@ -63,6 +63,9 @@ class QueryDAO(BaseDAO):
def save_metadata(query: Query, payload: dict[str, Any]) -> None:
# pull relevant data from payload and store in extra_json
columns = payload.get("columns", {})
for col in columns:
if "name" in col:
col["column_name"] = col.get("name")
db.session.add(query)
query.set_extra_json_key("columns", columns)

View File

@ -79,10 +79,10 @@ def get_table_metadata(
dtype = get_col_type(col)
payload_columns.append(
{
"name": col["name"],
"name": col["column_name"],
"type": dtype.split("(")[0] if "(" in dtype else dtype,
"longType": dtype,
"keys": [k for k in keys if col["name"] in k["column_names"]],
"keys": [k for k in keys if col["column_name"] in k["column_names"]],
"comment": col.get("comment"),
}
)

View File

@ -311,7 +311,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
try:
new_model = CreateDatasetCommand(item).run()
return self.response(201, id=new_model.id, result=item)
return self.response(201, id=new_model.id, result=item, data=new_model.data)
except DatasetInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
except DatasetCreateFailedError as ex:

View File

@ -22,6 +22,7 @@ from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand, CreateMixin
from superset.connectors.sqla.models import SqlMetric
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.datasets.commands.exceptions import (
@ -45,7 +46,10 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
try:
# Creates SqlaTable (Dataset)
dataset = DatasetDAO.create(self._properties, commit=False)
# Updates columns and metrics from the dataset
dataset.metrics = [SqlMetric(metric_name="count", expression="COUNT(*)")]
dataset.fetch_metadata(commit=False)
db.session.commit()
except (SQLAlchemyError, DAOCreateFailedError) as ex:

View File

@ -23,7 +23,7 @@ import logging
import re
from datetime import datetime
from re import Match, Pattern
from typing import Any, Callable, ContextManager, NamedTuple, TYPE_CHECKING, Union
from typing import Any, Callable, cast, ContextManager, NamedTuple, TYPE_CHECKING, Union
import pandas as pd
import sqlparse
@ -53,7 +53,7 @@ from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.hashing import md5_sha_from_str
@ -73,6 +73,13 @@ ColumnTypeMapping = tuple[
logger = logging.getLogger()
def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
    """
    Convert inspector-style column dicts into result-set column dicts.

    Each input dict is shallow-copied into a new dict that additionally carries
    a ``column_name`` key mirroring the value stored under ``name``. The input
    dicts are not mutated.

    :param cols: column dicts keyed by ``name``
    :return: a new list of dicts, one per input column, in the same order. If
        an input dict already has a ``column_name`` key, that value is kept
        (the input's keys are spread last); if it has no ``name`` key,
        ``column_name`` is ``None``.
    """
    # `**col` is spread last on purpose so that keys already present on the
    # input (including a pre-existing `column_name`) win over the derived one.
    return [{"column_name": col.get("name"), **col} for col in cols]  # type: ignore
class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
@ -1223,7 +1230,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[dict[str, Any]]:
) -> list[ResultSetColumnType]:
"""
Get all columns from a given schema and table
@ -1232,7 +1239,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param schema: Schema name. If omitted, uses default schema for database
:return: All columns in table
"""
return inspector.get_columns(table_name, schema)
return convert_inspector_columns(
cast(list[SQLAColumnType], inspector.get_columns(table_name, schema))
)
@classmethod
def get_metrics( # pylint: disable=unused-argument
@ -1261,7 +1270,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
schema: str | None,
database: Database,
query: Select,
columns: list[dict[str, Any]] | None = None,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
"""
Add a where clause to a query to reference only the most recent partition
@ -1278,8 +1287,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@classmethod
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
return [column(c["name"]) for c in cols]
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
return [column(c["column_name"]) for c in cols]
@classmethod
def select_star( # pylint: disable=too-many-arguments,too-many-locals
@ -1292,7 +1301,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: list[dict[str, Any]] | None = None,
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.

View File

@ -43,6 +43,7 @@ from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.errors import SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
from superset.utils.hashing import md5_sha_from_str
@ -637,7 +638,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: Optional[list[dict[str, Any]]] = None,
cols: Optional[list[ResultSetColumnType]] = None,
) -> str:
"""
Remove array structures from `SELECT *`.
@ -678,13 +679,15 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
# For arrays of structs, remove the child columns, otherwise the query
# will fail.
array_prefixes = {
col["name"] for col in cols if isinstance(col["type"], sqltypes.ARRAY)
col["column_name"]
for col in cols
if isinstance(col["type"], sqltypes.ARRAY)
}
cols = [
col
for col in cols
if "." not in col["name"]
or col["name"].split(".")[0] not in array_prefixes
if "." not in col["column_name"]
or col["column_name"].split(".")[0] not in array_prefixes
]
return super().select_star(
@ -700,7 +703,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
)
@classmethod
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
"""
Label columns using their fully qualified name.
@ -725,7 +728,10 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
the columns using their fully qualified name, so we end up with "author",
"author__name" and "author__email", respectively.
"""
return [column(c["name"]).label(c["name"].replace(".", "__")) for c in cols]
return [
column(c["column_name"]).label(c["column_name"].replace(".", "__"))
for c in cols
]
@classmethod
def parse_error_exception(cls, exception: Exception) -> Exception:

View File

@ -30,6 +30,7 @@ from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.exceptions import SupersetException
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
if TYPE_CHECKING:
@ -132,7 +133,7 @@ class DruidEngineSpec(BaseEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[dict[str, Any]]:
) -> list[ResultSetColumnType]:
"""
Update the Druid type map.
"""

View File

@ -46,6 +46,7 @@ from superset.exceptions import SupersetException
from superset.extensions import cache_manager
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
# prevent circular imports
@ -407,8 +408,8 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[dict[str, Any]]:
return inspector.get_columns(table_name, schema)
) -> list[ResultSetColumnType]:
return BaseEngineSpec.get_columns(inspector, table_name, schema)
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments
@ -417,7 +418,7 @@ class HiveEngineSpec(PrestoEngineSpec):
schema: str | None,
database: Database,
query: Select,
columns: list[dict[str, Any]] | None = None,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
try:
col_names, values = cls.latest_partition(
@ -436,7 +437,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]:
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod
@ -481,7 +482,7 @@ class HiveEngineSpec(PrestoEngineSpec):
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: list[dict[str, Any]] | None = None,
cols: list[ResultSetColumnType] | None = None,
) -> str:
return super( # pylint: disable=bad-super-call
PrestoEngineSpec, cls

View File

@ -126,7 +126,14 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
type_ = group["type"].upper()
children_type = group["children"]
if type_ == "ARRAY":
return [{"name": column["name"], "type": children_type, "is_dttm": False}]
return [
{
"column_name": column["column_name"],
"name": column["column_name"],
"type": children_type,
"is_dttm": False,
}
]
if type_ == "ROW":
nameless_columns = 0
@ -141,7 +148,8 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
type_ = parts[0]
nameless_columns += 1
_column: ResultSetColumnType = {
"name": f"{column['name']}.{name.lower()}",
"column_name": f"{column['column_name']}.{name.lower()}",
"name": f"{column['column_name']}.{name.lower()}",
"type": type_,
"is_dttm": False,
}
@ -482,7 +490,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
schema: str | None,
database: Database,
query: Select,
columns: list[dict[str, Any]] | None = None,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
try:
col_names, values = cls.latest_partition(
@ -496,7 +504,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return None
column_type_by_name = {
column.get("name"): column.get("type") for column in columns or []
column.get("column_name"): column.get("type") for column in columns or []
}
for col_name, value in zip(col_names, values):
@ -813,14 +821,20 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
) -> dict[str, Any]:
) -> ResultSetColumnType:
"""
Create column info object
:param name: column name
:param data_type: column data type
:return: column info object
"""
return {"name": name, "type": f"{data_type}"}
return {
"column_name": name,
"name": name,
"type": f"{data_type}",
"is_dttm": None,
"type_generic": None,
}
@classmethod
def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
@ -863,7 +877,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
parent_column_name: str,
parent_data_type: str,
result: list[dict[str, Any]],
result: list[ResultSetColumnType],
) -> None:
"""
Parse a row or array column
@ -941,7 +955,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# Unquote the column name if necessary
if formatted_parent_column_name != parent_column_name:
for index in range(original_result_len, len(result)):
result[index]["name"] = result[index]["name"].replace(
result[index]["column_name"] = result[index]["column_name"].replace(
formatted_parent_column_name, parent_column_name
)
@ -965,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[dict[str, Any]]:
) -> list[ResultSetColumnType]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
@ -976,7 +990,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
(i.e. column name and data type)
"""
columns = cls._show_columns(inspector, table_name, schema)
result: list[dict[str, Any]] = []
result: list[ResultSetColumnType] = []
for column in columns:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
@ -1003,6 +1017,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
column_info = cls._create_column_info(column.Column, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
column_info["column_name"] = column.Column
result.append(column_info)
return result
@ -1016,7 +1031,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"')
@classmethod
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]:
"""
Format column clauses where names are in quotes and labels are specified
:param cols: columns
@ -1034,7 +1049,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
dot_regex = re.compile(dot_pattern, re.VERBOSE)
for col in cols:
# get individual column names
col_names = re.split(dot_regex, col["name"])
col_names = re.split(dot_regex, col["column_name"])
# quote each column name if it is not already quoted
for index, col_name in enumerate(col_names):
if not cls._is_column_name_quoted(col_name):
@ -1044,7 +1059,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
for col_name in col_names
)
# create column clause in the format "name"."name" AS "name.name"
column_clause = literal_column(quoted_col_name).label(col["name"])
column_clause = literal_column(quoted_col_name).label(col["column_name"])
column_clauses.append(column_clause)
return column_clauses
@ -1059,7 +1074,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: list[dict[str, Any]] | None = None,
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""
Include selecting properties of row objects. We cannot easily break arrays into
@ -1071,7 +1086,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
if is_feature_enabled("PRESTO_EXPAND_DATA") and show_cols:
dot_regex = r"\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
presto_cols = [
col for col in presto_cols if not re.search(dot_regex, col["name"])
col
for col in presto_cols
if not re.search(dot_regex, col["column_name"])
]
return super().select_star(
database,
@ -1123,7 +1140,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
current_array_level = None
while to_process:
column, level = to_process.popleft()
if column["name"] not in [column["name"] for column in all_columns]:
if column["column_name"] not in [
column["column_name"] for column in all_columns
]:
all_columns.append(column)
# When unnesting arrays we need to keep track of how many extra rows
@ -1135,7 +1154,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
unnested_rows: dict[int, int] = defaultdict(int)
current_array_level = level
name = column["name"]
name = column["column_name"]
values: str | list[Any] | None
if column["type"] and column["type"].startswith("ARRAY("):
@ -1186,10 +1205,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
values = cast(Optional[list[Any]], destringify(values))
row[name] = values
for value, col in zip(values or [], expanded):
row[col["name"]] = value
row[col["column_name"]] = value
data = [
{k["name"]: row.get(k["name"], "") for k in all_columns} for row in data
{k["column_name"]: row.get(k["column_name"], "") for k in all_columns}
for row in data
]
return all_columns, data, expanded_columns

View File

@ -69,6 +69,7 @@ from superset.extensions import (
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.result_set import SupersetResultSet
from superset.superset_typing import ResultSetColumnType
from superset.utils import cache as cache_util, core as utils
from superset.utils.core import get_username
@ -632,7 +633,7 @@ class Database(
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = False,
cols: Optional[list[dict[str, Any]]] = None,
cols: Optional[list[ResultSetColumnType]] = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
@ -837,7 +838,7 @@ class Database(
def get_columns(
self, table_name: str, schema: Optional[str] = None
) -> list[dict[str, Any]]:
) -> list[ResultSetColumnType]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_columns(inspector, table_name, schema)

View File

@ -193,7 +193,7 @@ class Query(
return [
TableColumn(
column_name=col["name"],
column_name=col["column_name"],
database=self.database,
is_dttm=col["is_dttm"],
filterable=True,

View File

@ -253,10 +253,10 @@ class SupersetResultSet:
for col in self.table.schema:
db_type_str = self.data_type(col.name, col.type)
column: ResultSetColumnType = {
"column_name": col.name,
"name": col.name,
"type": db_type_str,
"is_dttm": self.is_temporal(db_type_str),
}
columns.append(column)
return columns

View File

@ -18,7 +18,7 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Any, Literal, Optional, TYPE_CHECKING, Union
from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict
from werkzeug.wrappers import Response
if TYPE_CHECKING:
@ -60,14 +60,29 @@ class AdhocColumn(TypedDict, total=False):
timeGrain: Optional[str]
# Column dict keyed by `name` rather than `column_name`. Elsewhere in this
# change the SQLAlchemy inspector's `get_columns` output is cast to a list of
# these before being converted to `ResultSetColumnType`.
class SQLAColumnType(TypedDict):
name: str
type: Optional[str]
is_dttm: bool
class ResultSetColumnType(TypedDict):
"""
Superset virtual dataset column interface
"""
name: str
name: str # legacy naming convention; kept for backwards compatibility
column_name: str
type: Optional[str]
is_dttm: bool
is_dttm: Optional[bool]
type_generic: NotRequired[Optional["GenericDataType"]]
nullable: NotRequired[Any]
default: NotRequired[Any]
comment: NotRequired[Any]
precision: NotRequired[Any]
scale: NotRequired[Any]
max_length: NotRequired[Any]
CacheConfig = dict[str, Any]

View File

@ -43,6 +43,7 @@ from superset.models.helpers import (
ImportExportMixin,
)
from superset.sql_parse import Table as TableName
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
from superset.datasets.models import Dataset
@ -131,8 +132,8 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
existing_columns = {column.name: column for column in self.columns}
quote_identifier = self.database.quote_identifier
def update_or_create_column(column_meta: dict[str, Any]) -> Column:
column_name: str = column_meta["name"]
def update_or_create_column(column_meta: ResultSetColumnType) -> Column:
column_name: str = column_meta["column_name"]
if column_name in existing_columns:
column = existing_columns[column_name]
else:

View File

@ -536,6 +536,10 @@ def _deserialize_results_payload(
df = result_set.SupersetResultSet.convert_table_to_df(pa_table)
ds_payload["data"] = dataframe.df_to_records(df) or []
for column in ds_payload["selected_columns"]:
if "name" in column:
column["column_name"] = column.get("name")
db_engine_spec = query.database.db_engine_spec
all_columns, data, expanded_columns = db_engine_spec.expand_data(
ds_payload["selected_columns"], ds_payload["data"]

View File

@ -791,7 +791,7 @@ class TestDatasetApi(SupersetTestCase):
mock_get_columns.return_value = [
{
"name": "col",
"column_name": "col",
"type": "VARCHAR",
"type_generic": None,
"is_dttm": None,

View File

@ -71,7 +71,7 @@ class TestDatasource(SupersetTestCase):
tbl = self.get_table(name="birth_names")
url = f"/datasource/external_metadata/table/{tbl.id}/"
resp = self.get_json_resp(url)
col_names = {o.get("name") for o in resp}
col_names = {o.get("column_name") for o in resp}
self.assertEqual(
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
)
@ -91,7 +91,7 @@ class TestDatasource(SupersetTestCase):
table = self.get_table(name="dummy_sql_table")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("name") for o in resp} == {"intcol", "strcol"}
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
session.delete(table)
session.commit()
@ -109,7 +109,7 @@ class TestDatasource(SupersetTestCase):
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
col_names = {o.get("name") for o in resp}
col_names = {o.get("column_name") for o in resp}
self.assertEqual(
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
)
@ -137,7 +137,7 @@ class TestDatasource(SupersetTestCase):
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
assert {o.get("name") for o in resp} == {"intcol", "strcol"}
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
session.delete(tbl)
session.commit()
@ -155,7 +155,7 @@ class TestDatasource(SupersetTestCase):
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
col_names = {o.get("name") for o in resp}
col_names = {o.get("column_name") for o in resp}
self.assertEqual(col_names, {"first", "second"})
# No databases found
@ -216,7 +216,7 @@ class TestDatasource(SupersetTestCase):
table = self.get_table(name="dummy_sql_table_with_template_params")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("name") for o in resp} == {"intcol"}
assert {o.get("column_name") for o in resp} == {"intcol"}
session.delete(table)
session.commit()

View File

@ -85,7 +85,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
results = PrestoEngineSpec.get_columns(inspector, "", "")
self.assertEqual(len(expected_results), len(results))
for expected_result, result in zip(expected_results, results):
self.assertEqual(expected_result[0], result["name"])
self.assertEqual(expected_result[0], result["column_name"])
self.assertEqual(expected_result[1], str(result["type"]))
def test_presto_get_column(self):
@ -175,21 +175,21 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_presto_get_fields(self):
    """
    ``PrestoEngineSpec._get_fields`` quotes each dotted path segment of a
    column's ``column_name`` and labels the result with the raw column name.

    All fixtures use the ``column_name`` key only; the legacy ``name`` key is
    deliberately absent so the test exercises the current column schema.
    """
    cols = [
        {"column_name": "column"},
        {"column_name": "column.nested_obj"},
        {"column_name": 'column."quoted.nested obj"'},
    ]
    actual_results = PrestoEngineSpec._get_fields(cols)
    expected_results = [
        {"column_name": '"column"', "label": "column"},
        {"column_name": '"column"."nested_obj"', "label": "column.nested_obj"},
        {
            "column_name": '"column"."quoted.nested obj"',
            "label": 'column."quoted.nested obj"',
        },
    ]
    # zip() stops at the shorter sequence, so check the lengths explicitly;
    # otherwise a dropped or extra field would pass unnoticed.
    self.assertEqual(len(actual_results), len(expected_results))
    for actual_result, expected_result in zip(actual_results, expected_results):
        self.assertEqual(actual_result.element.name, expected_result["column_name"])
        self.assertEqual(actual_result.name, expected_result["label"])
@mock.patch.dict(
@ -199,8 +199,18 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
)
def test_presto_expand_data_with_simple_structural_columns(self):
cols = [
{"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False},
{"name": "array_column", "type": "ARRAY(BIGINT)", "is_dttm": False},
{
"column_name": "row_column",
"name": "row_column",
"type": "ROW(NESTED_OBJ VARCHAR)",
"is_dttm": False,
},
{
"column_name": "array_column",
"name": "array_column",
"type": "ARRAY(BIGINT)",
"is_dttm": False,
},
]
data = [
{"row_column": ["a"], "array_column": [1, 2, 3]},
@ -210,9 +220,24 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
cols, data
)
expected_cols = [
{"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)", "is_dttm": False},
{"name": "row_column.nested_obj", "type": "VARCHAR", "is_dttm": False},
{"name": "array_column", "type": "ARRAY(BIGINT)", "is_dttm": False},
{
"column_name": "row_column",
"name": "row_column",
"type": "ROW(NESTED_OBJ VARCHAR)",
"is_dttm": False,
},
{
"column_name": "row_column.nested_obj",
"name": "row_column.nested_obj",
"type": "VARCHAR",
"is_dttm": False,
},
{
"column_name": "array_column",
"name": "array_column",
"type": "ARRAY(BIGINT)",
"is_dttm": False,
},
]
expected_data = [
@ -225,7 +250,12 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
]
expected_expanded_cols = [
{"name": "row_column.nested_obj", "type": "VARCHAR", "is_dttm": False}
{
"name": "row_column.nested_obj",
"column_name": "row_column.nested_obj",
"type": "VARCHAR",
"is_dttm": False,
}
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
@ -240,6 +270,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
cols = [
{
"name": "row_column",
"column_name": "row_column",
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
"is_dttm": False,
}
@ -251,17 +282,25 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
expected_cols = [
{
"name": "row_column",
"column_name": "row_column",
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
"is_dttm": False,
},
{"name": "row_column.nested_obj1", "type": "VARCHAR", "is_dttm": False},
{
"name": "row_column.nested_obj1",
"column_name": "row_column.nested_obj1",
"type": "VARCHAR",
"is_dttm": False,
},
{
"name": "row_column.nested_row",
"column_name": "row_column.nested_row",
"type": "ROW(NESTED_OBJ2 VARCHAR)",
"is_dttm": False,
},
{
"name": "row_column.nested_row.nested_obj2",
"column_name": "row_column.nested_row.nested_obj2",
"type": "VARCHAR",
"is_dttm": False,
},
@ -282,14 +321,21 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
]
expected_expanded_cols = [
{"name": "row_column.nested_obj1", "type": "VARCHAR", "is_dttm": False},
{
"name": "row_column.nested_obj1",
"column_name": "row_column.nested_obj1",
"type": "VARCHAR",
"is_dttm": False,
},
{
"name": "row_column.nested_row",
"column_name": "row_column.nested_row",
"type": "ROW(NESTED_OBJ2 VARCHAR)",
"is_dttm": False,
},
{
"name": "row_column.nested_row.nested_obj2",
"column_name": "row_column.nested_row.nested_obj2",
"type": "VARCHAR",
"is_dttm": False,
},
@ -307,6 +353,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
cols = [
{
"name": "row_column",
"column_name": "row_column",
"type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))",
"is_dttm": False,
}
@ -323,16 +370,19 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
expected_cols = [
{
"name": "row_column",
"column_name": "row_column",
"type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))",
"is_dttm": False,
},
{
"name": "row_column.nested_row",
"column_name": "row_column.nested_row",
"type": "ROW(NESTED_OBJ VARCHAR)",
"is_dttm": False,
},
{
"name": "row_column.nested_row.nested_obj",
"column_name": "row_column.nested_row.nested_obj",
"type": "VARCHAR",
"is_dttm": False,
},
@ -363,11 +413,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
expected_expanded_cols = [
{
"name": "row_column.nested_row",
"column_name": "row_column.nested_row",
"type": "ROW(NESTED_OBJ VARCHAR)",
"is_dttm": False,
},
{
"name": "row_column.nested_row.nested_obj",
"column_name": "row_column.nested_row.nested_obj",
"type": "VARCHAR",
"is_dttm": False,
},
@ -383,9 +435,15 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
)
def test_presto_expand_data_with_complex_array_columns(self):
cols = [
{"name": "int_column", "type": "BIGINT", "is_dttm": False},
{
"name": "int_column",
"column_name": "int_column",
"type": "BIGINT",
"is_dttm": False,
},
{
"name": "array_column",
"column_name": "array_column",
"type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
"is_dttm": False,
},
@ -398,19 +456,27 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
cols, data
)
expected_cols = [
{"name": "int_column", "type": "BIGINT", "is_dttm": False},
{
"name": "int_column",
"column_name": "int_column",
"type": "BIGINT",
"is_dttm": False,
},
{
"name": "array_column",
"column_name": "array_column",
"type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
"is_dttm": False,
},
{
"name": "array_column.nested_array",
"column_name": "array_column.nested_array",
"type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
"is_dttm": False,
},
{
"name": "array_column.nested_array.nested_obj",
"column_name": "array_column.nested_array.nested_obj",
"type": "VARCHAR",
"is_dttm": False,
},
@ -468,11 +534,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
expected_expanded_cols = [
{
"name": "array_column.nested_array",
"column_name": "array_column.nested_array",
"type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
"is_dttm": False,
},
{
"name": "array_column.nested_array.nested_obj",
"column_name": "array_column.nested_array.nested_obj",
"type": "VARCHAR",
"is_dttm": False,
},
@ -575,9 +643,20 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
)
def test_presto_expand_data_array(self):
cols = [
{"name": "event_id", "type": "VARCHAR", "is_dttm": False},
{"name": "timestamp", "type": "BIGINT", "is_dttm": False},
{
"column_name": "event_id",
"name": "event_id",
"type": "VARCHAR",
"is_dttm": False,
},
{
"column_name": "timestamp",
"name": "timestamp",
"type": "BIGINT",
"is_dttm": False,
},
{
"column_name": "user",
"name": "user",
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
"is_dttm": False,
@ -594,16 +673,42 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
cols, data
)
expected_cols = [
{"name": "event_id", "type": "VARCHAR", "is_dttm": False},
{"name": "timestamp", "type": "BIGINT", "is_dttm": False},
{
"column_name": "event_id",
"name": "event_id",
"type": "VARCHAR",
"is_dttm": False,
},
{
"column_name": "timestamp",
"name": "timestamp",
"type": "BIGINT",
"is_dttm": False,
},
{
"column_name": "user",
"name": "user",
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
"is_dttm": False,
},
{"name": "user.id", "type": "BIGINT", "is_dttm": False},
{"name": "user.first_name", "type": "VARCHAR", "is_dttm": False},
{"name": "user.last_name", "type": "VARCHAR", "is_dttm": False},
{
"column_name": "user.id",
"name": "user.id",
"type": "BIGINT",
"is_dttm": False,
},
{
"column_name": "user.first_name",
"name": "user.first_name",
"type": "VARCHAR",
"is_dttm": False,
},
{
"column_name": "user.last_name",
"name": "user.last_name",
"type": "VARCHAR",
"is_dttm": False,
},
]
expected_data = [
{
@ -616,9 +721,24 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
}
]
expected_expanded_cols = [
{"name": "user.id", "type": "BIGINT", "is_dttm": False},
{"name": "user.first_name", "type": "VARCHAR", "is_dttm": False},
{"name": "user.last_name", "type": "VARCHAR", "is_dttm": False},
{
"column_name": "user.id",
"name": "user.id",
"type": "BIGINT",
"is_dttm": False,
},
{
"column_name": "user.first_name",
"name": "user.first_name",
"type": "VARCHAR",
"is_dttm": False,
},
{
"column_name": "user.last_name",
"name": "user.last_name",
"type": "VARCHAR",
"is_dttm": False,
},
]
self.assertEqual(actual_cols, expected_cols)
@ -736,12 +856,12 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
table_name = "table_name"
engine = mock.Mock()
cols = [
{"name": "val1"},
{"name": "val2<?!@#$312,/'][p098"},
{"name": ".val2"},
{"name": "val2."},
{"name": "val.2"},
{"name": ".val2."},
{"column_name": "val1"},
{"column_name": "val2<?!@#$312,/'][p098"},
{"column_name": ".val2"},
{"column_name": "val2."},
{"column_name": "val.2"},
{"column_name": ".val2."},
]
PrestoEngineSpec.select_star(
database, table_name, engine, show_cols=True, cols=cols
@ -756,8 +876,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
True,
True,
[
{"name": "val1"},
{"name": "val2<?!@#$312,/'][p098"},
{"column_name": "val1"},
{"column_name": "val2<?!@#$312,/'][p098"},
],
)

View File

@ -48,9 +48,9 @@ class TestSupersetResultSet(SupersetTestCase):
self.assertEqual(
results.columns,
[
{"is_dttm": False, "type": "STRING", "name": "a"},
{"is_dttm": False, "type": "STRING", "name": "b"},
{"is_dttm": False, "type": "STRING", "name": "c"},
{"is_dttm": False, "type": "STRING", "column_name": "a", "name": "a"},
{"is_dttm": False, "type": "STRING", "column_name": "b", "name": "b"},
{"is_dttm": False, "type": "STRING", "column_name": "c", "name": "c"},
],
)
@ -61,8 +61,8 @@ class TestSupersetResultSet(SupersetTestCase):
self.assertEqual(
results.columns,
[
{"is_dttm": False, "type": "STRING", "name": "a"},
{"is_dttm": False, "type": "INT", "name": "b"},
{"is_dttm": False, "type": "STRING", "column_name": "a", "name": "a"},
{"is_dttm": False, "type": "INT", "column_name": "b", "name": "b"},
],
)
@ -76,11 +76,11 @@ class TestSupersetResultSet(SupersetTestCase):
self.assertEqual(
results.columns,
[
{"is_dttm": False, "type": "FLOAT", "name": "a"},
{"is_dttm": False, "type": "INT", "name": "b"},
{"is_dttm": False, "type": "STRING", "name": "c"},
{"is_dttm": True, "type": "DATETIME", "name": "d"},
{"is_dttm": False, "type": "BOOL", "name": "e"},
{"is_dttm": False, "type": "FLOAT", "column_name": "a", "name": "a"},
{"is_dttm": False, "type": "INT", "column_name": "b", "name": "b"},
{"is_dttm": False, "type": "STRING", "column_name": "c", "name": "c"},
{"is_dttm": True, "type": "DATETIME", "column_name": "d", "name": "d"},
{"is_dttm": False, "type": "BOOL", "column_name": "e", "name": "e"},
],
)
@ -100,7 +100,7 @@ class TestSupersetResultSet(SupersetTestCase):
data = [("a", 1), ("a", 2)]
cursor_descr = (("a", "string"), ("a", "string"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
column_names = [col["name"] for col in results.columns]
column_names = [col["column_name"] for col in results.columns]
self.assertListEqual(column_names, ["a", "a__1"])
def test_int64_with_missing_data(self):

View File

@ -121,9 +121,9 @@ def test_main_dttm_col(mocker: MockerFixture, test_table: "SqlaTable") -> None:
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"name": "id", "type": "INTEGER", "is_dttm": False},
{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
],
)
@ -154,9 +154,9 @@ def test_main_dttm_col_nonexistent(
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"name": "id", "type": "INTEGER", "is_dttm": False},
{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
],
)
@ -188,9 +188,9 @@ def test_main_dttm_col_nondttm(
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"name": "id", "type": "INTEGER", "is_dttm": False},
{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"column_name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
],
)
@ -226,9 +226,9 @@ def test_python_date_format_by_column_name(
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "id", "type": "INTEGER", "is_dttm": False},
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
{"column_name": "dttm", "type": "INTEGER", "is_dttm": False},
{"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False},
],
)
@ -274,8 +274,8 @@ def test_expression_by_column_name(
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
{"column_name": "dttm", "type": "INTEGER", "is_dttm": False},
{"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False},
],
)
@ -311,9 +311,9 @@ def test_full_setting(
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "id", "type": "INTEGER", "is_dttm": False},
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
{"column_name": "id", "type": "INTEGER", "is_dttm": False},
{"column_name": "dttm", "type": "INTEGER", "is_dttm": False},
{"column_name": "duration_ms", "type": "INTEGER", "is_dttm": False},
],
)

View File

@ -22,6 +22,7 @@ from typing import Any, Optional
import pytest
from sqlalchemy import types
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
@ -138,3 +139,32 @@ def test_get_column_spec(
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
@pytest.mark.parametrize(
    "cols, expected_result",
    [
        # One single-column case per name: the inspector-style input on the
        # left, the result-set column expected from the conversion on the right.
        (
            [SQLAColumnType(name=col_name, type="integer", is_dttm=False)],
            [
                ResultSetColumnType(
                    column_name=col_name,
                    name=col_name,
                    type="integer",
                    is_dttm=False,
                )
            ],
        )
        for col_name in ("John", "hugh")
    ],
)
def test_convert_inspector_columns(
    cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType]
) -> None:
    """
    ``convert_inspector_columns`` must return entries that carry the input's
    ``name`` under both ``name`` and ``column_name``, with ``type`` and
    ``is_dttm`` unchanged.
    """
    from superset.db_engine_specs.base import convert_inspector_columns

    converted = convert_inspector_columns(cols)
    assert converted == expected_result

View File

@ -27,6 +27,7 @@ from sqlalchemy import select
from sqlalchemy.sql import sqltypes
from sqlalchemy_bigquery import BigQueryDialect
from superset.superset_typing import ResultSetColumnType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@ -64,7 +65,16 @@ def test_get_fields() -> None:
"""
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
columns = [{"name": "limit"}, {"name": "name"}, {"name": "project.name"}]
columns: list[ResultSetColumnType] = [
{"column_name": "limit", "name": "limit", "type": "STRING", "is_dttm": False},
{"column_name": "name", "name": "name", "type": "STRING", "is_dttm": False},
{
"column_name": "project.name",
"name": "project.name",
"type": "STRING",
"is_dttm": False,
},
]
fields = BigQueryEngineSpec._get_fields(columns)
query = select(fields)
@ -84,8 +94,9 @@ def test_select_star(mocker: MockFixture) -> None:
"""
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
cols = [
cols: list[ResultSetColumnType] = [
{
"column_name": "trailer",
"name": "trailer",
"type": sqltypes.ARRAY(sqltypes.JSON()),
"nullable": True,
@ -94,8 +105,10 @@ def test_select_star(mocker: MockFixture) -> None:
"precision": None,
"scale": None,
"max_length": None,
"is_dttm": False,
},
{
"column_name": "trailer.key",
"name": "trailer.key",
"type": sqltypes.String(),
"nullable": True,
@ -104,8 +117,10 @@ def test_select_star(mocker: MockFixture) -> None:
"precision": None,
"scale": None,
"max_length": None,
"is_dttm": False,
},
{
"column_name": "trailer.value",
"name": "trailer.value",
"type": sqltypes.String(),
"nullable": True,
@ -114,8 +129,10 @@ def test_select_star(mocker: MockFixture) -> None:
"precision": None,
"scale": None,
"max_length": None,
"is_dttm": False,
},
{
"column_name": "trailer.email",
"name": "trailer.email",
"type": sqltypes.String(),
"nullable": True,
@ -124,6 +141,7 @@ def test_select_star(mocker: MockFixture) -> None:
"precision": None,
"scale": None,
"max_length": None,
"is_dttm": False,
},
]

View File

@ -24,6 +24,7 @@ from pyhive.sqlalchemy_presto import PrestoDialect
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url
from superset.superset_typing import ResultSetColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
@ -131,7 +132,14 @@ def test_where_latest_partition(
mock_latest_partition.return_value = (["partition_key"], [column_value])
query = sql.select(text("* FROM table"))
columns = [{"name": "partition_key", "type": column_type}]
columns: list[ResultSetColumnType] = [
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
]
expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
result = spec.where_latest_partition(

View File

@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
def test_column_attributes_on_query() -> None:
    """
    ``QueryDAO.save_metadata`` must persist result columns in the query's
    ``extra_json`` with a ``column_name`` key, even when the incoming payload
    only carries the legacy ``name`` key.
    """
    from superset.daos.query import QueryDAO
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    # Legacy-shaped column: it has "name" but no "column_name".
    columns = [{"name": "test", "is_dttm": False, "type": "INT"}]
    payload = {"columns": columns}

    QueryDAO.save_metadata(query_obj, payload)

    # Index with ["columns"] so a missing key fails as a clear KeyError rather
    # than a TypeError from subscripting None.
    saved_columns = json.loads(query_obj.extra_json)["columns"]
    assert len(saved_columns) == 1
    assert "column_name" in saved_columns[0]
    # NOTE(review): assumes save_metadata copies the legacy "name" value into
    # "column_name" -- confirm against QueryDAO.save_metadata.
    assert saved_columns[0]["column_name"] == "test"