Skip to content

Commit

Permalink
feat: search itm (#680)
Browse files Browse the repository at this point in the history
Co-authored-by: MalvinaNikandrou <[email protected]>
  • Loading branch information
gpantaz and MalvinaNikandrou authored Jan 26, 2023
1 parent 8a5b325 commit e4fa938
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 48 deletions.
9 changes: 1 addition & 8 deletions src/emma_policy/datamodules/simbot_action_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,11 @@ def _compute_sample_weights(self, dataset_db: Path) -> list[float]:
db = DatasetDb(dataset_db)
# First pass through the dataset to get action type counts
actions = []
subsampled_weight: list[int] = []
for _, _, instance_str in db:
instance = SimBotInstructionInstance.parse_raw(instance_str)
if self._skip_instance(instance.actions[-1]):
subsampled_weight.append(0)
else:
subsampled_weight.append(1)
actions.append(self._get_action_type(instance.actions[-1]))

data_weights = compute_weights(
actions, temperature=self._weight_temperature, subsampled_weight=subsampled_weight
)
data_weights = compute_weights(actions, temperature=self._weight_temperature)

return data_weights

Expand Down
88 changes: 73 additions & 15 deletions src/emma_policy/datamodules/simbot_action_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SimBotInstructionInstance,
SimBotObjectAttributes,
)
from emma_datasets.db import DatasetDb
from overrides import overrides
from torchvision.ops import masks_to_boxes
from transformers import PreTrainedTokenizer
Expand Down Expand Up @@ -48,8 +49,8 @@ def get_simbot_instruction_paraphrase(
paraphraser: InstructionParaphraser, instance: SimBotInstructionInstance, object_name: str
) -> str:
"""Paraphrase a SimBot instruction."""
action_type = instance.actions[0].type.lower()
action_object_metadata = instance.actions[0].get_action_data["object"]
action_type = instance.actions[-1].type.lower()
action_object_metadata = instance.actions[-1].get_action_data["object"]
attributes = SimBotObjectAttributes(
**action_object_metadata.get("attributes", {"readable_name": object_name})
)
Expand All @@ -61,14 +62,53 @@ def get_simbot_instruction_paraphrase(
)


def check_punctuation(text: str) -> str:
def format_instruction(text: str) -> str:
"""Make sure the instruction ends in a fullstop."""
if not text.endswith(("?", ".")):
text = f"{text}."
text = text.replace("..", ".")
return text.lower()


class SearchNegativeSampler:
"""Search negative selection class.
Used to sample negative examples for the search objective. Creates a dictionary where keys
indices of the positive examples in the dataset and values are the readable names of each
object in the example. Given a readable name for an object, it samples keys from the dictionary
until the readable name is not present in the list of objects for a key.
"""

def __init__(self, db: DatasetDb):
self._positive_indices_map = self._create_positive_indices_objects_map(db)

def __call__(self, readable_name: str) -> int:
"""Sample a negative example."""
while True:
rand_idx = random.choice(list(self._positive_indices_map.keys()))
if readable_name.lower() not in self._positive_indices_map[rand_idx]:
return rand_idx

def _create_positive_indices_objects_map(self, db: DatasetDb) -> dict[int, list[str]]:
"""Create a map of indices and positive examples."""
db_size = len(db)
positive_indices_map = {}
with db:
for index in range(db_size):
instance_str: str = db[index]
instance = SimBotInstructionInstance.parse_raw(instance_str)

action = instance.actions[-1]
if action.type == "Search" and action.search["object"]["mask"] is not None:
attributes = action.search["object"].get("attributes", None)
if attributes is not None:
readable_names = [ # noqa: WPS220
attribute["readable_name"].lower() for attribute in attributes
]
positive_indices_map[index] = readable_names # noqa: WPS220
return positive_indices_map


