Skip to content

Commit

Permalink
feat: Add dataset visualization (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
MalvinaNikandrou authored Feb 8, 2023
1 parent 8057ffa commit da6b764
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 3 deletions.
193 changes: 193 additions & 0 deletions notebooks/simbot_dataset_visualization_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import argparse
import json
import logging
from collections import defaultdict
from collections.abc import Collection
from pathlib import Path
from typing import Any

import gradio as gr
import plotly
from plotly.subplots import make_subplots
from tqdm import tqdm
from transformers import AutoTokenizer

from emma_policy.datamodules.simbot_action_dataset import SimBotActionDataset
from emma_policy.datamodules.simbot_nlu_dataset import SimBotNLUDataset


# Configure a module-level logger; INFO so dataset-loading progress is visible.
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# Alias for the visualization payload: maps a subset name (e.g. an object or
# action type) to the collection of labels observed for that subset.
data_type = dict[str, Collection[Any]]


class DatasetVisualizer:
    """Visualize the SimBot dataset distribution.

    Holds a mapping from subset name to the labels observed for that subset
    and renders, per subset, a histogram of its label distribution plus a pie
    chart of the subset's share of the full dataset.
    """

    def __init__(self, data: data_type, total_examples: int) -> None:
        # data: subset name -> collection of labels for that subset.
        # total_examples: size of the whole dataset, used for the pie chart's
        # "other" slice.
        self.data = data
        self.total_examples = total_examples
        # NOTE: the original comments had these two swapped — #ffb400 is
        # mustard/amber and #9080ff is purple.
        self._accent_color = "#ffb400"  # mustard
        self._base_color = "#9080ff"  # purple

    def get_data_visualization(self, subset_name: str = "") -> plotly.graph_objs.Figure:
        """Prepare the output for a subset of the data."""
        data_subset = self.data.get(subset_name, None)
        fig = make_subplots(rows=2, cols=1, specs=[[{"type": "histogram"}], [{"type": "pie"}]])
        # Unknown or empty subset: return the empty two-row figure as-is.
        if not data_subset:
            return fig
        # Row 1: histogram of the labels within the selected subset.
        fig.append_trace(
            plotly.graph_objects.Histogram(
                x=data_subset, showlegend=False, marker={"color": self._accent_color}
            ),
            row=1,
            col=1,
        )
        # Row 2: pie chart of this subset's size vs the rest of the dataset.
        fig.append_trace(
            plotly.graph_objects.Pie(
                values=[len(data_subset), self.total_examples - len(data_subset)],
                labels=[subset_name, "other"],
                marker={"colors": [self._accent_color, self._base_color]},
            ),
            row=2,
            col=1,
        )
        return fig


def get_data_from_action_dataset(args: argparse.Namespace) -> dict[str, Any]:
    """Get the visualization data from the action dataset.

    Iterates the SimBot action training set once, collecting the action-type
    label of each instance overall, grouped per object type, and grouped per
    action type. The result is cached as JSON at ``args.cache_dir`` (a file
    path, despite the name) and returned.

    Args:
        args: parsed CLI arguments; uses ``dataset_db`` and ``cache_dir``.

    Returns:
        Dict with keys ``overall`` (list of action types), ``per_object`` and
        ``per_action`` (label lists grouped by object/action type).
    """
    train_dataset = SimBotActionDataset(
        dataset_db_path=args.dataset_db,
        tokenizer=AutoTokenizer.from_pretrained("heriot-watt/emma-base"),
    )
    data = []
    data_per_object = defaultdict(list)
    data_per_action = defaultdict(list)

    for index, instance in tqdm(enumerate(train_dataset)):  # type: ignore[arg-type]
        data.append(instance.raw_target["action_type"])
        # Instances without a target object (e.g. pure navigation) only count
        # towards the overall distribution.
        if instance.raw_target["object_type"] is None:
            continue
        data_per_object[instance.raw_target["object_type"]].append(
            instance.raw_target["action_type"]
        )
        data_per_action[instance.raw_target["action_type"]].append(
            instance.raw_target["object_type"]
        )

        # Stop after one full epoch. BUGFIX: the original condition used
        # `index > len(train_dataset) - 1`, which a 0-based enumerate can
        # never satisfy; use `==` to match get_data_from_nlu_dataset.
        if index == len(train_dataset) - 1:
            break
    data_dict = {"overall": data, "per_object": data_per_object, "per_action": data_per_action}
    with open(args.cache_dir, "w") as file_out:
        json.dump(data_dict, file_out)
    return data_dict


def get_data_from_nlu_dataset(args: argparse.Namespace) -> dict[str, Any]:
    """Get the visualization data from the NLU dataset.

    Walks one epoch of the SimBot NLU training set and gathers the NLU class
    of every instance overall, grouped per object type, and grouped per
    action type. The result is written to the JSON cache and returned.
    """
    train_dataset = SimBotNLUDataset(
        dataset_db_path=args.dataset_db,
        tokenizer=AutoTokenizer.from_pretrained("heriot-watt/emma-base"),
        is_train=True,
    )
    overall_labels = []
    labels_per_object = defaultdict(list)
    labels_per_action = defaultdict(list)
    last_index = len(train_dataset) - 1

    for idx, instance in tqdm(enumerate(train_dataset)):  # type: ignore[arg-type]
        target = instance.raw_target
        nlu_class = target["nlu_class"]
        overall_labels.append(nlu_class)
        labels_per_object[target["object_type"]].append(nlu_class)
        labels_per_action[target["action_type"]].append(nlu_class)

        # Stop once a full epoch has been consumed.
        if idx == last_index:
            break
    data_dict = {
        "overall": overall_labels,
        "per_object": labels_per_object,
        "per_action": labels_per_action,
    }
    with open(args.cache_dir, "w") as file_out:
        json.dump(data_dict, file_out)
    return data_dict


def get_data_for_visualization(args: argparse.Namespace) -> dict[str, Any]:
    """Get the data for the visualization.

    Prefers the JSON cache at ``args.cache_dir`` when it exists; otherwise
    rebuilds the data from the dataset selected by ``args.dataset_type``.
    """
    if not args.cache_dir.exists():
        # Cache miss: compute from scratch (and the builder writes the cache).
        if args.dataset_type == "nlu":
            return get_data_from_nlu_dataset(args)
        return get_data_from_action_dataset(args)
    with open(args.cache_dir) as file_in:
        return json.load(file_in)


def main(args: argparse.Namespace) -> None:
    """Build and launch the gradio app for browsing the label distributions."""
    data = get_data_for_visualization(args)
    num_examples = len(data["overall"])
    per_object_visualizer = DatasetVisualizer(data["per_object"], total_examples=num_examples)
    per_action_visualizer = DatasetVisualizer(data["per_action"], total_examples=num_examples)

    # Static histogram over every label in the dataset.
    overall_figure = plotly.graph_objects.Figure(
        data=[plotly.graph_objects.Histogram(x=data["overall"], marker={"color": "#9080ff"})]
    )

    with gr.Blocks() as block:
        with gr.Row():
            gr.Plot(overall_figure, label="Overall Label Distribution")
        with gr.Row():
            input_object = gr.Dropdown(sorted(set(data["per_object"].keys())), label="Object")
            object_plot = gr.Plot(label="Distribution per Object")
        with gr.Row():
            input_action = gr.Dropdown(sorted(set(data["per_action"].keys())), label="Action")
            action_plot = gr.Plot(label="Distribution per Action")

        # Re-render the subset plots whenever a dropdown selection changes.
        input_object.change(
            fn=per_object_visualizer.get_data_visualization,
            inputs=[input_object],
            outputs=[object_plot],
        )
        input_action.change(
            fn=per_action_visualizer.get_data_visualization,
            inputs=[input_action],
            outputs=[action_plot],
        )
    block.launch(share=args.share)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_db",
        default=Path("storage/db/simbot_actions_train.db"),
        type=Path,
        # Fixed help-text grammar: "Path the" -> "Path to the".
        help="Path to the simbot dataset db.",
    )
    parser.add_argument(
        "--cache_dir",
        default=Path("storage/db/action_app_cache1.json"),
        type=Path,
        help="Path to the simbot dataset cache.",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        choices=["nlu", "action"],
        help="Type of the dataset",
    )
    parser.add_argument(
        "--share",
        help="Create a publicly shareable link from your computer for the interface",
        action="store_true",
    )
    args = parser.parse_args()
    main(args)
