fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2 [ID-7] #17287

Merged · 6 commits · Nov 5, 2021

Changes from 2 commits
superset/connectors/base/models.py (22 changes: 17 additions & 5 deletions)
@@ -317,11 +317,23 @@ def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]:
if "column" in filter_config
)

-            column_names.update(
-                column
-                for column_param in COLUMN_FORM_DATA_PARAMS
-                for column in utils.get_iterable(form_data.get(column_param) or [])
-            )
+            # legacy charts don't have a query_context
+            if slc.query_context:
+                query_context = slc.get_query_context()
+                column_names.update(
+                    [
+                        column
+                        for query in query_context.get("queries", [])
+                        for column in query.get("columns", [])
+                    ]
+                    or []
+                )
+            else:
+                column_names.update(
+                    column
+                    for column_param in COLUMN_FORM_DATA_PARAMS
+                    for column in utils.get_iterable(form_data.get(column_param) or [])
+                )

filtered_metrics = [
metric
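For context, the standalone sketch below shows what the new branch does with a stored query context. The dict literal is hypothetical and only mirrors the shape the code above expects (a "queries" list whose entries carry "columns"):

```python
# Hypothetical query_context payload, shaped like the one data_for_slices
# now reads via slc.get_query_context().
query_context = {
    "queries": [
        {"columns": ["name", "state"], "metrics": ["sum__num"]},
    ]
}

# Collect every column referenced by any query in the payload.
column_names = set()
column_names.update(
    column
    for query in query_context.get("queries", [])
    for column in query.get("columns", [])
)
assert column_names == {"name", "state"}
```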
superset/examples/birth_names.py (15 changes: 15 additions & 0 deletions)
@@ -358,6 +358,21 @@ def create_slices(
metrics=metrics,
),
),
+        Slice(
+            **slice_props,
+            slice_name="Pivot Table v2",
+            viz_type="pivot_table_v2",
+            params=get_slice_json(
+                defaults,
+                viz_type="pivot_table_v2",
+                groupbyRows=["name"],
+                groupbyColumns=["state"],
+                metrics=[metric],
+            ),
+            query_context=get_slice_json(
+                {"queries": [{"columns": ["name", "state"], "metrics": [metric],}]}
+            ),
+        ),
]
misc_slices = [
Slice(
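Assuming get_slice_json serializes its argument with json.dumps, as the other example slices suggest, this slice ends up persisting a JSON string that the new Slice.get_query_context (next file) parses back into a dict. A rough round-trip sketch with illustrative values:

```python
import json

# Roughly what the example slice stores (values are illustrative).
stored = json.dumps(
    {"queries": [{"columns": ["name", "state"], "metrics": ["sum__num"]}]}
)

# Roughly what get_query_context does with that stored string.
query_context = json.loads(stored)
assert query_context["queries"][0]["columns"] == ["name", "state"]
```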
superset/models/slice.py (11 changes: 11 additions & 0 deletions)
@@ -247,6 +247,17 @@ def form_data(self) -> Dict[str, Any]:
update_time_range(form_data)
return form_data

+    def get_query_context(self) -> Dict[str, Any]:
+        query_context: Dict[str, Any] = {}
+        if not self.query_context:
+            return query_context
+        try:
+            query_context = json.loads(self.query_context)
+        except json.decoder.JSONDecodeError as ex:
+            logger.error("Malformed json in slice's query context", exc_info=True)
+            logger.exception(ex)
+        return query_context

Member commented:

We can make this function more generic, like this:

    def get_query_context(self) -> QueryContext:
        if self.query_context:
            try:
                return self.query_context(**json.loads(self.query_context))
            except json.decoder.JSONDecodeError as ex:
                logger.error("Malformed json in slice's query context", exc_info=True)
                logger.exception(ex)
        return QueryContext(
            datasource={"type": self.datasource_type, "id": self.datasource_id},
            queries=[self.viz.query_obj()],
        )

Member commented:

@zhaoyongjie, on the first return statement I believe we could instantiate the QueryContext. Also, for non-v1 charts, or when there isn't a query context payload, we should probably just return None, as legacy charts aren't really compatible with QueryContext. Could this work?

    def get_query_context(self) -> Optional[QueryContext]:
        if self.query_context:
            try:
                return QueryContext(**json.loads(self.query_context))
            except json.decoder.JSONDecodeError as ex:
                logger.error("Malformed json in slice's query context", exc_info=True)
                logger.exception(ex)
        return None

Member commented:

Nice! It makes sense to return None for legacy charts.

Member (Author) commented:

Thanks for the suggestions! I made the change, would appreciate another look 🙂

def get_explore_url(
self,
base_url: str = "/superset/explore",
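Tying the thread back to the first file: if get_query_context ends up returning Optional[QueryContext] as suggested above, data_for_slices would read attributes rather than dict keys. The stand-in classes below are only an illustration of that access pattern, not the real superset.common types:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass
class QueryObject:  # stand-in for superset's QueryObject
    columns: List[str] = field(default_factory=list)


@dataclass
class QueryContext:  # stand-in for superset's QueryContext
    queries: List[QueryObject] = field(default_factory=list)


def collect_columns(query_context: Optional[QueryContext]) -> Set[str]:
    """Mirrors the new data_for_slices branch, adapted to objects."""
    if not query_context:  # legacy chart: caller falls back to form_data
        return set()
    return {
        column
        for query in query_context.queries
        for column in query.columns
    }


assert collect_columns(None) == set()
assert collect_columns(
    QueryContext(queries=[QueryObject(columns=["name", "state"])])
) == {"name", "state"}
```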
tests/integration_tests/model_tests.py (32 changes: 29 additions & 3 deletions)
@@ -518,7 +518,7 @@ def test_query_with_non_existent_metrics(self):
self.assertTrue("Metric 'invalid' does not exist", context.exception)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
-    def test_data_for_slices(self):
+    def test_data_for_slices_with_no_query_context(self):
tbl = self.get_table(name="birth_names")
slc = (
metadata_db.session.query(Slice)
@@ -532,9 +532,35 @@ def test_data_for_slices(self):
assert len(data_for_slices["columns"]) == 1
assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
assert data_for_slices["columns"][0]["column_name"] == "gender"
-        assert set(data_for_slices["verbose_map"].keys()) == set(
-            ["__timestamp", "sum__num", "gender",]
-        )
+        assert set(data_for_slices["verbose_map"].keys()) == {
+            "__timestamp",
+            "sum__num",
+            "gender",
+        }

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_data_for_slices_with_query_context(self):
+        tbl = self.get_table(name="birth_names")
+        slc = (
+            metadata_db.session.query(Slice)
+            .filter_by(
+                datasource_id=tbl.id,
+                datasource_type=tbl.type,
+                slice_name="Pivot Table v2",
+            )
+            .first()
+        )
+        data_for_slices = tbl.data_for_slices([slc])
+        assert len(data_for_slices["metrics"]) == 1
+        assert len(data_for_slices["columns"]) == 2
+        assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
+        assert data_for_slices["columns"][0]["column_name"] == "name"
+        assert set(data_for_slices["verbose_map"].keys()) == {
+            "__timestamp",
+            "sum__num",
+            "name",
+            "state",
+        }


def test_literal_dttm_type_factory():