feat(advanced analysis): support MultiIndex column in post processing stage (#19116)

This commit is contained in:
Yongjie Zhao 2022-03-23 13:46:28 +08:00 committed by GitHub
parent 6083545e86
commit 375c03e084
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 1267 additions and 772 deletions

View File

@ -21,16 +21,16 @@ import {
getColumnLabel,
getMetricLabel,
PostProcessingBoxplot,
BoxPlotQueryObjectWhiskerType,
} from '@superset-ui/core';
import { PostProcessingFactory } from './types';
type BoxPlotQueryObjectWhiskerType =
PostProcessingBoxplot['options']['whisker_type'];
const PERCENTILE_REGEX = /(\d+)\/(\d+) percentiles/;
export const boxplotOperator: PostProcessingFactory<
PostProcessingBoxplot | undefined
> = (formData, queryObject) => {
export const boxplotOperator: PostProcessingFactory<PostProcessingBoxplot> = (
formData,
queryObject,
) => {
const { groupby, whiskerOptions } = formData;
if (whiskerOptions) {

View File

@ -19,16 +19,15 @@
import { PostProcessingContribution } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
export const contributionOperator: PostProcessingFactory<
PostProcessingContribution | undefined
> = (formData, queryObject) => {
if (formData.contributionMode) {
return {
operation: 'contribution',
options: {
orientation: formData.contributionMode,
},
};
}
return undefined;
};
export const contributionOperator: PostProcessingFactory<PostProcessingContribution> =
(formData, queryObject) => {
if (formData.contributionMode) {
return {
operation: 'contribution',
options: {
orientation: formData.contributionMode,
},
};
}
return undefined;
};

View File

@ -0,0 +1,26 @@
/* eslint-disable camelcase */
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
* under the License.
*/
import { PostProcessingFlatten } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
export const flattenOperator: PostProcessingFactory<PostProcessingFlatten> = (
formData,
queryObject,
) => ({ operation: 'flatten' });

View File

@ -26,4 +26,5 @@ export { resampleOperator } from './resampleOperator';
export { contributionOperator } from './contributionOperator';
export { prophetOperator } from './prophetOperator';
export { boxplotOperator } from './boxplotOperator';
export { flattenOperator } from './flattenOperator';
export * from './utils';

View File

@ -24,19 +24,14 @@ import {
PostProcessingPivot,
} from '@superset-ui/core';
import { PostProcessingFactory } from './types';
import { isValidTimeCompare } from './utils';
import { timeComparePivotOperator } from './timeComparePivotOperator';
export const pivotOperator: PostProcessingFactory<
PostProcessingPivot | undefined
> = (formData, queryObject) => {
export const pivotOperator: PostProcessingFactory<PostProcessingPivot> = (
formData,
queryObject,
) => {
const metricLabels = ensureIsArray(queryObject.metrics).map(getMetricLabel);
const { x_axis: xAxis } = formData;
if ((xAxis || queryObject.is_timeseries) && metricLabels.length) {
if (isValidTimeCompare(formData, queryObject)) {
return timeComparePivotOperator(formData, queryObject);
}
return {
operation: 'pivot',
options: {
@ -48,6 +43,8 @@ export const pivotOperator: PostProcessingFactory<
metricLabels.map(metric => [metric, { operator: 'mean' }]),
),
drop_missing_columns: false,
flatten_columns: false,
reset_index: false,
},
};
}

View File

@ -19,9 +19,10 @@
import { DTTM_ALIAS, PostProcessingProphet } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
export const prophetOperator: PostProcessingFactory<
PostProcessingProphet | undefined
> = (formData, queryObject) => {
export const prophetOperator: PostProcessingFactory<PostProcessingProphet> = (
formData,
queryObject,
) => {
if (formData.forecastEnabled) {
return {
operation: 'prophet',

View File

@ -17,36 +17,23 @@
 * specific language governing permissions and limitations
* under the License.
*/
import {
DTTM_ALIAS,
ensureIsArray,
isPhysicalColumn,
PostProcessingResample,
} from '@superset-ui/core';
import { PostProcessingResample } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
export const resampleOperator: PostProcessingFactory<
PostProcessingResample | undefined
> = (formData, queryObject) => {
export const resampleOperator: PostProcessingFactory<PostProcessingResample> = (
formData,
queryObject,
) => {
const resampleZeroFill = formData.resample_method === 'zerofill';
const resampleMethod = resampleZeroFill ? 'asfreq' : formData.resample_method;
const resampleRule = formData.resample_rule;
if (resampleMethod && resampleRule) {
const groupby_columns = ensureIsArray(queryObject.columns).map(column => {
if (isPhysicalColumn(column)) {
return column;
}
return column.label;
});
return {
operation: 'resample',
options: {
method: resampleMethod,
rule: resampleRule,
fill_value: resampleZeroFill ? 0 : null,
time_column: formData.x_axis || DTTM_ALIAS,
groupby_columns,
},
};
}

View File

@ -18,39 +18,25 @@
* under the License.
*/
import {
ComparisionType,
ensureIsArray,
ensureIsInt,
PostProcessingCum,
PostProcessingRolling,
RollingType,
} from '@superset-ui/core';
import {
getMetricOffsetsMap,
isValidTimeCompare,
TIME_COMPARISON_SEPARATOR,
} from './utils';
import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
import { PostProcessingFactory } from './types';
export const rollingWindowOperator: PostProcessingFactory<
PostProcessingRolling | PostProcessingCum | undefined
PostProcessingRolling | PostProcessingCum
> = (formData, queryObject) => {
let columns: (string | undefined)[];
if (isValidTimeCompare(formData, queryObject)) {
const metricsMap = getMetricOffsetsMap(formData, queryObject);
const comparisonType = formData.comparison_type;
if (comparisonType === ComparisionType.Values) {
// time compare type: actual values
columns = [
...Array.from(metricsMap.values()),
...Array.from(metricsMap.keys()),
];
} else {
// time compare type: difference / percentage / ratio
columns = Array.from(metricsMap.entries()).map(([offset, metric]) =>
[comparisonType, metric, offset].join(TIME_COMPARISON_SEPARATOR),
);
}
columns = [
...Array.from(metricsMap.values()),
...Array.from(metricsMap.keys()),
];
} else {
columns = ensureIsArray(queryObject.metrics).map(metric => {
if (typeof metric === 'string') {
@ -67,7 +53,6 @@ export const rollingWindowOperator: PostProcessingFactory<
options: {
operator: 'sum',
columns: columnsMap,
is_pivot_df: true,
},
};
}
@ -84,7 +69,6 @@ export const rollingWindowOperator: PostProcessingFactory<
window: ensureIsInt(formData.rolling_periods, 1),
min_periods: ensureIsInt(formData.min_periods, 0),
columns: columnsMap,
is_pivot_df: true,
},
};
}

View File

@ -20,9 +20,10 @@
import { DTTM_ALIAS, PostProcessingSort, RollingType } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
export const sortOperator: PostProcessingFactory<
PostProcessingSort | undefined
> = (formData, queryObject) => {
export const sortOperator: PostProcessingFactory<PostProcessingSort> = (
formData,
queryObject,
) => {
const { x_axis: xAxis } = formData;
if (
(xAxis || queryObject.is_timeseries) &&

View File

@ -21,26 +21,25 @@ import { ComparisionType, PostProcessingCompare } from '@superset-ui/core';
import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
import { PostProcessingFactory } from './types';
export const timeCompareOperator: PostProcessingFactory<
PostProcessingCompare | undefined
> = (formData, queryObject) => {
const comparisonType = formData.comparison_type;
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
export const timeCompareOperator: PostProcessingFactory<PostProcessingCompare> =
(formData, queryObject) => {
const comparisonType = formData.comparison_type;
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
if (
isValidTimeCompare(formData, queryObject) &&
comparisonType !== ComparisionType.Values
) {
return {
operation: 'compare',
options: {
source_columns: Array.from(metricOffsetMap.values()),
compare_columns: Array.from(metricOffsetMap.keys()),
compare_type: comparisonType,
drop_original_columns: true,
},
};
}
if (
isValidTimeCompare(formData, queryObject) &&
comparisonType !== ComparisionType.Values
) {
return {
operation: 'compare',
options: {
source_columns: Array.from(metricOffsetMap.values()),
compare_columns: Array.from(metricOffsetMap.keys()),
compare_type: comparisonType,
drop_original_columns: true,
},
};
}
return undefined;
};
return undefined;
};

View File

@ -18,54 +18,40 @@
* under the License.
*/
import {
ComparisionType,
DTTM_ALIAS,
ensureIsArray,
getColumnLabel,
NumpyFunction,
PostProcessingPivot,
} from '@superset-ui/core';
import {
getMetricOffsetsMap,
isValidTimeCompare,
TIME_COMPARISON_SEPARATOR,
} from './utils';
import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
import { PostProcessingFactory } from './types';
export const timeComparePivotOperator: PostProcessingFactory<
PostProcessingPivot | undefined
> = (formData, queryObject) => {
const comparisonType = formData.comparison_type;
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
export const timeComparePivotOperator: PostProcessingFactory<PostProcessingPivot> =
(formData, queryObject) => {
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
if (isValidTimeCompare(formData, queryObject)) {
const valuesAgg = Object.fromEntries(
[...metricOffsetMap.values(), ...metricOffsetMap.keys()].map(metric => [
metric,
// use the 'mean' aggregates to avoid drop NaN
{ operator: 'mean' as NumpyFunction },
]),
);
const changeAgg = Object.fromEntries(
[...metricOffsetMap.entries()]
.map(([offset, metric]) =>
[comparisonType, metric, offset].join(TIME_COMPARISON_SEPARATOR),
)
// use the 'mean' aggregates to avoid drop NaN
.map(metric => [metric, { operator: 'mean' as NumpyFunction }]),
);
if (isValidTimeCompare(formData, queryObject)) {
const aggregates = Object.fromEntries(
[...metricOffsetMap.values(), ...metricOffsetMap.keys()].map(metric => [
metric,
// use the 'mean' aggregates to avoid drop NaN
{ operator: 'mean' as NumpyFunction },
]),
);
return {
operation: 'pivot',
options: {
index: [formData.x_axis || DTTM_ALIAS],
columns: ensureIsArray(queryObject.columns).map(getColumnLabel),
aggregates:
comparisonType === ComparisionType.Values ? valuesAgg : changeAgg,
drop_missing_columns: false,
},
};
}
return {
operation: 'pivot',
options: {
index: [formData.x_axis || DTTM_ALIAS],
columns: ensureIsArray(queryObject.columns).map(getColumnLabel),
drop_missing_columns: false,
flatten_columns: false,
reset_index: false,
aggregates,
},
};
}
return undefined;
};
return undefined;
};

View File

@ -0,0 +1,59 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { QueryObject, SqlaFormData } from '@superset-ui/core';
import { flattenOperator } from '@superset-ui/chart-controls';
const formData: SqlaFormData = {
metrics: [
'count(*)',
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
],
time_range: '2015 : 2016',
granularity: 'month',
datasource: 'foo',
viz_type: 'table',
};
const queryObject: QueryObject = {
metrics: [
'count(*)',
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
],
time_range: '2015 : 2016',
granularity: 'month',
post_processing: [
{
operation: 'pivot',
options: {
index: ['__timestamp'],
columns: ['nation'],
aggregates: {
'count(*)': {
operator: 'sum',
},
},
},
},
],
};
test('should do flattenOperator', () => {
expect(flattenOperator(formData, queryObject)).toEqual({
operation: 'flatten',
});
});

View File

@ -80,6 +80,8 @@ test('pivot by __timestamp without groupby', () => {
'sum(val)': { operator: 'mean' },
},
drop_missing_columns: false,
flatten_columns: false,
reset_index: false,
},
});
});
@ -101,6 +103,8 @@ test('pivot by __timestamp with groupby', () => {
'sum(val)': { operator: 'mean' },
},
drop_missing_columns: false,
flatten_columns: false,
reset_index: false,
},
});
});
@ -127,44 +131,8 @@ test('pivot by x_axis with groupby', () => {
'sum(val)': { operator: 'mean' },
},
drop_missing_columns: false,
},
});
});
test('timecompare in formdata', () => {
expect(
pivotOperator(
{
...formData,
comparison_type: 'values',
time_compare: ['1 year ago', '1 year later'],
},
{
...queryObject,
columns: ['foo', 'bar'],
is_timeseries: true,
},
),
).toEqual({
operation: 'pivot',
options: {
aggregates: {
'count(*)': { operator: 'mean' },
'count(*)__1 year ago': { operator: 'mean' },
'count(*)__1 year later': { operator: 'mean' },
'sum(val)': {
operator: 'mean',
},
'sum(val)__1 year ago': {
operator: 'mean',
},
'sum(val)__1 year later': {
operator: 'mean',
},
},
drop_missing_columns: false,
columns: ['foo', 'bar'],
index: ['__timestamp'],
flatten_columns: false,
reset_index: false,
},
});
});

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import { AdhocColumn, QueryObject, SqlaFormData } from '@superset-ui/core';
import { QueryObject, SqlaFormData } from '@superset-ui/core';
import { resampleOperator } from '@superset-ui/chart-controls';
const formData: SqlaFormData = {
@ -74,8 +74,6 @@ test('should do resample on implicit time column', () => {
method: 'ffill',
rule: '1D',
fill_value: null,
time_column: '__timestamp',
groupby_columns: [],
},
});
});
@ -95,10 +93,8 @@ test('should do resample on x-axis', () => {
operation: 'resample',
options: {
fill_value: null,
groupby_columns: [],
method: 'ffill',
rule: '1D',
time_column: 'ds',
},
});
});
@ -115,81 +111,6 @@ test('should do zerofill resample', () => {
method: 'asfreq',
rule: '1D',
fill_value: 0,
time_column: '__timestamp',
groupby_columns: [],
},
});
});
test('should append physical column to resample', () => {
expect(
resampleOperator(
{ ...formData, resample_method: 'zerofill', resample_rule: '1D' },
{ ...queryObject, columns: ['column1', 'column2'] },
),
).toEqual({
operation: 'resample',
options: {
method: 'asfreq',
rule: '1D',
fill_value: 0,
time_column: '__timestamp',
groupby_columns: ['column1', 'column2'],
},
});
});
test('should append label of adhoc column and physical column to resample', () => {
expect(
resampleOperator(
{ ...formData, resample_method: 'zerofill', resample_rule: '1D' },
{
...queryObject,
columns: [
{
hasCustomLabel: true,
label: 'concat_a_b',
expressionType: 'SQL',
sqlExpression: "'a' + 'b'",
} as AdhocColumn,
'column2',
],
},
),
).toEqual({
operation: 'resample',
options: {
method: 'asfreq',
rule: '1D',
fill_value: 0,
time_column: '__timestamp',
groupby_columns: ['concat_a_b', 'column2'],
},
});
});
test('should append `undefined` if adhoc non-existing label', () => {
expect(
resampleOperator(
{ ...formData, resample_method: 'zerofill', resample_rule: '1D' },
{
...queryObject,
columns: [
{
sqlExpression: "'a' + 'b'",
} as AdhocColumn,
'column2',
],
},
),
).toEqual({
operation: 'resample',
options: {
method: 'asfreq',
rule: '1D',
fill_value: 0,
time_column: '__timestamp',
groupby_columns: [undefined, 'column2'],
},
});
});

