feat(advanced analysis): support MultiIndex column in post processing stage (#19116)
This commit is contained in:
parent
6083545e86
commit
375c03e084
|
|
@ -21,16 +21,16 @@ import {
|
|||
getColumnLabel,
|
||||
getMetricLabel,
|
||||
PostProcessingBoxplot,
|
||||
BoxPlotQueryObjectWhiskerType,
|
||||
} from '@superset-ui/core';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
type BoxPlotQueryObjectWhiskerType =
|
||||
PostProcessingBoxplot['options']['whisker_type'];
|
||||
const PERCENTILE_REGEX = /(\d+)\/(\d+) percentiles/;
|
||||
|
||||
export const boxplotOperator: PostProcessingFactory<
|
||||
PostProcessingBoxplot | undefined
|
||||
> = (formData, queryObject) => {
|
||||
export const boxplotOperator: PostProcessingFactory<PostProcessingBoxplot> = (
|
||||
formData,
|
||||
queryObject,
|
||||
) => {
|
||||
const { groupby, whiskerOptions } = formData;
|
||||
|
||||
if (whiskerOptions) {
|
||||
|
|
|
|||
|
|
@ -19,16 +19,15 @@
|
|||
import { PostProcessingContribution } from '@superset-ui/core';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const contributionOperator: PostProcessingFactory<
|
||||
PostProcessingContribution | undefined
|
||||
> = (formData, queryObject) => {
|
||||
if (formData.contributionMode) {
|
||||
return {
|
||||
operation: 'contribution',
|
||||
options: {
|
||||
orientation: formData.contributionMode,
|
||||
},
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
export const contributionOperator: PostProcessingFactory<PostProcessingContribution> =
|
||||
(formData, queryObject) => {
|
||||
if (formData.contributionMode) {
|
||||
return {
|
||||
operation: 'contribution',
|
||||
options: {
|
||||
orientation: formData.contributionMode,
|
||||
},
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
/* eslint-disable camelcase */
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitationsxw
|
||||
* under the License.
|
||||
*/
|
||||
import { PostProcessingFlatten } from '@superset-ui/core';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const flattenOperator: PostProcessingFactory<PostProcessingFlatten> = (
|
||||
formData,
|
||||
queryObject,
|
||||
) => ({ operation: 'flatten' });
|
||||
|
|
@ -26,4 +26,5 @@ export { resampleOperator } from './resampleOperator';
|
|||
export { contributionOperator } from './contributionOperator';
|
||||
export { prophetOperator } from './prophetOperator';
|
||||
export { boxplotOperator } from './boxplotOperator';
|
||||
export { flattenOperator } from './flattenOperator';
|
||||
export * from './utils';
|
||||
|
|
|
|||
|
|
@ -24,19 +24,14 @@ import {
|
|||
PostProcessingPivot,
|
||||
} from '@superset-ui/core';
|
||||
import { PostProcessingFactory } from './types';
|
||||
import { isValidTimeCompare } from './utils';
|
||||
import { timeComparePivotOperator } from './timeComparePivotOperator';
|
||||
|
||||
export const pivotOperator: PostProcessingFactory<
|
||||
PostProcessingPivot | undefined
|
||||
> = (formData, queryObject) => {
|
||||
export const pivotOperator: PostProcessingFactory<PostProcessingPivot> = (
|
||||
formData,
|
||||
queryObject,
|
||||
) => {
|
||||
const metricLabels = ensureIsArray(queryObject.metrics).map(getMetricLabel);
|
||||
const { x_axis: xAxis } = formData;
|
||||
if ((xAxis || queryObject.is_timeseries) && metricLabels.length) {
|
||||
if (isValidTimeCompare(formData, queryObject)) {
|
||||
return timeComparePivotOperator(formData, queryObject);
|
||||
}
|
||||
|
||||
return {
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
|
|
@ -48,6 +43,8 @@ export const pivotOperator: PostProcessingFactory<
|
|||
metricLabels.map(metric => [metric, { operator: 'mean' }]),
|
||||
),
|
||||
drop_missing_columns: false,
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@
|
|||
import { DTTM_ALIAS, PostProcessingProphet } from '@superset-ui/core';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const prophetOperator: PostProcessingFactory<
|
||||
PostProcessingProphet | undefined
|
||||
> = (formData, queryObject) => {
|
||||
export const prophetOperator: PostProcessingFactory<PostProcessingProphet> = (
|
||||
formData,
|
||||
queryObject,
|
||||
) => {
|
||||
if (formData.forecastEnabled) {
|
||||
return {
|
||||
operation: 'prophet',
|
||||
|
|
|
|||
|
|
@ -17,36 +17,23 @@
|
|||
* specific language governing permissions and limitationsxw
|
||||
* under the License.
|
||||
*/
|
||||
import {
|
||||
DTTM_ALIAS,
|
||||
ensureIsArray,
|
||||
isPhysicalColumn,
|
||||
PostProcessingResample,
|
||||
} from '@superset-ui/core';
|
||||
import { PostProcessingResample } from '@superset-ui/core';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const resampleOperator: PostProcessingFactory<
|
||||
PostProcessingResample | undefined
|
||||
> = (formData, queryObject) => {
|
||||
export const resampleOperator: PostProcessingFactory<PostProcessingResample> = (
|
||||
formData,
|
||||
queryObject,
|
||||
) => {
|
||||
const resampleZeroFill = formData.resample_method === 'zerofill';
|
||||
const resampleMethod = resampleZeroFill ? 'asfreq' : formData.resample_method;
|
||||
const resampleRule = formData.resample_rule;
|
||||
if (resampleMethod && resampleRule) {
|
||||
const groupby_columns = ensureIsArray(queryObject.columns).map(column => {
|
||||
if (isPhysicalColumn(column)) {
|
||||
return column;
|
||||
}
|
||||
return column.label;
|
||||
});
|
||||
|
||||
return {
|
||||
operation: 'resample',
|
||||
options: {
|
||||
method: resampleMethod,
|
||||
rule: resampleRule,
|
||||
fill_value: resampleZeroFill ? 0 : null,
|
||||
time_column: formData.x_axis || DTTM_ALIAS,
|
||||
groupby_columns,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,39 +18,25 @@
|
|||
* under the License.
|
||||
*/
|
||||
import {
|
||||
ComparisionType,
|
||||
ensureIsArray,
|
||||
ensureIsInt,
|
||||
PostProcessingCum,
|
||||
PostProcessingRolling,
|
||||
RollingType,
|
||||
} from '@superset-ui/core';
|
||||
import {
|
||||
getMetricOffsetsMap,
|
||||
isValidTimeCompare,
|
||||
TIME_COMPARISON_SEPARATOR,
|
||||
} from './utils';
|
||||
import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const rollingWindowOperator: PostProcessingFactory<
|
||||
PostProcessingRolling | PostProcessingCum | undefined
|
||||
PostProcessingRolling | PostProcessingCum
|
||||
> = (formData, queryObject) => {
|
||||
let columns: (string | undefined)[];
|
||||
if (isValidTimeCompare(formData, queryObject)) {
|
||||
const metricsMap = getMetricOffsetsMap(formData, queryObject);
|
||||
const comparisonType = formData.comparison_type;
|
||||
if (comparisonType === ComparisionType.Values) {
|
||||
// time compare type: actual values
|
||||
columns = [
|
||||
...Array.from(metricsMap.values()),
|
||||
...Array.from(metricsMap.keys()),
|
||||
];
|
||||
} else {
|
||||
// time compare type: difference / percentage / ratio
|
||||
columns = Array.from(metricsMap.entries()).map(([offset, metric]) =>
|
||||
[comparisonType, metric, offset].join(TIME_COMPARISON_SEPARATOR),
|
||||
);
|
||||
}
|
||||
columns = [
|
||||
...Array.from(metricsMap.values()),
|
||||
...Array.from(metricsMap.keys()),
|
||||
];
|
||||
} else {
|
||||
columns = ensureIsArray(queryObject.metrics).map(metric => {
|
||||
if (typeof metric === 'string') {
|
||||
|
|
@ -67,7 +53,6 @@ export const rollingWindowOperator: PostProcessingFactory<
|
|||
options: {
|
||||
operator: 'sum',
|
||||
columns: columnsMap,
|
||||
is_pivot_df: true,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
@ -84,7 +69,6 @@ export const rollingWindowOperator: PostProcessingFactory<
|
|||
window: ensureIsInt(formData.rolling_periods, 1),
|
||||
min_periods: ensureIsInt(formData.min_periods, 0),
|
||||
columns: columnsMap,
|
||||
is_pivot_df: true,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,9 +20,10 @@
|
|||
import { DTTM_ALIAS, PostProcessingSort, RollingType } from '@superset-ui/core';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const sortOperator: PostProcessingFactory<
|
||||
PostProcessingSort | undefined
|
||||
> = (formData, queryObject) => {
|
||||
export const sortOperator: PostProcessingFactory<PostProcessingSort> = (
|
||||
formData,
|
||||
queryObject,
|
||||
) => {
|
||||
const { x_axis: xAxis } = formData;
|
||||
if (
|
||||
(xAxis || queryObject.is_timeseries) &&
|
||||
|
|
|
|||
|
|
@ -21,26 +21,25 @@ import { ComparisionType, PostProcessingCompare } from '@superset-ui/core';
|
|||
import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const timeCompareOperator: PostProcessingFactory<
|
||||
PostProcessingCompare | undefined
|
||||
> = (formData, queryObject) => {
|
||||
const comparisonType = formData.comparison_type;
|
||||
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
|
||||
export const timeCompareOperator: PostProcessingFactory<PostProcessingCompare> =
|
||||
(formData, queryObject) => {
|
||||
const comparisonType = formData.comparison_type;
|
||||
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
|
||||
|
||||
if (
|
||||
isValidTimeCompare(formData, queryObject) &&
|
||||
comparisonType !== ComparisionType.Values
|
||||
) {
|
||||
return {
|
||||
operation: 'compare',
|
||||
options: {
|
||||
source_columns: Array.from(metricOffsetMap.values()),
|
||||
compare_columns: Array.from(metricOffsetMap.keys()),
|
||||
compare_type: comparisonType,
|
||||
drop_original_columns: true,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (
|
||||
isValidTimeCompare(formData, queryObject) &&
|
||||
comparisonType !== ComparisionType.Values
|
||||
) {
|
||||
return {
|
||||
operation: 'compare',
|
||||
options: {
|
||||
source_columns: Array.from(metricOffsetMap.values()),
|
||||
compare_columns: Array.from(metricOffsetMap.keys()),
|
||||
compare_type: comparisonType,
|
||||
drop_original_columns: true,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return undefined;
|
||||
};
|
||||
return undefined;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -18,54 +18,40 @@
|
|||
* under the License.
|
||||
*/
|
||||
import {
|
||||
ComparisionType,
|
||||
DTTM_ALIAS,
|
||||
ensureIsArray,
|
||||
getColumnLabel,
|
||||
NumpyFunction,
|
||||
PostProcessingPivot,
|
||||
} from '@superset-ui/core';
|
||||
import {
|
||||
getMetricOffsetsMap,
|
||||
isValidTimeCompare,
|
||||
TIME_COMPARISON_SEPARATOR,
|
||||
} from './utils';
|
||||
import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
|
||||
import { PostProcessingFactory } from './types';
|
||||
|
||||
export const timeComparePivotOperator: PostProcessingFactory<
|
||||
PostProcessingPivot | undefined
|
||||
> = (formData, queryObject) => {
|
||||
const comparisonType = formData.comparison_type;
|
||||
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
|
||||
export const timeComparePivotOperator: PostProcessingFactory<PostProcessingPivot> =
|
||||
(formData, queryObject) => {
|
||||
const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
|
||||
|
||||
if (isValidTimeCompare(formData, queryObject)) {
|
||||
const valuesAgg = Object.fromEntries(
|
||||
[...metricOffsetMap.values(), ...metricOffsetMap.keys()].map(metric => [
|
||||
metric,
|
||||
// use the 'mean' aggregates to avoid drop NaN
|
||||
{ operator: 'mean' as NumpyFunction },
|
||||
]),
|
||||
);
|
||||
const changeAgg = Object.fromEntries(
|
||||
[...metricOffsetMap.entries()]
|
||||
.map(([offset, metric]) =>
|
||||
[comparisonType, metric, offset].join(TIME_COMPARISON_SEPARATOR),
|
||||
)
|
||||
// use the 'mean' aggregates to avoid drop NaN
|
||||
.map(metric => [metric, { operator: 'mean' as NumpyFunction }]),
|
||||
);
|
||||
if (isValidTimeCompare(formData, queryObject)) {
|
||||
const aggregates = Object.fromEntries(
|
||||
[...metricOffsetMap.values(), ...metricOffsetMap.keys()].map(metric => [
|
||||
metric,
|
||||
// use the 'mean' aggregates to avoid drop NaN
|
||||
{ operator: 'mean' as NumpyFunction },
|
||||
]),
|
||||
);
|
||||
|
||||
return {
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
index: [formData.x_axis || DTTM_ALIAS],
|
||||
columns: ensureIsArray(queryObject.columns).map(getColumnLabel),
|
||||
aggregates:
|
||||
comparisonType === ComparisionType.Values ? valuesAgg : changeAgg,
|
||||
drop_missing_columns: false,
|
||||
},
|
||||
};
|
||||
}
|
||||
return {
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
index: [formData.x_axis || DTTM_ALIAS],
|
||||
columns: ensureIsArray(queryObject.columns).map(getColumnLabel),
|
||||
drop_missing_columns: false,
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
aggregates,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return undefined;
|
||||
};
|
||||
return undefined;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { QueryObject, SqlaFormData } from '@superset-ui/core';
|
||||
import { flattenOperator } from '@superset-ui/chart-controls';
|
||||
|
||||
const formData: SqlaFormData = {
|
||||
metrics: [
|
||||
'count(*)',
|
||||
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
|
||||
],
|
||||
time_range: '2015 : 2016',
|
||||
granularity: 'month',
|
||||
datasource: 'foo',
|
||||
viz_type: 'table',
|
||||
};
|
||||
const queryObject: QueryObject = {
|
||||
metrics: [
|
||||
'count(*)',
|
||||
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
|
||||
],
|
||||
time_range: '2015 : 2016',
|
||||
granularity: 'month',
|
||||
post_processing: [
|
||||
{
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
index: ['__timestamp'],
|
||||
columns: ['nation'],
|
||||
aggregates: {
|
||||
'count(*)': {
|
||||
operator: 'sum',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
test('should do flattenOperator', () => {
|
||||
expect(flattenOperator(formData, queryObject)).toEqual({
|
||||
operation: 'flatten',
|
||||
});
|
||||
});
|
||||
|
|
@ -80,6 +80,8 @@ test('pivot by __timestamp without groupby', () => {
|
|||
'sum(val)': { operator: 'mean' },
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
@ -101,6 +103,8 @@ test('pivot by __timestamp with groupby', () => {
|
|||
'sum(val)': { operator: 'mean' },
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
@ -127,44 +131,8 @@ test('pivot by x_axis with groupby', () => {
|
|||
'sum(val)': { operator: 'mean' },
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('timecompare in formdata', () => {
|
||||
expect(
|
||||
pivotOperator(
|
||||
{
|
||||
...formData,
|
||||
comparison_type: 'values',
|
||||
time_compare: ['1 year ago', '1 year later'],
|
||||
},
|
||||
{
|
||||
...queryObject,
|
||||
columns: ['foo', 'bar'],
|
||||
is_timeseries: true,
|
||||
},
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
aggregates: {
|
||||
'count(*)': { operator: 'mean' },
|
||||
'count(*)__1 year ago': { operator: 'mean' },
|
||||
'count(*)__1 year later': { operator: 'mean' },
|
||||
'sum(val)': {
|
||||
operator: 'mean',
|
||||
},
|
||||
'sum(val)__1 year ago': {
|
||||
operator: 'mean',
|
||||
},
|
||||
'sum(val)__1 year later': {
|
||||
operator: 'mean',
|
||||
},
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
columns: ['foo', 'bar'],
|
||||
index: ['__timestamp'],
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { AdhocColumn, QueryObject, SqlaFormData } from '@superset-ui/core';
|
||||
import { QueryObject, SqlaFormData } from '@superset-ui/core';
|
||||
import { resampleOperator } from '@superset-ui/chart-controls';
|
||||
|
||||
const formData: SqlaFormData = {
|
||||
|
|
@ -74,8 +74,6 @@ test('should do resample on implicit time column', () => {
|
|||
method: 'ffill',
|
||||
rule: '1D',
|
||||
fill_value: null,
|
||||
time_column: '__timestamp',
|
||||
groupby_columns: [],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
@ -95,10 +93,8 @@ test('should do resample on x-axis', () => {
|
|||
operation: 'resample',
|
||||
options: {
|
||||
fill_value: null,
|
||||
groupby_columns: [],
|
||||
method: 'ffill',
|
||||
rule: '1D',
|
||||
time_column: 'ds',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
@ -115,81 +111,6 @@ test('should do zerofill resample', () => {
|
|||
method: 'asfreq',
|
||||
rule: '1D',
|
||||
fill_value: 0,
|
||||
time_column: '__timestamp',
|
||||
groupby_columns: [],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('should append physical column to resample', () => {
|
||||
expect(
|
||||
resampleOperator(
|
||||
{ ...formData, resample_method: 'zerofill', resample_rule: '1D' },
|
||||
{ ...queryObject, columns: ['column1', 'column2'] },
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'resample',
|
||||
options: {
|
||||
method: 'asfreq',
|
||||
rule: '1D',
|
||||
fill_value: 0,
|
||||
time_column: '__timestamp',
|
||||
groupby_columns: ['column1', 'column2'],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('should append label of adhoc column and physical column to resample', () => {
|
||||
expect(
|
||||
resampleOperator(
|
||||
{ ...formData, resample_method: 'zerofill', resample_rule: '1D' },
|
||||
{
|
||||
...queryObject,
|
||||
columns: [
|
||||
{
|
||||
hasCustomLabel: true,
|
||||
label: 'concat_a_b',
|
||||
expressionType: 'SQL',
|
||||
sqlExpression: "'a' + 'b'",
|
||||
} as AdhocColumn,
|
||||
'column2',
|
||||
],
|
||||
},
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'resample',
|
||||
options: {
|
||||
method: 'asfreq',
|
||||
rule: '1D',
|
||||
fill_value: 0,
|
||||
time_column: '__timestamp',
|
||||
groupby_columns: ['concat_a_b', 'column2'],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('should append `undefined` if adhoc non-existing label', () => {
|
||||
expect(
|
||||
resampleOperator(
|
||||
{ ...formData, resample_method: 'zerofill', resample_rule: '1D' },
|
||||
{
|
||||
...queryObject,
|
||||
columns: [
|
||||
{
|
||||
sqlExpression: "'a' + 'b'",
|
||||
} as AdhocColumn,
|
||||
'column2',
|
||||
],
|
||||
},
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'resample',
|
||||
options: {
|
||||
method: 'asfreq',
|
||||
rule: '1D',
|
||||
fill_value: 0,
|
||||
time_column: '__timestamp',
|
||||
groupby_columns: [undefined, 'column2'],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -79,7 +79,6 @@ test('rolling_type: cumsum', () => {
|
|||
'count(*)': 'count(*)',
|
||||
'sum(val)': 'sum(val)',
|
||||
},
|
||||
is_pivot_df: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
@ -102,42 +101,13 @@ test('rolling_type: sum/mean/std', () => {
|
|||
'count(*)': 'count(*)',
|
||||
'sum(val)': 'sum(val)',
|
||||
},
|
||||
is_pivot_df: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test('rolling window and "actual values" in the time compare', () => {
|
||||
expect(
|
||||
rollingWindowOperator(
|
||||
{
|
||||
...formData,
|
||||
rolling_type: 'cumsum',
|
||||
comparison_type: 'values',
|
||||
time_compare: ['1 year ago', '1 year later'],
|
||||
},
|
||||
queryObject,
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'cum',
|
||||
options: {
|
||||
operator: 'sum',
|
||||
columns: {
|
||||
'count(*)': 'count(*)',
|
||||
'count(*)__1 year ago': 'count(*)__1 year ago',
|
||||
'count(*)__1 year later': 'count(*)__1 year later',
|
||||
'sum(val)': 'sum(val)',
|
||||
'sum(val)__1 year ago': 'sum(val)__1 year ago',
|
||||
'sum(val)__1 year later': 'sum(val)__1 year later',
|
||||
},
|
||||
is_pivot_df: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('rolling window and "difference / percentage / ratio" in the time compare', () => {
|
||||
const comparisionTypes = ['difference', 'percentage', 'ratio'];
|
||||
test('should append compared metrics when sets time compare type', () => {
|
||||
const comparisionTypes = ['values', 'difference', 'percentage', 'ratio'];
|
||||
comparisionTypes.forEach(cType => {
|
||||
expect(
|
||||
rollingWindowOperator(
|
||||
|
|
@ -154,12 +124,13 @@ test('rolling window and "difference / percentage / ratio" in the time compare',
|
|||
options: {
|
||||
operator: 'sum',
|
||||
columns: {
|
||||
[`${cType}__count(*)__count(*)__1 year ago`]: `${cType}__count(*)__count(*)__1 year ago`,
|
||||
[`${cType}__count(*)__count(*)__1 year later`]: `${cType}__count(*)__count(*)__1 year later`,
|
||||
[`${cType}__sum(val)__sum(val)__1 year ago`]: `${cType}__sum(val)__sum(val)__1 year ago`,
|
||||
[`${cType}__sum(val)__sum(val)__1 year later`]: `${cType}__sum(val)__sum(val)__1 year later`,
|
||||
'count(*)': 'count(*)',
|
||||
'count(*)__1 year ago': 'count(*)__1 year ago',
|
||||
'count(*)__1 year later': 'count(*)__1 year later',
|
||||
'sum(val)': 'sum(val)',
|
||||
'sum(val)__1 year ago': 'sum(val)__1 year ago',
|
||||
'sum(val)__1 year later': 'sum(val)__1 year later',
|
||||
},
|
||||
is_pivot_df: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -17,17 +17,23 @@
|
|||
* under the License.
|
||||
*/
|
||||
import { QueryObject, SqlaFormData } from '@superset-ui/core';
|
||||
import { timeCompareOperator, timeComparePivotOperator } from '../../../src';
|
||||
import { timeCompareOperator } from '../../../src';
|
||||
|
||||
const formData: SqlaFormData = {
|
||||
metrics: ['count(*)'],
|
||||
metrics: [
|
||||
'count(*)',
|
||||
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
|
||||
],
|
||||
time_range: '2015 : 2016',
|
||||
granularity: 'month',
|
||||
datasource: 'foo',
|
||||
viz_type: 'table',
|
||||
};
|
||||
const queryObject: QueryObject = {
|
||||
metrics: ['count(*)'],
|
||||
metrics: [
|
||||
'count(*)',
|
||||
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
|
||||
],
|
||||
time_range: '2015 : 2016',
|
||||
granularity: 'month',
|
||||
post_processing: [
|
||||
|
|
@ -40,21 +46,26 @@ const queryObject: QueryObject = {
|
|||
'count(*)': {
|
||||
operator: 'mean',
|
||||
},
|
||||
'sum(val)': {
|
||||
operator: 'mean',
|
||||
},
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
operation: 'aggregation',
|
||||
options: {
|
||||
groupby: ['col1'],
|
||||
aggregates: 'count',
|
||||
aggregates: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
test('time compare: skip transformation', () => {
|
||||
test('should skip CompareOperator', () => {
|
||||
expect(timeCompareOperator(formData, queryObject)).toEqual(undefined);
|
||||
expect(
|
||||
timeCompareOperator({ ...formData, time_compare: [] }, queryObject),
|
||||
|
|
@ -80,7 +91,7 @@ test('time compare: skip transformation', () => {
|
|||
).toEqual(undefined);
|
||||
});
|
||||
|
||||
test('time compare: difference/percentage/ratio', () => {
|
||||
test('should generate difference/percentage/ratio CompareOperator', () => {
|
||||
const comparisionTypes = ['difference', 'percentage', 'ratio'];
|
||||
comparisionTypes.forEach(cType => {
|
||||
expect(
|
||||
|
|
@ -95,108 +106,16 @@ test('time compare: difference/percentage/ratio', () => {
|
|||
).toEqual({
|
||||
operation: 'compare',
|
||||
options: {
|
||||
source_columns: ['count(*)', 'count(*)'],
|
||||
compare_columns: ['count(*)__1 year ago', 'count(*)__1 year later'],
|
||||
source_columns: ['count(*)', 'count(*)', 'sum(val)', 'sum(val)'],
|
||||
compare_columns: [
|
||||
'count(*)__1 year ago',
|
||||
'count(*)__1 year later',
|
||||
'sum(val)__1 year ago',
|
||||
'sum(val)__1 year later',
|
||||
],
|
||||
compare_type: cType,
|
||||
drop_original_columns: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test('time compare pivot: skip transformation', () => {
|
||||
expect(timeComparePivotOperator(formData, queryObject)).toEqual(undefined);
|
||||
expect(
|
||||
timeComparePivotOperator({ ...formData, time_compare: [] }, queryObject),
|
||||
).toEqual(undefined);
|
||||
expect(
|
||||
timeComparePivotOperator(
|
||||
{ ...formData, comparison_type: null },
|
||||
queryObject,
|
||||
),
|
||||
).toEqual(undefined);
|
||||
expect(
|
||||
timeCompareOperator(
|
||||
{ ...formData, comparison_type: 'foobar' },
|
||||
queryObject,
|
||||
),
|
||||
).toEqual(undefined);
|
||||
});
|
||||
|
||||
test('time compare pivot: values', () => {
|
||||
expect(
|
||||
timeComparePivotOperator(
|
||||
{
|
||||
...formData,
|
||||
comparison_type: 'values',
|
||||
time_compare: ['1 year ago', '1 year later'],
|
||||
},
|
||||
queryObject,
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
aggregates: {
|
||||
'count(*)': { operator: 'mean' },
|
||||
'count(*)__1 year ago': { operator: 'mean' },
|
||||
'count(*)__1 year later': { operator: 'mean' },
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
columns: [],
|
||||
index: ['__timestamp'],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('time compare pivot: difference/percentage/ratio', () => {
|
||||
const comparisionTypes = ['difference', 'percentage', 'ratio'];
|
||||
comparisionTypes.forEach(cType => {
|
||||
expect(
|
||||
timeComparePivotOperator(
|
||||
{
|
||||
...formData,
|
||||
comparison_type: cType,
|
||||
time_compare: ['1 year ago', '1 year later'],
|
||||
},
|
||||
queryObject,
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
aggregates: {
|
||||
[`${cType}__count(*)__count(*)__1 year ago`]: { operator: 'mean' },
|
||||
[`${cType}__count(*)__count(*)__1 year later`]: { operator: 'mean' },
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
columns: [],
|
||||
index: ['__timestamp'],
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test('time compare pivot on x-axis', () => {
|
||||
expect(
|
||||
timeComparePivotOperator(
|
||||
{
|
||||
...formData,
|
||||
comparison_type: 'values',
|
||||
time_compare: ['1 year ago', '1 year later'],
|
||||
x_axis: 'ds',
|
||||
},
|
||||
queryObject,
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
aggregates: {
|
||||
'count(*)': { operator: 'mean' },
|
||||
'count(*)__1 year ago': { operator: 'mean' },
|
||||
'count(*)__1 year later': { operator: 'mean' },
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
columns: [],
|
||||
index: ['ds'],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { QueryObject, SqlaFormData } from '@superset-ui/core';
|
||||
import { timeCompareOperator, timeComparePivotOperator } from '../../../src';
|
||||
|
||||
const formData: SqlaFormData = {
|
||||
metrics: [
|
||||
'count(*)',
|
||||
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
|
||||
],
|
||||
time_range: '2015 : 2016',
|
||||
granularity: 'month',
|
||||
datasource: 'foo',
|
||||
viz_type: 'table',
|
||||
};
|
||||
const queryObject: QueryObject = {
|
||||
metrics: [
|
||||
'count(*)',
|
||||
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
|
||||
],
|
||||
columns: ['foo', 'bar'],
|
||||
time_range: '2015 : 2016',
|
||||
granularity: 'month',
|
||||
post_processing: [],
|
||||
};
|
||||
|
||||
test('should skip pivot', () => {
|
||||
expect(timeComparePivotOperator(formData, queryObject)).toEqual(undefined);
|
||||
expect(
|
||||
timeComparePivotOperator({ ...formData, time_compare: [] }, queryObject),
|
||||
).toEqual(undefined);
|
||||
expect(
|
||||
timeComparePivotOperator(
|
||||
{ ...formData, comparison_type: null },
|
||||
queryObject,
|
||||
),
|
||||
).toEqual(undefined);
|
||||
expect(
|
||||
timeCompareOperator(
|
||||
{ ...formData, comparison_type: 'foobar' },
|
||||
queryObject,
|
||||
),
|
||||
).toEqual(undefined);
|
||||
});
|
||||
|
||||
test('should pivot on any type of timeCompare', () => {
|
||||
const anyTimeCompareTypes = ['values', 'difference', 'percentage', 'ratio'];
|
||||
anyTimeCompareTypes.forEach(cType => {
|
||||
expect(
|
||||
timeComparePivotOperator(
|
||||
{
|
||||
...formData,
|
||||
comparison_type: cType,
|
||||
time_compare: ['1 year ago', '1 year later'],
|
||||
},
|
||||
{
|
||||
...queryObject,
|
||||
is_timeseries: true,
|
||||
},
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
aggregates: {
|
||||
'count(*)': { operator: 'mean' },
|
||||
'count(*)__1 year ago': { operator: 'mean' },
|
||||
'count(*)__1 year later': { operator: 'mean' },
|
||||
'sum(val)': { operator: 'mean' },
|
||||
'sum(val)__1 year ago': {
|
||||
operator: 'mean',
|
||||
},
|
||||
'sum(val)__1 year later': {
|
||||
operator: 'mean',
|
||||
},
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
columns: ['foo', 'bar'],
|
||||
index: ['__timestamp'],
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test('should pivot on x-axis', () => {
|
||||
expect(
|
||||
timeComparePivotOperator(
|
||||
{
|
||||
...formData,
|
||||
comparison_type: 'values',
|
||||
time_compare: ['1 year ago', '1 year later'],
|
||||
x_axis: 'ds',
|
||||
},
|
||||
queryObject,
|
||||
),
|
||||
).toEqual({
|
||||
operation: 'pivot',
|
||||
options: {
|
||||
aggregates: {
|
||||
'count(*)': { operator: 'mean' },
|
||||
'count(*)__1 year ago': { operator: 'mean' },
|
||||
'count(*)__1 year later': { operator: 'mean' },
|
||||
'sum(val)': {
|
||||
operator: 'mean',
|
||||
},
|
||||
'sum(val)__1 year ago': {
|
||||
operator: 'mean',
|
||||
},
|
||||
'sum(val)__1 year later': {
|
||||
operator: 'mean',
|
||||
},
|
||||
},
|
||||
drop_missing_columns: false,
|
||||
columns: ['foo', 'bar'],
|
||||
index: ['ds'],
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
@ -64,25 +64,34 @@ export interface Aggregates {
|
|||
};
|
||||
}
|
||||
|
||||
export interface PostProcessingAggregation {
|
||||
export type DefaultPostProcessing = undefined;
|
||||
|
||||
interface _PostProcessingAggregation {
|
||||
operation: 'aggregation';
|
||||
options: {
|
||||
groupby: string[];
|
||||
aggregates: Aggregates;
|
||||
};
|
||||
}
|
||||
export type PostProcessingAggregation =
|
||||
| _PostProcessingAggregation
|
||||
| DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingBoxplot {
|
||||
export type BoxPlotQueryObjectWhiskerType = 'tukey' | 'min/max' | 'percentile';
|
||||
interface _PostProcessingBoxplot {
|
||||
operation: 'boxplot';
|
||||
options: {
|
||||
groupby: string[];
|
||||
metrics: string[];
|
||||
whisker_type: 'tukey' | 'min/max' | 'percentile';
|
||||
whisker_type: BoxPlotQueryObjectWhiskerType;
|
||||
percentiles?: [number, number];
|
||||
};
|
||||
}
|
||||
export type PostProcessingBoxplot =
|
||||
| _PostProcessingBoxplot
|
||||
| DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingContribution {
|
||||
interface _PostProcessingContribution {
|
||||
operation: 'contribution';
|
||||
options?: {
|
||||
orientation?: 'row' | 'column';
|
||||
|
|
@ -90,8 +99,11 @@ export interface PostProcessingContribution {
|
|||
rename_columns?: string[];
|
||||
};
|
||||
}
|
||||
export type PostProcessingContribution =
|
||||
| _PostProcessingContribution
|
||||
| DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingPivot {
|
||||
interface _PostProcessingPivot {
|
||||
operation: 'pivot';
|
||||
options: {
|
||||
aggregates: Aggregates;
|
||||
|
|
@ -107,8 +119,9 @@ export interface PostProcessingPivot {
|
|||
reset_index?: boolean;
|
||||
};
|
||||
}
|
||||
export type PostProcessingPivot = _PostProcessingPivot | DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingProphet {
|
||||
interface _PostProcessingProphet {
|
||||
operation: 'prophet';
|
||||
options: {
|
||||
time_grain: TimeGranularity;
|
||||
|
|
@ -119,8 +132,11 @@ export interface PostProcessingProphet {
|
|||
daily_seasonality?: boolean | number;
|
||||
};
|
||||
}
|
||||
export type PostProcessingProphet =
|
||||
| _PostProcessingProphet
|
||||
| DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingDiff {
|
||||
interface _PostProcessingDiff {
|
||||
operation: 'diff';
|
||||
options: {
|
||||
columns: string[];
|
||||
|
|
@ -128,28 +144,31 @@ export interface PostProcessingDiff {
|
|||
axis: PandasAxis;
|
||||
};
|
||||
}
|
||||
export type PostProcessingDiff = _PostProcessingDiff | DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingRolling {
|
||||
interface _PostProcessingRolling {
|
||||
operation: 'rolling';
|
||||
options: {
|
||||
rolling_type: RollingType;
|
||||
window: number;
|
||||
min_periods: number;
|
||||
columns: string[];
|
||||
is_pivot_df?: boolean;
|
||||
};
|
||||
}
|
||||
export type PostProcessingRolling =
|
||||
| _PostProcessingRolling
|
||||
| DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingCum {
|
||||
interface _PostProcessingCum {
|
||||
operation: 'cum';
|
||||
options: {
|
||||
columns: string[];
|
||||
operator: NumpyFunction;
|
||||
is_pivot_df?: boolean;
|
||||
};
|
||||
}
|
||||
export type PostProcessingCum = _PostProcessingCum | DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingCompare {
|
||||
export interface _PostProcessingCompare {
|
||||
operation: 'compare';
|
||||
options: {
|
||||
source_columns: string[];
|
||||
|
|
@ -158,26 +177,39 @@ export interface PostProcessingCompare {
|
|||
drop_original_columns: boolean;
|
||||
};
|
||||
}
|
||||
export type PostProcessingCompare =
|
||||
| _PostProcessingCompare
|
||||
| DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingSort {
|
||||
interface _PostProcessingSort {
|
||||
operation: 'sort';
|
||||
options: {
|
||||
columns: Record<string, boolean>;
|
||||
};
|
||||
}
|
||||
export type PostProcessingSort = _PostProcessingSort | DefaultPostProcessing;
|
||||
|
||||
export interface PostProcessingResample {
|
||||
interface _PostProcessingResample {
|
||||
operation: 'resample';
|
||||
options: {
|
||||
method: string;
|
||||
rule: string;
|
||||
fill_value?: number | null;
|
||||
time_column: string;
|
||||
// If AdhocColumn doesn't have a label, it will be undefined.
|
||||
// todo: we have to give an explicit label for AdhocColumn.
|
||||
groupby_columns?: Array<string | undefined>;
|
||||
};
|
||||
}
|
||||
export type PostProcessingResample =
|
||||
| _PostProcessingResample
|
||||
| DefaultPostProcessing;
|
||||
|
||||
interface _PostProcessingFlatten {
|
||||
operation: 'flatten';
|
||||
options?: {
|
||||
reset_index?: boolean;
|
||||
};
|
||||
}
|
||||
export type PostProcessingFlatten =
|
||||
| _PostProcessingFlatten
|
||||
| DefaultPostProcessing;
|
||||
|
||||
/**
|
||||
* Parameters for chart data postprocessing.
|
||||
|
|
@ -194,7 +226,8 @@ export type PostProcessingRule =
|
|||
| PostProcessingCum
|
||||
| PostProcessingCompare
|
||||
| PostProcessingSort
|
||||
| PostProcessingResample;
|
||||
| PostProcessingResample
|
||||
| PostProcessingFlatten;
|
||||
|
||||
export function isPostProcessingAggregation(
|
||||
rule?: PostProcessingRule,
|
||||
|
|
|
|||
|
|
@ -18,11 +18,14 @@
|
|||
*/
|
||||
import {
|
||||
buildQueryContext,
|
||||
DTTM_ALIAS,
|
||||
PostProcessingResample,
|
||||
QueryFormData,
|
||||
} from '@superset-ui/core';
|
||||
import { rollingWindowOperator } from '@superset-ui/chart-controls';
|
||||
import {
|
||||
flattenOperator,
|
||||
rollingWindowOperator,
|
||||
sortOperator,
|
||||
} from '@superset-ui/chart-controls';
|
||||
|
||||
const TIME_GRAIN_MAP: Record<string, string> = {
|
||||
PT1S: 'S',
|
||||
|
|
@ -47,12 +50,10 @@ const TIME_GRAIN_MAP: Record<string, string> = {
|
|||
|
||||
export default function buildQuery(formData: QueryFormData) {
|
||||
return buildQueryContext(formData, baseQueryObject => {
|
||||
// todo: move into full advanced analysis section here
|
||||
const rollingProc = rollingWindowOperator(formData, baseQueryObject);
|
||||
if (rollingProc) {
|
||||
rollingProc.options = { ...rollingProc.options, is_pivot_df: false };
|
||||
}
|
||||
const { time_grain_sqla } = formData;
|
||||
let resampleProc: PostProcessingResample | undefined;
|
||||
let resampleProc: PostProcessingResample;
|
||||
if (rollingProc && time_grain_sqla) {
|
||||
const rule = TIME_GRAIN_MAP[time_grain_sqla];
|
||||
if (rule) {
|
||||
|
|
@ -62,7 +63,6 @@ export default function buildQuery(formData: QueryFormData) {
|
|||
method: 'asfreq',
|
||||
rule,
|
||||
fill_value: null,
|
||||
time_column: DTTM_ALIAS,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
@ -72,16 +72,10 @@ export default function buildQuery(formData: QueryFormData) {
|
|||
...baseQueryObject,
|
||||
is_timeseries: true,
|
||||
post_processing: [
|
||||
{
|
||||
operation: 'sort',
|
||||
options: {
|
||||
columns: {
|
||||
[DTTM_ALIAS]: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
sortOperator(formData, baseQueryObject),
|
||||
resampleProc,
|
||||
rollingProc,
|
||||
flattenOperator(formData, baseQueryObject),
|
||||
],
|
||||
},
|
||||
];
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import {
|
|||
QueryObject,
|
||||
normalizeOrderBy,
|
||||
} from '@superset-ui/core';
|
||||
import { pivotOperator } from '@superset-ui/chart-controls';
|
||||
import { flattenOperator, pivotOperator } from '@superset-ui/chart-controls';
|
||||
|
||||
export default function buildQuery(formData: QueryFormData) {
|
||||
const {
|
||||
|
|
@ -66,6 +66,7 @@ export default function buildQuery(formData: QueryFormData) {
|
|||
is_timeseries: true,
|
||||
post_processing: [
|
||||
pivotOperator(formData1, { ...baseQueryObject, is_timeseries: true }),
|
||||
flattenOperator(formData1, { ...baseQueryObject, is_timeseries: true }),
|
||||
],
|
||||
} as QueryObject;
|
||||
return [normalizeOrderBy(queryObjectA)];
|
||||
|
|
@ -77,6 +78,7 @@ export default function buildQuery(formData: QueryFormData) {
|
|||
is_timeseries: true,
|
||||
post_processing: [
|
||||
pivotOperator(formData2, { ...baseQueryObject, is_timeseries: true }),
|
||||
flattenOperator(formData2, { ...baseQueryObject, is_timeseries: true }),
|
||||
],
|
||||
} as QueryObject;
|
||||
return [normalizeOrderBy(queryObjectB)];
|
||||
|
|
|
|||
|
|
@ -22,42 +22,54 @@ import {
|
|||
ensureIsArray,
|
||||
QueryFormData,
|
||||
normalizeOrderBy,
|
||||
RollingType,
|
||||
PostProcessingPivot,
|
||||
} from '@superset-ui/core';
|
||||
import {
|
||||
rollingWindowOperator,
|
||||
timeCompareOperator,
|
||||
isValidTimeCompare,
|
||||
sortOperator,
|
||||
pivotOperator,
|
||||
resampleOperator,
|
||||
contributionOperator,
|
||||
prophetOperator,
|
||||
timeComparePivotOperator,
|
||||
flattenOperator,
|
||||
} from '@superset-ui/chart-controls';
|
||||
|
||||
export default function buildQuery(formData: QueryFormData) {
|
||||
const { x_axis, groupby } = formData;
|
||||
const is_timeseries = x_axis === DTTM_ALIAS || !x_axis;
|
||||
return buildQueryContext(formData, baseQueryObject => {
|
||||
const pivotOperatorInRuntime: PostProcessingPivot | undefined =
|
||||
pivotOperator(formData, {
|
||||
...baseQueryObject,
|
||||
index: x_axis,
|
||||
is_timeseries,
|
||||
});
|
||||
if (
|
||||
pivotOperatorInRuntime &&
|
||||
Object.values(RollingType).includes(formData.rolling_type)
|
||||
) {
|
||||
pivotOperatorInRuntime.options = {
|
||||
...pivotOperatorInRuntime.options,
|
||||
...{
|
||||
flatten_columns: false,
|
||||
reset_index: false,
|
||||
},
|
||||
};
|
||||
}
|
||||
/* the `pivotOperatorInRuntime` determines how to pivot the dataframe returned from the raw query.
|
||||
1. If it's a time compared query, there will return a pivoted dataframe that append time compared metrics. for instance:
|
||||
|
||||
MAX(value) MAX(value)__1 year ago MIN(value) MIN(value)__1 year ago
|
||||
city LA LA LA LA
|
||||
__timestamp
|
||||
2015-01-01 568.0 671.0 5.0 6.0
|
||||
2015-02-01 407.0 649.0 4.0 3.0
|
||||
2015-03-01 318.0 465.0 0.0 3.0
|
||||
|
||||
2. If it's a normal query, there will return a pivoted dataframe.
|
||||
|
||||
MAX(value) MIN(value)
|
||||
city LA LA
|
||||
__timestamp
|
||||
2015-01-01 568.0 5.0
|
||||
2015-02-01 407.0 4.0
|
||||
2015-03-01 318.0 0.0
|
||||
|
||||
*/
|
||||
const pivotOperatorInRuntime: PostProcessingPivot = isValidTimeCompare(
|
||||
formData,
|
||||
baseQueryObject,
|
||||
)
|
||||
? timeComparePivotOperator(formData, baseQueryObject)
|
||||
: pivotOperator(formData, {
|
||||
...baseQueryObject,
|
||||
index: x_axis,
|
||||
is_timeseries,
|
||||
});
|
||||
|
||||
return [
|
||||
{
|
||||
|
|
@ -70,13 +82,16 @@ export default function buildQuery(formData: QueryFormData) {
|
|||
time_offsets: isValidTimeCompare(formData, baseQueryObject)
|
||||
? formData.time_compare
|
||||
: [],
|
||||
/* Note that:
|
||||
1. The resample, rolling, cum, timeCompare operators should be after pivot.
|
||||
2. the flatOperator makes multiIndex Dataframe into flat Dataframe
|
||||
*/
|
||||
post_processing: [
|
||||
resampleOperator(formData, baseQueryObject),
|
||||
timeCompareOperator(formData, baseQueryObject),
|
||||
sortOperator(formData, { ...baseQueryObject, is_timeseries: true }),
|
||||
// in order to be able to rolling in multiple series, must do pivot before rollingOperator
|
||||
pivotOperatorInRuntime,
|
||||
rollingWindowOperator(formData, baseQueryObject),
|
||||
timeCompareOperator(formData, baseQueryObject),
|
||||
resampleOperator(formData, baseQueryObject),
|
||||
flattenOperator(formData, baseQueryObject),
|
||||
contributionOperator(formData, baseQueryObject),
|
||||
prophetOperator(formData, baseQueryObject),
|
||||
],
|
||||
|
|
|
|||
|
|
@ -767,6 +767,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
|
|||
"diff",
|
||||
"compare",
|
||||
"resample",
|
||||
"flatten",
|
||||
)
|
||||
),
|
||||
example="aggregate",
|
||||
|
|
|
|||
|
|
@ -36,7 +36,11 @@ from superset.common.utils import dataframe_utils as df_utils
|
|||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.constants import CacheRegion
|
||||
from superset.exceptions import QueryObjectValidationError, SupersetException
|
||||
from superset.exceptions import (
|
||||
InvalidPostProcessingError,
|
||||
QueryObjectValidationError,
|
||||
SupersetException,
|
||||
)
|
||||
from superset.extensions import cache_manager, security_manager
|
||||
from superset.models.helpers import QueryResult
|
||||
from superset.utils import csv
|
||||
|
|
@ -196,7 +200,11 @@ class QueryContextProcessor:
|
|||
query += ";\n\n".join(queries)
|
||||
query += ";\n\n"
|
||||
|
||||
df = query_object.exec_post_processing(df)
|
||||
# Re-raising QueryObjectValidationError
|
||||
try:
|
||||
df = query_object.exec_post_processing(df)
|
||||
except InvalidPostProcessingError as ex:
|
||||
raise QueryObjectValidationError from ex
|
||||
|
||||
result.df = df
|
||||
result.query = query
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# pylint: disable=invalid-name
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pprint import pformat
|
||||
|
|
@ -27,6 +28,7 @@ from pandas import DataFrame
|
|||
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.exceptions import (
|
||||
InvalidPostProcessingError,
|
||||
QueryClauseValidationException,
|
||||
QueryObjectValidationError,
|
||||
)
|
||||
|
|
@ -337,6 +339,10 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
}
|
||||
return query_object_dict
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# we use `print` or `logging` output QueryObject
|
||||
return json.dumps(self.to_dict(), sort_keys=True, default=str,)
|
||||
|
||||
def cache_key(self, **extra: Any) -> str:
|
||||
"""
|
||||
The cache key is made out of the key/values from to_dict(), plus any
|
||||
|
|
@ -398,15 +404,15 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
:raises QueryObjectValidationError: If the post processing operation
|
||||
is incorrect
|
||||
"""
|
||||
logger.debug("post_processing: %s", pformat(self.post_processing))
|
||||
logger.debug("post_processing: \n %s", pformat(self.post_processing))
|
||||
for post_process in self.post_processing:
|
||||
operation = post_process.get("operation")
|
||||
if not operation:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("`operation` property of post processing object undefined")
|
||||
)
|
||||
if not hasattr(pandas_postprocessing, operation):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"Unsupported post processing operation: %(operation)s",
|
||||
type=operation,
|
||||
|
|
|
|||
|
|
@ -190,6 +190,10 @@ class QueryObjectValidationError(SupersetException):
|
|||
status = 400
|
||||
|
||||
|
||||
class InvalidPostProcessingError(SupersetException):
|
||||
status = 400
|
||||
|
||||
|
||||
class CacheLoadError(SupersetException):
|
||||
status = 404
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from superset.utils.pandas_postprocessing.compare import compare
|
|||
from superset.utils.pandas_postprocessing.contribution import contribution
|
||||
from superset.utils.pandas_postprocessing.cum import cum
|
||||
from superset.utils.pandas_postprocessing.diff import diff
|
||||
from superset.utils.pandas_postprocessing.flatten import flatten
|
||||
from superset.utils.pandas_postprocessing.geography import (
|
||||
geodetic_parse,
|
||||
geohash_decode,
|
||||
|
|
@ -49,5 +50,6 @@ __all__ = [
|
|||
"rolling",
|
||||
"select",
|
||||
"sort",
|
||||
"flatten",
|
||||
"_flatten_column_after_pivot",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ def aggregate(
|
|||
:param groupby: columns to aggregate
|
||||
:param aggregates: A mapping from metric column to the function used to
|
||||
aggregate values.
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request in incorrect
|
||||
"""
|
||||
aggregates = aggregates or {}
|
||||
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import numpy as np
|
|||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame, Series
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import PostProcessingBoxplotWhiskerType
|
||||
from superset.utils.pandas_postprocessing.aggregate import aggregate
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ def boxplot(
|
|||
or not isinstance(percentiles[1], (int, float))
|
||||
or percentiles[0] >= percentiles[1]
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"percentiles must be a list or tuple with two numeric values, "
|
||||
"of which the first is lower than the second value"
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from flask_babel import gettext as _
|
|||
from pandas import DataFrame
|
||||
|
||||
from superset.constants import PandasPostprocessingCompare
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import TIME_COMPARISION
|
||||
from superset.utils.pandas_postprocessing.utils import validate_column_args
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ def compare( # pylint: disable=too-many-arguments
|
|||
df: DataFrame,
|
||||
source_columns: List[str],
|
||||
compare_columns: List[str],
|
||||
compare_type: Optional[PandasPostprocessingCompare],
|
||||
compare_type: PandasPostprocessingCompare,
|
||||
drop_original_columns: Optional[bool] = False,
|
||||
precision: Optional[int] = 4,
|
||||
) -> DataFrame:
|
||||
|
|
@ -46,31 +46,38 @@ def compare( # pylint: disable=too-many-arguments
|
|||
compare columns.
|
||||
:param precision: Round a change rate to a variable number of decimal places.
|
||||
:return: DataFrame with compared columns.
|
||||
:raises QueryObjectValidationError: If the request in incorrect.
|
||||
:raises InvalidPostProcessingError: If the request in incorrect.
|
||||
"""
|
||||
if len(source_columns) != len(compare_columns):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("`compare_columns` must have the same length as `source_columns`.")
|
||||
)
|
||||
if compare_type not in tuple(PandasPostprocessingCompare):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("`compare_type` must be `difference`, `percentage` or `ratio`")
|
||||
)
|
||||
if len(source_columns) == 0:
|
||||
return df
|
||||
|
||||
for s_col, c_col in zip(source_columns, compare_columns):
|
||||
s_df = df.loc[:, [s_col]]
|
||||
s_df.rename(columns={s_col: "__intermediate"}, inplace=True)
|
||||
c_df = df.loc[:, [c_col]]
|
||||
c_df.rename(columns={c_col: "__intermediate"}, inplace=True)
|
||||
if compare_type == PandasPostprocessingCompare.DIFF:
|
||||
diff_series = df[s_col] - df[c_col]
|
||||
diff_df = c_df - s_df
|
||||
elif compare_type == PandasPostprocessingCompare.PCT:
|
||||
diff_series = (
|
||||
((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision)
|
||||
)
|
||||
# https://en.wikipedia.org/wiki/Relative_change_and_difference#Percentage_change
|
||||
diff_df = ((c_df - s_df) / s_df).astype(float).round(precision)
|
||||
else:
|
||||
# compare_type == "ratio"
|
||||
diff_series = (df[s_col] / df[c_col]).astype(float).round(precision)
|
||||
diff_df = diff_series.to_frame(
|
||||
name=TIME_COMPARISION.join([compare_type, s_col, c_col])
|
||||
diff_df = (c_df / s_df).astype(float).round(precision)
|
||||
|
||||
diff_df.rename(
|
||||
columns={
|
||||
"__intermediate": TIME_COMPARISION.join([compare_type, s_col, c_col])
|
||||
},
|
||||
inplace=True,
|
||||
)
|
||||
df = pd.concat([df, diff_df], axis=1)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from typing import List, Optional
|
|||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import PostProcessingContributionOrientation
|
||||
from superset.utils.pandas_postprocessing.utils import validate_column_args
|
||||
|
||||
|
|
@ -55,7 +55,7 @@ def contribution(
|
|||
numeric_columns = numeric_df.columns.tolist()
|
||||
for col in columns:
|
||||
if col not in numeric_columns:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
'Column "%(column)s" is not numeric or does not '
|
||||
"exists in the query results.",
|
||||
|
|
@ -65,7 +65,7 @@ def contribution(
|
|||
columns = columns or numeric_df.columns
|
||||
rename_columns = rename_columns or columns
|
||||
if len(rename_columns) != len(columns):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("`rename_columns` must have the same length as `columns`.")
|
||||
)
|
||||
# limit to selected columns
|
||||
|
|
|
|||
|
|
@ -14,27 +14,21 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_append_columns,
|
||||
_flatten_column_after_pivot,
|
||||
ALLOWLIST_CUMULATIVE_FUNCTIONS,
|
||||
validate_column_args,
|
||||
)
|
||||
|
||||
|
||||
@validate_column_args("columns")
|
||||
def cum(
|
||||
df: DataFrame,
|
||||
operator: str,
|
||||
columns: Optional[Dict[str, str]] = None,
|
||||
is_pivot_df: bool = False,
|
||||
) -> DataFrame:
|
||||
def cum(df: DataFrame, operator: str, columns: Dict[str, str],) -> DataFrame:
|
||||
"""
|
||||
Calculate cumulative sum/product/min/max for select columns.
|
||||
|
||||
|
|
@ -45,29 +39,16 @@ def cum(
|
|||
`y2` based on cumulative values calculated from `y`, leaving the original
|
||||
column `y` unchanged.
|
||||
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
|
||||
:param is_pivot_df: Dataframe is pivoted or not
|
||||
:return: DataFrame with cumulated columns
|
||||
"""
|
||||
columns = columns or {}
|
||||
if is_pivot_df:
|
||||
df_cum = df
|
||||
else:
|
||||
df_cum = df[columns.keys()]
|
||||
df_cum = df.loc[:, columns.keys()]
|
||||
operation = "cum" + operator
|
||||
if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
|
||||
df_cum, operation
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Invalid cumulative operator: %(operator)s", operator=operator)
|
||||
)
|
||||
if is_pivot_df:
|
||||
df_cum = getattr(df_cum, operation)()
|
||||
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
|
||||
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
|
||||
df_cum.columns = [
|
||||
_flatten_column_after_pivot(col, agg) for col in df_cum.columns
|
||||
]
|
||||
df_cum.reset_index(level=0, inplace=True)
|
||||
else:
|
||||
df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
|
||||
df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
|
||||
return df_cum
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def diff(
|
|||
:param periods: periods to shift for calculating difference.
|
||||
:param axis: 0 for row, 1 for column. default 0.
|
||||
:return: DataFrame with diffed columns
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request in incorrect
|
||||
"""
|
||||
df_diff = df[columns.keys()]
|
||||
df_diff = df_diff.diff(periods=periods, axis=axis)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_is_multi_index_on_columns,
|
||||
FLAT_COLUMN_SEPARATOR,
|
||||
)
|
||||
|
||||
|
||||
def flatten(df: pd.DataFrame, reset_index: bool = True,) -> pd.DataFrame:
|
||||
"""
|
||||
Convert N-dimensional DataFrame to a flat DataFrame
|
||||
|
||||
:param df: N-dimensional DataFrame.
|
||||
:param reset_index: Convert index to column when df.index isn't RangeIndex
|
||||
:return: a flat DataFrame
|
||||
|
||||
Examples
|
||||
-----------
|
||||
|
||||
Convert DatetimeIndex into columns.
|
||||
|
||||
>>> index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03",])
|
||||
>>> index.name = "__timestamp"
|
||||
>>> df = pd.DataFrame(index=index, data={"metric": [1, 2, 3]})
|
||||
>>> df
|
||||
metric
|
||||
__timestamp
|
||||
2021-01-01 1
|
||||
2021-01-02 2
|
||||
2021-01-03 3
|
||||
>>> df = flatten(df)
|
||||
>>> df
|
||||
__timestamp metric
|
||||
0 2021-01-01 1
|
||||
1 2021-01-02 2
|
||||
2 2021-01-03 3
|
||||
|
||||
Convert DatetimeIndex and MultipleIndex into columns
|
||||
|
||||
>>> iterables = [["foo", "bar"], ["one", "two"]]
|
||||
>>> columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
|
||||
>>> df = pd.DataFrame(index=index, columns=columns, data=1)
|
||||
>>> df
|
||||
level1 foo bar
|
||||
level2 one two one two
|
||||
__timestamp
|
||||
2021-01-01 1 1 1 1
|
||||
2021-01-02 1 1 1 1
|
||||
2021-01-03 1 1 1 1
|
||||
>>> flatten(df)
|
||||
__timestamp foo, one foo, two bar, one bar, two
|
||||
0 2021-01-01 1 1 1 1
|
||||
1 2021-01-02 1 1 1 1
|
||||
2 2021-01-03 1 1 1 1
|
||||
"""
|
||||
if _is_multi_index_on_columns(df):
|
||||
# every cell should be converted to string
|
||||
df.columns = [
|
||||
FLAT_COLUMN_SEPARATOR.join([str(cell) for cell in series])
|
||||
for series in df.columns.to_flat_index()
|
||||
]
|
||||
|
||||
if reset_index and not isinstance(df.index, pd.RangeIndex):
|
||||
df = df.reset_index(level=0)
|
||||
return df
|
||||
|
|
@ -21,7 +21,7 @@ from flask_babel import gettext as _
|
|||
from geopy.point import Point
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import _append_columns
|
||||
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ def geohash_decode(
|
|||
df, lonlat_df, {"latitude": latitude, "longitude": longitude}
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise QueryObjectValidationError(_("Invalid geohash string")) from ex
|
||||
raise InvalidPostProcessingError(_("Invalid geohash string")) from ex
|
||||
|
||||
|
||||
def geohash_encode(
|
||||
|
|
@ -69,7 +69,7 @@ def geohash_encode(
|
|||
)
|
||||
return _append_columns(df, encode_df, {"geohash": geohash})
|
||||
except ValueError as ex:
|
||||
raise QueryObjectValidationError(_("Invalid longitude/latitude")) from ex
|
||||
raise InvalidPostProcessingError(_("Invalid longitude/latitude")) from ex
|
||||
|
||||
|
||||
def geodetic_parse(
|
||||
|
|
@ -111,4 +111,4 @@ def geodetic_parse(
|
|||
columns["altitude"] = altitude
|
||||
return _append_columns(df, geodetic_df, columns)
|
||||
except ValueError as ex:
|
||||
raise QueryObjectValidationError(_("Invalid geodetic string")) from ex
|
||||
raise InvalidPostProcessingError(_("Invalid geodetic string")) from ex
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from flask_babel import gettext as _
|
|||
from pandas import DataFrame
|
||||
|
||||
from superset.constants import NULL_STRING, PandasAxis
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_flatten_column_after_pivot,
|
||||
_get_aggregate_funcs,
|
||||
|
|
@ -64,14 +64,14 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
|
|||
:param flatten_columns: Convert column names to strings
|
||||
:param reset_index: Convert index to column
|
||||
:return: A pivot table
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request in incorrect
|
||||
"""
|
||||
if not index:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Pivot operation requires at least one index")
|
||||
)
|
||||
if not aggregates:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Pivot operation must include at least one aggregate")
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from typing import Optional, Union
|
|||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import DTTM_ALIAS
|
||||
from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
|
|||
prophet_logger.setLevel(logging.CRITICAL)
|
||||
prophet_logger.setLevel(logging.NOTSET)
|
||||
except ModuleNotFoundError as ex:
|
||||
raise QueryObjectValidationError(_("`prophet` package not installed")) from ex
|
||||
raise InvalidPostProcessingError(_("`prophet` package not installed")) from ex
|
||||
model = Prophet(
|
||||
interval_width=confidence_interval,
|
||||
yearly_seasonality=yearly_seasonality,
|
||||
|
|
@ -111,24 +111,24 @@ def prophet( # pylint: disable=too-many-arguments
|
|||
index = index or DTTM_ALIAS
|
||||
# validate inputs
|
||||
if not time_grain:
|
||||
raise QueryObjectValidationError(_("Time grain missing"))
|
||||
raise InvalidPostProcessingError(_("Time grain missing"))
|
||||
if time_grain not in PROPHET_TIME_GRAIN_MAP:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
|
||||
)
|
||||
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
|
||||
# check type at runtime due to marhsmallow schema not being able to handle
|
||||
# union types
|
||||
if not isinstance(periods, int) or periods < 0:
|
||||
raise QueryObjectValidationError(_("Periods must be a whole number"))
|
||||
raise InvalidPostProcessingError(_("Periods must be a whole number"))
|
||||
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Confidence interval must be between 0 and 1 (exclusive)")
|
||||
)
|
||||
if index not in df.columns:
|
||||
raise QueryObjectValidationError(_("DataFrame must include temporal column"))
|
||||
raise InvalidPostProcessingError(_("DataFrame must include temporal column"))
|
||||
if len(df.columns) < 2:
|
||||
raise QueryObjectValidationError(_("DataFrame include at least one series"))
|
||||
raise InvalidPostProcessingError(_("DataFrame include at least one series"))
|
||||
|
||||
target_df = DataFrame()
|
||||
for column in [column for column in df.columns if column != index]:
|
||||
|
|
|
|||
|
|
@ -14,48 +14,35 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from pandas import DataFrame
|
||||
import pandas as pd
|
||||
from flask_babel import gettext as _
|
||||
|
||||
from superset.utils.pandas_postprocessing.utils import validate_column_args
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
|
||||
|
||||
@validate_column_args("groupby_columns")
|
||||
def resample( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
def resample(
|
||||
df: pd.DataFrame,
|
||||
rule: str,
|
||||
method: str,
|
||||
time_column: str,
|
||||
groupby_columns: Optional[Tuple[Optional[str], ...]] = None,
|
||||
fill_value: Optional[Union[float, int]] = None,
|
||||
) -> DataFrame:
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
support upsampling in resample
|
||||
|
||||
:param df: DataFrame to resample.
|
||||
:param rule: The offset string representing target conversion.
|
||||
:param method: How to fill the NaN value after resample.
|
||||
:param time_column: existing columns in DataFrame.
|
||||
:param groupby_columns: columns except time_column in dataframe
|
||||
:param fill_value: What values do fill missing.
|
||||
:return: DataFrame after resample
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request in incorrect
|
||||
"""
|
||||
if not isinstance(df.index, pd.DatetimeIndex):
|
||||
raise InvalidPostProcessingError(_("Resample operation requires DatetimeIndex"))
|
||||
|
||||
def _upsampling(_df: DataFrame) -> DataFrame:
|
||||
_df = _df.set_index(time_column)
|
||||
if method == "asfreq" and fill_value is not None:
|
||||
return _df.resample(rule).asfreq(fill_value=fill_value)
|
||||
return getattr(_df.resample(rule), method)()
|
||||
|
||||
if groupby_columns:
|
||||
df = (
|
||||
df.set_index(keys=list(groupby_columns))
|
||||
.groupby(by=list(groupby_columns))
|
||||
.apply(_upsampling)
|
||||
)
|
||||
df = df.reset_index().set_index(time_column).sort_index()
|
||||
if method == "asfreq" and fill_value is not None:
|
||||
_df = df.resample(rule).asfreq(fill_value=fill_value)
|
||||
else:
|
||||
df = _upsampling(df)
|
||||
return df.reset_index()
|
||||
_df = getattr(df.resample(rule), method)()
|
||||
return _df
|
||||
|
|
|
|||
|
|
@ -19,10 +19,9 @@ from typing import Any, Dict, Optional, Union
|
|||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_append_columns,
|
||||
_flatten_column_after_pivot,
|
||||
DENYLIST_ROLLING_FUNCTIONS,
|
||||
validate_column_args,
|
||||
)
|
||||
|
|
@ -32,13 +31,12 @@ from superset.utils.pandas_postprocessing.utils import (
|
|||
def rolling( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
rolling_type: str,
|
||||
columns: Optional[Dict[str, str]] = None,
|
||||
columns: Dict[str, str],
|
||||
window: Optional[int] = None,
|
||||
rolling_type_options: Optional[Dict[str, Any]] = None,
|
||||
center: bool = False,
|
||||
win_type: Optional[str] = None,
|
||||
min_periods: Optional[int] = None,
|
||||
is_pivot_df: bool = False,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Apply a rolling window on the dataset. See the Pandas docs for further details:
|
||||
|
|
@ -58,21 +56,17 @@ def rolling( # pylint: disable=too-many-arguments
|
|||
:param win_type: Type of window function.
|
||||
:param min_periods: The minimum amount of periods required for a row to be included
|
||||
in the result set.
|
||||
:param is_pivot_df: Dataframe is pivoted or not
|
||||
:return: DataFrame with the rolling columns
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request in incorrect
|
||||
"""
|
||||
rolling_type_options = rolling_type_options or {}
|
||||
columns = columns or {}
|
||||
if is_pivot_df:
|
||||
df_rolling = df
|
||||
else:
|
||||
df_rolling = df[columns.keys()]
|
||||
df_rolling = df.loc[:, columns.keys()]
|
||||
|
||||
kwargs: Dict[str, Union[str, int]] = {}
|
||||
if window is None:
|
||||
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
|
||||
raise InvalidPostProcessingError(_("Undefined window for rolling operation"))
|
||||
if window == 0:
|
||||
raise QueryObjectValidationError(_("Window must be > 0"))
|
||||
raise InvalidPostProcessingError(_("Window must be > 0"))
|
||||
|
||||
kwargs["window"] = window
|
||||
if min_periods is not None:
|
||||
|
|
@ -86,13 +80,13 @@ def rolling( # pylint: disable=too-many-arguments
|
|||
if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr(
|
||||
df_rolling, rolling_type
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Invalid rolling_type: %(type)s", type=rolling_type)
|
||||
)
|
||||
try:
|
||||
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
|
||||
except TypeError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"Invalid options for %(rolling_type)s: %(options)s",
|
||||
rolling_type=rolling_type,
|
||||
|
|
@ -100,15 +94,7 @@ def rolling( # pylint: disable=too-many-arguments
|
|||
)
|
||||
) from ex
|
||||
|
||||
if is_pivot_df:
|
||||
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
|
||||
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
|
||||
df_rolling.columns = [
|
||||
_flatten_column_after_pivot(col, agg) for col in df_rolling.columns
|
||||
]
|
||||
df_rolling.reset_index(level=0, inplace=True)
|
||||
else:
|
||||
df_rolling = _append_columns(df, df_rolling, columns)
|
||||
df_rolling = _append_columns(df, df_rolling, columns)
|
||||
|
||||
if min_periods:
|
||||
df_rolling = df_rolling[min_periods:]
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def select(
|
|||
For instance, `{'y': 'y2'}` will rename the column `y` to
|
||||
`y2`.
|
||||
:return: Subset of columns in original DataFrame
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request in incorrect
|
||||
"""
|
||||
df_select = df.copy(deep=False)
|
||||
if columns:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,6 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
|
|||
:param columns: columns by by which to sort. The key specifies the column name,
|
||||
value specifies if sorting in ascending order.
|
||||
:return: Sorted DataFrame
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request in incorrect
|
||||
"""
|
||||
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
|
||||
|
|
|
|||
|
|
@ -18,10 +18,11 @@ from functools import partial
|
|||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame, NamedAgg, Timestamp
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
|
||||
NUMPY_FUNCTIONS = {
|
||||
"average": np.average,
|
||||
|
|
@ -91,6 +92,8 @@ PROPHET_TIME_GRAIN_MAP = {
|
|||
"P1W/1970-01-04T00:00:00Z": "W",
|
||||
}
|
||||
|
||||
FLAT_COLUMN_SEPARATOR = ", "
|
||||
|
||||
|
||||
def _flatten_column_after_pivot(
|
||||
column: Union[float, Timestamp, str, Tuple[str, ...]],
|
||||
|
|
@ -113,21 +116,26 @@ def _flatten_column_after_pivot(
|
|||
# drop aggregate for single aggregate pivots with multiple groupings
|
||||
# from column name (aggregates always come first in column name)
|
||||
column = column[1:]
|
||||
return ", ".join([str(col) for col in column])
|
||||
return FLAT_COLUMN_SEPARATOR.join([str(col) for col in column])
|
||||
|
||||
|
||||
def _is_multi_index_on_columns(df: DataFrame) -> bool:
|
||||
return isinstance(df.columns, pd.MultiIndex)
|
||||
|
||||
|
||||
def validate_column_args(*argnames: str) -> Callable[..., Any]:
|
||||
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapped(df: DataFrame, **options: Any) -> Any:
|
||||
if options.get("is_pivot_df"):
|
||||
# skip validation when pivot Dataframe
|
||||
return func(df, **options)
|
||||
columns = df.columns.tolist()
|
||||
if _is_multi_index_on_columns(df):
|
||||
# MultiIndex column validate first level
|
||||
columns = df.columns.get_level_values(0)
|
||||
else:
|
||||
columns = df.columns.tolist()
|
||||
for name in argnames:
|
||||
if name in options and not all(
|
||||
elem in columns for elem in options.get(name) or []
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Referenced columns not available in DataFrame.")
|
||||
)
|
||||
return func(df, **options)
|
||||
|
|
@ -152,14 +160,14 @@ def _get_aggregate_funcs(
|
|||
for name, agg_obj in aggregates.items():
|
||||
column = agg_obj.get("column", name)
|
||||
if column not in df:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"Column referenced by aggregate is undefined: %(column)s",
|
||||
column=column,
|
||||
)
|
||||
)
|
||||
if "operator" not in agg_obj:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Operator undefined for aggregator: %(name)s", name=name,)
|
||||
)
|
||||
operator = agg_obj["operator"]
|
||||
|
|
@ -168,7 +176,7 @@ def _get_aggregate_funcs(
|
|||
else:
|
||||
func = NUMPY_FUNCTIONS.get(operator)
|
||||
if not func:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Invalid numpy function: %(operator)s", operator=operator,)
|
||||
)
|
||||
options = agg_obj.get("options", {})
|
||||
|
|
@ -186,6 +194,8 @@ def _append_columns(
|
|||
assign method, which overwrites the original column in `base_df` if the column
|
||||
already exists, and appends the column if the name is not defined.
|
||||
|
||||
Note that! this is a memory-intensive operation.
|
||||
|
||||
:param base_df: DataFrame which to use as the base
|
||||
:param append_df: DataFrame from which to select data.
|
||||
:param columns: columns on which to append, mapping source column to
|
||||
|
|
@ -196,6 +206,10 @@ def _append_columns(
|
|||
in `base_df` unchanged.
|
||||
:return: new DataFrame with combined data from `base_df` and `append_df`
|
||||
"""
|
||||
return base_df.assign(
|
||||
**{target: append_df[source] for source, target in columns.items()}
|
||||
)
|
||||
if all(key == value for key, value in columns.items()):
|
||||
# make sure to return a new DataFrame instead of changing the `base_df`.
|
||||
_base_df = base_df.copy()
|
||||
_base_df.loc[:, columns.keys()] = append_df
|
||||
return _base_df
|
||||
append_df = append_df.rename(columns=columns)
|
||||
return pd.concat([base_df, append_df], axis="columns")
|
||||
|
|
|
|||
|
|
@ -172,18 +172,21 @@ POSTPROCESSING_OPERATIONS = {
|
|||
{
|
||||
"operation": "aggregate",
|
||||
"options": {
|
||||
"groupby": ["gender"],
|
||||
"groupby": ["name"],
|
||||
"aggregates": {
|
||||
"q1": {
|
||||
"operator": "percentile",
|
||||
"column": "sum__num",
|
||||
"options": {"q": 25},
|
||||
# todo: rename "interpolation" to "method" when we updated
|
||||
# numpy.
|
||||
# https://numpy.org/doc/stable/reference/generated/numpy.percentile.html
|
||||
"options": {"q": 25, "interpolation": "lower"},
|
||||
},
|
||||
"median": {"operator": "median", "column": "sum__num",},
|
||||
},
|
||||
},
|
||||
},
|
||||
{"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
|
||||
{"operation": "sort", "options": {"columns": {"q1": False, "name": True},},},
|
||||
]
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import datetime
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
|
@ -30,7 +29,7 @@ from superset.common.query_object import QueryObject
|
|||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.extensions import cache_manager
|
||||
from superset.utils.core import AdhocMetricExpressionType, backend
|
||||
from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
|
|
@ -91,8 +90,9 @@ class TestQueryContext(SupersetTestCase):
|
|||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_cache(self):
|
||||
table_name = "birth_names"
|
||||
table = self.get_table(name=table_name)
|
||||
payload = get_query_context(table_name, table.id)
|
||||
payload = get_query_context(
|
||||
query_name=table_name, add_postprocessing_operations=True,
|
||||
)
|
||||
payload["force"] = True
|
||||
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
|
|
@ -100,6 +100,10 @@ class TestQueryContext(SupersetTestCase):
|
|||
query_cache_key = query_context.query_cache_key(query_object)
|
||||
|
||||
response = query_context.get_payload(cache_query_context=True)
|
||||
# MUST BE a successful query
|
||||
query_dump = response["queries"][0]
|
||||
assert query_dump["status"] == QueryStatus.SUCCESS
|
||||
|
||||
cache_key = response["cache_key"]
|
||||
assert cache_key is not None
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import PostProcessingBoxplotWhiskerType
|
||||
from superset.utils.pandas_postprocessing import boxplot
|
||||
from tests.unit_tests.fixtures.dataframes import names_df
|
||||
|
|
@ -90,7 +90,7 @@ def test_boxplot_percentile():
|
|||
|
||||
|
||||
def test_boxplot_percentile_incorrect_params():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
|
|
@ -98,7 +98,7 @@ def test_boxplot_percentile_incorrect_params():
|
|||
metrics=["cars"],
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
|
|
@ -107,7 +107,7 @@ def test_boxplot_percentile_incorrect_params():
|
|||
percentiles=[10],
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
|
|
@ -116,7 +116,7 @@ def test_boxplot_percentile_incorrect_params():
|
|||
percentiles=[90, 10],
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
|
|
|
|||
|
|
@ -14,49 +14,220 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
|
||||
from superset.utils.pandas_postprocessing import compare
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df2
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
from superset.constants import PandasPostprocessingCompare as PPC
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.unit_tests.fixtures.dataframes import multiple_metrics_df, timeseries_df2
|
||||
|
||||
|
||||
def test_compare():
|
||||
def test_compare_should_not_side_effect():
|
||||
_timeseries_df2 = timeseries_df2.copy()
|
||||
pp.compare(
|
||||
df=_timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type=PPC.DIFF,
|
||||
)
|
||||
assert _timeseries_df2.equals(timeseries_df2)
|
||||
|
||||
|
||||
def test_compare_diff():
|
||||
# `difference` comparison
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="difference",
|
||||
compare_type=PPC.DIFF,
|
||||
)
|
||||
"""
|
||||
label y z difference__y__z
|
||||
2019-01-01 x 2.0 2.0 0.0
|
||||
2019-01-02 y 2.0 4.0 2.0
|
||||
2019-01-05 z 2.0 10.0 8.0
|
||||
2019-01-07 q 2.0 8.0 6.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"y": [2.0, 2.0, 2.0, 2.0],
|
||||
"z": [2.0, 4.0, 10.0, 8.0],
|
||||
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "y", "z", "difference__y__z"]
|
||||
assert series_to_list(post_df["difference__y__z"]) == [0.0, -2.0, -8.0, -6.0]
|
||||
|
||||
# drop original columns
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="difference",
|
||||
compare_type=PPC.DIFF,
|
||||
drop_original_columns=True,
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "difference__y__z"]
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_compare_percentage():
|
||||
# `percentage` comparison
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="percentage",
|
||||
compare_type=PPC.PCT,
|
||||
)
|
||||
"""
|
||||
label y z percentage__y__z
|
||||
2019-01-01 x 2.0 2.0 0.0
|
||||
2019-01-02 y 2.0 4.0 1.0
|
||||
2019-01-05 z 2.0 10.0 4.0
|
||||
2019-01-07 q 2.0 8.0 3.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"y": [2.0, 2.0, 2.0, 2.0],
|
||||
"z": [2.0, 4.0, 10.0, 8.0],
|
||||
"percentage__y__z": [0.0, 1.0, 4.0, 3.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "y", "z", "percentage__y__z"]
|
||||
assert series_to_list(post_df["percentage__y__z"]) == [0.0, -0.5, -0.8, -0.75]
|
||||
|
||||
|
||||
def test_compare_ratio():
|
||||
# `ratio` comparison
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="ratio",
|
||||
compare_type=PPC.RAT,
|
||||
)
|
||||
"""
|
||||
label y z ratio__y__z
|
||||
2019-01-01 x 2.0 2.0 1.0
|
||||
2019-01-02 y 2.0 4.0 2.0
|
||||
2019-01-05 z 2.0 10.0 5.0
|
||||
2019-01-07 q 2.0 8.0 4.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"y": [2.0, 2.0, 2.0, 2.0],
|
||||
"z": [2.0, 4.0, 10.0, 8.0],
|
||||
"ratio__y__z": [1.0, 2.0, 5.0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_compare_multi_index_column():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
|
||||
columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"])
|
||||
df = pd.DataFrame(index=index, columns=columns, data=1)
|
||||
"""
|
||||
m1 m2
|
||||
level1 a b a b
|
||||
level2 x y x y x y x y
|
||||
__timestamp
|
||||
2021-01-01 1 1 1 1 1 1 1 1
|
||||
2021-01-02 1 1 1 1 1 1 1 1
|
||||
2021-01-03 1 1 1 1 1 1 1 1
|
||||
"""
|
||||
post_df = pp.compare(
|
||||
df,
|
||||
source_columns=["m1"],
|
||||
compare_columns=["m2"],
|
||||
compare_type=PPC.DIFF,
|
||||
drop_original_columns=True,
|
||||
)
|
||||
flat_df = pp.flatten(post_df)
|
||||
"""
|
||||
__timestamp difference__m1__m2, a, x difference__m1__m2, a, y difference__m1__m2, b, x difference__m1__m2, b, y
|
||||
0 2021-01-01 0 0 0 0
|
||||
1 2021-01-02 0 0 0 0
|
||||
2 2021-01-03 0 0 0 0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
["2021-01-01", "2021-01-02", "2021-01-03"]
|
||||
),
|
||||
"difference__m1__m2, a, x": [0, 0, 0],
|
||||
"difference__m1__m2, a, y": [0, 0, 0],
|
||||
"difference__m1__m2, b, x": [0, 0, 0],
|
||||
"difference__m1__m2, b, y": [0, 0, 0],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_compare_after_pivot():
|
||||
pivot_df = pp.pivot(
|
||||
df=multiple_metrics_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
aggregates={
|
||||
"sum_metric": {"operator": "sum"},
|
||||
"count_metric": {"operator": "sum"},
|
||||
},
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 3 4 7 8
|
||||
"""
|
||||
compared_df = pp.compare(
|
||||
pivot_df,
|
||||
source_columns=["count_metric"],
|
||||
compare_columns=["sum_metric"],
|
||||
compare_type=PPC.DIFF,
|
||||
drop_original_columns=True,
|
||||
)
|
||||
"""
|
||||
difference__count_metric__sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 4 4
|
||||
2019-01-02 4 4
|
||||
"""
|
||||
flat_df = pp.flatten(compared_df)
|
||||
"""
|
||||
dttm difference__count_metric__sum_metric, UK difference__count_metric__sum_metric, US
|
||||
0 2019-01-01 4 4
|
||||
1 2019-01-02 4 4
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(
|
||||
["difference__count_metric__sum_metric", "UK"]
|
||||
): [4, 4],
|
||||
FLAT_COLUMN_SEPARATOR.join(
|
||||
["difference__count_metric__sum_metric", "US"]
|
||||
): [4, 4],
|
||||
}
|
||||
)
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "y", "z", "ratio__y__z"]
|
||||
assert series_to_list(post_df["ratio__y__z"]) == [1.0, 0.5, 0.2, 0.25]
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from numpy import nan
|
|||
from numpy.testing import assert_array_equal
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
|
||||
from superset.utils.pandas_postprocessing import contribution
|
||||
|
||||
|
|
@ -40,10 +40,10 @@ def test_contribution():
|
|||
"c": [nan, nan, nan],
|
||||
}
|
||||
)
|
||||
with pytest.raises(QueryObjectValidationError, match="not numeric"):
|
||||
with pytest.raises(InvalidPostProcessingError, match="not numeric"):
|
||||
contribution(df, columns=[DTTM_ALIAS])
|
||||
|
||||
with pytest.raises(QueryObjectValidationError, match="same length"):
|
||||
with pytest.raises(InvalidPostProcessingError, match="same length"):
|
||||
contribution(df, columns=["a"], rename_columns=["aa", "bb"])
|
||||
|
||||
# cell contribution across row
|
||||
|
|
|
|||
|
|
@ -14,11 +14,12 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import cum, pivot
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.unit_tests.fixtures.dataframes import (
|
||||
multiple_metrics_df,
|
||||
single_metric_df,
|
||||
|
|
@ -27,33 +28,41 @@ from tests.unit_tests.fixtures.dataframes import (
|
|||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_cum_should_not_side_effect():
|
||||
_timeseries_df = timeseries_df.copy()
|
||||
pp.cum(
|
||||
df=timeseries_df, columns={"y": "y2"}, operator="sum",
|
||||
)
|
||||
assert _timeseries_df.equals(timeseries_df)
|
||||
|
||||
|
||||
def test_cum():
|
||||
# create new column (cumsum)
|
||||
post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
|
||||
post_df = pp.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
|
||||
assert post_df.columns.tolist() == ["label", "y", "y2"]
|
||||
assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
|
||||
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
|
||||
assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]
|
||||
|
||||
# overwrite column (cumprod)
|
||||
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
|
||||
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
|
||||
assert post_df.columns.tolist() == ["label", "y"]
|
||||
assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]
|
||||
|
||||
# overwrite column (cummin)
|
||||
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
|
||||
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
|
||||
assert post_df.columns.tolist() == ["label", "y"]
|
||||
assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# invalid operator
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
cum(
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.cum(
|
||||
df=timeseries_df, columns={"y": "y"}, operator="abc",
|
||||
)
|
||||
|
||||
|
||||
def test_cum_with_pivot_df_and_single_metric():
|
||||
pivot_df = pivot(
|
||||
def test_cum_after_pivot_with_single_metric():
|
||||
pivot_df = pp.pivot(
|
||||
df=single_metric_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
|
|
@ -61,19 +70,40 @@ def test_cum_with_pivot_df_and_single_metric():
|
|||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
|
||||
# dttm UK US
|
||||
# 0 2019-01-01 5 6
|
||||
# 1 2019-01-02 12 14
|
||||
assert cum_df["UK"].to_list() == [5.0, 12.0]
|
||||
assert cum_df["US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
|
||||
"""
|
||||
sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 5 6
|
||||
2019-01-02 7 8
|
||||
"""
|
||||
cum_df = pp.cum(df=pivot_df, operator="sum", columns={"sum_metric": "sum_metric"})
|
||||
"""
|
||||
sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 5 6
|
||||
2019-01-02 12 14
|
||||
"""
|
||||
cum_and_flat_df = pp.flatten(cum_df)
|
||||
"""
|
||||
dttm sum_metric, UK sum_metric, US
|
||||
0 2019-01-01 5 6
|
||||
1 2019-01-02 12 14
|
||||
"""
|
||||
assert cum_and_flat_df.equals(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_cum_with_pivot_df_and_multiple_metrics():
|
||||
pivot_df = pivot(
|
||||
def test_cum_after_pivot_with_multiple_metrics():
|
||||
pivot_df = pp.pivot(
|
||||
df=multiple_metrics_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
|
|
@ -84,14 +114,39 @@ def test_cum_with_pivot_df_and_multiple_metrics():
|
|||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
|
||||
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
# 0 2019-01-01 1 2 5 6
|
||||
# 1 2019-01-02 4 6 12 14
|
||||
assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
|
||||
assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
|
||||
assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
|
||||
assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 3 4 7 8
|
||||
"""
|
||||
cum_df = pp.cum(
|
||||
df=pivot_df,
|
||||
operator="sum",
|
||||
columns={"sum_metric": "sum_metric", "count_metric": "count_metric"},
|
||||
)
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 4 6 12 14
|
||||
"""
|
||||
flat_df = pp.flatten(cum_df)
|
||||
"""
|
||||
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
0 2019-01-01 1 2 5 6
|
||||
1 2019-01-02 4 6 12 14
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1, 4],
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2, 6],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing import diff
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
|
@ -39,7 +39,7 @@ def test_diff():
|
|||
assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]
|
||||
|
||||
# invalid column reference
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
diff(
|
||||
df=timeseries_df, columns={"abc": "abc"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
|
||||
|
||||
def test_flat_should_not_change():
|
||||
df = pd.DataFrame(data={"foo": [1, 2, 3], "bar": [4, 5, 6],})
|
||||
|
||||
assert pp.flatten(df).equals(df)
|
||||
|
||||
|
||||
def test_flat_should_not_reset_index():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
|
||||
|
||||
assert pp.flatten(df, reset_index=False).equals(df)
|
||||
|
||||
|
||||
def test_flat_should_flat_datetime_index():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
|
||||
|
||||
assert pp.flatten(df).equals(
|
||||
pd.DataFrame({"__timestamp": index, "foo": [1, 2, 3], "bar": [4, 5, 6],})
|
||||
)
|
||||
|
||||
|
||||
def test_flat_should_flat_multiple_index():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
iterables = [["foo", "bar"], [1, "two"]]
|
||||
columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
|
||||
df = pd.DataFrame(index=index, columns=columns, data=1)
|
||||
|
||||
assert pp.flatten(df).equals(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"__timestamp": index,
|
||||
FLAT_COLUMN_SEPARATOR.join(["foo", "1"]): [1, 1, 1],
|
||||
FLAT_COLUMN_SEPARATOR.join(["foo", "two"]): [1, 1, 1],
|
||||
FLAT_COLUMN_SEPARATOR.join(["bar", "1"]): [1, 1, 1],
|
||||
FLAT_COLUMN_SEPARATOR.join(["bar", "two"]): [1, 1, 1],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
import pytest
|
||||
from pandas import DataFrame, Timestamp, to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import (
|
||||
|
|
@ -172,7 +172,7 @@ def test_pivot_exceptions():
|
|||
pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
|
||||
|
||||
# invalid index reference
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pivot(
|
||||
df=categories_df,
|
||||
index=["abc"],
|
||||
|
|
@ -181,7 +181,7 @@ def test_pivot_exceptions():
|
|||
)
|
||||
|
||||
# invalid column reference
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pivot(
|
||||
df=categories_df,
|
||||
index=["dept"],
|
||||
|
|
@ -190,7 +190,7 @@ def test_pivot_exceptions():
|
|||
)
|
||||
|
||||
# invalid aggregate options
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pivot(
|
||||
df=categories_df,
|
||||
index=["name"],
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from importlib.util import find_spec
|
|||
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import DTTM_ALIAS
|
||||
from superset.utils.pandas_postprocessing import prophet
|
||||
from tests.unit_tests.fixtures.dataframes import prophet_df
|
||||
|
|
@ -75,40 +75,40 @@ def test_prophet_valid_zero_periods():
|
|||
def test_prophet_import():
|
||||
dynamic_module = find_spec("prophet")
|
||||
if dynamic_module is None:
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
|
||||
|
||||
|
||||
def test_prophet_missing_temporal_column():
|
||||
df = prophet_df.drop(DTTM_ALIAS, axis=1)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=df, time_grain="P1M", periods=3, confidence_interval=0.9,
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_incorrect_confidence_interval():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.0,
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=1.0,
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_incorrect_periods():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="P1M", periods=-1, confidence_interval=0.8,
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_incorrect_time_grain():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,45 +14,80 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import DataFrame, to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import resample
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
|
||||
|
||||
|
||||
def test_resample_should_not_side_effect():
|
||||
_timeseries_df = timeseries_df.copy()
|
||||
pp.resample(df=_timeseries_df, rule="1D", method="ffill")
|
||||
assert _timeseries_df.equals(timeseries_df)
|
||||
|
||||
|
||||
def test_resample():
|
||||
df = timeseries_df.copy()
|
||||
df.index.name = "time_column"
|
||||
df.reset_index(inplace=True)
|
||||
|
||||
post_df = resample(df=df, rule="1D", method="ffill", time_column="time_column",)
|
||||
assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
|
||||
|
||||
assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]
|
||||
|
||||
post_df = resample(
|
||||
df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0,
|
||||
post_df = pp.resample(df=timeseries_df, rule="1D", method="ffill")
|
||||
"""
|
||||
label y
|
||||
2019-01-01 x 1.0
|
||||
2019-01-02 y 2.0
|
||||
2019-01-03 y 2.0
|
||||
2019-01-04 y 2.0
|
||||
2019-01-05 z 3.0
|
||||
2019-01-06 z 3.0
|
||||
2019-01-07 q 4.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=pd.to_datetime(
|
||||
[
|
||||
"2019-01-01",
|
||||
"2019-01-02",
|
||||
"2019-01-03",
|
||||
"2019-01-04",
|
||||
"2019-01-05",
|
||||
"2019-01-06",
|
||||
"2019-01-07",
|
||||
]
|
||||
),
|
||||
data={
|
||||
"label": ["x", "y", "y", "y", "z", "z", "q"],
|
||||
"y": [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
|
||||
assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
|
||||
|
||||
|
||||
def test_resample_with_groupby():
|
||||
"""
|
||||
The Dataframe contains a timestamp column, a string column and a numeric column.
|
||||
__timestamp city val
|
||||
0 2022-01-13 Chicago 6.0
|
||||
1 2022-01-13 LA 5.0
|
||||
2 2022-01-13 NY 4.0
|
||||
3 2022-01-11 Chicago 3.0
|
||||
4 2022-01-11 LA 2.0
|
||||
5 2022-01-11 NY 1.0
|
||||
"""
|
||||
df = DataFrame(
|
||||
{
|
||||
"__timestamp": to_datetime(
|
||||
def test_resample_zero_fill():
|
||||
post_df = pp.resample(df=timeseries_df, rule="1D", method="asfreq", fill_value=0)
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=pd.to_datetime(
|
||||
[
|
||||
"2019-01-01",
|
||||
"2019-01-02",
|
||||
"2019-01-03",
|
||||
"2019-01-04",
|
||||
"2019-01-05",
|
||||
"2019-01-06",
|
||||
"2019-01-07",
|
||||
]
|
||||
),
|
||||
data={
|
||||
"label": ["x", "y", 0, 0, "z", 0, "q"],
|
||||
"y": [1.0, 2.0, 0, 0, 3.0, 0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_resample_after_pivot():
|
||||
df = pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
[
|
||||
"2022-01-13",
|
||||
"2022-01-13",
|
||||
|
|
@ -66,42 +101,53 @@ __timestamp city val
|
|||
"val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
|
||||
}
|
||||
)
|
||||
post_df = resample(
|
||||
pivot_df = pp.pivot(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city",),
|
||||
index=["__timestamp"],
|
||||
columns=["city"],
|
||||
aggregates={"val": {"operator": "sum"},},
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
assert list(post_df.columns) == [
|
||||
"__timestamp",
|
||||
"city",
|
||||
"val",
|
||||
]
|
||||
assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
|
||||
["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
|
||||
"""
|
||||
val
|
||||
city Chicago LA NY
|
||||
__timestamp
|
||||
2022-01-11 3.0 2.0 1.0
|
||||
2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
resample_df = pp.resample(df=pivot_df, rule="1D", method="asfreq", fill_value=0,)
|
||||
"""
|
||||
val
|
||||
city Chicago LA NY
|
||||
__timestamp
|
||||
2022-01-11 3.0 2.0 1.0
|
||||
2022-01-12 0.0 0.0 0.0
|
||||
2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
flat_df = pp.flatten(resample_df)
|
||||
"""
|
||||
__timestamp val, Chicago val, LA val, NY
|
||||
0 2022-01-11 3.0 2.0 1.0
|
||||
1 2022-01-12 0.0 0.0 0.0
|
||||
2 2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
["2022-01-11", "2022-01-12", "2022-01-13"]
|
||||
),
|
||||
"val, Chicago": [3.0, 0, 6.0],
|
||||
"val, LA": [2.0, 0, 5.0],
|
||||
"val, NY": [1.0, 0, 4.0],
|
||||
}
|
||||
)
|
||||
)
|
||||
assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
|
||||
|
||||
# should raise error when get a non-existent column
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
resample(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city", "unkonw_column",),
|
||||
)
|
||||
|
||||
# should raise error when get a None value in groupby list
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
resample(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city", None,),
|
||||
def test_resample_should_raise_ex():
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.resample(
|
||||
df=categories_df, rule="1D", method="asfreq",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,11 +14,12 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import pivot, rolling
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.unit_tests.fixtures.dataframes import (
|
||||
multiple_metrics_df,
|
||||
single_metric_df,
|
||||
|
|
@ -27,9 +28,21 @@ from tests.unit_tests.fixtures.dataframes import (
|
|||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_rolling_should_not_side_effect():
|
||||
_timeseries_df = timeseries_df.copy()
|
||||
pp.rolling(
|
||||
df=timeseries_df,
|
||||
columns={"y": "y"},
|
||||
rolling_type="sum",
|
||||
window=2,
|
||||
min_periods=0,
|
||||
)
|
||||
assert _timeseries_df.equals(timeseries_df)
|
||||
|
||||
|
||||
def test_rolling():
|
||||
# sum rolling type
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
columns={"y": "y"},
|
||||
rolling_type="sum",
|
||||
|
|
@ -41,7 +54,7 @@ def test_rolling():
|
|||
assert series_to_list(post_df["y"]) == [1.0, 3.0, 5.0, 7.0]
|
||||
|
||||
# mean rolling type with alias
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
rolling_type="mean",
|
||||
columns={"y": "y_mean"},
|
||||
|
|
@ -52,7 +65,7 @@ def test_rolling():
|
|||
assert series_to_list(post_df["y_mean"]) == [1.0, 1.5, 2.0, 2.5]
|
||||
|
||||
# count rolling type
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
rolling_type="count",
|
||||
columns={"y": "y"},
|
||||
|
|
@ -63,7 +76,7 @@ def test_rolling():
|
|||
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
|
||||
|
||||
# quantile rolling type
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
columns={"y": "q1"},
|
||||
rolling_type="quantile",
|
||||
|
|
@ -75,14 +88,14 @@ def test_rolling():
|
|||
assert series_to_list(post_df["q1"]) == [1.0, 1.25, 1.5, 1.75]
|
||||
|
||||
# incorrect rolling type
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
rolling(
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.rolling(
|
||||
df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2,
|
||||
)
|
||||
|
||||
# incorrect rolling type options
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
rolling(
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.rolling(
|
||||
df=timeseries_df,
|
||||
columns={"y": "y"},
|
||||
rolling_type="quantile",
|
||||
|
|
@ -91,8 +104,8 @@ def test_rolling():
|
|||
)
|
||||
|
||||
|
||||
def test_rolling_with_pivot_df_and_single_metric():
|
||||
pivot_df = pivot(
|
||||
def test_rolling_should_empty_df():
|
||||
pivot_df = pp.pivot(
|
||||
df=single_metric_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
|
|
@ -100,27 +113,65 @@ def test_rolling_with_pivot_df_and_single_metric():
|
|||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
rolling_df = rolling(
|
||||
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
|
||||
)
|
||||
# dttm UK US
|
||||
# 0 2019-01-01 5 6
|
||||
# 1 2019-01-02 12 14
|
||||
assert rolling_df["UK"].to_list() == [5.0, 12.0]
|
||||
assert rolling_df["US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
rolling_df["dttm"].to_list()
|
||||
== to_datetime(["2019-01-01", "2019-01-02"]).to_list()
|
||||
)
|
||||
|
||||
rolling_df = rolling(
|
||||
df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
|
||||
rolling_df = pp.rolling(
|
||||
df=pivot_df,
|
||||
rolling_type="sum",
|
||||
window=2,
|
||||
min_periods=2,
|
||||
columns={"sum_metric": "sum_metric"},
|
||||
)
|
||||
assert rolling_df.empty is True
|
||||
|
||||
|
||||
def test_rolling_with_pivot_df_and_multiple_metrics():
|
||||
pivot_df = pivot(
|
||||
def test_rolling_after_pivot_with_single_metric():
|
||||
pivot_df = pp.pivot(
|
||||
df=single_metric_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
aggregates={"sum_metric": {"operator": "sum"}},
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
"""
|
||||
sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 5 6
|
||||
2019-01-02 7 8
|
||||
"""
|
||||
rolling_df = pp.rolling(
|
||||
df=pivot_df,
|
||||
columns={"sum_metric": "sum_metric"},
|
||||
rolling_type="sum",
|
||||
window=2,
|
||||
min_periods=0,
|
||||
)
|
||||
"""
|
||||
sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 5.0 6.0
|
||||
2019-01-02 12.0 14.0
|
||||
"""
|
||||
flat_df = pp.flatten(rolling_df)
|
||||
"""
|
||||
dttm sum_metric, UK sum_metric, US
|
||||
0 2019-01-01 5.0 6.0
|
||||
1 2019-01-02 12.0 14.0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_rolling_after_pivot_with_multiple_metrics():
|
||||
pivot_df = pp.pivot(
|
||||
df=multiple_metrics_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
|
|
@ -131,17 +182,41 @@ def test_rolling_with_pivot_df_and_multiple_metrics():
|
|||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
rolling_df = rolling(
|
||||
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 3 4 7 8
|
||||
"""
|
||||
rolling_df = pp.rolling(
|
||||
df=pivot_df,
|
||||
columns={"count_metric": "count_metric", "sum_metric": "sum_metric",},
|
||||
rolling_type="sum",
|
||||
window=2,
|
||||
min_periods=0,
|
||||
)
|
||||
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
# 0 2019-01-01 1.0 2.0 5.0 6.0
|
||||
# 1 2019-01-02 4.0 6.0 12.0 14.0
|
||||
assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
|
||||
assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
|
||||
assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
|
||||
assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
rolling_df["dttm"].to_list()
|
||||
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1.0 2.0 5.0 6.0
|
||||
2019-01-02 4.0 6.0 12.0 14.0
|
||||
"""
|
||||
flat_df = pp.flatten(rolling_df)
|
||||
"""
|
||||
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
0 2019-01-01 1.0 2.0 5.0 6.0
|
||||
1 2019-01-02 4.0 6.0 12.0 14.0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1.0, 4.0],
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2.0, 6.0],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.select import select
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df
|
||||
|
||||
|
|
@ -47,9 +47,9 @@ def test_select():
|
|||
assert post_df.columns.tolist() == ["y1"]
|
||||
|
||||
# invalid columns
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})
|
||||
|
||||
# select renamed column by new name
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing import sort
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
|
@ -26,5 +26,5 @@ def test_sort():
|
|||
df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
|
||||
assert series_to_list(df["asc_idx"])[1] == 96
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
sort(df=df, columns={"abc": True})
|
||||
|
|
|
|||
Loading…
Reference in New Issue