class SimBotActionDataset(EmmaBaseDataset[EmmaDatasetItem]):
"""Dataset for SimBotAction.
Expand All @@ -81,6 +121,7 @@ def __init__(
dataset_db_path: Path,
tokenizer: PreTrainedTokenizer,
iou_threshold: float = 0.5,
search_negative_proba: float = 0.5,
max_frames: int = 15,
use_only_necessary_questions: bool = True,
allow_paraphrasing: bool = False,
Expand All @@ -92,6 +133,7 @@ def __init__(

self._iou_threshold = iou_threshold
self._goto_proba = 0
self._search_negative_proba = search_negative_proba
self._use_only_necessary_questions = use_only_necessary_questions
self.question_answer_prompt = "<<driver>> {question} <<commander>> {answer}"
arena_definitions = get_arena_definitions()
Expand All @@ -100,6 +142,8 @@ def __init__(
self._image_height = arena_definitions["image_height"]
self._paraphraser = InstructionParaphraser()
self._allow_paraphrasing = allow_paraphrasing
self._search_negative_sampler = SearchNegativeSampler(self.db)
self._special_name_cases = arena_definitions["special_asset_to_readable_name"]

@overrides(check_signature=False)
def __getitem__(self, index: int) -> EmmaDatasetItem:
Expand All @@ -112,12 +156,12 @@ def __getitem__(self, index: int) -> EmmaDatasetItem:
return self.simbot_vision_augmentation(instance)
return self.simbot_action_execution(instance)

def simbot_vision_augmentation( # noqa: WPS210
def simbot_vision_augmentation( # noqa: WPS210, WPS231
self, instance: SimBotInstructionInstance
) -> EmmaDatasetItem:
"""Process a visual augmentation instance for the SimBot action task."""
if instance.actions[0].type == "Search":
action_object_metadata = instance.actions[0].get_action_data["object"]
if instance.actions[-1].type == "Search":
action_object_metadata = instance.actions[-1].get_action_data["object"]

object_candidates = len(action_object_metadata["id"])
object_candidate_idx = random.choice(range(object_candidates))
Expand All @@ -139,22 +183,36 @@ def simbot_vision_augmentation( # noqa: WPS210
source_text = self._get_random_template_for_task(Task.visual_grounding).format(
caption=source_text
)
source_text = check_punctuation(source_text)
source_text = format_instruction(source_text)

object_name = get_object_label_from_object_id(
object_id=action_object_metadata["id"][object_candidate_idx],
object_assets_to_names=self._object_assets_to_names,
)

target_frames = [0 for _ in instance.actions]

visual_features, _, _ = self._load_visual_features(
features_path=instance.features_path,
target_frames=target_frames,
# We need to skip the instances that are from annotations aka paraphrasable
# TODO: we need to make this easier
select_negative = (
random.random() >= self._search_negative_proba and instance.paraphrasable
)
if select_negative:
negative_idx = self._search_negative_sampler(
action_object_metadata["attributes"][object_candidate_idx]["readable_name"]
)
instance_str = self.db[negative_idx]
negative_instance = SimBotInstructionInstance.parse_raw(instance_str)
visual_features, _, _ = self._load_visual_features(
features_path=negative_instance.features_path,
target_frames=[0 for _ in negative_instance.actions],
)
else:
visual_features, _, _ = self._load_visual_features(
features_path=instance.features_path,
target_frames=[0 for _ in instance.actions],
)

ground_truth_bboxes = action_object_metadata["mask"]
if ground_truth_bboxes is None:
if ground_truth_bboxes is None or select_negative:
target_text = f"no {object_name} <stop>."
else:
ground_truth_bbox = ground_truth_bboxes[object_candidate_idx]
Expand Down Expand Up @@ -317,7 +375,7 @@ def _prepare_source_text(self, instance: SimBotInstructionInstance) -> str:
and answer: where is the hammer? the hammer is on the table in the robotics lab.
"""
if self._allow_paraphrasing and instance.paraphrasable:
action_object_metadata = instance.actions[0].get_action_data["object"]
action_object_metadata = instance.actions[-1].get_action_data["object"]
object_name = get_object_label_from_object_id(
object_id=action_object_metadata["id"],
object_assets_to_names=self._object_assets_to_names,
Expand Down Expand Up @@ -346,7 +404,7 @@ def _prepare_source_text(self, instance: SimBotInstructionInstance) -> str:
)
source_text = f"{source_text}. {question_answer_text}"

source_text = check_punctuation(source_text)
source_text = format_instruction(source_text)

return source_text