View File

@ -79,7 +79,6 @@ test('rolling_type: cumsum', () => {
'count(*)': 'count(*)',
'sum(val)': 'sum(val)',
},
is_pivot_df: true,
},
});
});
@ -102,42 +101,13 @@ test('rolling_type: sum/mean/std', () => {
'count(*)': 'count(*)',
'sum(val)': 'sum(val)',
},
is_pivot_df: true,
},
});
});
});
test('rolling window and "actual values" in the time compare', () => {
expect(
rollingWindowOperator(
{
...formData,
rolling_type: 'cumsum',
comparison_type: 'values',
time_compare: ['1 year ago', '1 year later'],
},
queryObject,
),
).toEqual({
operation: 'cum',
options: {
operator: 'sum',
columns: {
'count(*)': 'count(*)',
'count(*)__1 year ago': 'count(*)__1 year ago',
'count(*)__1 year later': 'count(*)__1 year later',
'sum(val)': 'sum(val)',
'sum(val)__1 year ago': 'sum(val)__1 year ago',
'sum(val)__1 year later': 'sum(val)__1 year later',
},
is_pivot_df: true,
},
});
});
test('rolling window and "difference / percentage / ratio" in the time compare', () => {
const comparisionTypes = ['difference', 'percentage', 'ratio'];
test('should append compared metrics when sets time compare type', () => {
const comparisionTypes = ['values', 'difference', 'percentage', 'ratio'];
comparisionTypes.forEach(cType => {
expect(
rollingWindowOperator(
@ -154,12 +124,13 @@ test('rolling window and "difference / percentage / ratio" in the time compare',
options: {
operator: 'sum',
columns: {
[`${cType}__count(*)__count(*)__1 year ago`]: `${cType}__count(*)__count(*)__1 year ago`,
[`${cType}__count(*)__count(*)__1 year later`]: `${cType}__count(*)__count(*)__1 year later`,
[`${cType}__sum(val)__sum(val)__1 year ago`]: `${cType}__sum(val)__sum(val)__1 year ago`,
[`${cType}__sum(val)__sum(val)__1 year later`]: `${cType}__sum(val)__sum(val)__1 year later`,
'count(*)': 'count(*)',
'count(*)__1 year ago': 'count(*)__1 year ago',
'count(*)__1 year later': 'count(*)__1 year later',
'sum(val)': 'sum(val)',
'sum(val)__1 year ago': 'sum(val)__1 year ago',
'sum(val)__1 year later': 'sum(val)__1 year later',
},
is_pivot_df: true,
},
});
});

View File

@ -17,17 +17,23 @@
* under the License.
*/
import { QueryObject, SqlaFormData } from '@superset-ui/core';
import { timeCompareOperator, timeComparePivotOperator } from '../../../src';
import { timeCompareOperator } from '../../../src';
const formData: SqlaFormData = {
metrics: ['count(*)'],
metrics: [
'count(*)',
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
],
time_range: '2015 : 2016',
granularity: 'month',
datasource: 'foo',
viz_type: 'table',
};
const queryObject: QueryObject = {
metrics: ['count(*)'],
metrics: [
'count(*)',
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
],
time_range: '2015 : 2016',
granularity: 'month',
post_processing: [
@ -40,21 +46,26 @@ const queryObject: QueryObject = {
'count(*)': {
operator: 'mean',
},
'sum(val)': {
operator: 'mean',
},
},
drop_missing_columns: false,
flatten_columns: false,
reset_index: false,
},
},
{
operation: 'aggregation',
options: {
groupby: ['col1'],
aggregates: 'count',
aggregates: {},
},
},
],
};
test('time compare: skip transformation', () => {
test('should skip CompareOperator', () => {
expect(timeCompareOperator(formData, queryObject)).toEqual(undefined);
expect(
timeCompareOperator({ ...formData, time_compare: [] }, queryObject),
@ -80,7 +91,7 @@ test('time compare: skip transformation', () => {
).toEqual(undefined);
});
test('time compare: difference/percentage/ratio', () => {
test('should generate difference/percentage/ratio CompareOperator', () => {
const comparisionTypes = ['difference', 'percentage', 'ratio'];
comparisionTypes.forEach(cType => {
expect(
@ -95,108 +106,16 @@ test('time compare: difference/percentage/ratio', () => {
).toEqual({
operation: 'compare',
options: {
source_columns: ['count(*)', 'count(*)'],
compare_columns: ['count(*)__1 year ago', 'count(*)__1 year later'],
source_columns: ['count(*)', 'count(*)', 'sum(val)', 'sum(val)'],
compare_columns: [
'count(*)__1 year ago',
'count(*)__1 year later',
'sum(val)__1 year ago',
'sum(val)__1 year later',
],
compare_type: cType,
drop_original_columns: true,
},
});
});
});
test('time compare pivot: skip transformation', () => {
expect(timeComparePivotOperator(formData, queryObject)).toEqual(undefined);
expect(
timeComparePivotOperator({ ...formData, time_compare: [] }, queryObject),
).toEqual(undefined);
expect(
timeComparePivotOperator(
{ ...formData, comparison_type: null },
queryObject,
),
).toEqual(undefined);
expect(
timeCompareOperator(
{ ...formData, comparison_type: 'foobar' },
queryObject,
),
).toEqual(undefined);
});
test('time compare pivot: values', () => {
expect(
timeComparePivotOperator(
{
...formData,
comparison_type: 'values',
time_compare: ['1 year ago', '1 year later'],
},
queryObject,
),
).toEqual({
operation: 'pivot',
options: {
aggregates: {
'count(*)': { operator: 'mean' },
'count(*)__1 year ago': { operator: 'mean' },
'count(*)__1 year later': { operator: 'mean' },
},
drop_missing_columns: false,
columns: [],
index: ['__timestamp'],
},
});
});
test('time compare pivot: difference/percentage/ratio', () => {
const comparisionTypes = ['difference', 'percentage', 'ratio'];
comparisionTypes.forEach(cType => {
expect(
timeComparePivotOperator(
{
...formData,
comparison_type: cType,
time_compare: ['1 year ago', '1 year later'],
},
queryObject,
),
).toEqual({
operation: 'pivot',
options: {
aggregates: {
[`${cType}__count(*)__count(*)__1 year ago`]: { operator: 'mean' },
[`${cType}__count(*)__count(*)__1 year later`]: { operator: 'mean' },
},
drop_missing_columns: false,
columns: [],
index: ['__timestamp'],
},
});
});
});
test('time compare pivot on x-axis', () => {
expect(
timeComparePivotOperator(
{
...formData,
comparison_type: 'values',
time_compare: ['1 year ago', '1 year later'],
x_axis: 'ds',
},
queryObject,
),
).toEqual({
operation: 'pivot',
options: {
aggregates: {
'count(*)': { operator: 'mean' },
'count(*)__1 year ago': { operator: 'mean' },
'count(*)__1 year later': { operator: 'mean' },
},
drop_missing_columns: false,
columns: [],
index: ['ds'],
},
});
});

View File

@ -0,0 +1,137 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { QueryObject, SqlaFormData } from '@superset-ui/core';
import { timeCompareOperator, timeComparePivotOperator } from '../../../src';
const formData: SqlaFormData = {
metrics: [
'count(*)',
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
],
time_range: '2015 : 2016',
granularity: 'month',
datasource: 'foo',
viz_type: 'table',
};
const queryObject: QueryObject = {
metrics: [
'count(*)',
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
],
columns: ['foo', 'bar'],
time_range: '2015 : 2016',
granularity: 'month',
post_processing: [],
};
test('should skip pivot', () => {
expect(timeComparePivotOperator(formData, queryObject)).toEqual(undefined);
expect(
timeComparePivotOperator({ ...formData, time_compare: [] }, queryObject),
).toEqual(undefined);
expect(
timeComparePivotOperator(
{ ...formData, comparison_type: null },
queryObject,
),
).toEqual(undefined);
expect(
timeCompareOperator(
{ ...formData, comparison_type: 'foobar' },
queryObject,
),
).toEqual(undefined);
});
test('should pivot on any type of timeCompare', () => {
const anyTimeCompareTypes = ['values', 'difference', 'percentage', 'ratio'];
anyTimeCompareTypes.forEach(cType => {
expect(
timeComparePivotOperator(
{
...formData,
comparison_type: cType,
time_compare: ['1 year ago', '1 year later'],
},
{
...queryObject,
is_timeseries: true,
},
),
).toEqual({
operation: 'pivot',
options: {
aggregates: {
'count(*)': { operator: 'mean' },
'count(*)__1 year ago': { operator: 'mean' },
'count(*)__1 year later': { operator: 'mean' },
'sum(val)': { operator: 'mean' },
'sum(val)__1 year ago': {
operator: 'mean',
},
'sum(val)__1 year later': {
operator: 'mean',
},
},
drop_missing_columns: false,
flatten_columns: false,
reset_index: false,
columns: ['foo', 'bar'],
index: ['__timestamp'],
},
});
});
});
test('should pivot on x-axis', () => {
expect(
timeComparePivotOperator(
{
...formData,
comparison_type: 'values',
time_compare: ['1 year ago', '1 year later'],
x_axis: 'ds',
},
queryObject,
),
).toEqual({
operation: 'pivot',
options: {
aggregates: {
'count(*)': { operator: 'mean' },
'count(*)__1 year ago': { operator: 'mean' },
'count(*)__1 year later': { operator: 'mean' },
'sum(val)': {
operator: 'mean',
},
'sum(val)__1 year ago': {
operator: 'mean',
},
'sum(val)__1 year later': {
operator: 'mean',
},
},
drop_missing_columns: false,
columns: ['foo', 'bar'],
index: ['ds'],
flatten_columns: false,
reset_index: false,
},
});
});

View File

@ -64,25 +64,34 @@ export interface Aggregates {
};
}
export interface PostProcessingAggregation {
export type DefaultPostProcessing = undefined;
interface _PostProcessingAggregation {
operation: 'aggregation';
options: {
groupby: string[];
aggregates: Aggregates;
};
}
export type PostProcessingAggregation =
| _PostProcessingAggregation
| DefaultPostProcessing;
export interface PostProcessingBoxplot {
export type BoxPlotQueryObjectWhiskerType = 'tukey' | 'min/max' | 'percentile';
interface _PostProcessingBoxplot {
operation: 'boxplot';
options: {
groupby: string[];
metrics: string[];
whisker_type: 'tukey' | 'min/max' | 'percentile';
whisker_type: BoxPlotQueryObjectWhiskerType;
percentiles?: [number, number];
};
}
export type PostProcessingBoxplot =
| _PostProcessingBoxplot
| DefaultPostProcessing;
export interface PostProcessingContribution {
interface _PostProcessingContribution {
operation: 'contribution';
options?: {
orientation?: 'row' | 'column';
@ -90,8 +99,11 @@ export interface PostProcessingContribution {
rename_columns?: string[];
};
}
export type PostProcessingContribution =
| _PostProcessingContribution
| DefaultPostProcessing;
export interface PostProcessingPivot {
interface _PostProcessingPivot {
operation: 'pivot';
options: {
aggregates: Aggregates;
@ -107,8 +119,9 @@ export interface PostProcessingPivot {
reset_index?: boolean;
};
}
export type PostProcessingPivot = _PostProcessingPivot | DefaultPostProcessing;
export interface PostProcessingProphet {
interface _PostProcessingProphet {
operation: 'prophet';
options: {
time_grain: TimeGranularity;
@ -119,8 +132,11 @@ export interface PostProcessingProphet {
daily_seasonality?: boolean | number;
};
}
export type PostProcessingProphet =
| _PostProcessingProphet
| DefaultPostProcessing;
export interface PostProcessingDiff {
interface _PostProcessingDiff {
operation: 'diff';
options: {
columns: string[];
@ -128,28 +144,31 @@ export interface PostProcessingDiff {
axis: PandasAxis;
};
}
export type PostProcessingDiff = _PostProcessingDiff | DefaultPostProcessing;
export interface PostProcessingRolling {
interface _PostProcessingRolling {
operation: 'rolling';
options: {
rolling_type: RollingType;
window: number;
min_periods: number;
columns: string[];
is_pivot_df?: boolean;
};
}
export type PostProcessingRolling =
| _PostProcessingRolling
| DefaultPostProcessing;
export interface PostProcessingCum {
interface _PostProcessingCum {
operation: 'cum';
options: {
columns: string[];
operator: NumpyFunction;
is_pivot_df?: boolean;
};
}
export type PostProcessingCum = _PostProcessingCum | DefaultPostProcessing;
export interface PostProcessingCompare {
export interface _PostProcessingCompare {
operation: 'compare';
options: {
source_columns: string[];
@ -158,26 +177,39 @@ export interface PostProcessingCompare {
drop_original_columns: boolean;
};
}
export type PostProcessingCompare =
| _PostProcessingCompare
| DefaultPostProcessing;
export interface PostProcessingSort {
interface _PostProcessingSort {
operation: 'sort';
options: {
columns: Record<string, boolean>;
};
}
export type PostProcessingSort = _PostProcessingSort | DefaultPostProcessing;
export interface PostProcessingResample {
interface _PostProcessingResample {
operation: 'resample';
options: {
method: string;
rule: string;
fill_value?: number | null;
time_column: string;
// If AdhocColumn doesn't have a label, it will be undefined.
// todo: we have to give an explicit label for AdhocColumn.
groupby_columns?: Array<string | undefined>;
};
}
export type PostProcessingResample =
| _PostProcessingResample
| DefaultPostProcessing;
interface _PostProcessingFlatten {
operation: 'flatten';
options?: {
reset_index?: boolean;
};
}
export type PostProcessingFlatten =
| _PostProcessingFlatten
| DefaultPostProcessing;
/**
* Parameters for chart data postprocessing.
@ -194,7 +226,8 @@ export type PostProcessingRule =
| PostProcessingCum
| PostProcessingCompare
| PostProcessingSort
| PostProcessingResample;
| PostProcessingResample
| PostProcessingFlatten;
export function isPostProcessingAggregation(
rule?: PostProcessingRule,

View File

@ -18,11 +18,14 @@
*/
import {
buildQueryContext,
DTTM_ALIAS,
PostProcessingResample,
QueryFormData,
} from '@superset-ui/core';
import { rollingWindowOperator } from '@superset-ui/chart-controls';
import {
flattenOperator,
rollingWindowOperator,
sortOperator,
} from '@superset-ui/chart-controls';
const TIME_GRAIN_MAP: Record<string, string> = {
PT1S: 'S',
@ -47,12 +50,10 @@ const TIME_GRAIN_MAP: Record<string, string> = {
export default function buildQuery(formData: QueryFormData) {
return buildQueryContext(formData, baseQueryObject => {
// todo: move into full advanced analysis section here
const rollingProc = rollingWindowOperator(formData, baseQueryObject);
if (rollingProc) {
rollingProc.options = { ...rollingProc.options, is_pivot_df: false };
}
const { time_grain_sqla } = formData;
let resampleProc: PostProcessingResample | undefined;
let resampleProc: PostProcessingResample;
if (rollingProc && time_grain_sqla) {
const rule = TIME_GRAIN_MAP[time_grain_sqla];
if (rule) {
@ -62,7 +63,6 @@ export default function buildQuery(formData: QueryFormData) {
method: 'asfreq',
rule,
fill_value: null,
time_column: DTTM_ALIAS,
},
};
}
@ -72,16 +72,10 @@ export default function buildQuery(formData: QueryFormData) {
...baseQueryObject,
is_timeseries: true,
post_processing: [
{
operation: 'sort',
options: {
columns: {
[DTTM_ALIAS]: true,
},
},
},
sortOperator(formData, baseQueryObject),
resampleProc,
rollingProc,
flattenOperator(formData, baseQueryObject),
],
},
];

View File

@ -22,7 +22,7 @@ import {
QueryObject,
normalizeOrderBy,
} from '@superset-ui/core';
import { pivotOperator } from '@superset-ui/chart-controls';
import { flattenOperator, pivotOperator } from '@superset-ui/chart-controls';
export default function buildQuery(formData: QueryFormData) {
const {
@ -66,6 +66,7 @@ export default function buildQuery(formData: QueryFormData) {
is_timeseries: true,
post_processing: [
pivotOperator(formData1, { ...baseQueryObject, is_timeseries: true }),
flattenOperator(formData1, { ...baseQueryObject, is_timeseries: true }),
],
} as QueryObject;
return [normalizeOrderBy(queryObjectA)];
@ -77,6 +78,7 @@ export default function buildQuery(formData: QueryFormData) {
is_timeseries: true,
post_processing: [
pivotOperator(formData2, { ...baseQueryObject, is_timeseries: true }),
flattenOperator(formData2, { ...baseQueryObject, is_timeseries: true }),
],
} as QueryObject;
return [normalizeOrderBy(queryObjectB)];

View File

@ -22,42 +22,54 @@ import {
ensureIsArray,
QueryFormData,
normalizeOrderBy,
RollingType,
PostProcessingPivot,
} from '@superset-ui/core';
import {
rollingWindowOperator,
timeCompareOperator,
isValidTimeCompare,
sortOperator,
pivotOperator,
resampleOperator,
contributionOperator,
prophetOperator,
timeComparePivotOperator,
flattenOperator,
} from '@superset-ui/chart-controls';
export default function buildQuery(formData: QueryFormData) {
const { x_axis, groupby } = formData;
const is_timeseries = x_axis === DTTM_ALIAS || !x_axis;
return buildQueryContext(formData, baseQueryObject => {
const pivotOperatorInRuntime: PostProcessingPivot | undefined =
pivotOperator(formData, {
...baseQueryObject,
index: x_axis,
is_timeseries,
});
if (
pivotOperatorInRuntime &&
Object.values(RollingType).includes(formData.rolling_type)
) {
pivotOperatorInRuntime.options = {
...pivotOperatorInRuntime.options,
...{
flatten_columns: false,
reset_index: false,
},
};
}
/* the `pivotOperatorInRuntime` determines how to pivot the dataframe returned from the raw query.
1. If it's a time compared query, there will return a pivoted dataframe that append time compared metrics. for instance:
MAX(value) MAX(value)__1 year ago MIN(value) MIN(value)__1 year ago
city LA LA LA LA
__timestamp
2015-01-01 568.0 671.0 5.0 6.0
2015-02-01 407.0 649.0 4.0 3.0
2015-03-01 318.0 465.0 0.0 3.0
2. If it's a normal query, there will return a pivoted dataframe.
MAX(value) MIN(value)
city LA LA
__timestamp
2015-01-01 568.0 5.0
2015-02-01 407.0 4.0
2015-03-01 318.0 0.0
*/
const pivotOperatorInRuntime: PostProcessingPivot = isValidTimeCompare(
formData,
baseQueryObject,
)
? timeComparePivotOperator(formData, baseQueryObject)
: pivotOperator(formData, {
...baseQueryObject,
index: x_axis,
is_timeseries,
});
return [
{
@ -70,13 +82,16 @@ export default function buildQuery(formData: QueryFormData) {
time_offsets: isValidTimeCompare(formData, baseQueryObject)
? formData.time_compare
: [],
/* Note that:
   1. The resample, rolling, cum, and timeCompare operators must run after the pivot operator.
   2. The flattenOperator turns a MultiIndex DataFrame into a flat DataFrame.
*/
post_processing: [
resampleOperator(formData, baseQueryObject),
timeCompareOperator(formData, baseQueryObject),
sortOperator(formData, { ...baseQueryObject, is_timeseries: true }),
// in order to be able to rolling in multiple series, must do pivot before rollingOperator
pivotOperatorInRuntime,
rollingWindowOperator(formData, baseQueryObject),
timeCompareOperator(formData, baseQueryObject),
resampleOperator(formData, baseQueryObject),
flattenOperator(formData, baseQueryObject),
contributionOperator(formData, baseQueryObject),
prophetOperator(formData, baseQueryObject),
],

View File

@ -767,6 +767,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
"diff",
"compare",
"resample",
"flatten",
)
),
example="aggregate",

View File

@ -36,7 +36,11 @@ from superset.common.utils import dataframe_utils as df_utils
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.connectors.base.models import BaseDatasource
from superset.constants import CacheRegion
from superset.exceptions import QueryObjectValidationError, SupersetException
from superset.exceptions import (
InvalidPostProcessingError,
QueryObjectValidationError,
SupersetException,
)
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.utils import csv
@ -196,7 +200,11 @@ class QueryContextProcessor:
query += ";\n\n".join(queries)
query += ";\n\n"
df = query_object.exec_post_processing(df)
# Re-raising QueryObjectValidationError
try:
df = query_object.exec_post_processing(df)
except InvalidPostProcessingError as ex:
raise QueryObjectValidationError from ex
result.df = df
result.query = query

View File

@ -17,6 +17,7 @@
# pylint: disable=invalid-name
from __future__ import annotations
import json
import logging
from datetime import datetime, timedelta
from pprint import pformat
@ -27,6 +28,7 @@ from pandas import DataFrame
from superset.common.chart_data import ChartDataResultType
from superset.exceptions import (
InvalidPostProcessingError,
QueryClauseValidationException,
QueryObjectValidationError,
)
@ -337,6 +339,10 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
}
return query_object_dict
def __repr__(self) -> str:
# readable representation used when a QueryObject is passed to `print` or `logging`
return json.dumps(self.to_dict(), sort_keys=True, default=str,)
def cache_key(self, **extra: Any) -> str:
"""
The cache key is made out of the key/values from to_dict(), plus any
@ -398,15 +404,15 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
:raises QueryObjectValidationError: If the post processing operation
is incorrect
"""
logger.debug("post_processing: %s", pformat(self.post_processing))
logger.debug("post_processing: \n %s", pformat(self.post_processing))
for post_process in self.post_processing:
operation = post_process.get("operation")
if not operation:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("`operation` property of post processing object undefined")
)
if not hasattr(pandas_postprocessing, operation):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_(
"Unsupported post processing operation: %(operation)s",
type=operation,

View File

@ -190,6 +190,10 @@ class QueryObjectValidationError(SupersetException):
status = 400
class InvalidPostProcessingError(SupersetException):
    """Raised when a pandas post-processing operation is malformed or
    references data not available in the DataFrame (see
    ``superset.utils.pandas_postprocessing``)."""

    # HTTP status returned to the client for this error
    status = 400
class CacheLoadError(SupersetException):
    """Raised when a cached value cannot be loaded; surfaced as HTTP 404."""

    # HTTP status returned to the client for this error
    status = 404

View File

@ -20,6 +20,7 @@ from superset.utils.pandas_postprocessing.compare import compare
from superset.utils.pandas_postprocessing.contribution import contribution
from superset.utils.pandas_postprocessing.cum import cum
from superset.utils.pandas_postprocessing.diff import diff
from superset.utils.pandas_postprocessing.flatten import flatten
from superset.utils.pandas_postprocessing.geography import (
geodetic_parse,
geohash_decode,
@ -49,5 +50,6 @@ __all__ = [
"rolling",
"select",
"sort",
"flatten",
"_flatten_column_after_pivot",
]

View File

@ -35,7 +35,7 @@ def aggregate(
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
:raises QueryObjectValidationError: If the request in incorrect
:raises InvalidPostProcessingError: If the request is incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)

View File

@ -20,7 +20,7 @@ import numpy as np
from flask_babel import gettext as _
from pandas import DataFrame, Series
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing.aggregate import aggregate
@ -84,7 +84,7 @@ def boxplot(
or not isinstance(percentiles[1], (int, float))
or percentiles[0] >= percentiles[1]
):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_(
"percentiles must be a list or tuple with two numeric values, "
"of which the first is lower than the second value"

View File

@ -21,7 +21,7 @@ from flask_babel import gettext as _
from pandas import DataFrame
from superset.constants import PandasPostprocessingCompare
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import TIME_COMPARISION
from superset.utils.pandas_postprocessing.utils import validate_column_args
@ -31,7 +31,7 @@ def compare( # pylint: disable=too-many-arguments
df: DataFrame,
source_columns: List[str],
compare_columns: List[str],
compare_type: Optional[PandasPostprocessingCompare],
compare_type: PandasPostprocessingCompare,
drop_original_columns: Optional[bool] = False,
precision: Optional[int] = 4,
) -> DataFrame:
@ -46,31 +46,38 @@ def compare( # pylint: disable=too-many-arguments
compare columns.
:param precision: Round a change rate to a variable number of decimal places.
:return: DataFrame with compared columns.
:raises QueryObjectValidationError: If the request in incorrect.
:raises InvalidPostProcessingError: If the request is incorrect.
"""
if len(source_columns) != len(compare_columns):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("`compare_columns` must have the same length as `source_columns`.")
)
if compare_type not in tuple(PandasPostprocessingCompare):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("`compare_type` must be `difference`, `percentage` or `ratio`")
)
if len(source_columns) == 0:
return df
for s_col, c_col in zip(source_columns, compare_columns):
s_df = df.loc[:, [s_col]]
s_df.rename(columns={s_col: "__intermediate"}, inplace=True)
c_df = df.loc[:, [c_col]]
c_df.rename(columns={c_col: "__intermediate"}, inplace=True)
if compare_type == PandasPostprocessingCompare.DIFF:
diff_series = df[s_col] - df[c_col]
diff_df = c_df - s_df
elif compare_type == PandasPostprocessingCompare.PCT:
diff_series = (
((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision)
)
# https://en.wikipedia.org/wiki/Relative_change_and_difference#Percentage_change
diff_df = ((c_df - s_df) / s_df).astype(float).round(precision)
else:
# compare_type == "ratio"
diff_series = (df[s_col] / df[c_col]).astype(float).round(precision)
diff_df = diff_series.to_frame(
name=TIME_COMPARISION.join([compare_type, s_col, c_col])
diff_df = (c_df / s_df).astype(float).round(precision)
diff_df.rename(
columns={
"__intermediate": TIME_COMPARISION.join([compare_type, s_col, c_col])
},
inplace=True,
)
df = pd.concat([df, diff_df], axis=1)

View File

@ -20,7 +20,7 @@ from typing import List, Optional
from flask_babel import gettext as _
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing.utils import validate_column_args
@ -55,7 +55,7 @@ def contribution(
numeric_columns = numeric_df.columns.tolist()
for col in columns:
if col not in numeric_columns:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_(
'Column "%(column)s" is not numeric or does not '
"exists in the query results.",
@ -65,7 +65,7 @@ def contribution(
columns = columns or numeric_df.columns
rename_columns = rename_columns or columns
if len(rename_columns) != len(columns):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("`rename_columns` must have the same length as `columns`.")
)
# limit to selected columns

View File

@ -14,27 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional
from typing import Dict
from flask_babel import gettext as _
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_append_columns,
_flatten_column_after_pivot,
ALLOWLIST_CUMULATIVE_FUNCTIONS,
validate_column_args,
)
@validate_column_args("columns")
def cum(
df: DataFrame,
operator: str,
columns: Optional[Dict[str, str]] = None,
is_pivot_df: bool = False,
) -> DataFrame:
def cum(df: DataFrame, operator: str, columns: Dict[str, str],) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.
@ -45,29 +39,16 @@ def cum(
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
:param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with cumulated columns
"""
columns = columns or {}
if is_pivot_df:
df_cum = df
else:
df_cum = df[columns.keys()]
df_cum = df.loc[:, columns.keys()]
operation = "cum" + operator
if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
df_cum, operation
):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
if is_pivot_df:
df_cum = getattr(df_cum, operation)()
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
df_cum.columns = [
_flatten_column_after_pivot(col, agg) for col in df_cum.columns
]
df_cum.reset_index(level=0, inplace=True)
else:
df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
return df_cum

View File

@ -44,7 +44,7 @@ def diff(
:param periods: periods to shift for calculating difference.
:param axis: 0 for row, 1 for column. default 0.
:return: DataFrame with diffed columns
:raises QueryObjectValidationError: If the request in incorrect
:raises InvalidPostProcessingError: If the request is incorrect
"""
df_diff = df[columns.keys()]
df_diff = df_diff.diff(periods=periods, axis=axis)

View File

@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
from superset.utils.pandas_postprocessing.utils import (
_is_multi_index_on_columns,
FLAT_COLUMN_SEPARATOR,
)
def flatten(df: pd.DataFrame, reset_index: bool = True) -> pd.DataFrame:
    """
    Convert an N-dimensional DataFrame into a flat DataFrame.

    Two transformations are applied:
    1. MultiIndex columns are collapsed into plain string labels by joining
       the levels of each column tuple with ", ".
    2. When ``reset_index`` is True and the index is not a RangeIndex, the
       first index level (e.g. a DatetimeIndex) is moved into a column.

    Note that collapsing the columns mutates ``df`` in place, while
    resetting the index produces a new frame.

    :param df: N-dimensional DataFrame.
    :param reset_index: convert the index to a column when ``df.index``
        isn't a RangeIndex.
    :return: a flat DataFrame

    Examples
    -----------

    Convert DatetimeIndex into a column:
    >>> index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
    >>> index.name = "__timestamp"
    >>> df = pd.DataFrame(index=index, data={"metric": [1, 2, 3]})
    >>> flatten(df)
      __timestamp  metric
    0  2021-01-01       1
    1  2021-01-02       2
    2  2021-01-03       3

    Convert DatetimeIndex and MultiIndex columns into plain columns:
    >>> iterables = [["foo", "bar"], ["one", "two"]]
    >>> columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
    >>> df = pd.DataFrame(index=index, columns=columns, data=1)
    >>> flatten(df)
      __timestamp  foo, one  foo, two  bar, one  bar, two
    0  2021-01-01         1         1         1         1
    1  2021-01-02         1         1         1         1
    2  2021-01-03         1         1         1         1
    """
    if isinstance(df.columns, pd.MultiIndex):
        # each level value may be a non-string (Timestamp, float, ...), so
        # coerce every part to str before joining
        flattened_labels = []
        for column_parts in df.columns.to_flat_index():
            flattened_labels.append(", ".join(str(part) for part in column_parts))
        df.columns = flattened_labels

    if reset_index and not isinstance(df.index, pd.RangeIndex):
        df = df.reset_index(level=0)

    return df

View File

@ -21,7 +21,7 @@ from flask_babel import gettext as _
from geopy.point import Point
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import _append_columns
@ -46,7 +46,7 @@ def geohash_decode(
df, lonlat_df, {"latitude": latitude, "longitude": longitude}
)
except ValueError as ex:
raise QueryObjectValidationError(_("Invalid geohash string")) from ex
raise InvalidPostProcessingError(_("Invalid geohash string")) from ex
def geohash_encode(
@ -69,7 +69,7 @@ def geohash_encode(
)
return _append_columns(df, encode_df, {"geohash": geohash})
except ValueError as ex:
raise QueryObjectValidationError(_("Invalid longitude/latitude")) from ex
raise InvalidPostProcessingError(_("Invalid longitude/latitude")) from ex
def geodetic_parse(
@ -111,4 +111,4 @@ def geodetic_parse(
columns["altitude"] = altitude
return _append_columns(df, geodetic_df, columns)
except ValueError as ex:
raise QueryObjectValidationError(_("Invalid geodetic string")) from ex
raise InvalidPostProcessingError(_("Invalid geodetic string")) from ex

View File

@ -20,7 +20,7 @@ from flask_babel import gettext as _
from pandas import DataFrame
from superset.constants import NULL_STRING, PandasAxis
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_flatten_column_after_pivot,
_get_aggregate_funcs,
@ -64,14 +64,14 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
:param flatten_columns: Convert column names to strings
:param reset_index: Convert index to column
:return: A pivot table
:raises QueryObjectValidationError: If the request in incorrect
:raises InvalidPostProcessingError: If the request in incorrect
"""
if not index:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Pivot operation requires at least one index")
)
if not aggregates:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Pivot operation must include at least one aggregate")
)

View File

@ -20,7 +20,7 @@ from typing import Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP
@ -58,7 +58,7 @@ def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
prophet_logger.setLevel(logging.CRITICAL)
prophet_logger.setLevel(logging.NOTSET)
except ModuleNotFoundError as ex:
raise QueryObjectValidationError(_("`prophet` package not installed")) from ex
raise InvalidPostProcessingError(_("`prophet` package not installed")) from ex
model = Prophet(
interval_width=confidence_interval,
yearly_seasonality=yearly_seasonality,
@ -111,24 +111,24 @@ def prophet( # pylint: disable=too-many-arguments
index = index or DTTM_ALIAS
# validate inputs
if not time_grain:
raise QueryObjectValidationError(_("Time grain missing"))
raise InvalidPostProcessingError(_("Time grain missing"))
if time_grain not in PROPHET_TIME_GRAIN_MAP:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
)
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marhsmallow schema not being able to handle
# union types
if not isinstance(periods, int) or periods < 0:
raise QueryObjectValidationError(_("Periods must be a whole number"))
raise InvalidPostProcessingError(_("Periods must be a whole number"))
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Confidence interval must be between 0 and 1 (exclusive)")
)
if index not in df.columns:
raise QueryObjectValidationError(_("DataFrame must include temporal column"))
raise InvalidPostProcessingError(_("DataFrame must include temporal column"))
if len(df.columns) < 2:
raise QueryObjectValidationError(_("DataFrame include at least one series"))
raise InvalidPostProcessingError(_("DataFrame include at least one series"))
target_df = DataFrame()
for column in [column for column in df.columns if column != index]:

View File

@ -14,48 +14,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional, Tuple, Union
from typing import Optional, Union
from pandas import DataFrame
import pandas as pd
from flask_babel import gettext as _
from superset.utils.pandas_postprocessing.utils import validate_column_args
from superset.exceptions import InvalidPostProcessingError
@validate_column_args("groupby_columns")
def resample( # pylint: disable=too-many-arguments
df: DataFrame,
def resample(
df: pd.DataFrame,
rule: str,
method: str,
time_column: str,
groupby_columns: Optional[Tuple[Optional[str], ...]] = None,
fill_value: Optional[Union[float, int]] = None,
) -> DataFrame:
) -> pd.DataFrame:
"""
support upsampling in resample
:param df: DataFrame to resample.
:param rule: The offset string representing target conversion.
:param method: How to fill the NaN value after resample.
:param time_column: existing columns in DataFrame.
:param groupby_columns: columns except time_column in dataframe
:param fill_value: What values do fill missing.
:return: DataFrame after resample
:raises QueryObjectValidationError: If the request in incorrect
:raises InvalidPostProcessingError: If the request is incorrect
"""
if not isinstance(df.index, pd.DatetimeIndex):
raise InvalidPostProcessingError(_("Resample operation requires DatetimeIndex"))
def _upsampling(_df: DataFrame) -> DataFrame:
_df = _df.set_index(time_column)
if method == "asfreq" and fill_value is not None:
return _df.resample(rule).asfreq(fill_value=fill_value)
return getattr(_df.resample(rule), method)()
if groupby_columns:
df = (
df.set_index(keys=list(groupby_columns))
.groupby(by=list(groupby_columns))
.apply(_upsampling)
)
df = df.reset_index().set_index(time_column).sort_index()
if method == "asfreq" and fill_value is not None:
_df = df.resample(rule).asfreq(fill_value=fill_value)
else:
df = _upsampling(df)
return df.reset_index()
_df = getattr(df.resample(rule), method)()
return _df

View File

@ -19,10 +19,9 @@ from typing import Any, Dict, Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_append_columns,
_flatten_column_after_pivot,
DENYLIST_ROLLING_FUNCTIONS,
validate_column_args,
)
@ -32,13 +31,12 @@ from superset.utils.pandas_postprocessing.utils import (
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
rolling_type: str,
columns: Optional[Dict[str, str]] = None,
columns: Dict[str, str],
window: Optional[int] = None,
rolling_type_options: Optional[Dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
is_pivot_df: bool = False,
) -> DataFrame:
"""
Apply a rolling window on the dataset. See the Pandas docs for further details:
@ -58,21 +56,17 @@ def rolling( # pylint: disable=too-many-arguments
:param win_type: Type of window function.
:param min_periods: The minimum amount of periods required for a row to be included
in the result set.
:param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with the rolling columns
:raises QueryObjectValidationError: If the request in incorrect
:raises InvalidPostProcessingError: If the request is incorrect
"""
rolling_type_options = rolling_type_options or {}
columns = columns or {}
if is_pivot_df:
df_rolling = df
else:
df_rolling = df[columns.keys()]
df_rolling = df.loc[:, columns.keys()]
kwargs: Dict[str, Union[str, int]] = {}
if window is None:
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
raise InvalidPostProcessingError(_("Undefined window for rolling operation"))
if window == 0:
raise QueryObjectValidationError(_("Window must be > 0"))
raise InvalidPostProcessingError(_("Window must be > 0"))
kwargs["window"] = window
if min_periods is not None:
@ -86,13 +80,13 @@ def rolling( # pylint: disable=too-many-arguments
if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr(
df_rolling, rolling_type
):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Invalid rolling_type: %(type)s", type=rolling_type)
)
try:
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
except TypeError as ex:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_(
"Invalid options for %(rolling_type)s: %(options)s",
rolling_type=rolling_type,
@ -100,15 +94,7 @@ def rolling( # pylint: disable=too-many-arguments
)
) from ex
if is_pivot_df:
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
df_rolling.columns = [
_flatten_column_after_pivot(col, agg) for col in df_rolling.columns
]
df_rolling.reset_index(level=0, inplace=True)
else:
df_rolling = _append_columns(df, df_rolling, columns)
df_rolling = _append_columns(df, df_rolling, columns)
if min_periods:
df_rolling = df_rolling[min_periods:]

View File

@ -42,7 +42,7 @@ def select(
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
:raises QueryObjectValidationError: If the request in incorrect
:raises InvalidPostProcessingError: If the request in incorrect
"""
df_select = df.copy(deep=False)
if columns:

View File

@ -30,6 +30,6 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
:param columns: columns by by which to sort. The key specifies the column name,
value specifies if sorting in ascending order.
:return: Sorted DataFrame
:raises QueryObjectValidationError: If the request in incorrect
:raises InvalidPostProcessingError: If the request is incorrect
"""
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))

View File

@ -18,10 +18,11 @@ from functools import partial
from typing import Any, Callable, Dict, Tuple, Union
import numpy as np
import pandas as pd
from flask_babel import gettext as _
from pandas import DataFrame, NamedAgg, Timestamp
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
NUMPY_FUNCTIONS = {
"average": np.average,
@ -91,6 +92,8 @@ PROPHET_TIME_GRAIN_MAP = {
"P1W/1970-01-04T00:00:00Z": "W",
}
FLAT_COLUMN_SEPARATOR = ", "
def _flatten_column_after_pivot(
column: Union[float, Timestamp, str, Tuple[str, ...]],
@ -113,21 +116,26 @@ def _flatten_column_after_pivot(
# drop aggregate for single aggregate pivots with multiple groupings
# from column name (aggregates always come first in column name)
column = column[1:]
return ", ".join([str(col) for col in column])
return FLAT_COLUMN_SEPARATOR.join([str(col) for col in column])
def _is_multi_index_on_columns(df: DataFrame) -> bool:
    # True when the frame's columns are hierarchical (a pandas MultiIndex),
    # e.g. the result of a pivot with multiple aggregates/groupings.
    return isinstance(df.columns, pd.MultiIndex)
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
if options.get("is_pivot_df"):
# skip validation when pivot Dataframe
return func(df, **options)
columns = df.columns.tolist()
if _is_multi_index_on_columns(df):
# MultiIndex column validate first level
columns = df.columns.get_level_values(0)
else:
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options.get(name) or []
):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Referenced columns not available in DataFrame.")
)
return func(df, **options)
@ -152,14 +160,14 @@ def _get_aggregate_funcs(
for name, agg_obj in aggregates.items():
column = agg_obj.get("column", name)
if column not in df:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_(
"Column referenced by aggregate is undefined: %(column)s",
column=column,
)
)
if "operator" not in agg_obj:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Operator undefined for aggregator: %(name)s", name=name,)
)
operator = agg_obj["operator"]
@ -168,7 +176,7 @@ def _get_aggregate_funcs(
else:
func = NUMPY_FUNCTIONS.get(operator)
if not func:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Invalid numpy function: %(operator)s", operator=operator,)
)
options = agg_obj.get("options", {})
@ -186,6 +194,8 @@ def _append_columns(
assign method, which overwrites the original column in `base_df` if the column
already exists, and appends the column if the name is not defined.
Note that! this is a memory-intensive operation.
:param base_df: DataFrame which to use as the base
:param append_df: DataFrame from which to select data.
:param columns: columns on which to append, mapping source column to
@ -196,6 +206,10 @@ def _append_columns(
in `base_df` unchanged.
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
return base_df.assign(
**{target: append_df[source] for source, target in columns.items()}
)
if all(key == value for key, value in columns.items()):
# make sure to return a new DataFrame instead of changing the `base_df`.
_base_df = base_df.copy()
_base_df.loc[:, columns.keys()] = append_df
return _base_df
append_df = append_df.rename(columns=columns)
return pd.concat([base_df, append_df], axis="columns")

View File

@ -172,18 +172,21 @@ POSTPROCESSING_OPERATIONS = {
{
"operation": "aggregate",
"options": {
"groupby": ["gender"],
"groupby": ["name"],
"aggregates": {
"q1": {
"operator": "percentile",
"column": "sum__num",
"options": {"q": 25},
# TODO: rename "interpolation" to "method" once numpy is upgraded.
# https://numpy.org/doc/stable/reference/generated/numpy.percentile.html
"options": {"q": 25, "interpolation": "lower"},
},
"median": {"operator": "median", "column": "sum__num",},
},
},
},
{"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
{"operation": "sort", "options": {"columns": {"q1": False, "name": True},},},
]
}

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import re
import time
from typing import Any, Dict
@ -30,7 +29,7 @@ from superset.common.query_object import QueryObject
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlMetric
from superset.extensions import cache_manager
from superset.utils.core import AdhocMetricExpressionType, backend
from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@ -91,8 +90,9 @@ class TestQueryContext(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_cache(self):
table_name = "birth_names"
table = self.get_table(name=table_name)
payload = get_query_context(table_name, table.id)
payload = get_query_context(
query_name=table_name, add_postprocessing_operations=True,
)
payload["force"] = True
query_context = ChartDataQueryContextSchema().load(payload)
@ -100,6 +100,10 @@ class TestQueryContext(SupersetTestCase):
query_cache_key = query_context.query_cache_key(query_object)
response = query_context.get_payload(cache_query_context=True)
# MUST BE a successful query
query_dump = response["queries"][0]
assert query_dump["status"] == QueryStatus.SUCCESS
cache_key = response["cache_key"]
assert cache_key is not None

View File

@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing import boxplot
from tests.unit_tests.fixtures.dataframes import names_df
@ -90,7 +90,7 @@ def test_boxplot_percentile():
def test_boxplot_percentile_incorrect_params():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@ -98,7 +98,7 @@ def test_boxplot_percentile_incorrect_params():
metrics=["cars"],
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@ -107,7 +107,7 @@ def test_boxplot_percentile_incorrect_params():
percentiles=[10],
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@ -116,7 +116,7 @@ def test_boxplot_percentile_incorrect_params():
percentiles=[90, 10],
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],

View File

@ -14,49 +14,220 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
from superset.utils.pandas_postprocessing import compare
from tests.unit_tests.fixtures.dataframes import timeseries_df2
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
from superset.constants import PandasPostprocessingCompare as PPC
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import multiple_metrics_df, timeseries_df2
def test_compare():
def test_compare_should_not_side_effect():
_timeseries_df2 = timeseries_df2.copy()
pp.compare(
df=_timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type=PPC.DIFF,
)
assert _timeseries_df2.equals(timeseries_df2)
def test_compare_diff():
# `difference` comparison
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="difference",
compare_type=PPC.DIFF,
)
"""
label y z difference__y__z
2019-01-01 x 2.0 2.0 0.0
2019-01-02 y 2.0 4.0 2.0
2019-01-05 z 2.0 10.0 8.0
2019-01-07 q 2.0 8.0 6.0
"""
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"y": [2.0, 2.0, 2.0, 2.0],
"z": [2.0, 4.0, 10.0, 8.0],
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
},
)
)
assert post_df.columns.tolist() == ["label", "y", "z", "difference__y__z"]
assert series_to_list(post_df["difference__y__z"]) == [0.0, -2.0, -8.0, -6.0]
# drop original columns
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="difference",
compare_type=PPC.DIFF,
drop_original_columns=True,
)
assert post_df.columns.tolist() == ["label", "difference__y__z"]
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
},
)
)
def test_compare_percentage():
# `percentage` comparison
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="percentage",
compare_type=PPC.PCT,
)
"""
label y z percentage__y__z
2019-01-01 x 2.0 2.0 0.0
2019-01-02 y 2.0 4.0 1.0
2019-01-05 z 2.0 10.0 4.0
2019-01-07 q 2.0 8.0 3.0
"""
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"y": [2.0, 2.0, 2.0, 2.0],
"z": [2.0, 4.0, 10.0, 8.0],
"percentage__y__z": [0.0, 1.0, 4.0, 3.0],
},
)
)
assert post_df.columns.tolist() == ["label", "y", "z", "percentage__y__z"]
assert series_to_list(post_df["percentage__y__z"]) == [0.0, -0.5, -0.8, -0.75]
def test_compare_ratio():
# `ratio` comparison
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="ratio",
compare_type=PPC.RAT,
)
"""
label y z ratio__y__z
2019-01-01 x 2.0 2.0 1.0
2019-01-02 y 2.0 4.0 2.0
2019-01-05 z 2.0 10.0 5.0
2019-01-07 q 2.0 8.0 4.0
"""
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"y": [2.0, 2.0, 2.0, 2.0],
"z": [2.0, 4.0, 10.0, 8.0],
"ratio__y__z": [1.0, 2.0, 5.0, 4.0],
},
)
)
def test_compare_multi_index_column():
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
index.name = "__timestamp"
iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"])
df = pd.DataFrame(index=index, columns=columns, data=1)
"""
m1 m2
level1 a b a b
level2 x y x y x y x y
__timestamp
2021-01-01 1 1 1 1 1 1 1 1
2021-01-02 1 1 1 1 1 1 1 1
2021-01-03 1 1 1 1 1 1 1 1
"""
post_df = pp.compare(
df,
source_columns=["m1"],
compare_columns=["m2"],
compare_type=PPC.DIFF,
drop_original_columns=True,
)
flat_df = pp.flatten(post_df)
"""
__timestamp difference__m1__m2, a, x difference__m1__m2, a, y difference__m1__m2, b, x difference__m1__m2, b, y
0 2021-01-01 0 0 0 0
1 2021-01-02 0 0 0 0
2 2021-01-03 0 0 0 0
"""
assert flat_df.equals(
pd.DataFrame(
data={
"__timestamp": pd.to_datetime(
["2021-01-01", "2021-01-02", "2021-01-03"]
),
"difference__m1__m2, a, x": [0, 0, 0],
"difference__m1__m2, a, y": [0, 0, 0],
"difference__m1__m2, b, x": [0, 0, 0],
"difference__m1__m2, b, y": [0, 0, 0],
}
)
)
def test_compare_after_pivot():
pivot_df = pp.pivot(
df=multiple_metrics_df,
index=["dttm"],
columns=["country"],
aggregates={
"sum_metric": {"operator": "sum"},
"count_metric": {"operator": "sum"},
},
flatten_columns=False,
reset_index=False,
)
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1 2 5 6
2019-01-02 3 4 7 8
"""
compared_df = pp.compare(
pivot_df,
source_columns=["count_metric"],
compare_columns=["sum_metric"],
compare_type=PPC.DIFF,
drop_original_columns=True,
)
"""
difference__count_metric__sum_metric
country UK US
dttm
2019-01-01 4 4
2019-01-02 4 4
"""
flat_df = pp.flatten(compared_df)
"""
dttm difference__count_metric__sum_metric, UK difference__count_metric__sum_metric, US
0 2019-01-01 4 4
1 2019-01-02 4 4
"""
assert flat_df.equals(
pd.DataFrame(
data={
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(
["difference__count_metric__sum_metric", "UK"]
): [4, 4],
FLAT_COLUMN_SEPARATOR.join(
["difference__count_metric__sum_metric", "US"]
): [4, 4],
}
)
)
assert post_df.columns.tolist() == ["label", "y", "z", "ratio__y__z"]
assert series_to_list(post_df["ratio__y__z"]) == [1.0, 0.5, 0.2, 0.25]

View File

@ -22,7 +22,7 @@ from numpy import nan
from numpy.testing import assert_array_equal
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing import contribution
@ -40,10 +40,10 @@ def test_contribution():
"c": [nan, nan, nan],
}
)
with pytest.raises(QueryObjectValidationError, match="not numeric"):
with pytest.raises(InvalidPostProcessingError, match="not numeric"):
contribution(df, columns=[DTTM_ALIAS])
with pytest.raises(QueryObjectValidationError, match="same length"):
with pytest.raises(InvalidPostProcessingError, match="same length"):
contribution(df, columns=["a"], rename_columns=["aa", "bb"])
# cell contribution across row

View File

@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pytest
from pandas import to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import cum, pivot
from superset.exceptions import InvalidPostProcessingError
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
@ -27,33 +28,41 @@ from tests.unit_tests.fixtures.dataframes import (
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_cum_should_not_side_effect():
_timeseries_df = timeseries_df.copy()
pp.cum(
df=timeseries_df, columns={"y": "y2"}, operator="sum",
)
assert _timeseries_df.equals(timeseries_df)
def test_cum():
# create new column (cumsum)
post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
post_df = pp.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
assert post_df.columns.tolist() == ["label", "y", "y2"]
assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]
# overwrite column (cumprod)
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
assert post_df.columns.tolist() == ["label", "y"]
assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]
# overwrite column (cummin)
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
assert post_df.columns.tolist() == ["label", "y"]
assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]
# invalid operator
with pytest.raises(QueryObjectValidationError):
cum(
with pytest.raises(InvalidPostProcessingError):
pp.cum(
df=timeseries_df, columns={"y": "y"}, operator="abc",
)
def test_cum_with_pivot_df_and_single_metric():
pivot_df = pivot(
def test_cum_after_pivot_with_single_metric():
pivot_df = pp.pivot(
df=single_metric_df,
index=["dttm"],
columns=["country"],
@ -61,19 +70,40 @@ def test_cum_with_pivot_df_and_single_metric():
flatten_columns=False,
reset_index=False,
)
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
# dttm UK US
# 0 2019-01-01 5 6
# 1 2019-01-02 12 14
assert cum_df["UK"].to_list() == [5.0, 12.0]
assert cum_df["US"].to_list() == [6.0, 14.0]
assert (
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
"""
sum_metric
country UK US
dttm
2019-01-01 5 6
2019-01-02 7 8
"""
cum_df = pp.cum(df=pivot_df, operator="sum", columns={"sum_metric": "sum_metric"})
"""
sum_metric
country UK US
dttm
2019-01-01 5 6
2019-01-02 12 14
"""
cum_and_flat_df = pp.flatten(cum_df)
"""
dttm sum_metric, UK sum_metric, US
0 2019-01-01 5 6
1 2019-01-02 12 14
"""
assert cum_and_flat_df.equals(
pd.DataFrame(
{
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
}
)
)
def test_cum_with_pivot_df_and_multiple_metrics():
pivot_df = pivot(
def test_cum_after_pivot_with_multiple_metrics():
pivot_df = pp.pivot(
df=multiple_metrics_df,
index=["dttm"],
columns=["country"],
@ -84,14 +114,39 @@ def test_cum_with_pivot_df_and_multiple_metrics():
flatten_columns=False,
reset_index=False,
)
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
# 0 2019-01-01 1 2 5 6
# 1 2019-01-02 4 6 12 14
assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
assert (
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1 2 5 6
2019-01-02 3 4 7 8
"""
cum_df = pp.cum(
df=pivot_df,
operator="sum",
columns={"sum_metric": "sum_metric", "count_metric": "count_metric"},
)
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1 2 5 6
2019-01-02 4 6 12 14
"""
flat_df = pp.flatten(cum_df)
"""
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
0 2019-01-01 1 2 5 6
1 2019-01-02 4 6 12 14
"""
assert flat_df.equals(
pd.DataFrame(
{
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1, 4],
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2, 6],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
}
)
)

View File

@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import diff
from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
@ -39,7 +39,7 @@ def test_diff():
assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]
# invalid column reference
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
diff(
df=timeseries_df, columns={"abc": "abc"},
)

View File

@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
def test_flat_should_not_change():
df = pd.DataFrame(data={"foo": [1, 2, 3], "bar": [4, 5, 6],})
assert pp.flatten(df).equals(df)
def test_flat_should_not_reset_index():
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
index.name = "__timestamp"
df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
assert pp.flatten(df, reset_index=False).equals(df)
def test_flat_should_flat_datetime_index():
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
index.name = "__timestamp"
df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
assert pp.flatten(df).equals(
pd.DataFrame({"__timestamp": index, "foo": [1, 2, 3], "bar": [4, 5, 6],})
)
def test_flat_should_flat_multiple_index():
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
index.name = "__timestamp"
iterables = [["foo", "bar"], [1, "two"]]
columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
df = pd.DataFrame(index=index, columns=columns, data=1)
assert pp.flatten(df).equals(
pd.DataFrame(
{
"__timestamp": index,
FLAT_COLUMN_SEPARATOR.join(["foo", "1"]): [1, 1, 1],
FLAT_COLUMN_SEPARATOR.join(["foo", "two"]): [1, 1, 1],
FLAT_COLUMN_SEPARATOR.join(["bar", "1"]): [1, 1, 1],
FLAT_COLUMN_SEPARATOR.join(["bar", "two"]): [1, 1, 1],
}
)
)

View File

@ -19,7 +19,7 @@ import numpy as np
import pytest
from pandas import DataFrame, Timestamp, to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot
from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df
from tests.unit_tests.pandas_postprocessing.utils import (
@ -172,7 +172,7 @@ def test_pivot_exceptions():
pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
# invalid index reference
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["abc"],
@ -181,7 +181,7 @@ def test_pivot_exceptions():
)
# invalid column reference
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["dept"],
@ -190,7 +190,7 @@ def test_pivot_exceptions():
)
# invalid aggregate options
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["name"],

View File

@ -19,7 +19,7 @@ from importlib.util import find_spec
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing import prophet
from tests.unit_tests.fixtures.dataframes import prophet_df
@ -75,40 +75,40 @@ def test_prophet_valid_zero_periods():
def test_prophet_import():
dynamic_module = find_spec("prophet")
if dynamic_module is None:
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
def test_prophet_missing_temporal_column():
df = prophet_df.drop(DTTM_ALIAS, axis=1)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=df, time_grain="P1M", periods=3, confidence_interval=0.9,
)
def test_prophet_incorrect_confidence_interval():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.0,
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=1.0,
)
def test_prophet_incorrect_periods():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=-1, confidence_interval=0.8,
)
def test_prophet_incorrect_time_grain():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8,
)

View File

@ -14,45 +14,80 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pytest
from pandas import DataFrame, to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import resample
from tests.unit_tests.fixtures.dataframes import timeseries_df
from superset.exceptions import InvalidPostProcessingError
from superset.utils import pandas_postprocessing as pp
from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
def test_resample_should_not_side_effect():
_timeseries_df = timeseries_df.copy()
pp.resample(df=_timeseries_df, rule="1D", method="ffill")
assert _timeseries_df.equals(timeseries_df)
def test_resample():
df = timeseries_df.copy()
df.index.name = "time_column"
df.reset_index(inplace=True)
post_df = resample(df=df, rule="1D", method="ffill", time_column="time_column",)
assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]
post_df = resample(
df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0,
post_df = pp.resample(df=timeseries_df, rule="1D", method="ffill")
"""
label y
2019-01-01 x 1.0
2019-01-02 y 2.0
2019-01-03 y 2.0
2019-01-04 y 2.0
2019-01-05 z 3.0
2019-01-06 z 3.0
2019-01-07 q 4.0
"""
assert post_df.equals(
pd.DataFrame(
index=pd.to_datetime(
[
"2019-01-01",
"2019-01-02",
"2019-01-03",
"2019-01-04",
"2019-01-05",
"2019-01-06",
"2019-01-07",
]
),
data={
"label": ["x", "y", "y", "y", "z", "z", "q"],
"y": [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0],
},
)
)
assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
def test_resample_with_groupby():
"""
The Dataframe contains a timestamp column, a string column and a numeric column.
__timestamp city val
0 2022-01-13 Chicago 6.0
1 2022-01-13 LA 5.0
2 2022-01-13 NY 4.0
3 2022-01-11 Chicago 3.0
4 2022-01-11 LA 2.0
5 2022-01-11 NY 1.0
"""
df = DataFrame(
{
"__timestamp": to_datetime(
def test_resample_zero_fill():
post_df = pp.resample(df=timeseries_df, rule="1D", method="asfreq", fill_value=0)
assert post_df.equals(
pd.DataFrame(
index=pd.to_datetime(
[
"2019-01-01",
"2019-01-02",
"2019-01-03",
"2019-01-04",
"2019-01-05",
"2019-01-06",
"2019-01-07",
]
),
data={
"label": ["x", "y", 0, 0, "z", 0, "q"],
"y": [1.0, 2.0, 0, 0, 3.0, 0, 4.0],
},
)
)
def test_resample_after_pivot():
df = pd.DataFrame(
data={
"__timestamp": pd.to_datetime(
[
"2022-01-13",
"2022-01-13",
@ -66,42 +101,53 @@ __timestamp city val
"val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
}
)
post_df = resample(
pivot_df = pp.pivot(
df=df,
rule="1D",
method="asfreq",
fill_value=0,
time_column="__timestamp",
groupby_columns=("city",),
index=["__timestamp"],
columns=["city"],
aggregates={"val": {"operator": "sum"},},
flatten_columns=False,
reset_index=False,
)
assert list(post_df.columns) == [
"__timestamp",
"city",
"val",
]
assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
"""
val
city Chicago LA NY
__timestamp
2022-01-11 3.0 2.0 1.0
2022-01-13 6.0 5.0 4.0
"""
resample_df = pp.resample(df=pivot_df, rule="1D", method="asfreq", fill_value=0,)
"""
val
city Chicago LA NY
__timestamp
2022-01-11 3.0 2.0 1.0
2022-01-12 0.0 0.0 0.0
2022-01-13 6.0 5.0 4.0
"""
flat_df = pp.flatten(resample_df)
"""
__timestamp val, Chicago val, LA val, NY
0 2022-01-11 3.0 2.0 1.0
1 2022-01-12 0.0 0.0 0.0
2 2022-01-13 6.0 5.0 4.0
"""
assert flat_df.equals(
pd.DataFrame(
data={
"__timestamp": pd.to_datetime(
["2022-01-11", "2022-01-12", "2022-01-13"]
),
"val, Chicago": [3.0, 0, 6.0],
"val, LA": [2.0, 0, 5.0],
"val, NY": [1.0, 0, 4.0],
}
)
)
assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
# should raise error when get a non-existent column
with pytest.raises(QueryObjectValidationError):
resample(
df=df,
rule="1D",
method="asfreq",
fill_value=0,
time_column="__timestamp",
groupby_columns=("city", "unkonw_column",),
)
# should raise error when get a None value in groupby list
with pytest.raises(QueryObjectValidationError):
resample(
df=df,
rule="1D",
method="asfreq",
fill_value=0,
time_column="__timestamp",
groupby_columns=("city", None,),
def test_resample_should_raise_ex():
with pytest.raises(InvalidPostProcessingError):
pp.resample(
df=categories_df, rule="1D", method="asfreq",
)

View File

@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pytest
from pandas import to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import pivot, rolling
from superset.exceptions import InvalidPostProcessingError
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
@ -27,9 +28,21 @@ from tests.unit_tests.fixtures.dataframes import (
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_rolling_should_not_side_effect():
_timeseries_df = timeseries_df.copy()
pp.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="sum",
window=2,
min_periods=0,
)
assert _timeseries_df.equals(timeseries_df)
def test_rolling():
# sum rolling type
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="sum",
@ -41,7 +54,7 @@ def test_rolling():
assert series_to_list(post_df["y"]) == [1.0, 3.0, 5.0, 7.0]
# mean rolling type with alias
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
rolling_type="mean",
columns={"y": "y_mean"},
@ -52,7 +65,7 @@ def test_rolling():
assert series_to_list(post_df["y_mean"]) == [1.0, 1.5, 2.0, 2.5]
# count rolling type
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
rolling_type="count",
columns={"y": "y"},
@ -63,7 +76,7 @@ def test_rolling():
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
# quantile rolling type
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
columns={"y": "q1"},
rolling_type="quantile",
@ -75,14 +88,14 @@ def test_rolling():
assert series_to_list(post_df["q1"]) == [1.0, 1.25, 1.5, 1.75]
# incorrect rolling type
with pytest.raises(QueryObjectValidationError):
rolling(
with pytest.raises(InvalidPostProcessingError):
pp.rolling(
df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2,
)
# incorrect rolling type options
with pytest.raises(QueryObjectValidationError):
rolling(
with pytest.raises(InvalidPostProcessingError):
pp.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="quantile",
@ -91,8 +104,8 @@ def test_rolling():
)
def test_rolling_with_pivot_df_and_single_metric():
pivot_df = pivot(
def test_rolling_should_empty_df():
pivot_df = pp.pivot(
df=single_metric_df,
index=["dttm"],
columns=["country"],
@ -100,27 +113,65 @@ def test_rolling_with_pivot_df_and_single_metric():
flatten_columns=False,
reset_index=False,
)
rolling_df = rolling(
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
)
# dttm UK US
# 0 2019-01-01 5 6
# 1 2019-01-02 12 14
assert rolling_df["UK"].to_list() == [5.0, 12.0]
assert rolling_df["US"].to_list() == [6.0, 14.0]
assert (
rolling_df["dttm"].to_list()
== to_datetime(["2019-01-01", "2019-01-02"]).to_list()
)
rolling_df = rolling(
df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
rolling_df = pp.rolling(
df=pivot_df,
rolling_type="sum",
window=2,
min_periods=2,
columns={"sum_metric": "sum_metric"},
)
assert rolling_df.empty is True
def test_rolling_with_pivot_df_and_multiple_metrics():
pivot_df = pivot(
def test_rolling_after_pivot_with_single_metric():
pivot_df = pp.pivot(
df=single_metric_df,
index=["dttm"],
columns=["country"],
aggregates={"sum_metric": {"operator": "sum"}},
flatten_columns=False,
reset_index=False,
)
"""
sum_metric
country UK US
dttm
2019-01-01 5 6
2019-01-02 7 8
"""
rolling_df = pp.rolling(
df=pivot_df,
columns={"sum_metric": "sum_metric"},
rolling_type="sum",
window=2,
min_periods=0,
)
"""
sum_metric
country UK US
dttm
2019-01-01 5.0 6.0
2019-01-02 12.0 14.0
"""
flat_df = pp.flatten(rolling_df)
"""
dttm sum_metric, UK sum_metric, US
0 2019-01-01 5.0 6.0
1 2019-01-02 12.0 14.0
"""
assert flat_df.equals(
pd.DataFrame(
data={
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
}
)
)
def test_rolling_after_pivot_with_multiple_metrics():
pivot_df = pp.pivot(
df=multiple_metrics_df,
index=["dttm"],
columns=["country"],
@ -131,17 +182,41 @@ def test_rolling_with_pivot_df_and_multiple_metrics():
flatten_columns=False,
reset_index=False,
)
rolling_df = rolling(
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1 2 5 6
2019-01-02 3 4 7 8
"""
rolling_df = pp.rolling(
df=pivot_df,
columns={"count_metric": "count_metric", "sum_metric": "sum_metric",},
rolling_type="sum",
window=2,
min_periods=0,
)
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
# 0 2019-01-01 1.0 2.0 5.0 6.0
# 1 2019-01-02 4.0 6.0 12.0 14.0
assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
assert (
rolling_df["dttm"].to_list()
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1.0 2.0 5.0 6.0
2019-01-02 4.0 6.0 12.0 14.0
"""
flat_df = pp.flatten(rolling_df)
"""
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
0 2019-01-01 1.0 2.0 5.0 6.0
1 2019-01-02 4.0 6.0 12.0 14.0
"""
assert flat_df.equals(
pd.DataFrame(
data={
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1.0, 4.0],
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2.0, 6.0],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
}
)
)

View File

@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.select import select
from tests.unit_tests.fixtures.dataframes import timeseries_df
@ -47,9 +47,9 @@ def test_select():
assert post_df.columns.tolist() == ["y1"]
# invalid columns
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})
# select renamed column by new name
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})

View File

@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import sort
from tests.unit_tests.fixtures.dataframes import categories_df
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
@ -26,5 +26,5 @@ def test_sort():
df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
assert series_to_list(df["asc_idx"])[1] == 96
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
sort(df=df, columns={"abc": True})