feat(native-filters): add temporal support to select filter (#13622)
This commit is contained in:
parent
adc247b7e4
commit
13f7e0d755
|
|
@ -18,9 +18,18 @@
|
|||
*/
|
||||
|
||||
import {
|
||||
GenericDataType,
|
||||
getNumberFormatter,
|
||||
getTimeFormatter,
|
||||
NumberFormats,
|
||||
TimeFormats,
|
||||
} from '@superset-ui/core';
|
||||
import {
|
||||
getDataRecordFormatter,
|
||||
getRangeExtraFormData,
|
||||
getSelectExtraFormData,
|
||||
} from '../../../src/filters/utils';
|
||||
} from 'src/filters/utils';
|
||||
import { FALSE_STRING, NULL_STRING, TRUE_STRING } from 'src/utils/common';
|
||||
|
||||
describe('Filter utils', () => {
|
||||
describe('getRangeExtraFormData', () => {
|
||||
|
|
@ -157,4 +166,85 @@ describe('Filter utils', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDataRecordFormatter', () => {
|
||||
it('default formatter returns expected values', () => {
|
||||
const formatter = getDataRecordFormatter();
|
||||
expect(formatter(null, GenericDataType.STRING)).toEqual(NULL_STRING);
|
||||
expect(formatter(null, GenericDataType.NUMERIC)).toEqual(NULL_STRING);
|
||||
expect(formatter(null, GenericDataType.TEMPORAL)).toEqual(NULL_STRING);
|
||||
expect(formatter(null, GenericDataType.BOOLEAN)).toEqual(NULL_STRING);
|
||||
expect(formatter('foo', GenericDataType.STRING)).toEqual('foo');
|
||||
expect(formatter('foo', GenericDataType.NUMERIC)).toEqual('foo');
|
||||
expect(formatter('foo', GenericDataType.TEMPORAL)).toEqual('foo');
|
||||
expect(formatter('foo', GenericDataType.BOOLEAN)).toEqual(FALSE_STRING);
|
||||
expect(formatter(true, GenericDataType.BOOLEAN)).toEqual(TRUE_STRING);
|
||||
expect(formatter(false, GenericDataType.BOOLEAN)).toEqual(FALSE_STRING);
|
||||
expect(formatter('true', GenericDataType.BOOLEAN)).toEqual(TRUE_STRING);
|
||||
expect(formatter('false', GenericDataType.BOOLEAN)).toEqual(FALSE_STRING);
|
||||
expect(formatter('TRUE', GenericDataType.BOOLEAN)).toEqual(TRUE_STRING);
|
||||
expect(formatter('FALSE', GenericDataType.BOOLEAN)).toEqual(FALSE_STRING);
|
||||
expect(formatter(0, GenericDataType.BOOLEAN)).toEqual(FALSE_STRING);
|
||||
expect(formatter(1, GenericDataType.BOOLEAN)).toEqual(TRUE_STRING);
|
||||
expect(formatter(2, GenericDataType.BOOLEAN)).toEqual(TRUE_STRING);
|
||||
expect(formatter(0, GenericDataType.STRING)).toEqual('0');
|
||||
expect(formatter(0, GenericDataType.NUMERIC)).toEqual('0');
|
||||
expect(formatter(0, GenericDataType.TEMPORAL)).toEqual('0');
|
||||
expect(formatter(1234567.89, GenericDataType.STRING)).toEqual(
|
||||
'1234567.89',
|
||||
);
|
||||
expect(formatter(1234567.89, GenericDataType.NUMERIC)).toEqual(
|
||||
'1234567.89',
|
||||
);
|
||||
expect(formatter(1234567.89, GenericDataType.TEMPORAL)).toEqual(
|
||||
'1234567.89',
|
||||
);
|
||||
expect(formatter(1234567.89, GenericDataType.BOOLEAN)).toEqual(
|
||||
TRUE_STRING,
|
||||
);
|
||||
});
|
||||
|
||||
it('formatter with defined formatters returns expected values', () => {
|
||||
const formatter = getDataRecordFormatter({
|
||||
timeFormatter: getTimeFormatter(TimeFormats.DATABASE_DATETIME),
|
||||
numberFormatter: getNumberFormatter(NumberFormats.SMART_NUMBER),
|
||||
});
|
||||
expect(formatter(null, GenericDataType.STRING)).toEqual(NULL_STRING);
|
||||
expect(formatter(null, GenericDataType.NUMERIC)).toEqual(NULL_STRING);
|
||||
expect(formatter(null, GenericDataType.TEMPORAL)).toEqual(NULL_STRING);
|
||||
expect(formatter(null, GenericDataType.BOOLEAN)).toEqual(NULL_STRING);
|
||||
expect(formatter('foo', GenericDataType.STRING)).toEqual('foo');
|
||||
expect(formatter('foo', GenericDataType.NUMERIC)).toEqual('foo');
|
||||
expect(formatter('foo', GenericDataType.TEMPORAL)).toEqual('foo');
|
||||
expect(formatter('foo', GenericDataType.BOOLEAN)).toEqual(FALSE_STRING);
|
||||
expect(formatter(0, GenericDataType.STRING)).toEqual('0');
|
||||
expect(formatter(0, GenericDataType.NUMERIC)).toEqual('0');
|
||||
expect(formatter(0, GenericDataType.TEMPORAL)).toEqual(
|
||||
'1970-01-01 00:00:00',
|
||||
);
|
||||
expect(formatter(0, GenericDataType.BOOLEAN)).toEqual(FALSE_STRING);
|
||||
expect(formatter(1234567.89, GenericDataType.STRING)).toEqual(
|
||||
'1234567.89',
|
||||
);
|
||||
expect(formatter(1234567.89, GenericDataType.NUMERIC)).toEqual('1.23M');
|
||||
expect(formatter(1234567.89, GenericDataType.TEMPORAL)).toEqual(
|
||||
'1970-01-01 00:20:34',
|
||||
);
|
||||
expect(formatter(1234567.89, GenericDataType.BOOLEAN)).toEqual(
|
||||
TRUE_STRING,
|
||||
);
|
||||
expect(formatter('1970-01-01 00:00:00', GenericDataType.STRING)).toEqual(
|
||||
'1970-01-01 00:00:00',
|
||||
);
|
||||
expect(formatter('1970-01-01 00:00:00', GenericDataType.NUMERIC)).toEqual(
|
||||
'1970-01-01 00:00:00',
|
||||
);
|
||||
expect(formatter('1970-01-01 00:00:00', GenericDataType.BOOLEAN)).toEqual(
|
||||
FALSE_STRING,
|
||||
);
|
||||
expect(
|
||||
formatter('1970-01-01 00:00:00', GenericDataType.TEMPORAL),
|
||||
).toEqual('1970-01-01 00:00:00');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -16,17 +16,48 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { Behavior, DataMask, t, tn, ensureIsArray } from '@superset-ui/core';
|
||||
import {
|
||||
createMultiFormatter,
|
||||
Behavior,
|
||||
DataMask,
|
||||
ensureIsArray,
|
||||
GenericDataType,
|
||||
t,
|
||||
tn,
|
||||
} from '@superset-ui/core';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Select } from 'src/common/components';
|
||||
import { PluginFilterSelectProps } from './types';
|
||||
import { Styles, StyledSelect } from '../common';
|
||||
import { getSelectExtraFormData } from '../../utils';
|
||||
import { StyledSelect, Styles } from '../common';
|
||||
import { getDataRecordFormatter, getSelectExtraFormData } from '../../utils';
|
||||
|
||||
const { Option } = Select;
|
||||
|
||||
const timeFormatter = createMultiFormatter({
|
||||
id: 'smart_date_verbose',
|
||||
label: 'Adaptive temporal formatter',
|
||||
formats: {
|
||||
millisecond: '%Y-%m-%d %H:%M:%S.%L',
|
||||
second: '%Y-%m-%d %H:%M:%S',
|
||||
minute: '%Y-%m-%d %H:%M',
|
||||
hour: '%Y-%m-%d %H:%M:%M',
|
||||
day: '%Y-%m-%d',
|
||||
week: '%Y-%m-%d',
|
||||
month: '%Y-%m-%d',
|
||||
year: '%Y-%m-%d',
|
||||
},
|
||||
});
|
||||
|
||||
export default function PluginFilterSelect(props: PluginFilterSelectProps) {
|
||||
const { data, formData, height, width, behaviors, setDataMask } = props;
|
||||
const {
|
||||
coltypeMap,
|
||||
data,
|
||||
formData,
|
||||
height,
|
||||
width,
|
||||
behaviors,
|
||||
setDataMask,
|
||||
} = props;
|
||||
const {
|
||||
defaultValue,
|
||||
enableEmptyFilter,
|
||||
|
|
@ -37,10 +68,16 @@ export default function PluginFilterSelect(props: PluginFilterSelectProps) {
|
|||
inputRef,
|
||||
} = formData;
|
||||
|
||||
const [values, setValues] = useState<(string | number)[]>(defaultValue ?? []);
|
||||
const [values, setValues] = useState<(string | number | boolean)[]>(
|
||||
defaultValue ?? [],
|
||||
);
|
||||
const groupby = ensureIsArray<string>(formData.groupby);
|
||||
|
||||
let { groupby = [] } = formData;
|
||||
groupby = Array.isArray(groupby) ? groupby : [groupby];
|
||||
const [col] = groupby;
|
||||
const datatype: GenericDataType = coltypeMap[col];
|
||||
const labelFormatter = getDataRecordFormatter({
|
||||
timeFormatter,
|
||||
});
|
||||
|
||||
const handleChange = (
|
||||
value?: (number | string)[] | number | string | null,
|
||||
|
|
@ -50,7 +87,6 @@ export default function PluginFilterSelect(props: PluginFilterSelectProps) {
|
|||
);
|
||||
setValues(resultValue);
|
||||
|
||||
const [col] = groupby;
|
||||
const emptyFilter =
|
||||
enableEmptyFilter && !inverseSelection && resultValue?.length === 0;
|
||||
|
||||
|
|
@ -104,6 +140,7 @@ export default function PluginFilterSelect(props: PluginFilterSelectProps) {
|
|||
<Styles height={height} width={width}>
|
||||
<StyledSelect
|
||||
allowClear
|
||||
// @ts-ignore
|
||||
value={values}
|
||||
showSearch={showSearch}
|
||||
mode={multiSelect ? 'multiple' : undefined}
|
||||
|
|
@ -113,10 +150,11 @@ export default function PluginFilterSelect(props: PluginFilterSelectProps) {
|
|||
ref={inputRef}
|
||||
>
|
||||
{(data || []).map(row => {
|
||||
const option = `${groupby.map(col => row[col])[0]}`;
|
||||
const [value] = groupby.map(col => row[col]);
|
||||
return (
|
||||
<Option key={option} value={option}>
|
||||
{option}
|
||||
// @ts-ignore
|
||||
<Option key={`${value}`} value={value}>
|
||||
{labelFormatter(value, datatype)}
|
||||
</Option>
|
||||
);
|
||||
})}
|
||||
|
|
|
|||
|
|
@ -21,14 +21,18 @@ import { DEFAULT_FORM_DATA, PluginFilterSelectQueryFormData } from './types';
|
|||
|
||||
export default function buildQuery(formData: PluginFilterSelectQueryFormData) {
|
||||
const { sortAscending } = { ...DEFAULT_FORM_DATA, ...formData };
|
||||
return buildQueryContext(formData, baseQueryObject => [
|
||||
{
|
||||
...baseQueryObject,
|
||||
apply_fetch_values_predicate: true,
|
||||
groupby: baseQueryObject.columns,
|
||||
orderby: sortAscending
|
||||
? baseQueryObject.columns.map(column => [column, true])
|
||||
: [],
|
||||
},
|
||||
]);
|
||||
return buildQueryContext(formData, baseQueryObject => {
|
||||
const { columns, filters = [] } = baseQueryObject;
|
||||
return [
|
||||
{
|
||||
...baseQueryObject,
|
||||
apply_fetch_values_predicate: true,
|
||||
groupby: columns,
|
||||
filters: filters.concat(
|
||||
columns.map(column => ({ col: column, op: 'IS NOT NULL' })),
|
||||
),
|
||||
orderby: sortAscending ? columns.map(column => [column, true]) : [],
|
||||
},
|
||||
];
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,17 +16,24 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { ChartProps } from '@superset-ui/core';
|
||||
import { DEFAULT_FORM_DATA } from './types';
|
||||
import { GenericDataType } from '@superset-ui/core';
|
||||
import { DEFAULT_FORM_DATA, PluginFilterSelectChartProps } from './types';
|
||||
|
||||
export default function transformProps(chartProps: ChartProps) {
|
||||
export default function transformProps(
|
||||
chartProps: PluginFilterSelectChartProps,
|
||||
) {
|
||||
const { formData, height, hooks, queriesData, width, behaviors } = chartProps;
|
||||
const newFormData = { ...DEFAULT_FORM_DATA, ...formData };
|
||||
const { setDataMask = () => {} } = hooks;
|
||||
|
||||
const { data } = queriesData[0];
|
||||
const [queryData] = queriesData;
|
||||
const { colnames = [], coltypes = [], data } = queryData || [];
|
||||
const coltypeMap: Record<string, GenericDataType> = colnames.reduce(
|
||||
(accumulator, item, index) => ({ ...accumulator, [item]: coltypes[index] }),
|
||||
{},
|
||||
);
|
||||
|
||||
return {
|
||||
coltypeMap,
|
||||
width,
|
||||
behaviors,
|
||||
height,
|
||||
|
|
|
|||
|
|
@ -17,10 +17,13 @@
|
|||
* under the License.
|
||||
*/
|
||||
import {
|
||||
ChartProps,
|
||||
Behavior,
|
||||
DataRecord,
|
||||
GenericDataType,
|
||||
QueryFormData,
|
||||
SetDataMaskHook,
|
||||
ChartDataResponseResult,
|
||||
} from '@superset-ui/core';
|
||||
import { RefObject } from 'react';
|
||||
import { PluginFilterStylesProps } from '../types';
|
||||
|
|
@ -39,7 +42,12 @@ export type PluginFilterSelectQueryFormData = QueryFormData &
|
|||
PluginFilterStylesProps &
|
||||
PluginFilterSelectCustomizeProps;
|
||||
|
||||
export interface PluginFilterSelectChartProps extends ChartProps {
|
||||
queriesData: ChartDataResponseResult[];
|
||||
}
|
||||
|
||||
export type PluginFilterSelectProps = PluginFilterStylesProps & {
|
||||
coltypeMap: Record<string, GenericDataType>;
|
||||
data: DataRecord[];
|
||||
setDataMask: SetDataMaskHook;
|
||||
behaviors: Behavior[];
|
||||
|
|
|
|||
|
|
@ -16,7 +16,14 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { QueryObjectFilterClause } from '@superset-ui/core';
|
||||
import {
|
||||
DataRecordValue,
|
||||
GenericDataType,
|
||||
NumberFormatter,
|
||||
QueryObjectFilterClause,
|
||||
TimeFormatter,
|
||||
} from '@superset-ui/core';
|
||||
import { FALSE_STRING, NULL_STRING, TRUE_STRING } from 'src/utils/common';
|
||||
|
||||
export const getSelectExtraFormData = (
|
||||
col: string,
|
||||
|
|
@ -67,3 +74,47 @@ export const getRangeExtraFormData = (
|
|||
},
|
||||
};
|
||||
};
|
||||
|
||||
export interface DataRecordValueFormatter {
|
||||
(value: DataRecordValue, dtype: GenericDataType): string;
|
||||
}
|
||||
|
||||
export function getDataRecordFormatter({
|
||||
timeFormatter,
|
||||
numberFormatter,
|
||||
}: {
|
||||
timeFormatter?: TimeFormatter;
|
||||
numberFormatter?: NumberFormatter;
|
||||
} = {}): DataRecordValueFormatter {
|
||||
return (value, dtype) => {
|
||||
if (value === null || value === undefined) {
|
||||
return NULL_STRING;
|
||||
}
|
||||
if (typeof value === 'boolean') {
|
||||
return value ? TRUE_STRING : FALSE_STRING;
|
||||
}
|
||||
if (dtype === GenericDataType.BOOLEAN) {
|
||||
try {
|
||||
return JSON.parse(String(value).toLowerCase())
|
||||
? TRUE_STRING
|
||||
: FALSE_STRING;
|
||||
} catch {
|
||||
return FALSE_STRING;
|
||||
}
|
||||
}
|
||||
if (typeof value === 'string') {
|
||||
return value;
|
||||
}
|
||||
if (timeFormatter && dtype === GenericDataType.TEMPORAL) {
|
||||
return timeFormatter(value);
|
||||
}
|
||||
if (
|
||||
numberFormatter &&
|
||||
typeof value === 'number' &&
|
||||
dtype === GenericDataType.NUMERIC
|
||||
) {
|
||||
return numberFormatter(value);
|
||||
}
|
||||
return String(value);
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ import {
|
|||
// ATTENTION: If you change any constants, make sure to also change constants.py
|
||||
|
||||
export const NULL_STRING = '<NULL>';
|
||||
export const TRUE_STRING = 'TRUE';
|
||||
export const FALSE_STRING = 'FALSE';
|
||||
|
||||
// moment time format strings
|
||||
export const SHORT_DATE = 'MMM D, YYYY';
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Hashable, List, Optional, Type, Union
|
||||
|
||||
|
|
@ -334,7 +335,7 @@ class BaseDatasource(
|
|||
@staticmethod
|
||||
def filter_values_handler(
|
||||
values: Optional[FilterValues],
|
||||
target_column_is_numeric: bool = False,
|
||||
target_column_type: utils.GenericDataType,
|
||||
is_list_target: bool = False,
|
||||
) -> Optional[FilterValues]:
|
||||
if values is None:
|
||||
|
|
@ -342,12 +343,18 @@ class BaseDatasource(
|
|||
|
||||
def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
|
||||
# backward compatibility with previous <select> components
|
||||
if (
|
||||
isinstance(value, (float, int))
|
||||
and target_column_type == utils.GenericDataType.TEMPORAL
|
||||
):
|
||||
return datetime.utcfromtimestamp(value / 1000)
|
||||
if isinstance(value, str):
|
||||
value = value.strip("\t\n'\"")
|
||||
if target_column_is_numeric:
|
||||
|
||||
if target_column_type == utils.GenericDataType.NUMERIC:
|
||||
# For backwards compatibility and edge cases
|
||||
# where a column data type might have changed
|
||||
value = utils.cast_to_num(value)
|
||||
return utils.cast_to_num(value)
|
||||
if value == NULL_STRING:
|
||||
return None
|
||||
if value == "<empty string>":
|
||||
|
|
|
|||
|
|
@ -1515,7 +1515,9 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
eq = cls.filter_values_handler(
|
||||
eq,
|
||||
is_list_target=is_list_target,
|
||||
target_column_is_numeric=is_numeric_col,
|
||||
target_column_type=utils.GenericDataType.NUMERIC
|
||||
if is_numeric_col
|
||||
else utils.GenericDataType.STRING,
|
||||
)
|
||||
|
||||
# For these two ops, could have used Dimension,
|
||||
|
|
|
|||
|
|
@ -1112,16 +1112,22 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
if not all([flt.get(s) for s in ["col", "op"]]):
|
||||
continue
|
||||
col = flt["col"]
|
||||
val = flt.get("val")
|
||||
op = flt["op"].upper()
|
||||
col_obj = columns_by_name.get(col)
|
||||
if col_obj:
|
||||
col_spec = db_engine_spec.get_column_spec(col_obj.type)
|
||||
is_list_target = op in (
|
||||
utils.FilterOperator.IN.value,
|
||||
utils.FilterOperator.NOT_IN.value,
|
||||
)
|
||||
if col_spec:
|
||||
target_type = col_spec.generic_type
|
||||
else:
|
||||
target_type = GenericDataType.STRING
|
||||
eq = self.filter_values_handler(
|
||||
values=flt.get("val"),
|
||||
target_column_is_numeric=col_obj.is_numeric,
|
||||
values=val,
|
||||
target_column_type=target_type,
|
||||
is_list_target=is_list_target,
|
||||
)
|
||||
if is_list_target:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ from sqlalchemy.types import String, TypeEngine, UnicodeText
|
|||
from superset import app, security_manager, sql_parse
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.models.sql_types.base import literal_dttm_type_factory
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
|
|
@ -209,6 +210,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
String(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r"^datetime", re.IGNORECASE),
|
||||
types.DateTime(),
|
||||
GenericDataType.TEMPORAL,
|
||||
),
|
||||
(re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,),
|
||||
(
|
||||
re.compile(r"^timestamp", re.IGNORECASE),
|
||||
|
|
@ -1176,22 +1182,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param source: Type coming from the database table or cursor description
|
||||
:return: ColumnSpec object
|
||||
"""
|
||||
column_type = None
|
||||
|
||||
if (
|
||||
cls.get_sqla_column_type(
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
is not None
|
||||
):
|
||||
column_type, generic_type = cls.get_sqla_column_type( # type: ignore
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
col_types = cls.get_sqla_column_type(
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
if col_types:
|
||||
column_type, generic_type = col_types
|
||||
# wrap temporal types in custom type that supports literal binding
|
||||
# using datetimes
|
||||
if generic_type == GenericDataType.TEMPORAL:
|
||||
column_type = literal_dttm_type_factory(
|
||||
type(column_type), cls, native_type or ""
|
||||
)
|
||||
is_dttm = generic_type == GenericDataType.TEMPORAL
|
||||
|
||||
if column_type:
|
||||
return ColumnSpec(
|
||||
sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
|
|||
tt = target_type.upper()
|
||||
if tt == utils.TemporalType.DATE:
|
||||
return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')"
|
||||
if tt == utils.TemporalType.TIMESTAMP:
|
||||
if "TIMESTAMP" in tt or "DATETIME" in tt:
|
||||
dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds")
|
||||
return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.US')"""
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class SqliteEngineSpec(BaseEngineSpec):
|
|||
@classmethod
|
||||
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
|
||||
tt = target_type.upper()
|
||||
if tt == utils.TemporalType.TEXT:
|
||||
if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME):
|
||||
return f"""'{dttm.isoformat(sep=" ", timespec="microseconds")}'"""
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Type, TYPE_CHECKING
|
||||
|
||||
from flask_babel import gettext as __
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
|
||||
def literal_dttm_type_factory(
|
||||
sqla_type: Type[types.TypeEngine],
|
||||
db_engine_spec: Type["BaseEngineSpec"],
|
||||
col_type: str,
|
||||
) -> Type[types.TypeEngine]:
|
||||
"""
|
||||
Create a custom SQLAlchemy type that supports datetime literal binds.
|
||||
|
||||
:param sqla_type: Base type to extend
|
||||
:param db_engine_spec: Database engine spec which supports `convert_dttm` method
|
||||
:param col_type: native column type as defined in table metadata
|
||||
:return: SQLAlchemy type that supports using datetima as literal bind
|
||||
"""
|
||||
# pylint: disable=too-few-public-methods
|
||||
class TemporalWrapperType(sqla_type): # type: ignore
|
||||
# pylint: disable=unused-argument
|
||||
def literal_processor(self, dialect: Dialect) -> Callable[[Any], Any]:
|
||||
def process(value: Any) -> Any:
|
||||
if isinstance(value, datetime):
|
||||
ts_expression = db_engine_spec.convert_dttm(col_type, value)
|
||||
if ts_expression is None:
|
||||
raise NotImplementedError(
|
||||
__(
|
||||
"Temporal expression not supported for type: "
|
||||
"%(col_type)s",
|
||||
col_type=col_type,
|
||||
)
|
||||
)
|
||||
return ts_expression
|
||||
return super().process(value)
|
||||
|
||||
return process
|
||||
|
||||
return TemporalWrapperType
|
||||
|
|
@ -14,6 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from flask import Flask
|
||||
|
|
@ -26,7 +27,7 @@ DbapiDescriptionRow = Tuple[
|
|||
]
|
||||
DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]]
|
||||
DbapiResult = Sequence[Union[List[Any], Tuple[Any, ...]]]
|
||||
FilterValue = Union[float, int, str]
|
||||
FilterValue = Union[datetime, float, int, str]
|
||||
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
|
||||
FormData = Dict[str, Any]
|
||||
Granularity = Union[str, Dict[str, Union[str, float]]]
|
||||
|
|
|
|||
|
|
@ -430,6 +430,10 @@ def parse_js_uri_path_item(
|
|||
def cast_to_num(value: Optional[Union[float, int, str]]) -> Optional[Union[float, int]]:
|
||||
"""Casts a value to an int/float
|
||||
|
||||
>>> cast_to_num('1 ')
|
||||
1.0
|
||||
>>> cast_to_num(' 2')
|
||||
2.0
|
||||
>>> cast_to_num('5')
|
||||
5
|
||||
>>> cast_to_num('5.2')
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
# under the License.
|
||||
# isort:skip_file
|
||||
"""Unit tests for Superset"""
|
||||
from datetime import datetime
|
||||
import imp
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
|
|
@ -130,7 +131,7 @@ class SupersetTestCase(TestCase):
|
|||
return (db.session.query(func.max(model.id)).scalar() or 0) + 1
|
||||
|
||||
@staticmethod
|
||||
def get_birth_names_dataset():
|
||||
def get_birth_names_dataset() -> SqlaTable:
|
||||
example_db = get_example_database()
|
||||
return (
|
||||
db.session.query(SqlaTable)
|
||||
|
|
@ -526,6 +527,10 @@ class SupersetTestCase(TestCase):
|
|||
mock_method.assert_called_once_with("error", func_name)
|
||||
return rv
|
||||
|
||||
@classmethod
|
||||
def get_dttm(cls):
|
||||
return datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_insert_temp_object(obj: DeclarativeMeta):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
from zipfile import is_zipfile, ZipFile
|
||||
|
||||
|
|
@ -36,7 +37,7 @@ from sqlalchemy.sql import func
|
|||
from tests.fixtures.world_bank_dashboard import load_world_bank_dashboard_with_slices
|
||||
from tests.test_app import app
|
||||
from superset.charts.commands.data import ChartDataCommand
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.extensions import async_query_manager, cache_manager, db
|
||||
from superset.models.annotations import AnnotationLayer
|
||||
from superset.models.core import Database, FavStar, FavStarClassName
|
||||
|
|
@ -1194,6 +1195,46 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 10)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_dttm_filter(self):
|
||||
"""
|
||||
Chart data API: Ensure temporal column filter converts epoch to dttm expression
|
||||
"""
|
||||
table = self.get_birth_names_dataset()
|
||||
if table.database.backend == "presto":
|
||||
# TODO: date handling on Presto not fully in line with other engine specs
|
||||
return
|
||||
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["time_range"] = ""
|
||||
dttm = self.get_dttm()
|
||||
ms_epoch = dttm.timestamp() * 1000
|
||||
request_payload["queries"][0]["filters"][0] = {
|
||||
"col": "ds",
|
||||
"op": "!=",
|
||||
"val": ms_epoch,
|
||||
}
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
|
||||
# assert that unconverted timestamp is not present in query
|
||||
assert str(ms_epoch) not in result["query"]
|
||||
|
||||
# assert that converted timestamp is present in query where supported
|
||||
dttm_col: Optional[TableColumn] = None
|
||||
for col in table.columns:
|
||||
if col.column_name == table.main_dttm_col:
|
||||
dttm_col = col
|
||||
if dttm_col:
|
||||
dttm_expression = table.database.db_engine_spec.convert_dttm(
|
||||
dttm_col.type, dttm,
|
||||
)
|
||||
self.assertIn(dttm_expression, result["query"])
|
||||
else:
|
||||
raise Exception("ds column not found")
|
||||
|
||||
def test_chart_data_prophet(self):
|
||||
"""
|
||||
Chart data API: Ensure prophet post transformation works
|
||||
|
|
|
|||
|
|
@ -30,7 +30,3 @@ class TestDbEngineSpec(SupersetTestCase):
|
|||
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
|
||||
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
|
||||
self.assertEqual(expected_sql, limited)
|
||||
|
||||
@classmethod
|
||||
def get_dttm(cls):
|
||||
return datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
|
||||
|
|
|
|||
|
|
@ -109,7 +109,12 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
|
||||
)
|
||||
|
||||
self.assertEqual(PostgresEngineSpec.convert_dttm("DATETIME", dttm), None)
|
||||
self.assertEqual(
|
||||
PostgresEngineSpec.convert_dttm("DATETIME", dttm),
|
||||
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
|
||||
)
|
||||
|
||||
self.assertEqual(PostgresEngineSpec.convert_dttm("TIME", dttm), None)
|
||||
|
||||
def test_empty_dbapi_cursor_description(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -560,11 +560,11 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("time")
|
||||
assert isinstance(column_spec.sqla_type, types.Time)
|
||||
assert issubclass(column_spec.sqla_type, types.Time)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("timestamp")
|
||||
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
|
||||
assert issubclass(column_spec.sqla_type, types.TIMESTAMP)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
|
||||
|
|
|
|||
Loading…
Reference in New Issue