Expand Down
73 changes: 50 additions & 23 deletions src/emma_policy/datamodules/simbot_nlu_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from emma_policy.datamodules.base_dataset import EmmaBaseDataset
from emma_policy.datamodules.emma_dataclasses import EmmaDatasetItem, EmmaVisualFeatures
from emma_policy.datamodules.simbot_action_dataset import (
check_punctuation,
SearchNegativeSampler,
format_instruction,
get_simbot_instruction_paraphrase,
)
from emma_policy.utils import get_logger
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
max_frames: int = 4,
is_train: bool = True,
iou_threshold: float = 0.5,
search_negative_proba: float = 0.5,
) -> None:
super().__init__(
dataset_db_path=dataset_db_path, tokenizer=tokenizer, max_frames=max_frames
Expand All @@ -104,11 +106,13 @@ def __init__(
arena_definitions = get_arena_definitions()
self._object_assets_to_names = arena_definitions["asset_to_label"]
self._label_to_idx = arena_definitions["label_to_idx"]
self._special_name_cases = get_arena_definitions()["special_asset_to_readable_name"]
self._special_name_cases = arena_definitions["special_asset_to_readable_name"]
self._image_width = arena_definitions["image_width"]
self._image_height = arena_definitions["image_height"]
self._paraphraser = InstructionParaphraser()
self._iou_threshold = iou_threshold
self._search_negative_proba = search_negative_proba
self._search_negative_sampler = SearchNegativeSampler(self.db)

@overrides(check_signature=False)
def __len__(self) -> int:
Expand All @@ -122,11 +126,16 @@ def __getitem__(self, index: int) -> EmmaDatasetItem:
instance_str = self.db[index]
instance = SimBotInstructionInstance.parse_raw(instance_str)
first_action = instance.actions[0]
frame_idx = 0
if first_action.type == "Search":
instruction, target_text = self.prepare_search_instance(instance)
instruction, visual_features, target_text = self.prepare_search_instance(instance)
else:
instruction, target_text = self.prepare_action_instance(instance)
source_text = f"Predict the system act: {check_punctuation(instruction)}"
frame_idx = self._get_instance_frame(instance, target_text.lower())
visual_features = self._load_visual_features(
features_path=instance.features_path[0], frame_idx=frame_idx
)
source_text = f"Predict the system act: {format_instruction(instruction)}"
target_text = target_text.lower()
input_encoding = self.tokenizer.encode_plus(
source_text, return_tensors=self._return_tensor_type, truncation=True
Expand All @@ -135,15 +144,13 @@ def __getitem__(self, index: int) -> EmmaDatasetItem:
target_text, return_tensors=self._return_tensor_type, truncation=True
)

frame_idx = self._get_instance_frame(instance, target_text)
visual_features = self._load_visual_features(
features_path=instance.features_path[0], frame_idx=frame_idx
)

raw_target = {
"example_id": f"{instance.mission_id}_{instance.annotation_id}_{instance.instruction_id}",
"references": target_text,
"instruction": source_text,
"nlu_class": target_text.split()[0],
"object_type": " ".join(target_text.split()[1:]),
"action_type": first_action.type,
}

return EmmaDatasetItem(
Expand Down Expand Up @@ -204,33 +211,53 @@ def prepare_synthetic_action_instance(
instruction, target_text = self._augment_synthetic_action(instance)
return instruction, target_text

def prepare_search_instance(self, instance: SimBotInstructionInstance) -> tuple[str, str]:
def prepare_search_instance(
self, instance: SimBotInstructionInstance
) -> tuple[str, EmmaVisualFeatures, str]:
"""Get source and target text for Search instructions."""
# Select the object
action_object_metadata = instance.actions[0].get_action_data["object"]
object_candidates = len(action_object_metadata["id"])
object_candidate_idx = random.choice(range(object_candidates))
object_id = action_object_metadata["id"][object_candidate_idx]
# Always paraphrase after selecting a target for search
instruction = self._paraphraser(
action_type="search",
object_id=object_id,
object_attributes=SimBotObjectAttributes(
**action_object_metadata["attributes"][object_candidate_idx]
),
)

# Prepare the instruction
if instance.paraphrasable:
instruction = self._paraphraser(
action_type="search",
object_id=object_id,
object_attributes=SimBotObjectAttributes(
**action_object_metadata["attributes"][object_candidate_idx]
),
)
else:
instruction = instance.instruction.instruction

# Prepare the visual_features and target
object_readable_name = get_object_readable_name_from_object_id(
object_id=object_id,
object_assets_to_names=self._object_assets_to_names,
special_name_cases=self._special_name_cases,
)

if action_object_metadata["mask"] is None:
negative_proba = random.random() >= self._search_negative_proba
# We need to skip the instances that are from annotations aka paraphrasable
select_negative = negative_proba and instance.paraphrasable
if select_negative:
negative_idx = self._search_negative_sampler(object_readable_name)
negative_instance = SimBotInstructionInstance.parse_raw(self.db[negative_idx])
visual_features = self._load_visual_features(
features_path=negative_instance.features_path[0]
)
target_text = f"{SimBotNLUIntents.search_no_match.value} {object_readable_name}"
elif action_object_metadata["mask"] is None:
# A negative search sample
visual_features = self._load_visual_features(features_path=instance.features_path[0])
target_text = f"{SimBotNLUIntents.search_no_match.value} {object_readable_name}"
else:
ground_truth_bboxes = action_object_metadata["mask"]
ground_truth_bbox = ground_truth_bboxes[object_candidate_idx]
# A positive search sample
ground_truth_bbox = torch.tensor(
ground_truth_bbox,
action_object_metadata["mask"][object_candidate_idx],
dtype=torch.float32,
).unsqueeze(0)

Expand All @@ -253,7 +280,7 @@ def prepare_search_instance(self, instance: SimBotInstructionInstance) -> tuple[
f"{SimBotNLUIntents.search_too_many_matches.value} {object_readable_name}"
)

return instruction, target_text
return instruction, visual_features, target_text

def _augment_synthetic_action(self, instance: SimBotInstructionInstance) -> tuple[str, str]:
"""Prepare the instruction and target text for a synthetic unambiguous instruction.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
EmmaDatasetItem,
EmmaVisualFeatures,
)
from emma_policy.datamodules.simbot_action_dataset import check_punctuation
from emma_policy.datamodules.simbot_action_dataset import format_instruction
from emma_policy.inference.api.simbot_state import GenerateRequest


Expand All @@ -35,7 +35,7 @@ def __call__(self, request: GenerateRequest) -> EmmaDatasetBatch:
the agent in the environment.
"""
# Add a fullstop at the end and lowercase
instruction = check_punctuation(request.dialogue_history[-1].utterance)
instruction = format_instruction(request.dialogue_history[-1].utterance)
logger.debug(f"Preparing NLU input for instruction: {instruction}")
encoded_inputs = self._prepare_input_text(instruction)
feature_dicts = self._prepare_feature_dicts(request.environment_history[-1].features)
Expand Down

0 comments on commit e4fa938

Please sign in to comment.