feat: Axis sort in the Bar Chart V2 (#21993)

This commit is contained in:
Yongjie Zhao 2022-11-26 22:06:26 +08:00 committed by GitHub
parent cc2334e58c
commit 22fab5e58c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 349 additions and 167 deletions

View File

@ -1,4 +1,3 @@
/* eslint-disable camelcase */
/** /**
* Licensed to the Apache Software Foundation (ASF) under one * Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file * or more contributor license agreements. See the NOTICE file
@ -17,25 +16,50 @@
* specific language governing permissions and limitationsxw * specific language governing permissions and limitationsxw
* under the License. * under the License.
*/ */
import { DTTM_ALIAS, PostProcessingSort, RollingType } from '@superset-ui/core'; import { isEmpty } from 'lodash';
import {
ensureIsArray,
getMetricLabel,
getXAxisLabel,
hasGenericChartAxes,
isDefined,
PostProcessingSort,
} from '@superset-ui/core';
import { PostProcessingFactory } from './types'; import { PostProcessingFactory } from './types';
export const sortOperator: PostProcessingFactory<PostProcessingSort> = ( export const sortOperator: PostProcessingFactory<PostProcessingSort> = (
formData, formData,
queryObject, queryObject,
) => { ) => {
const { x_axis: xAxis } = formData; // the sortOperator only used in the barchart v2
const sortableLabels = [
getXAxisLabel(formData),
...ensureIsArray(formData.metrics).map(metric => getMetricLabel(metric)),
].filter(Boolean);
if ( if (
(xAxis || queryObject.is_timeseries) && hasGenericChartAxes &&
Object.values(RollingType).includes(formData.rolling_type) isDefined(formData?.x_axis_sort) &&
isDefined(formData?.x_axis_sort_asc) &&
sortableLabels.includes(formData.x_axis_sort) &&
// the sort operator doesn't support sort-by multiple series.
isEmpty(formData.groupby)
) { ) {
const index = xAxis || DTTM_ALIAS; if (formData.x_axis_sort === getXAxisLabel(formData)) {
return {
operation: 'sort',
options: {
is_sort_index: true,
ascending: formData.x_axis_sort_asc,
},
};
}
return { return {
operation: 'sort', operation: 'sort',
options: { options: {
columns: { by: formData.x_axis_sort,
[index]: true, ascending: formData.x_axis_sort_asc,
},
}, },
}; };
} }

View File

@ -16,9 +16,28 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
import { ContributionType, hasGenericChartAxes, t } from '@superset-ui/core'; import { hasGenericChartAxes, t } from '@superset-ui/core';
import { ControlPanelSectionConfig } from '../types'; import { ControlPanelSectionConfig, ControlSetRow } from '../types';
import { emitFilterControl } from '../shared-controls'; import {
contributionModeControl,
emitFilterControl,
xAxisSortControl,
xAxisSortAscControl,
} from '../shared-controls';
const controlsWithoutXAxis: ControlSetRow[] = [
['metrics'],
['groupby'],
[contributionModeControl],
['adhoc_filters'],
emitFilterControl,
['limit'],
['timeseries_limit_metric'],
['order_desc'],
['row_limit'],
['truncate_metric'],
['show_empty_columns'],
];
export const echartsTimeSeriesQuery: ControlPanelSectionConfig = { export const echartsTimeSeriesQuery: ControlPanelSectionConfig = {
label: t('Query'), label: t('Query'),
@ -26,31 +45,18 @@ export const echartsTimeSeriesQuery: ControlPanelSectionConfig = {
controlSetRows: [ controlSetRows: [
[hasGenericChartAxes ? 'x_axis' : null], [hasGenericChartAxes ? 'x_axis' : null],
[hasGenericChartAxes ? 'time_grain_sqla' : null], [hasGenericChartAxes ? 'time_grain_sqla' : null],
['metrics'], ...controlsWithoutXAxis,
['groupby'], ],
[ };
{
name: 'contributionMode', export const echartsTimeSeriesQueryWithXAxisSort: ControlPanelSectionConfig = {
config: { label: t('Query'),
type: 'SelectControl', expanded: true,
label: t('Contribution Mode'), controlSetRows: [
default: null, [hasGenericChartAxes ? 'x_axis' : null],
choices: [ [hasGenericChartAxes ? 'time_grain_sqla' : null],
[null, t('None')], [hasGenericChartAxes ? xAxisSortControl : null],
[ContributionType.Row, t('Row')], [hasGenericChartAxes ? xAxisSortAscControl : null],
[ContributionType.Column, t('Series')], ...controlsWithoutXAxis,
],
description: t('Calculate contribution per series or row'),
},
},
],
['adhoc_filters'],
emitFilterControl,
['limit'],
['timeseries_limit_metric'],
['order_desc'],
['row_limit'],
['truncate_metric'],
['show_empty_columns'],
], ],
}; };

View File

@ -17,7 +17,21 @@
* under the License. * under the License.
*/ */
import { FeatureFlag, isFeatureEnabled, t } from '@superset-ui/core'; import {
ContributionType,
ensureIsArray,
FeatureFlag,
getColumnLabel,
getMetricLabel,
isDefined,
isEqualArray,
isFeatureEnabled,
QueryFormColumn,
QueryFormMetric,
t,
} from '@superset-ui/core';
import { ControlPanelState, ControlState, ControlStateMapping } from '../types';
import { isTemporalColumn } from '../utils';
export const emitFilterControl = isFeatureEnabled( export const emitFilterControl = isFeatureEnabled(
FeatureFlag.DASHBOARD_CROSS_FILTERS, FeatureFlag.DASHBOARD_CROSS_FILTERS,
@ -35,3 +49,93 @@ export const emitFilterControl = isFeatureEnabled(
}, },
] ]
: []; : [];
export const contributionModeControl = {
name: 'contributionMode',
config: {
type: 'SelectControl',
label: t('Contribution Mode'),
default: null,
choices: [
[null, t('None')],
[ContributionType.Row, t('Row')],
[ContributionType.Column, t('Series')],
],
description: t('Calculate contribution per series or row'),
},
};
const xAxisSortVisibility = ({ controls }: { controls: ControlStateMapping }) =>
isDefined(controls?.x_axis?.value) &&
!isTemporalColumn(
getColumnLabel(controls?.x_axis?.value as QueryFormColumn),
controls?.datasource?.datasource,
) &&
Array.isArray(controls?.groupby?.value) &&
controls.groupby.value.length === 0;
export const xAxisSortControl = {
name: 'x_axis_sort',
config: {
type: 'XAxisSortControl',
label: t('X-Axis Sort By'),
description: t('Whether to sort descending or ascending on the X-Axis.'),
shouldMapStateToProps: (
prevState: ControlPanelState,
state: ControlPanelState,
) => {
const prevOptions = [
getColumnLabel(prevState?.controls?.x_axis?.value as QueryFormColumn),
...ensureIsArray(prevState?.controls?.metrics?.value).map(metric =>
getMetricLabel(metric as QueryFormMetric),
),
];
const currOptions = [
getColumnLabel(state?.controls?.x_axis?.value as QueryFormColumn),
...ensureIsArray(state?.controls?.metrics?.value).map(metric =>
getMetricLabel(metric as QueryFormMetric),
),
];
return !isEqualArray(prevOptions, currOptions);
},
mapStateToProps: (
{ controls }: { controls: ControlStateMapping },
controlState: ControlState,
) => {
const choices = [
getColumnLabel(controls?.x_axis?.value as QueryFormColumn),
...ensureIsArray(controls?.metrics?.value).map(metric =>
getMetricLabel(metric as QueryFormMetric),
),
].filter(Boolean);
const shouldReset = !(
typeof controlState.value === 'string' &&
choices.includes(controlState.value) &&
!isTemporalColumn(
getColumnLabel(controls?.x_axis?.value as QueryFormColumn),
controls?.datasource?.datasource,
)
);
return {
shouldReset,
options: choices.map(entry => ({
value: entry,
label: entry,
})),
};
},
visibility: xAxisSortVisibility,
},
};
export const xAxisSortAscControl = {
name: 'x_axis_sort_asc',
config: {
type: 'CheckboxControl',
label: t('X-Axis Sort Ascending'),
default: true,
description: t('Whether to sort descending or ascending on the X-Axis.'),
visibility: xAxisSortVisibility,
},
};

View File

@ -354,7 +354,7 @@ const show_empty_columns: SharedControlConfig<'CheckboxControl'> = {
description: t('Show empty columns'), description: t('Show empty columns'),
}; };
const datetime_columns_lookup: SharedControlConfig<'HiddenControl'> = { const temporal_columns_lookup: SharedControlConfig<'HiddenControl'> = {
type: 'HiddenControl', type: 'HiddenControl',
initialValue: (control: ControlState, state: ControlPanelState | null) => initialValue: (control: ControlState, state: ControlPanelState | null) =>
Object.fromEntries( Object.fromEntries(
@ -400,5 +400,5 @@ export default {
truncate_metric, truncate_metric,
x_axis: dndXAxisControl, x_axis: dndXAxisControl,
show_empty_columns, show_empty_columns,
datetime_columns_lookup, temporal_columns_lookup,
}; };

View File

@ -18,6 +18,7 @@
*/ */
import { QueryObject, SqlaFormData } from '@superset-ui/core'; import { QueryObject, SqlaFormData } from '@superset-ui/core';
import { sortOperator } from '@superset-ui/chart-controls'; import { sortOperator } from '@superset-ui/chart-controls';
import * as supersetCoreModule from '@superset-ui/core';
const formData: SqlaFormData = { const formData: SqlaFormData = {
metrics: [ metrics: [
@ -52,92 +53,96 @@ const queryObject: QueryObject = {
], ],
}; };
test('skip sort', () => { test('should ignore the sortOperator', () => {
// FF is disabled
Object.defineProperty(supersetCoreModule, 'hasGenericChartAxes', {
value: false,
});
expect(sortOperator(formData, queryObject)).toEqual(undefined); expect(sortOperator(formData, queryObject)).toEqual(undefined);
expect(
sortOperator(formData, { ...queryObject, is_timeseries: false }), // FF is enabled
).toEqual(undefined); Object.defineProperty(supersetCoreModule, 'hasGenericChartAxes', {
value: true,
});
expect( expect(
sortOperator( sortOperator(
{ ...formData, rolling_type: 'xxxx' }, {
{ ...queryObject, is_timeseries: true }, ...formData,
...{
x_axis_sort: undefined,
x_axis_sort_asc: true,
},
},
queryObject,
), ),
).toEqual(undefined); ).toEqual(undefined);
// sortOperator doesn't support multiple series
Object.defineProperty(supersetCoreModule, 'hasGenericChartAxes', {
value: true,
});
expect( expect(
sortOperator(formData, { ...queryObject, is_timeseries: true }), sortOperator(
{
...formData,
...{
x_axis_sort: 'metric label',
x_axis_sort_asc: true,
groupby: ['col1'],
x_axis: 'axis column',
},
},
queryObject,
),
).toEqual(undefined); ).toEqual(undefined);
}); });
test('sort by __timestamp', () => { test('should sort by metric', () => {
expect( Object.defineProperty(supersetCoreModule, 'hasGenericChartAxes', {
sortOperator( value: true,
{ ...formData, rolling_type: 'cumsum' },
{ ...queryObject, is_timeseries: true },
),
).toEqual({
operation: 'sort',
options: {
columns: {
__timestamp: true,
},
},
}); });
expect( expect(
sortOperator( sortOperator(
{ ...formData, rolling_type: 'sum' }, {
{ ...queryObject, is_timeseries: true }, ...formData,
...{
metrics: ['a metric label'],
x_axis_sort: 'a metric label',
x_axis_sort_asc: true,
},
},
queryObject,
), ),
).toEqual({ ).toEqual({
operation: 'sort', operation: 'sort',
options: { options: {
columns: { by: 'a metric label',
__timestamp: true, ascending: true,
},
},
});
expect(
sortOperator(
{ ...formData, rolling_type: 'mean' },
{ ...queryObject, is_timeseries: true },
),
).toEqual({
operation: 'sort',
options: {
columns: {
__timestamp: true,
},
},
});
expect(
sortOperator(
{ ...formData, rolling_type: 'std' },
{ ...queryObject, is_timeseries: true },
),
).toEqual({
operation: 'sort',
options: {
columns: {
__timestamp: true,
},
}, },
}); });
}); });
test('sort by named x-axis', () => { test('should sort by axis', () => {
Object.defineProperty(supersetCoreModule, 'hasGenericChartAxes', {
value: true,
});
expect( expect(
sortOperator( sortOperator(
{ ...formData, x_axis: 'ds', rolling_type: 'cumsum' }, {
{ ...queryObject }, ...formData,
...{
x_axis_sort: 'Categorical Column',
x_axis_sort_asc: true,
x_axis: 'Categorical Column',
},
},
queryObject,
), ),
).toEqual({ ).toEqual({
operation: 'sort', operation: 'sort',
options: { options: {
columns: { is_sort_index: true,
ds: true, ascending: true,
},
}, },
}); });
}); });

View File

@ -23,8 +23,8 @@ export default function getColumnLabel(column: QueryFormColumn): string {
if (isPhysicalColumn(column)) { if (isPhysicalColumn(column)) {
return column; return column;
} }
if (column.label) { if (column?.label) {
return column.label; return column.label;
} }
return column.sqlExpression; return column?.sqlExpression;
} }

View File

@ -29,13 +29,8 @@ export { default as getMetricLabel } from './getMetricLabel';
export { default as DatasourceKey } from './DatasourceKey'; export { default as DatasourceKey } from './DatasourceKey';
export { default as normalizeOrderBy } from './normalizeOrderBy'; export { default as normalizeOrderBy } from './normalizeOrderBy';
export { normalizeTimeColumn } from './normalizeTimeColumn'; export { normalizeTimeColumn } from './normalizeTimeColumn';
export {
getXAxisLabel,
getXAxisColumn,
isXAxisSet,
hasGenericChartAxes,
} from './getXAxis';
export { default as extractQueryFields } from './extractQueryFields'; export { default as extractQueryFields } from './extractQueryFields';
export * from './getXAxis';
export * from './types/AnnotationLayer'; export * from './types/AnnotationLayer';
export * from './types/QueryFormData'; export * from './types/QueryFormData';

View File

@ -182,7 +182,9 @@ export type PostProcessingCompare =
interface _PostProcessingSort { interface _PostProcessingSort {
operation: 'sort'; operation: 'sort';
options: { options: {
columns: Record<string, boolean>; is_sort_index?: boolean;
by?: string[] | string;
ascending?: boolean[] | boolean;
}; };
} }
export type PostProcessingSort = _PostProcessingSort | DefaultPostProcessing; export type PostProcessingSort = _PostProcessingSort | DefaultPostProcessing;

View File

@ -147,7 +147,7 @@ const ROLLING_RULE: PostProcessingRolling = {
const SORT_RULE: PostProcessingSort = { const SORT_RULE: PostProcessingSort = {
operation: 'sort', operation: 'sort',
options: { options: {
columns: { foo: true }, by: 'foo',
}, },
}; };

View File

@ -38,7 +38,7 @@ export default function buildQuery(formData: BoxPlotQueryFormData) {
if ( if (
isPhysicalColumn(col) && isPhysicalColumn(col) &&
formData.time_grain_sqla && formData.time_grain_sqla &&
formData?.datetime_columns_lookup?.[col] formData?.temporal_columns_lookup?.[col]
) { ) {
return { return {
timeGrain: formData.time_grain_sqla, timeGrain: formData.time_grain_sqla,

View File

@ -73,7 +73,7 @@ const config: ControlPanelConfig = {
}, },
}, },
}, },
'datetime_columns_lookup', 'temporal_columns_lookup',
], ],
['groupby'], ['groupby'],
['metrics'], ['metrics'],

View File

@ -41,7 +41,6 @@ import {
const { const {
logAxis, logAxis,
minorSplitLine, minorSplitLine,
rowLimit,
truncateYAxis, truncateYAxis,
yAxisBounds, yAxisBounds,
zoomable, zoomable,
@ -260,7 +259,7 @@ function createAxisControl(axis: 'x' | 'y'): ControlSetRow[] {
const config: ControlPanelConfig = { const config: ControlPanelConfig = {
controlPanelSections: [ controlPanelSections: [
sections.genericTime, sections.genericTime,
sections.echartsTimeSeriesQuery, sections.echartsTimeSeriesQueryWithXAxisSort,
sections.advancedAnalyticsControls, sections.advancedAnalyticsControls,
sections.annotationsAndLayersControls, sections.annotationsAndLayersControls,
sections.forecastIntervalControls, sections.forecastIntervalControls,
@ -324,40 +323,6 @@ const config: ControlPanelConfig = {
], ],
}, },
], ],
controlOverrides: {
row_limit: {
default: rowLimit,
},
limit: {
rerender: ['timeseries_limit_metric', 'order_desc'],
},
timeseries_limit_metric: {
label: t('Series Limit Sort By'),
description: t(
'Metric used to order the limit if a series limit is present. ' +
'If undefined reverts to the first metric (where appropriate).',
),
visibility: ({ controls }) => Boolean(controls?.limit.value),
mapStateToProps: (state, controlState) => {
const timeserieslimitProps =
sharedControls.timeseries_limit_metric.mapStateToProps?.(
state,
controlState,
) || {};
timeserieslimitProps.value = state.controls?.limit?.value
? controlState?.value
: [];
return timeserieslimitProps;
},
},
order_desc: {
label: t('Series Limit Sort Descending'),
default: false,
description: t(
'Whether to sort descending or ascending if a series limit is present',
),
},
},
formDataOverrides: formData => ({ formDataOverrides: formData => ({
...formData, ...formData,
metrics: getStandardizedControls().popAllMetrics(), metrics: getStandardizedControls().popAllMetrics(),

View File

@ -36,6 +36,7 @@ import {
prophetOperator, prophetOperator,
timeComparePivotOperator, timeComparePivotOperator,
flattenOperator, flattenOperator,
sortOperator,
} from '@superset-ui/chart-controls'; } from '@superset-ui/chart-controls';
export default function buildQuery(formData: QueryFormData) { export default function buildQuery(formData: QueryFormData) {
@ -95,6 +96,7 @@ export default function buildQuery(formData: QueryFormData) {
resampleOperator(formData, baseQueryObject), resampleOperator(formData, baseQueryObject),
renameOperator(formData, baseQueryObject), renameOperator(formData, baseQueryObject),
contributionOperator(formData, baseQueryObject), contributionOperator(formData, baseQueryObject),
sortOperator(formData, baseQueryObject),
flattenOperator(formData, baseQueryObject), flattenOperator(formData, baseQueryObject),
// todo: move prophet before flatten // todo: move prophet before flatten
prophetOperator(formData, baseQueryObject), prophetOperator(formData, baseQueryObject),

View File

@ -42,7 +42,7 @@ export default function buildQuery(formData: PivotTableQueryFormData) {
isPhysicalColumn(col) && isPhysicalColumn(col) &&
formData.time_grain_sqla && formData.time_grain_sqla &&
hasGenericChartAxes && hasGenericChartAxes &&
formData?.datetime_columns_lookup?.[col] formData?.temporal_columns_lookup?.[col]
) { ) {
return { return {
timeGrain: formData.time_grain_sqla, timeGrain: formData.time_grain_sqla,

View File

@ -97,7 +97,7 @@ const config: ControlPanelConfig = {
}, },
} }
: null, : null,
hasGenericChartAxes ? 'datetime_columns_lookup' : null, hasGenericChartAxes ? 'temporal_columns_lookup' : null,
], ],
[ [
{ {

View File

@ -104,7 +104,7 @@ const buildQuery: BuildQuery<TableChartFormData> = (
isPhysicalColumn(col) && isPhysicalColumn(col) &&
formData.time_grain_sqla && formData.time_grain_sqla &&
hasGenericChartAxes && hasGenericChartAxes &&
formData?.datetime_columns_lookup?.[col] formData?.temporal_columns_lookup?.[col]
) { ) {
return { return {
timeGrain: formData.time_grain_sqla, timeGrain: formData.time_grain_sqla,

View File

@ -218,7 +218,7 @@ const config: ControlPanelConfig = {
}, },
} }
: null, : null,
hasGenericChartAxes && isAggMode ? 'datetime_columns_lookup' : null, hasGenericChartAxes && isAggMode ? 'temporal_columns_lookup' : null,
], ],
[ [
{ {

View File

@ -0,0 +1,36 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import React, { useEffect, useState } from 'react';
import SelectControl from './SelectControl';
export default function XAxisSortControl(props: {
onChange: (val: string | undefined) => void;
value: string | null;
shouldReset: boolean;
}) {
const [value, setValue] = useState(props.value);
useEffect(() => {
if (props.shouldReset) {
props.onChange(undefined);
setValue(null);
}
}, [props.shouldReset, props.value]);
return <SelectControl {...props} value={value} />;
}

View File

@ -45,6 +45,7 @@ import DndColumnSelectControl, {
DndFilterSelect, DndFilterSelect,
DndMetricSelect, DndMetricSelect,
} from './DndColumnSelectControl'; } from './DndColumnSelectControl';
import XAxisSortControl from './XAxisSortControl';
const controlMap = { const controlMap = {
AnnotationLayerControl, AnnotationLayerControl,
@ -74,6 +75,7 @@ const controlMap = {
AdhocFilterControl, AdhocFilterControl,
FilterBoxItemControl, FilterBoxItemControl,
ConditionalFormattingControl, ConditionalFormattingControl,
XAxisSortControl,
...sharedControlComponents, ...sharedControlComponents,
}; };
export default controlMap; export default controlMap;

View File

@ -14,22 +14,34 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Dict from typing import List, Optional, Union
from pandas import DataFrame from pandas import DataFrame
from superset.utils.pandas_postprocessing.utils import validate_column_args from superset.utils.pandas_postprocessing.utils import validate_column_args
@validate_column_args("columns") # pylint: disable=invalid-name
def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: @validate_column_args("by")
def sort(
df: DataFrame,
is_sort_index: bool = False,
by: Optional[Union[List[str], str]] = None,
ascending: Union[List[bool], bool] = True,
) -> DataFrame:
""" """
Sort a DataFrame. Sort a DataFrame.
:param df: DataFrame to sort. :param df: DataFrame to sort.
:param columns: columns by by which to sort. The key specifies the column name, :param is_sort_index: Whether by index or value to sort
value specifies if sorting in ascending order. :param by: Name or list of names to sort by.
:param ascending: Sort ascending or descending.
:return: Sorted DataFrame :return: Sorted DataFrame
:raises InvalidPostProcessingError: If the request in incorrect :raises InvalidPostProcessingError: If the request in incorrect
""" """
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values())) if not is_sort_index and not by:
return df
if is_sort_index:
return df.sort_index(ascending=ascending)
return df.sort_values(by=by, ascending=ascending)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from functools import partial from functools import partial
from typing import Any, Callable, Dict from typing import Any, Callable, Dict, Sequence
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -101,6 +101,14 @@ def _is_multi_index_on_columns(df: DataFrame) -> bool:
return isinstance(df.columns, pd.MultiIndex) return isinstance(df.columns, pd.MultiIndex)
def scalar_to_sequence(val: Any) -> Sequence[str]:
if val is None:
return []
if isinstance(val, str):
return [val]
return val
def validate_column_args(*argnames: str) -> Callable[..., Any]: def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any: def wrapped(df: DataFrame, **options: Any) -> Any:
@ -111,7 +119,7 @@ def validate_column_args(*argnames: str) -> Callable[..., Any]:
columns = df.columns.tolist() columns = df.columns.tolist()
for name in argnames: for name in argnames:
if name in options and not all( if name in options and not all(
elem in columns for elem in options.get(name) or [] elem in columns for elem in scalar_to_sequence(options.get(name))
): ):
raise InvalidPostProcessingError( raise InvalidPostProcessingError(
_("Referenced columns not available in DataFrame.") _("Referenced columns not available in DataFrame.")

View File

@ -195,9 +195,7 @@ POSTPROCESSING_OPERATIONS = {
}, },
{ {
"operation": "sort", "operation": "sort",
"options": { "options": {"by": ["q1", "name"], "ascending": [False, True]},
"columns": {"q1": False, "name": True},
},
}, },
] ]
} }

View File

@ -15,16 +15,39 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import pytest import pytest
from dateutil.parser import parse
from superset.exceptions import InvalidPostProcessingError from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import sort from superset.utils.pandas_postprocessing import sort
from tests.unit_tests.fixtures.dataframes import categories_df from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
from tests.unit_tests.pandas_postprocessing.utils import series_to_list from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_sort(): def test_sort():
df = sort(df=categories_df, columns={"category": True, "asc_idx": False}) df = sort(df=categories_df, by=["category", "asc_idx"], ascending=[True, False])
assert series_to_list(df["asc_idx"])[1] == 96 assert series_to_list(df["asc_idx"])[1] == 96
df = sort(df=categories_df.set_index("name"), is_sort_index=True)
assert df.index[0] == "person0"
df = sort(df=categories_df.set_index("name"), is_sort_index=True, ascending=False)
assert df.index[0] == "person99"
df = sort(df=categories_df.set_index("name"), by="asc_idx")
assert df["asc_idx"][0] == 0
df = sort(df=categories_df.set_index("name"), by="asc_idx", ascending=False)
assert df["asc_idx"][0] == 100
df = sort(df=timeseries_df, is_sort_index=True)
assert df.index[0] == parse("2019-01-01")
df = sort(df=timeseries_df, is_sort_index=True, ascending=False)
assert df.index[0] == parse("2019-01-07")
df = sort(df=timeseries_df)
assert df.equals(timeseries_df)
with pytest.raises(InvalidPostProcessingError): with pytest.raises(InvalidPostProcessingError):
sort(df=df, columns={"abc": True}) sort(df=df, by="abc", ascending=False)
sort(df=df, by=["abc", "def"])