fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2 [ID-7] (#17287)

* fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2

* Refactor

* Fix test and lint

* Fix test

* Refactor

* Fix lint
This commit is contained in:
Kamil Gabryjelski 2021-11-05 16:05:48 +01:00 committed by GitHub
parent ab1fcf3068
commit fa51b3234e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 89 additions and 14 deletions

View File

@ -282,7 +282,9 @@ class BaseDatasource(
"select_star": self.select_star,
}
def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]:
def data_for_slices( # pylint: disable=too-many-locals
self, slices: List[Slice]
) -> Dict[str, Any]:
"""
The representation of the datasource containing only the required data
to render the provided slices.
@ -317,11 +319,23 @@ class BaseDatasource(
if "column" in filter_config
)
column_names.update(
column
for column_param in COLUMN_FORM_DATA_PARAMS
for column in utils.get_iterable(form_data.get(column_param) or [])
)
# legacy charts don't have a query_context
query_context = slc.get_query_context()
if query_context:
column_names.update(
[
column
for query in query_context.queries
for column in query.columns
]
or []
)
else:
column_names.update(
column
for column_param in COLUMN_FORM_DATA_PARAMS
for column in utils.get_iterable(form_data.get(column_param) or [])
)
filtered_metrics = [
metric
@ -639,7 +653,6 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
class BaseMetric(AuditMixinNullable, ImportExportMixin):
"""Interface for Metrics"""
__tablename__: Optional[str] = None # {connector_name}_metric

View File

@ -184,6 +184,13 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
"markup_type": "markdown",
}
default_query_context = {
"result_format": "json",
"result_type": "full",
"datasource": {"id": tbl.id, "type": "table",},
"queries": [{"columns": [], "metrics": [],},],
}
admin = get_admin_user()
if admin_owner:
slice_props = dict(
@ -362,6 +369,22 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
metrics=metrics,
),
),
Slice(
**slice_props,
slice_name="Pivot Table v2",
viz_type="pivot_table_v2",
params=get_slice_json(
defaults,
viz_type="pivot_table_v2",
groupbyRows=["name"],
groupbyColumns=["state"],
metrics=[metric],
),
query_context=get_slice_json(
default_query_context,
queries=[{"columns": ["name", "state"], "metrics": [metric],}],
),
),
]
misc_slices = [
Slice(

View File

@ -40,6 +40,7 @@ from superset.utils.urls import get_url_path
from superset.viz import BaseViz, viz_types
if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
metadata = Model.metadata # pylint: disable=no-member
@ -247,6 +248,18 @@ class Slice( # pylint: disable=too-many-public-methods
update_time_range(form_data)
return form_data
def get_query_context(self) -> Optional["QueryContext"]:
    """Return the chart's saved query context, or ``None``.

    The ``query_context`` column holds a JSON blob describing the queries
    the chart issues. Legacy charts have no stored query context, and a
    stored blob may be malformed; in both cases ``None`` is returned (the
    malformed case is logged with its traceback).
    """
    # pylint: disable=import-outside-toplevel
    from superset.common.query_context import QueryContext

    if self.query_context:
        try:
            return QueryContext(**json.loads(self.query_context))
        except json.decoder.JSONDecodeError:
            # logger.exception already attaches the active traceback; the
            # previous logger.error(..., exc_info=True) + logger.exception(ex)
            # pair logged the same failure twice.
            logger.exception("Malformed json in slice's query context")
    return None
def get_explore_url(
self,
base_url: str = "/superset/explore",

View File

@ -790,7 +790,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 33)
self.assertEqual(data["count"], 34)
def test_get_charts_changed_on(self):
"""
@ -1040,7 +1040,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
"""
Chart API: Test get charts filter
"""
# Assuming we have 33 sample charts
# Assuming we have 34 sample charts
self.login(username="admin")
arguments = {"page_size": 10, "page": 0}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
@ -1054,7 +1054,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 3)
self.assertEqual(len(data["result"]), 4)
def test_get_charts_no_data_access(self):
"""

View File

@ -1099,7 +1099,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["charts"]["count"], 33)
self.assertEqual(response["charts"]["count"], 34)
self.assertEqual(response["dashboards"]["count"], 3)
def test_get_database_related_objects_not_found(self):

View File

@ -518,7 +518,7 @@ class TestSqlaTableModel(SupersetTestCase):
self.assertTrue("Metric 'invalid' does not exist", context.exception)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_data_for_slices(self):
def test_data_for_slices_with_no_query_context(self):
tbl = self.get_table(name="birth_names")
slc = (
metadata_db.session.query(Slice)
@ -532,9 +532,35 @@ class TestSqlaTableModel(SupersetTestCase):
assert len(data_for_slices["columns"]) == 1
assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
assert data_for_slices["columns"][0]["column_name"] == "gender"
assert set(data_for_slices["verbose_map"].keys()) == set(
["__timestamp", "sum__num", "gender",]
assert set(data_for_slices["verbose_map"].keys()) == {
"__timestamp",
"sum__num",
"gender",
}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_data_for_slices_with_query_context(self):
    """Columns listed in a chart's query context surface in data_for_slices."""
    table = self.get_table(name="birth_names")
    chart = (
        metadata_db.session.query(Slice)
        .filter_by(
            datasource_id=table.id,
            datasource_type=table.type,
            slice_name="Pivot Table v2",
        )
        .first()
    )
    data = table.data_for_slices([chart])
    # Exactly one metric and the two groupby columns from the query context.
    assert [m["metric_name"] for m in data["metrics"]] == ["sum__num"]
    columns = data["columns"]
    assert len(columns) == 2
    assert columns[0]["column_name"] == "name"
    expected_verbose = {"__timestamp", "sum__num", "name", "state"}
    assert set(data["verbose_map"]) == expected_verbose
def test_literal_dttm_type_factory():