diff --git a/src/emma_policy/datamodules/simbot_action_datamodule.py b/src/emma_policy/datamodules/simbot_action_datamodule.py index 39738b3..51a353a 100644 --- a/src/emma_policy/datamodules/simbot_action_datamodule.py +++ b/src/emma_policy/datamodules/simbot_action_datamodule.py @@ -186,18 +186,11 @@ def _compute_sample_weights(self, dataset_db: Path) -> list[float]: db = DatasetDb(dataset_db) # First pass through the dataset to get action type counts actions = [] - subsampled_weight: list[int] = [] for _, _, instance_str in db: instance = SimBotInstructionInstance.parse_raw(instance_str) - if self._skip_instance(instance.actions[-1]): - subsampled_weight.append(0) - else: - subsampled_weight.append(1) actions.append(self._get_action_type(instance.actions[-1])) - data_weights = compute_weights( - actions, temperature=self._weight_temperature, subsampled_weight=subsampled_weight - ) + data_weights = compute_weights(actions, temperature=self._weight_temperature) return data_weights diff --git a/src/emma_policy/datamodules/simbot_action_dataset.py b/src/emma_policy/datamodules/simbot_action_dataset.py index 8685ed6..81f2edf 100644 --- a/src/emma_policy/datamodules/simbot_action_dataset.py +++ b/src/emma_policy/datamodules/simbot_action_dataset.py @@ -15,6 +15,7 @@ SimBotInstructionInstance, SimBotObjectAttributes, ) +from emma_datasets.db import DatasetDb from overrides import overrides from torchvision.ops import masks_to_boxes from transformers import PreTrainedTokenizer @@ -48,8 +49,8 @@ def get_simbot_instruction_paraphrase( paraphraser: InstructionParaphraser, instance: SimBotInstructionInstance, object_name: str ) -> str: """Paraphrase a SimBot instruction.""" - action_type = instance.actions[0].type.lower() - action_object_metadata = instance.actions[0].get_action_data["object"] + action_type = instance.actions[-1].type.lower() + action_object_metadata = instance.actions[-1].get_action_data["object"] attributes = SimBotObjectAttributes( 
class SearchNegativeSampler:
    """Sample negative examples for the search objective.

    Builds a dictionary whose keys are the dataset indices of the positive search examples and
    whose values are the lowercased readable names of the objects in each example. Given the
    readable name of an object, it returns the index of a stored example that does not contain
    that object.
    """

    # Number of cheap random draws attempted before falling back to an exhaustive scan.
    _max_rejection_attempts = 16

    def __init__(self, db: "DatasetDb"):
        self._positive_indices_map = self._create_positive_indices_objects_map(db)
        # Cache the keys once; rebuilding this list on every draw costs O(number of positives).
        self._positive_indices = list(self._positive_indices_map)

    def __call__(self, readable_name: str) -> int:
        """Sample the index of a positive example that does not contain `readable_name`.

        Args:
            readable_name: Readable name of the object to avoid. It is lowercased before being
                compared against the stored (lowercased) names.

        Returns:
            The index, drawn uniformly at random, of a stored example without that object.

        Raises:
            IndexError: If there are no positive examples at all, or if every positive example
                contains `readable_name`.
        """
        target_name = readable_name.lower()
        positive_indices = self._positive_indices
        if not positive_indices:
            raise IndexError("Cannot sample a search negative: there are no positive examples.")

        # Fast path: rejection sampling is O(1) per draw and almost always succeeds quickly.
        for _ in range(self._max_rejection_attempts):
            rand_idx = random.choice(positive_indices)
            if target_name not in self._positive_indices_map[rand_idx]:
                return rand_idx

        # Slow path: the object is present in (nearly) every example. Enumerate the valid
        # candidates so that the call always terminates instead of looping forever.
        candidates = [
            index
            for index in positive_indices
            if target_name not in self._positive_indices_map[index]
        ]
        if not candidates:
            raise IndexError(
                f"Cannot sample a search negative: every positive example contains "
                f"'{target_name}'."
            )
        return random.choice(candidates)

    def _create_positive_indices_objects_map(self, db: "DatasetDb") -> dict[int, list[str]]:
        """Map the index of each positive search example to its lowercased readable names.

        An example is kept when its last action is a `Search` action whose object has a
        non-null `mask` and a non-null `attributes` entry.
        """
        db_size = len(db)
        positive_indices_map: dict[int, list[str]] = {}
        with db:
            for index in range(db_size):
                instance_str: str = db[index]
                instance = SimBotInstructionInstance.parse_raw(instance_str)

                action = instance.actions[-1]
                if action.type != "Search":
                    continue

                search_object = action.search["object"]
                if search_object["mask"] is None:
                    continue

                attributes = search_object.get("attributes")
                if attributes is None:
                    continue

                positive_indices_map[index] = [
                    attribute["readable_name"].lower() for attribute in attributes
                ]
        return positive_indices_map
