refactor: queryObject - add QueryObjectFactory (#17466)
This commit is contained in:
parent
b914e2d497
commit
377db1bd71
|
|
@ -32,7 +32,7 @@ from superset.charts.dao import ChartDAO
|
|||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.common.query_actions import get_query_results
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.common.query_object import QueryObject, QueryObjectFactory
|
||||
from superset.common.utils import QueryCacheManager
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
|
|
@ -102,8 +102,10 @@ class QueryContext:
|
|||
)
|
||||
self.result_type = result_type or ChartDataResultType.FULL
|
||||
self.result_format = result_format or ChartDataResultFormat.JSON
|
||||
query_object_factory = QueryObjectFactory()
|
||||
self.queries = [
|
||||
QueryObject(self.result_type, **query_obj) for query_obj in queries
|
||||
query_object_factory.create(self.result_type, **query_obj)
|
||||
for query_obj in queries
|
||||
]
|
||||
self.force = force
|
||||
self.custom_cache_timeout = custom_cache_timeout
|
||||
|
|
|
|||
|
|
@ -14,19 +14,18 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=invalid-name
|
||||
# pylint: disable=invalid-name, no-self-use
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset import app, db
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.typing import Column, Metric, OrderBy
|
||||
|
|
@ -47,7 +46,7 @@ from superset.utils.hashing import md5_sha_from_dict
|
|||
from superset.views.utils import get_time_range_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.common.query_context import QueryContext # pragma: no cover
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
||||
|
||||
config = app.config
|
||||
|
|
@ -111,14 +110,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
time_range: Optional[str]
|
||||
to_dttm: Optional[datetime]
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments,too-many-locals
|
||||
def __init__( # pylint: disable=too-many-locals
|
||||
self,
|
||||
parent_result_type: ChartDataResultType,
|
||||
*,
|
||||
annotation_layers: Optional[List[Dict[str, Any]]] = None,
|
||||
applied_time_extras: Optional[Dict[str, str]] = None,
|
||||
apply_fetch_values_predicate: bool = False,
|
||||
columns: Optional[List[Column]] = None,
|
||||
datasource: Optional[DatasourceDict] = None,
|
||||
datasource: Optional[BaseDatasource] = None,
|
||||
extras: Optional[Dict[str, Any]] = None,
|
||||
filters: Optional[List[QueryObjectFilterClause]] = None,
|
||||
granularity: Optional[str] = None,
|
||||
|
|
@ -128,7 +127,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
order_desc: bool = True,
|
||||
orderby: Optional[List[OrderBy]] = None,
|
||||
post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
|
||||
row_limit: Optional[int] = None,
|
||||
row_limit: int,
|
||||
row_offset: Optional[int] = None,
|
||||
series_columns: Optional[List[Column]] = None,
|
||||
series_limit: int = 0,
|
||||
|
|
@ -137,13 +136,12 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
time_shift: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.result_type = kwargs.get("result_type", parent_result_type)
|
||||
self._set_annotation_layers(annotation_layers)
|
||||
self.applied_time_extras = applied_time_extras or {}
|
||||
self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
|
||||
self.columns = columns or []
|
||||
self._set_datasource(datasource)
|
||||
self._set_extras(extras)
|
||||
self.datasource = datasource
|
||||
self.extras = extras or {}
|
||||
self.filter = filters or []
|
||||
self.granularity = granularity
|
||||
self.is_rowcount = is_rowcount
|
||||
|
|
@ -152,14 +150,16 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
self.order_desc = order_desc
|
||||
self.orderby = orderby or []
|
||||
self._set_post_processing(post_processing)
|
||||
self._set_row_limit(row_limit)
|
||||
self.row_limit = row_limit
|
||||
self.row_offset = row_offset or 0
|
||||
self._init_series_columns(series_columns, metrics, is_timeseries)
|
||||
self.series_limit = series_limit
|
||||
self.series_limit_metric = series_limit_metric
|
||||
self.set_dttms(time_range, time_shift)
|
||||
self.time_range = time_range
|
||||
self.time_shift = parse_human_timedelta(time_shift)
|
||||
self.from_dttm = kwargs.get("from_dttm")
|
||||
self.to_dttm = kwargs.get("to_dttm")
|
||||
self.result_type = kwargs.get("result_type")
|
||||
self.time_offsets = kwargs.get("time_offsets", [])
|
||||
self.inner_from_dttm = kwargs.get("inner_from_dttm")
|
||||
self.inner_to_dttm = kwargs.get("inner_to_dttm")
|
||||
|
|
@ -176,20 +176,6 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
if layer["annotationType"] != "FORMULA"
|
||||
]
|
||||
|
||||
def _set_datasource(self, datasource: Optional[DatasourceDict]) -> None:
|
||||
self.datasource = None
|
||||
if datasource:
|
||||
self.datasource = ConnectorRegistry.get_datasource(
|
||||
str(datasource["type"]), int(datasource["id"]), db.session
|
||||
)
|
||||
|
||||
def _set_extras(self, extras: Optional[Dict[str, Any]]) -> None:
|
||||
self.extras = extras or {}
|
||||
if config["SIP_15_ENABLED"]:
|
||||
self.extras["time_range_endpoints"] = get_time_range_endpoints(
|
||||
form_data=self.extras
|
||||
)
|
||||
|
||||
def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
|
||||
# is_timeseries is True if time column is in either columns or groupby
|
||||
# (both are dimensions)
|
||||
|
|
@ -212,17 +198,8 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
def _set_post_processing(
|
||||
self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
|
||||
) -> None:
|
||||
self.post_processing = [
|
||||
post_proc for post_proc in post_processing or [] if post_proc
|
||||
]
|
||||
|
||||
def _set_row_limit(self, row_limit: Optional[int]) -> None:
|
||||
default_row_limit = (
|
||||
config["SAMPLES_ROW_LIMIT"]
|
||||
if self.result_type == ChartDataResultType.SAMPLES
|
||||
else config["ROW_LIMIT"]
|
||||
)
|
||||
self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
|
||||
post_processing = post_processing or []
|
||||
self.post_processing = [post_proc for post_proc in post_processing if post_proc]
|
||||
|
||||
def _init_series_columns(
|
||||
self,
|
||||
|
|
@ -237,18 +214,6 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
else:
|
||||
self.series_columns = []
|
||||
|
||||
def set_dttms(self, time_range: Optional[str], time_shift: Optional[str]) -> None:
|
||||
self.from_dttm, self.to_dttm = get_since_until(
|
||||
relative_start=self.extras.get(
|
||||
"relative_start", config["DEFAULT_RELATIVE_START_TIME"]
|
||||
),
|
||||
relative_end=self.extras.get(
|
||||
"relative_end", config["DEFAULT_RELATIVE_END_TIME"]
|
||||
),
|
||||
time_range=time_range,
|
||||
time_shift=time_shift,
|
||||
)
|
||||
|
||||
def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
|
||||
# rename deprecated fields
|
||||
for field in DEPRECATED_FIELDS:
|
||||
|
|
@ -439,3 +404,71 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
options = post_process.get("options", {})
|
||||
df = getattr(pandas_postprocessing, operation)(df, **options)
|
||||
return df
|
||||
|
||||
|
||||
class QueryObjectFactory:  # pylint: disable=too-few-public-methods
    """Build ``QueryObject`` instances from raw query payload values.

    The factory performs the pre-processing that needs access to the
    application config or the database session (datasource lookup, extras
    processing, row-limit defaults, time-range resolution) and hands the
    results to the ``QueryObject`` constructor.
    """

    def create(  # pylint: disable=too-many-arguments
        self,
        parent_result_type: ChartDataResultType,
        datasource: Optional[DatasourceDict] = None,
        extras: Optional[Dict[str, Any]] = None,
        row_limit: Optional[int] = None,
        time_range: Optional[str] = None,
        time_shift: Optional[str] = None,
        **kwargs: Any,
    ) -> QueryObject:
        """Create a ``QueryObject`` from a single query definition.

        :param parent_result_type: result type used when ``kwargs`` does not
            already carry a ``result_type`` key
        :param datasource: ``type``/``id`` mapping identifying the datasource;
            when falsy the query object is created without a datasource
        :param extras: optional extras mapping of the query
        :param row_limit: requested row limit; a falsy value selects the
            configured default
        :param time_range: time range expression forwarded to
            ``get_since_until`` and to the query object
        :param time_shift: time shift expression forwarded to
            ``get_since_until`` and to the query object
        :param kwargs: remaining ``QueryObject`` keyword arguments; the keys
            ``result_type`` (if absent), ``from_dttm`` and ``to_dttm`` are
            set by this method
        :returns: the constructed query object
        """
        datasource_model_instance = None
        if datasource:
            datasource_model_instance = self._convert_to_model(datasource)

        processed_extras = self._process_extras(extras)

        # An explicit ``result_type`` in the query wins over the parent's.
        result_type = kwargs.setdefault("result_type", parent_result_type)
        row_limit = self._process_row_limit(row_limit, result_type)

        from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras)
        kwargs["from_dttm"] = from_dttm
        kwargs["to_dttm"] = to_dttm

        return QueryObject(
            datasource=datasource_model_instance,
            # Pass the processed mapping, not the raw argument: when the
            # caller supplies ``None`` or an empty dict, ``_process_extras``
            # returns a *new* dict, and forwarding the raw value would
            # silently drop the ``time_range_endpoints`` entry added there.
            extras=processed_extras,
            row_limit=row_limit,
            time_range=time_range,
            time_shift=time_shift,
            **kwargs,
        )

    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
        """Look up the datasource model for a ``type``/``id`` mapping.

        The ``type`` value is coerced to ``str`` and the ``id`` value to
        ``int`` before being passed, together with ``db.session``, to
        ``ConnectorRegistry.get_datasource``.
        """
        return ConnectorRegistry.get_datasource(
            str(datasource["type"]), int(datasource["id"]), db.session
        )

    def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return the extras mapping to use for the query object.

        A falsy ``extras`` is replaced by a new empty dict. When the
        ``SIP_15_ENABLED`` config flag is set, ``time_range_endpoints`` is
        stored in the mapping from ``get_time_range_endpoints``.

        NOTE: a non-empty ``extras`` dict is updated in place and returned,
        so the caller's mapping receives the ``time_range_endpoints`` key.
        """
        extras = extras or {}
        if config["SIP_15_ENABLED"]:
            extras["time_range_endpoints"] = get_time_range_endpoints(form_data=extras)
        return extras

    def _process_row_limit(
        self, row_limit: Optional[int], result_type: ChartDataResultType
    ) -> int:
        """Resolve the effective row limit for a query.

        The default is ``config["SAMPLES_ROW_LIMIT"]`` for
        ``ChartDataResultType.SAMPLES`` and ``config["ROW_LIMIT"]`` otherwise.
        A falsy ``row_limit`` (``None`` or ``0``) selects that default. The
        chosen value is then passed through ``apply_max_row_limit``.
        """
        default_row_limit = (
            config["SAMPLES_ROW_LIMIT"]
            if result_type == ChartDataResultType.SAMPLES
            else config["ROW_LIMIT"]
        )
        return apply_max_row_limit(row_limit or default_row_limit)

    def _get_dttms(
        self,
        time_range: Optional[str],
        time_shift: Optional[str],
        extras: Dict[str, Any],
    ) -> Tuple[Optional[datetime], Optional[datetime]]:
        """Resolve the from/to datetimes via ``get_since_until``.

        ``relative_start`` and ``relative_end`` are read from ``extras``,
        falling back to ``config["DEFAULT_RELATIVE_START_TIME"]`` and
        ``config["DEFAULT_RELATIVE_END_TIME"]`` respectively.
        """
        return get_since_until(
            relative_start=extras.get(
                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
            ),
            relative_end=extras.get(
                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
            ),
            time_range=time_range,
            time_shift=time_shift,
        )
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ def get_sql_text(payload: Dict[str, Any]) -> str:
|
|||
|
||||
|
||||
class TestQueryContext(SupersetTestCase):
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_schema_deserialization(self):
|
||||
"""
|
||||
Ensure that the deserialized QueryContext contains all required fields.
|
||||
|
|
|
|||
Loading…
Reference in New Issue