feat: fine-grain chart data telemetry (#31273)

Beto Dealmeida 2024-12-10 13:09:39 -05:00 committed by GitHub
parent 232e2055aa
commit d6a82f7852
4 changed files with 59 additions and 29 deletions
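
The diffs below wire Superset's event_logger into the chart-data path in two ways: as a decorator (event_logger.log_this) on whole functions and as a context manager (event_logger.log_context) around individual steps, so each step shows up as its own timed entry in the event log. A minimal sketch of the pattern, assuming the stock superset.extensions.event_logger instance; serialize_payload and send_response are hypothetical names, not part of this commit:

    import json

    from superset.extensions import event_logger


    @event_logger.log_this  # every call to this function is recorded as a timed log event
    def serialize_payload(queries: list[dict]) -> str:
        return json.dumps({"result": queries})


    def send_response(queries: list[dict]) -> str:
        # log_context times just this block, under an explicit action name
        with event_logger.log_context("ChartDataRestApi.json_dumps"):
            return serialize_payload(queries)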


@@ -399,17 +399,19 @@ class ChartDataRestApi(ChartRestApi):
             for query in queries:
                 with contextlib.suppress(KeyError):
                     del query["query"]
-            response_data = json.dumps(
-                {"result": queries},
-                default=json.json_int_dttm_ser,
-                ignore_nan=True,
-            )
+            with event_logger.log_context(f"{self.__class__.__name__}.json_dumps"):
+                response_data = json.dumps(
+                    {"result": queries},
+                    default=json.json_int_dttm_ser,
+                    ignore_nan=True,
+                )
             resp = make_response(response_data, 200)
             resp.headers["Content-Type"] = "application/json; charset=utf-8"
             return resp
 
         return self.response_400(message=f"Unsupported result_format: {result_format}")
 
+    @event_logger.log_this
     def _get_data_response(
         self,
         command: ChartDataCommand,

@@ -435,11 +437,13 @@ class ChartDataRestApi(ChartRestApi):
     ) -> dict[str, Any]:
         return {
             "dashboard_id": form_data.get("form_data", {}).get("dashboardId"),
-            "dataset_id": form_data.get("datasource", {}).get("id")
-            if isinstance(form_data.get("datasource"), dict)
-            and form_data.get("datasource", {}).get("type")
-            == DatasourceType.TABLE.value
-            else None,
+            "dataset_id": (
+                form_data.get("datasource", {}).get("id")
+                if isinstance(form_data.get("datasource"), dict)
+                and form_data.get("datasource", {}).get("type")
+                == DatasourceType.TABLE.value
+                else None
+            ),
             "slice_id": form_data.get("form_data", {}).get("slice_id"),
         }
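
The reformatted conditional above controls the dataset_id reported with each chart-data event: it is only filled in when the incoming datasource is a dict whose type is "table", and falls back to None otherwise. A small self-contained illustration of that rule (dataset_id_from is a hypothetical helper, and the "table" literal stands in for DatasourceType.TABLE.value):

    def dataset_id_from(form_data: dict) -> int | None:
        # mirrors the conditional in the hunk above
        return (
            form_data.get("datasource", {}).get("id")
            if isinstance(form_data.get("datasource"), dict)
            and form_data.get("datasource", {}).get("type") == "table"
            else None
        )


    assert dataset_id_from({"datasource": {"id": 42, "type": "table"}}) == 42
    assert dataset_id_from({"datasource": {"id": 42, "type": "query"}}) is None
    assert dataset_id_from({"datasource": "42__table"}) is None  # not a dict, so ignored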


@@ -34,6 +34,7 @@ import pandas as pd
 from flask_babel import gettext as __
 
 from superset.common.chart_data import ChartDataResultFormat
+from superset.extensions import event_logger
 from superset.utils.core import (
     extract_dataframe_dtypes,
     get_column_names,

@@ -296,6 +297,7 @@ post_processors = {
 }
 
 
+@event_logger.log_this
 def apply_post_process(
     result: dict[Any, Any],
     form_data: Optional[dict[str, Any]] = None,

@@ -344,15 +346,19 @@ def apply_post_process(
         # `Tuple[str]`. Otherwise encoding to JSON later will fail because
         # maps cannot have tuples as their keys in JSON.
         processed_df.columns = [
-            " ".join(str(name) for name in column).strip()
-            if isinstance(column, tuple)
-            else column
+            (
+                " ".join(str(name) for name in column).strip()
+                if isinstance(column, tuple)
+                else column
+            )
             for column in processed_df.columns
         ]
         processed_df.index = [
-            " ".join(str(name) for name in index).strip()
-            if isinstance(index, tuple)
-            else index
+            (
+                " ".join(str(name) for name in index).strip()
+                if isinstance(index, tuple)
+                else index
+            )
             for index in processed_df.index
         ]


@@ -75,7 +75,11 @@ from superset.extensions import (
 from superset.models.helpers import AuditMixinNullable, ImportExportMixin
 from superset.result_set import SupersetResultSet
 from superset.sql_parse import Table
-from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
+from superset.superset_typing import (
+    DbapiDescription,
+    OAuth2ClientConfig,
+    ResultSetColumnType,
+)
 from superset.utils import cache as cache_util, core as utils, json
 from superset.utils.backports import StrEnum
 from superset.utils.core import DatasourceName, get_username

@@ -667,7 +671,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         )
         return sql_
 
-    def get_df(  # pylint: disable=too-many-locals
+    def get_df(
         self,
         sql: str,
         catalog: str | None = None,

@@ -700,21 +704,37 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
                     object_ref=__name__,
                 ):
                     self.db_engine_spec.execute(cursor, sql_, self)
-                    if i < len(sqls) - 1:
-                        # If it's not the last, we don't keep the results
-                        cursor.fetchall()
-                    else:
-                        # Last query, fetch and process the results
-                        data = self.db_engine_spec.fetch_data(cursor)
-                        result_set = SupersetResultSet(
-                            data, cursor.description, self.db_engine_spec
-                        )
-                        df = result_set.to_pandas_df()
+
+                    rows = self.fetch_rows(cursor, i == len(sqls) - 1)
+                    if rows is not None:
+                        df = self.load_into_dataframe(cursor.description, rows)
                         if mutator:
                             df = mutator(df)
 
         return self.post_process_df(df)
 
+    @event_logger.log_this
+    def fetch_rows(self, cursor: Any, last: bool) -> list[tuple[Any, ...]] | None:
+        if not last:
+            cursor.fetchall()
+            return None
+
+        return self.db_engine_spec.fetch_data(cursor)
+
+    @event_logger.log_this
+    def load_into_dataframe(
+        self,
+        description: DbapiDescription,
+        data: list[tuple[Any, ...]],
+    ) -> pd.DataFrame:
+        result_set = SupersetResultSet(
+            data,
+            description,
+            self.db_engine_spec,
+        )
+        return result_set.to_pandas_df()
+
     def compile_sqla_query(
         self,
         qry: Select,
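
get_df's inner loop is now split into two instrumented helpers, fetch_rows and load_into_dataframe, so row fetching and DataFrame construction each get their own entry in the event log. A rough standalone sketch of the fetch_rows contract, re-implemented here purely for illustration (fetch_rows_sketch is not imported from Superset):

    from typing import Any, Optional
    from unittest.mock import MagicMock


    def fetch_rows_sketch(db_engine_spec: Any, cursor: Any, last: bool) -> Optional[list]:
        # same contract as the new Database.fetch_rows: drain and discard the results
        # of every statement except the last, and return the last statement's rows
        if not last:
            cursor.fetchall()
            return None
        return db_engine_spec.fetch_data(cursor)


    spec = MagicMock()
    spec.fetch_data.return_value = [(1, "a"), (2, "b")]
    cursor = MagicMock()

    assert fetch_rows_sketch(spec, cursor, last=False) is None                  # intermediate: discarded
    assert fetch_rows_sketch(spec, cursor, last=True) == [(1, "a"), (2, "b")]   # final: kept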


@@ -1710,7 +1710,7 @@ def test_alert_limit_is_applied(
     with patch.object(
         create_alert_email_chart.database.db_engine_spec,
         "fetch_data",
-        return_value=None,
+        return_value=[],
     ):  # noqa: F841
         AsyncExecuteReportScheduleCommand(
             TEST_ID, create_alert_email_chart.id, datetime.utcnow()
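
The mocked fetch_data value changes from None to an empty list because, as I read the refactor above, fetch_rows treats None as "intermediate statement, nothing to load": a None mock would skip load_into_dataframe and leave the final DataFrame unbuilt, while an empty list still flows through load_into_dataframe and yields an (empty) result for the alert to evaluate.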