SimBotActionDataset(EmmaBaseDataset[EmmaDatasetItem]): """Dataset for SimBotAction. @@ -81,6 +121,7 @@ def __init__( dataset_db_path: Path, tokenizer: PreTrainedTokenizer, iou_threshold: float = 0.5, + search_negative_proba: float = 0.5, max_frames: int = 15, use_only_necessary_questions: bool = True, allow_paraphrasing: bool = False, @@ -92,6 +133,7 @@ def __init__( self._iou_threshold = iou_threshold self._goto_proba = 0 + self._search_negative_proba = search_negative_proba self._use_only_necessary_questions = use_only_necessary_questions self.question_answer_prompt = "<> {question} <> {answer}" arena_definitions = get_arena_definitions() @@ -100,6 +142,8 @@ def __init__( self._image_height = arena_definitions["image_height"] self._paraphraser = InstructionParaphraser() self._allow_paraphrasing = allow_paraphrasing + self._search_negative_sampler = SearchNegativeSampler(self.db) + self._special_name_cases = arena_definitions["special_asset_to_readable_name"] @overrides(check_signature=False) def __getitem__(self, index: int) -> EmmaDatasetItem: @@ -112,12 +156,12 @@ def __getitem__(self, index: int) -> EmmaDatasetItem: return self.simbot_vision_augmentation(instance) return self.simbot_action_execution(instance) - def simbot_vision_augmentation( # noqa: WPS210 + def simbot_vision_augmentation( # noqa: WPS210, WPS231 self, instance: SimBotInstructionInstance ) -> EmmaDatasetItem: """Process a visual augmentation instance for the SimBot action task.""" - if instance.actions[0].type == "Search": - action_object_metadata = instance.actions[0].get_action_data["object"] + if instance.actions[-1].type == "Search": + action_object_metadata = instance.actions[-1].get_action_data["object"] object_candidates = len(action_object_metadata["id"]) object_candidate_idx = random.choice(range(object_candidates)) @@ -139,22 +183,36 @@ def simbot_vision_augmentation( # noqa: WPS210 source_text = self._get_random_template_for_task(Task.visual_grounding).format( caption=source_text ) 
- source_text = check_punctuation(source_text) + source_text = format_instruction(source_text) object_name = get_object_label_from_object_id( object_id=action_object_metadata["id"][object_candidate_idx], object_assets_to_names=self._object_assets_to_names, ) - target_frames = [0 for _ in instance.actions] - - visual_features, _, _ = self._load_visual_features( - features_path=instance.features_path, - target_frames=target_frames, + # We need to skip the instances that are from annotations aka paraphrasable + # TODO: we need to make this easier + select_negative = ( + random.random() >= self._search_negative_proba and instance.paraphrasable ) + if select_negative: + negative_idx = self._search_negative_sampler( + action_object_metadata["attributes"][object_candidate_idx]["readable_name"] + ) + instance_str = self.db[negative_idx] + negative_instance = SimBotInstructionInstance.parse_raw(instance_str) + visual_features, _, _ = self._load_visual_features( + features_path=negative_instance.features_path, + target_frames=[0 for _ in negative_instance.actions], + ) + else: + visual_features, _, _ = self._load_visual_features( + features_path=instance.features_path, + target_frames=[0 for _ in instance.actions], + ) ground_truth_bboxes = action_object_metadata["mask"] - if ground_truth_bboxes is None: + if ground_truth_bboxes is None or select_negative: target_text = f"no {object_name} ." else: ground_truth_bbox = ground_truth_bboxes[object_candidate_idx] @@ -317,7 +375,7 @@ def _prepare_source_text(self, instance: SimBotInstructionInstance) -> str: and answer: where is the hammer? the hammer is on the table in the robotics lab. 
""" if self._allow_paraphrasing and instance.paraphrasable: - action_object_metadata = instance.actions[0].get_action_data["object"] + action_object_metadata = instance.actions[-1].get_action_data["object"] object_name = get_object_label_from_object_id( object_id=action_object_metadata["id"], object_assets_to_names=self._object_assets_to_names, @@ -346,7 +404,7 @@ def _prepare_source_text(self, instance: SimBotInstructionInstance) -> str: ) source_text = f"{source_text}. {question_answer_text}" - source_text = check_punctuation(source_text) + source_text = format_instruction(source_text) return source_text diff --git a/src/emma_policy/datamodules/simbot_nlu_dataset.py b/src/emma_policy/datamodules/simbot_nlu_dataset.py index c0e3afc..0157763 100644 --- a/src/emma_policy/datamodules/simbot_nlu_dataset.py +++ b/src/emma_policy/datamodules/simbot_nlu_dataset.py @@ -24,7 +24,8 @@ from emma_policy.datamodules.base_dataset import EmmaBaseDataset from emma_policy.datamodules.emma_dataclasses import EmmaDatasetItem, EmmaVisualFeatures from emma_policy.datamodules.simbot_action_dataset import ( - check_punctuation, + SearchNegativeSampler, + format_instruction, get_simbot_instruction_paraphrase, ) from emma_policy.utils import get_logger @@ -87,6 +88,7 @@ def __init__( max_frames: int = 4, is_train: bool = True, iou_threshold: float = 0.5, + search_negative_proba: float = 0.5, ) -> None: super().__init__( dataset_db_path=dataset_db_path, tokenizer=tokenizer, max_frames=max_frames @@ -104,11 +106,13 @@ def __init__( arena_definitions = get_arena_definitions() self._object_assets_to_names = arena_definitions["asset_to_label"] self._label_to_idx = arena_definitions["label_to_idx"] - self._special_name_cases = get_arena_definitions()["special_asset_to_readable_name"] + self._special_name_cases = arena_definitions["special_asset_to_readable_name"] self._image_width = arena_definitions["image_width"] self._image_height = arena_definitions["image_height"] self._paraphraser = 
InstructionParaphraser() self._iou_threshold = iou_threshold + self._search_negative_proba = search_negative_proba + self._search_negative_sampler = SearchNegativeSampler(self.db) @overrides(check_signature=False) def __len__(self) -> int: @@ -122,11 +126,16 @@ def __getitem__(self, index: int) -> EmmaDatasetItem: instance_str = self.db[index] instance = SimBotInstructionInstance.parse_raw(instance_str) first_action = instance.actions[0] + frame_idx = 0 if first_action.type == "Search": - instruction, target_text = self.prepare_search_instance(instance) + instruction, visual_features, target_text = self.prepare_search_instance(instance) else: instruction, target_text = self.prepare_action_instance(instance) - source_text = f"Predict the system act: {check_punctuation(instruction)}" + frame_idx = self._get_instance_frame(instance, target_text.lower()) + visual_features = self._load_visual_features( + features_path=instance.features_path[0], frame_idx=frame_idx + ) + source_text = f"Predict the system act: {format_instruction(instruction)}" target_text = target_text.lower() input_encoding = self.tokenizer.encode_plus( source_text, return_tensors=self._return_tensor_type, truncation=True @@ -135,15 +144,13 @@ def __getitem__(self, index: int) -> EmmaDatasetItem: target_text, return_tensors=self._return_tensor_type, truncation=True ) - frame_idx = self._get_instance_frame(instance, target_text) - visual_features = self._load_visual_features( - features_path=instance.features_path[0], frame_idx=frame_idx - ) - raw_target = { "example_id": f"{instance.mission_id}_{instance.annotation_id}_{instance.instruction_id}", "references": target_text, "instruction": source_text, + "nlu_class": target_text.split()[0], + "object_type": " ".join(target_text.split()[1:]), + "action_type": first_action.type, } return EmmaDatasetItem( @@ -204,33 +211,53 @@ def prepare_synthetic_action_instance( instruction, target_text = self._augment_synthetic_action(instance) return instruction, 
target_text - def prepare_search_instance(self, instance: SimBotInstructionInstance) -> tuple[str, str]: + def prepare_search_instance( + self, instance: SimBotInstructionInstance + ) -> tuple[str, EmmaVisualFeatures, str]: """Get source and target text for Search instructions.""" + # Select the object action_object_metadata = instance.actions[0].get_action_data["object"] object_candidates = len(action_object_metadata["id"]) object_candidate_idx = random.choice(range(object_candidates)) object_id = action_object_metadata["id"][object_candidate_idx] - # Always paraphrase after selecting a target for search - instruction = self._paraphraser( - action_type="search", - object_id=object_id, - object_attributes=SimBotObjectAttributes( - **action_object_metadata["attributes"][object_candidate_idx] - ), - ) + + # Prepare the instruction + if instance.paraphrasable: + instruction = self._paraphraser( + action_type="search", + object_id=object_id, + object_attributes=SimBotObjectAttributes( + **action_object_metadata["attributes"][object_candidate_idx] + ), + ) + else: + instruction = instance.instruction.instruction + + # Prepare the visual_features and target object_readable_name = get_object_readable_name_from_object_id( object_id=object_id, object_assets_to_names=self._object_assets_to_names, special_name_cases=self._special_name_cases, ) - if action_object_metadata["mask"] is None: + negative_proba = random.random() >= self._search_negative_proba + # We need to skip the instances that are from annotations aka paraphrasable + select_negative = negative_proba and instance.paraphrasable + if select_negative: + negative_idx = self._search_negative_sampler(object_readable_name) + negative_instance = SimBotInstructionInstance.parse_raw(self.db[negative_idx]) + visual_features = self._load_visual_features( + features_path=negative_instance.features_path[0] + ) + target_text = f"{SimBotNLUIntents.search_no_match.value} {object_readable_name}" + elif 
action_object_metadata["mask"] is None: + # A negative search sample + visual_features = self._load_visual_features(features_path=instance.features_path[0]) target_text = f"{SimBotNLUIntents.search_no_match.value} {object_readable_name}" else: - ground_truth_bboxes = action_object_metadata["mask"] - ground_truth_bbox = ground_truth_bboxes[object_candidate_idx] + # A positive search sample ground_truth_bbox = torch.tensor( - ground_truth_bbox, + action_object_metadata["mask"][object_candidate_idx], dtype=torch.float32, ).unsqueeze(0) @@ -253,7 +280,7 @@ def prepare_search_instance(self, instance: SimBotInstructionInstance) -> tuple[ f"{SimBotNLUIntents.search_too_many_matches.value} {object_readable_name}" ) - return instruction, target_text + return instruction, visual_features, target_text def _augment_synthetic_action(self, instance: SimBotInstructionInstance) -> tuple[str, str]: """Prepare the instruction and target text for a synthetic unambiguous instruction. diff --git a/src/emma_policy/inference/model_wrapper/simbot_nlu_input_builder.py b/src/emma_policy/inference/model_wrapper/simbot_nlu_input_builder.py index dc0054d..9d03b4c 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_nlu_input_builder.py +++ b/src/emma_policy/inference/model_wrapper/simbot_nlu_input_builder.py @@ -12,7 +12,7 @@ EmmaDatasetItem, EmmaVisualFeatures, ) -from emma_policy.datamodules.simbot_action_dataset import check_punctuation +from emma_policy.datamodules.simbot_action_dataset import format_instruction from emma_policy.inference.api.simbot_state import GenerateRequest @@ -35,7 +35,7 @@ def __call__(self, request: GenerateRequest) -> EmmaDatasetBatch: the agent in the environment. 
""" # Add a fullstop at the end and lowercase - instruction = check_punctuation(request.dialogue_history[-1].utterance) + instruction = format_instruction(request.dialogue_history[-1].utterance) logger.debug(f"Preparing NLU input for instruction: {instruction}") encoded_inputs = self._prepare_input_text(instruction) feature_dicts = self._prepare_feature_dicts(request.environment_history[-1].features)