50 changes: 47 additions & 3 deletions src/emma_policy/datamodules/simbot_action_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from pathlib import Path
from typing import Union
from typing import Optional, Union

import torch
from emma_datasets.constants.simbot.simbot import get_arena_definitions
Expand Down Expand Up @@ -230,6 +230,7 @@ def simbot_vision_augmentation( # noqa: WPS210, WPS231
ground_truth_bboxes = action_object_metadata["mask"]
if ground_truth_bboxes is None or select_negative:
target_text = f"no {object_name} <stop>."
action_type = "search_negative"
else:
ground_truth_bbox = ground_truth_bboxes[object_candidate_idx]
ground_truth_bbox = torch.tensor(
Expand All @@ -254,8 +255,10 @@ def simbot_vision_augmentation( # noqa: WPS210, WPS231
visual_features.scene_frame_tokens[0]
)
target_text = f"{scene_frame_token} {object_token} <stop>."
action_type = "search_positive"
else:
target_text = f"no {object_name} <stop>."
action_type = "search_negative"

target_text = target_text.lower()

Expand All @@ -281,7 +284,13 @@ def simbot_vision_augmentation( # noqa: WPS210, WPS231
# Now shift them to the right
decoder_input_ids[1:] = full_target_token_ids[:-1].clone() # noqa: WPS362
decoder_attention_mask = torch.ones_like(decoder_input_ids)
raw_target = f"mission{instance.mission_id}_instr{instance.instruction_id}_ann{instance.annotation_id}_action{instance.actions[-1].type}" # noqa: WPS221
raw_target = {
"instance_id": self._get_instance_id(instance),
"instruction": source_text,
"target": target_text,
"action_type": action_type,
"object_type": object_name,
}

return EmmaDatasetItem(
input_token_ids=input_encoding.input_ids.squeeze(0),
Expand Down Expand Up @@ -353,7 +362,13 @@ def simbot_action_execution(self, instance: SimBotInstructionInstance) -> EmmaDa
# Now shift them to the right
decoder_input_ids[1:] = full_target_token_ids[:-1].clone() # noqa: WPS362
decoder_attention_mask = torch.ones_like(decoder_input_ids)
raw_target = f"mission{instance.mission_id}_instr{instance.instruction_id}_ann{instance.annotation_id}_action{instance.actions[-1].type}" # noqa: WPS221
raw_target = {
"instance_id": self._get_instance_id(instance),
"instruction": source_text,
"target": target_text,
"action_type": instance.actions[-1].type,
"object_type": self._get_target_object(instance.actions[-1]),
}

return EmmaDatasetItem(
input_token_ids=input_encoding.input_ids.squeeze(0),
Expand Down Expand Up @@ -611,3 +626,32 @@ def _load_visual_features(
visual_features = self._prepare_emma_visual_features(feature_dicts=feature_dicts)

return visual_features, frames, objects_per_frame

def _get_instance_id(self, instance: SimBotInstructionInstance) -> str:
"""Construct the instance id."""
instruction_id = f"mission{instance.mission_id}_instr{instance.instruction_id}"
return f"{instruction_id}_ann{instance.annotation_id}_action{instance.actions[-1].type}"

def _get_target_object(self, action: SimBotAction) -> Optional[str]:
"""Prepare the object name."""
action_type = action.type
# case 1: navigation actions except GoTo
if action_type in {"Look", "Move", "Rotate", "Turn"}:
return None

action_object_metadata = action.get_action_data["object"]
# case 2: room/object navigation or interaction action
object_id = action_object_metadata.get("id", None)
# action with a specific object
if object_id is not None:
object_name = get_object_readable_name_from_object_id(
object_id=action_object_metadata["id"],
object_assets_to_names=self._object_assets_to_names,
special_name_cases=self._special_name_cases,
)
# action without an object (e.g, Goto Office)
else:
# {'object': {'officeRoom': 'Lab1'}}
object_name = list(action_object_metadata.values())[0]

return object_name

0 comments on commit da6b764

Please sign in to comment.