diff --git a/CHANGELOG.md b/CHANGELOG.md index d0363b7..d058315 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,20 +84,6 @@ All notable changes to this project will be documented in this file. See ## [1.45.1](https://github.com/emma-simbot/policy/compare/v1.45.0...v1.45.1) (2023-04-27) -### Bug Fixes - -* sticky note patch ([#785](https://github.com/emma-simbot/policy/issues/785)) ([e6f64d3](https://github.com/emma-simbot/policy/commit/e6f64d3cacf6aab444a2e29899479aaec84c4f25)) - -## [1.45.0](https://github.com/emma-simbot/policy/compare/v1.44.2...v1.45.0) (2023-04-27) - - -### Features - -* patch sticky note ([#784](https://github.com/emma-simbot/policy/issues/784)) ([6cf66a0](https://github.com/emma-simbot/policy/commit/6cf66a0e17a52e2b6d73e128a991d698d957ac36)) - -## [1.44.2](https://github.com/emma-simbot/policy/compare/v1.44.1...v1.44.2) (2023-04-25) - - ### Bug Fixes * patching v2 ([#782](https://github.com/emma-simbot/policy/issues/782)) ([827401c](https://github.com/emma-simbot/policy/commit/827401c45854ef482717c2e2a1efbf0cac9bfec3)) @@ -297,14 +283,6 @@ All notable changes to this project will be documented in this file. See ## [1.28.0](https://github.com/emma-simbot/policy/compare/v1.27.0...v1.28.0) (2023-02-11) - -### Features - -* add new examine sticky examples ([#704](https://github.com/emma-simbot/policy/issues/704)) ([121bcda](https://github.com/emma-simbot/policy/commit/121bcda1be9623904687320120ac6da4433d65f7)) - -## [1.27.0](https://github.com/emma-simbot/policy/compare/v1.26.0...v1.27.0) (2023-02-08) - - ### Features * Add dataset visualization ([#701](https://github.com/emma-simbot/policy/issues/701)) ([da6b764](https://github.com/emma-simbot/policy/commit/da6b764af4cfc3845fbb6fe5b63a2d0c0b854c2e)) @@ -312,13 +290,6 @@ All notable changes to this project will be documented in this file. 
See ## [1.26.0](https://github.com/emma-simbot/policy/compare/v1.25.1...v1.26.0) (2023-02-07) -### Features - -* Raw text matching with examine sticky note ([#698](https://github.com/emma-simbot/policy/issues/698)) ([776d364](https://github.com/emma-simbot/policy/commit/776d364ce9f00aff6e16751981c8af6a4a3b2868)) - -## [1.25.1](https://github.com/emma-simbot/policy/compare/v1.25.0...v1.25.1) (2023-02-06) - - ### Bug Fixes * negative candidate ids ([#696](https://github.com/emma-simbot/policy/issues/696)) ([8a27cf6](https://github.com/emma-simbot/policy/commit/8a27cf6f69fe95b1cd8ac760f04fc9fe58546e9e)) diff --git a/configs/datamodule/teach_datamodule.yaml b/configs/datamodule/teach_datamodule.yaml deleted file mode 100644 index f9f4389..0000000 --- a/configs/datamodule/teach_datamodule.yaml +++ /dev/null @@ -1,13 +0,0 @@ -_target_: emma_policy.datamodules.teach_edh_datamodule.TeachEdhDataModule - -model_name: heriot-watt/emma-small -teach_edh_train_db_file: ${work_dir}/storage/db/teach_training.db -teach_edh_valid_seen_db_file: ${work_dir}/storage/db/teach_valid_seen.db -teach_edh_valid_unseen_db_file: ${work_dir}/storage/db/teach_valid_unseen.db -load_valid_data_split: both -train_batch_size: 100 -val_batch_size: 100 -num_workers: 12 -max_lang_tokens: 512 -max_frames: 100 -tokenizer_truncation_side: right diff --git a/configs/experiment/simbot_nlu.yaml b/configs/experiment/simbot_cr.yaml similarity index 94% rename from configs/experiment/simbot_nlu.yaml rename to configs/experiment/simbot_cr.yaml index ac7cc60..95bf00b 100644 --- a/configs/experiment/simbot_nlu.yaml +++ b/configs/experiment/simbot_cr.yaml @@ -17,7 +17,7 @@ defaults: # name of the run determines folder name in logs # it's also accessed by loggers -name: "simbot_nlu" +name: "simbot_cr" seed: 12345 @@ -37,7 +37,7 @@ trainer: reload_dataloaders_every_n_epochs: 1 model: - _target_: emma_policy.models.simbot_nlu_policy.SimBotNLUEmmaPolicy + _target_: emma_policy.models.simbot_cr_policy.SimBotCREmmaPolicy model_name: heriot-watt/emma-base initialization_checkpoint: null @@ -53,7 +53,7 @@ model: resize_embeddings: True datamodule: - _target_: emma_policy.datamodules.simbot_nlu_datamodule.SimBotNLUDataModule + _target_: emma_policy.datamodules.simbot_cr_datamodule.SimBotCRDataModule model_name: heriot-watt/emma-base train_db_file: storage/db/simbot_clarifications_train.db @@ -112,7 +112,7 @@ callbacks: logger: wandb: _target_: pytorch_lightning.loggers.wandb.WandbLogger - project: "simbot_nlu" + project: "simbot_cr" name: ${name} save_dir: "logs/" offline: False # set True to store all logs only locally diff --git a/heriot-watt/emma-base-nlu/config.json b/heriot-watt/emma-base-cr/config.json similarity index 100% rename from heriot-watt/emma-base-nlu/config.json rename to heriot-watt/emma-base-cr/config.json diff --git a/notebooks/simbot_annotation_tool.py b/notebooks/simbot_annotation_tool.py deleted file mode 100644 index f889d8d..0000000 --- a/notebooks/simbot_annotation_tool.py +++ /dev/null @@ -1,895 +0,0 @@ -import argparse -import glob -import json -import logging -import os -import subprocess # noqa: S404 -from copy import deepcopy -from datetime import datetime -from pathlib import Path -from typing import Any, Literal, Optional - -import boto3 -import cv2 -import gradio as gr -import numpy as np -import pandas as pd -import torch -from boto3.dynamodb.conditions import Key -from botocore.exceptions import ClientError -from emma_datasets.constants.simbot.simbot import get_arena_definitions - -from 
emma_policy.commands.decode_images import decode_images_for_file -from emma_policy.commands.plot_bb import PlotBoundingBoxes - - -logging.basicConfig() -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -TurnOut = tuple[str, dict[str, Any], list[str], int] -SessionOut = tuple[int, str, str, list[Any], str, dict[str, Any], list[str], int] # noqa: WPS221 - - -class SessionClient: - """A simple client for retrieving sessions from the s3 bucket and dynamo db.""" - - def __init__( - self, - primary_key: str = "session_id", - resource_region: str = "us-east-1", - table_name: str = "SIMBOT_MEMORY_TABLE", - s3_sessions_bucket_url: str = "s3://emma-simbot-live-challenge", - sessions_file: str = "./notebooks/sessions.txt", - ignore_session_suffix: bool = False, - ) -> None: - - self._primary_key = primary_key - self._resource_region = resource_region - self._table_name = table_name - - self._db = boto3.resource("dynamodb", self._resource_region) - self._table = self._db.Table(self._table_name) - - self._s3_sessions_bucket_url = s3_sessions_bucket_url - self._sessions_file = sessions_file - self._ignore_session_suffix = ignore_session_suffix - - def get_all_session_turns_for_session(self, session_id: str) -> list[Any]: - """Get all the turns for a given session.""" - try: - response = self._table.query( - KeyConditionExpression=Key(self._primary_key).eq(session_id) - ) - except ClientError as err: - error_code = err.response["Error"]["Code"] - - if error_code != "ConditionalCheckFailedException": - logger.exception("Could not add turn to table.", exc_info=err) - raise err - return [] - - parsed_responses = response["Items"] - logger.debug(f"Successfully got previous {len(parsed_responses)} turns") - return parsed_responses - - def get_all_session_ids_from_bucket(self) -> dict[str, str]: - """Get all the session ids from the s3 bucket.""" - command = f"aws s3 ls --recursive {self._s3_sessions_bucket_url}" - with open(self._sessions_file, "w") as fpw: - subprocess.call(command.split(), stdout=fpw) # noqa: S603 - - df_csv = pd.read_csv(self._sessions_file, sep=r"\s+") - - session_days = df_csv.iloc[:, 0].tolist() - session_times = df_csv.iloc[:, 1].tolist() - session_files = df_csv.iloc[:, 3].tolist() - - sessions: dict[str, str] = {} - - session_metadata = zip(session_days, session_times, session_files) - for session_day, session_time, session_file in session_metadata: - # if not session_file.startswith("amzn1"): - # continue - session_name = os.path.dirname(session_file) - - timestamp = sessions.get(session_name, None) - if timestamp is not None: - t1 = datetime.strptime(timestamp, "%Y-%m-%d_%H:%M:%S") - t2 = datetime.strptime(f"{session_day}_{session_time}", "%Y-%m-%d_%H:%M:%S") - - earliest_datetime = min((t1, t2)) - sessions[session_name] = datetime.strftime(earliest_datetime, "%Y-%m-%d_%H:%M:%S") - else: - sessions[session_name] = f"{session_day}_{session_time}" - - return sessions - - def download_from_s3( - self, local_cache_path: str, s3_object_url: str, is_folder: bool = False - ) -> None: - """Download a file or folder from the s3 bucket.""" - local_path = os.path.join(local_cache_path, s3_object_url) - if os.path.exists(local_path): - logger.debug(f"{s3_object_url} has been downloaded to {local_path}") - return - - if self._ignore_session_suffix: - s3_object_url = "/".join(s3_object_url.split("/")[1:]) - - s3_url = os.path.join(self._s3_sessions_bucket_url, s3_object_url) - command = f"aws s3 cp {s3_url} {local_path}" - if is_folder: - command = f"{command} --recursive" -
logger.debug(f"Downloading {s3_url} into {local_path}") - subprocess.call( # noqa: S603 - command.split(), - stderr=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - ) - - -class ArenaSessionAnnotation: - """Class for visualising and annotating turns from arena sessions.""" - - def __init__( - self, - output_annotation_json: str, - output_features_directory: str, - s3_sessions_bucket_url: str = "s3://emma-simbot-live-challenge", - cache_dir: str = "sessions", - max_bboxes: int = 36, - ignore_session_suffix: bool = True, - ) -> None: - self.output_annotation_json = output_annotation_json - - os.makedirs(output_features_directory, exist_ok=True) - self.output_features_directory = output_features_directory - - os.makedirs(cache_dir, exist_ok=True) - self.cache_dir = cache_dir - - arena_definitions = get_arena_definitions() - self.actions = sorted(arena_definitions["action_list"] + ["Search"]) - self.assets = list(arena_definitions["asset_to_label"].keys()) - self.max_bboxes = max_bboxes - self._ignore_session_suffix = ignore_session_suffix - - self._session_client = SessionClient( - s3_sessions_bucket_url=s3_sessions_bucket_url, - ignore_session_suffix=ignore_session_suffix, - ) - - sessions_dict = self._session_client.get_all_session_ids_from_bucket() - self._session_ids = list(sessions_dict.keys()) - self._session_timestamps = list(sessions_dict.values()) - self._bbox_plot = PlotBoundingBoxes() - - def __len__(self) -> int: - """Return the number of sessions.""" - return len(self._session_ids) - - def sort_sessions(self, session_index: int, key: Literal["alphabetical", "timestamp"]) -> int: - """Sort the sessions depending on the key and update the current session index.""" - if key == "alphabetical": - order = np.argsort(self._session_ids) - elif key == "timestamp": - predicate = [ - datetime.strptime(timestamp, "%Y-%m-%d_%H:%M:%S") - for timestamp in self._session_timestamps - ] - order = np.argsort(predicate) # type: ignore[arg-type] - - self._session_ids = np.array(self._session_ids)[order].tolist() - self._session_timestamps = np.array(self._session_timestamps)[order].tolist() - return order.tolist().index(session_index) - - def get_user_utterance_for_turn(self, current_session_turn: dict[str, Any]) -> str: - """Get the user utterance for the current turn.""" - metadata_current_turn = json.loads(current_session_turn["turn"]) - if metadata_current_turn["speech"] is None: - return "" - try: # noqa: WPS229 - # Utterances have changed fix compatibility here - modified_utterance = metadata_current_turn["speech"].get("modified_utterance", None) - if modified_utterance is None: - user_utterance = metadata_current_turn["speech"]["original_utterance"]["utterance"] - else: - user_utterance = metadata_current_turn["speech"]["modified_utterance"]["utterance"] - return user_utterance - except Exception: - utterance_metadata = metadata_current_turn["speech"].get("utterance", None) - - if utterance_metadata is not None: - return metadata_current_turn["speech"]["utterance"] - return " ".join( - [token["value"] for token in metadata_current_turn["speech"]["tokens"]] - ) - - def get_agent_metadata_for_turn(self, current_session_turn: dict[str, Any]) -> dict[str, Any]: - """Get the metadata dict for the current turn.""" - metadata_current_turn = json.loads(current_session_turn["turn"]) - - agent_turn_metadata = deepcopy(metadata_current_turn) - agent_turn_metadata.pop("timestamp", None) - agent_turn_metadata.pop("environment", None) - agent_turn_metadata.pop("auxiliary_metadata_uri", None) - 
agent_turn_metadata.pop("viewpoints", None) - agent_turn_metadata.pop("unique_room_names", None) - agent_turn_metadata["actions"].pop("dialog", None) - - inventory = agent_turn_metadata["state"]["inventory"] - agent_turn_metadata.pop("state", None) - agent_turn_metadata["inventory"] = inventory - return agent_turn_metadata - - def get_images_for_turn(self, current_session_turn: dict[str, Any]) -> list[str]: - """Get the images for the current turn.""" - turn_metadata = json.loads(current_session_turn["turn"]) - session_id = turn_metadata["session_id"] - prediction_id = turn_metadata["prediction_request_id"] - - local_image_path = Path(os.path.join(self.cache_dir, session_id, "images")) - os.makedirs(local_image_path, exist_ok=True) - - local_json_image_path = Path( - os.path.join(self.cache_dir, session_id, f"{prediction_id}.json") - ) - - if not local_json_image_path.exists(): - logger.debug(f"{local_json_image_path} does not exist") - return [] - - decode_images_for_file(local_json_image_path, local_image_path) - images_pattern = f"{prediction_id}*.png" - images = glob.glob(f"{os.path.join(self.cache_dir, session_id, 'images', images_pattern)}") - return sorted(images) - - def prepare_output_for_turn( - self, current_session_turn: dict[str, Any] - ) -> tuple[str, dict[str, Any], list[str]]: - """Prepare the output for the current turn.""" - user_utterance = self.get_user_utterance_for_turn(current_session_turn) - agent_turn_metadata = self.get_agent_metadata_for_turn(current_session_turn) - images = self.get_images_for_turn(current_session_turn) - return (user_utterance, agent_turn_metadata, images) - - def on_previous_turn(self, session_turns: list[Any], turn_index: int) -> Optional[TurnOut]: - """Get the previous turn.""" - # The next turn is either the turn with a -1 index or the first element in the list. - new_turn_index = max(0, turn_index - 1) - if session_turns: - new_session_turn = session_turns[new_turn_index] - (user_utterance, agent_turn_metadata, images) = self.prepare_output_for_turn( - new_session_turn - ) - return (user_utterance, agent_turn_metadata, images, new_turn_index) - return None - - def on_next_turn(self, session_turns: list[Any], turn_index: int) -> Optional[TurnOut]: - """Get the next turn.""" - # The next turn is either the turn with a +1 index or the last element in the list. 
- new_turn_index = min(len(session_turns) - 1, turn_index + 1) - if session_turns: - new_session_turn = session_turns[new_turn_index] - (user_utterance, agent_turn_metadata, images) = self.prepare_output_for_turn( - new_session_turn - ) - return (user_utterance, agent_turn_metadata, images, new_turn_index) - return None - - def on_previous_session_id(self, session_index: int) -> Optional[SessionOut]: - """Get the previous session id.""" - new_session_index = max(0, session_index - 1) - - session_id = self._session_ids[new_session_index] - session_timestamp = self._session_timestamps[new_session_index] - - self._session_client.download_from_s3( - local_cache_path=self.cache_dir, s3_object_url=session_id, is_folder=True - ) - - session_turns = self._session_client.get_all_session_turns_for_session(session_id) - if session_turns: - new_session_turn = session_turns[0] - (user_utterance, agent_turn_metadata, images) = self.prepare_output_for_turn( - new_session_turn - ) - - return ( # noqa: WPS227 - new_session_index, - session_id, - session_timestamp, - session_turns, - user_utterance, - agent_turn_metadata, - images, - 0, - ) - return None - - def on_next_session_id(self, session_index: int) -> Optional[SessionOut]: - """Get the next session id.""" - new_session_index = min(len(self._session_ids) - 1, session_index + 1) - - session_id = self._session_ids[new_session_index] - session_timestamp = self._session_timestamps[new_session_index] - - self._session_client.download_from_s3( - local_cache_path=self.cache_dir, s3_object_url=session_id, is_folder=True - ) - - session_turns = self._session_client.get_all_session_turns_for_session(session_id) - if session_turns: - new_session_turn = session_turns[0] - (user_utterance, agent_turn_metadata, images) = self.prepare_output_for_turn( - new_session_turn - ) - - return ( # noqa: WPS227 - new_session_index, - session_id, - session_timestamp, - session_turns, - user_utterance, - agent_turn_metadata, - images, - 0, - ) - return None - - def on_jump_session_id_slider(self, session_index: int) -> Optional[SessionOut]: - """Go to a session provided by its index.""" - session_id = self._session_ids[session_index] - session_timestamp = self._session_timestamps[session_index] - - self._session_client.download_from_s3( - local_cache_path=self.cache_dir, s3_object_url=session_id, is_folder=True - ) - session_turns = self._session_client.get_all_session_turns_for_session(session_id) - if session_turns: - new_session_turn = session_turns[0] - (user_utterance, agent_turn_metadata, images) = self.prepare_output_for_turn( - new_session_turn - ) - - return ( # noqa: WPS227 - session_index, - session_id, - session_timestamp, - session_turns, - user_utterance, - agent_turn_metadata, - images, - 0, - ) - return None - - def on_jump_session_id_textbox(self, session_id: str) -> Optional[SessionOut]: - """Go to a session provided by its id.""" - self._session_client.download_from_s3( - local_cache_path=self.cache_dir, s3_object_url=session_id, is_folder=True - ) - session_index = self._session_ids.index(session_id) - session_timestamp = self._session_timestamps[session_index] - session_turns = self._session_client.get_all_session_turns_for_session(session_id) - if session_turns: - new_session_turn = session_turns[0] - (user_utterance, agent_turn_metadata, images) = self.prepare_output_for_turn( - new_session_turn - ) - - return ( # noqa: WPS227 - session_index, - session_id, - session_timestamp, - session_turns, - user_utterance, - agent_turn_metadata, - images, - 0, - ) 
- return None - - def on_hide_all_boxes( - self, session_id: str, session_turns: list[Any], session_turn_index: int - ) -> tuple[list[int], list[str]]: - """Disable all the bounding boxes.""" - indices: list[int] = [] - return indices, self.on_show_specific_boxes( - session_id, session_turns, session_turn_index, indices - ) - - def on_show_all_boxes( - self, session_id: str, session_turns: list[Any], session_turn_index: int - ) -> tuple[list[int], list[str]]: - """Show all the bounding boxes.""" - indices = [idx + 1 for idx in range(self.max_bboxes)] - return indices, self.on_show_specific_boxes( - session_id, session_turns, session_turn_index, list(range(self.max_bboxes)) - ) - - def on_show_specific_boxes( - self, - session_id: str, - session_turns: list[Any], - session_turn_index: int, - indices: list[int], - ) -> list[str]: - """Show only a subset of the available bounding boxes.""" - images = self.get_images_for_turn(session_turns[session_turn_index]) - - local_image_bbox_path = Path(os.path.join(self.cache_dir, "images_bboxes")) - os.makedirs(local_image_bbox_path, exist_ok=True) - images_bboxes = [] - for idx, image in enumerate(images): - if self._ignore_session_suffix: - image_suffix = image - else: - image_suffix = "/".join(image.split("/")[1:]) - - image_bname = os.path.splitext(os.path.basename(image_suffix))[0] - (feature_basename, image_index) = image_bname.split("_") - feature_path = os.path.join(self.cache_dir, session_id, f"{feature_basename}.pt") - - if os.path.exists(feature_path): - image_features = torch.load(feature_path)[int(image_index)] - - boxes_coords = image_features["bbox_coords"].cpu().numpy() - - num_boxes = boxes_coords.shape[0] - boxes_indices = [idx for idx in indices if 0 <= idx < num_boxes] - - image_cv = cv2.imread(image) - self._bbox_plot.draw_bb( - image=image_cv, - boxes_coords=boxes_coords[boxes_indices], - boxes_labels=[f"{idx + 1}" for idx in boxes_indices], - draw_label=True, - ) - - if self._ignore_session_suffix: - image_from_suffix = "/".join(session_id.split("/")[1:]) - else: - image_from_suffix = session_id - image_bbox_path = os.path.join( - local_image_bbox_path, f"{image_from_suffix}_{idx}.png" - ) - cv2.imwrite(image_bbox_path, image_cv) - images_bboxes.append(image_bbox_path) - else: - logger.debug(f"Feature path {feature_path} does not exist") - return images_bboxes - - def on_update_annotation( - self, - session_id: str, - session_turns: list[Any], - session_turn_index: int, - user_utterance: str, - action_type: Optional[str] = None, - object_id: Optional[str] = None, - visual_token: Optional[int] = None, - ) -> dict[str, Any]: - """Update the annotation for a turn.""" - instruction_metadata: dict[str, Any] = { - "instruction": { - "instruction": user_utterance, - "actions": [0], - } - } - if action_type: - images = self.get_images_for_turn(session_turns[session_turn_index]) - if images: - image = images[0] - image_bname = os.path.splitext(os.path.basename(image))[0] - (feature_basename, image_index) = image_bname.split("_") - feature_path = os.path.join(self.cache_dir, session_id, f"{feature_basename}.pt") - - mask = None - if os.path.exists(feature_path) and visual_token: - image_features = torch.load(feature_path)[int(image_index)] - boxes_coords = image_features["bbox_coords"].cpu().numpy() - mask = boxes_coords[int(visual_token - 1)].astype(int).tolist() - - # The search metadata are slightly different from the other actions. - # The object dictionary has multiple object ids and object masks.
- if action_type == "Search": - instruction_metadata["actions"] = [ - { - "id": 0, - "type": action_type, - action_type.lower(): { - "object": {"id": [object_id], "mask": [mask], "colorImageIndex": 0} - }, - "colorImages": [os.path.basename(image)], - "final": True, - "positive": True, - } - ] - else: - instruction_metadata["actions"] = [ - { - "id": 0, - "type": action_type, - action_type.lower(): { - "object": {"id": object_id, "mask": mask, "colorImageIndex": 0} - }, - "colorImages": [os.path.basename(image)], - "final": True, - } - ] - - # Fill in required metadata so that the dictionary can be parsed by the SimBotInstructionInstance - instruction_metadata["annotation_id"] = 0 - instruction_metadata["instruction_id"] = 0 - instruction_metadata["synthetic"] = False - instruction_metadata["mission_id"] = session_id - # This needs to be set to true so that we can get the correct features path from - # https://github.com/emma-simbot/datasets/blob/19db6ef9244e2e78acf2cb36a1c2f1bd6be799cd/src/emma_datasets/datamodels/datasets/utils/simbot_utils/simbot_datamodels.py#L149-L162 - instruction_metadata["vision_augmentation"] = True - return instruction_metadata - - def on_save_annotation_for_turn( - self, - session_id: str, - session_turns: list[Any], - session_turn_index: int, - instruction_metadata: dict[str, Any], - ) -> None: - """Save the annotations for a turn.""" - images = self.get_images_for_turn(session_turns[session_turn_index]) - - prediction_request_id = json.loads(session_turns[session_turn_index]["turn"])[ - "prediction_request_id" - ] - data = {} - # Allow for multiple annotations per prediction id - date_time = datetime.now().strftime("%m/%d/%Y-%H:%M:%S") - data_key = f"session_id_{session_id}_prediction_id_{prediction_request_id}_{date_time}" - if os.path.exists(self.output_annotation_json): - with open(self.output_annotation_json) as fpr: - data = json.load(fpr) - - data[data_key] = instruction_metadata - with open(self.output_annotation_json, "w") as fpw: - json.dump(data, fpw, indent=4) - - features = torch.load( - os.path.join(self.cache_dir, session_id, f"{prediction_request_id}.pt") - ) - features_formatted: dict[str, Any] = {"frames": []} - for feature_idx, feature_dict in features.items(): - feature_dict_formatted = { - "image": os.path.basename(images[feature_idx]), - "features": { - "bbox_features": feature_dict["bbox_features"].cpu(), - "bbox_coords": feature_dict["bbox_coords"].cpu(), - "bbox_probas": feature_dict["bbox_probas"].cpu(), - "cnn_features": feature_dict["cnn_features"].cpu(), - "width": 300, - "height": 300, - }, - } - features_formatted["frames"].append(feature_dict_formatted) - - torch.save( - features_formatted, - os.path.join( - self.output_features_directory, f"{prediction_request_id}_{feature_idx}.pt" - ), - ) - - -def main(args: argparse.Namespace) -> None: # noqa: WPS210 - """Main.""" - session_visualizer = ArenaSessionAnnotation( - output_annotation_json=args.output_annotation_json, - output_features_directory=args.output_features_directory, - cache_dir=args.cache_dir, - s3_sessions_bucket_url=args.s3_sessions_bucket_url, - ignore_session_suffix=args.ignore_session_suffix, - ) - - with gr.Blocks() as block: - session_id_turns = gr.State([]) - session_turn_index = gr.State(0) - with gr.Row(): - session_id_textbox = gr.Textbox(label="Session ID \U0000270D", interactive=True) - - session_timestamp_textbox = gr.Textbox(label="Session Timestamp", interactive=False) - - sort_sessions_dropdown = gr.Radio( - label="Sort Sessions", 
choices=["alphabetical", "timestamp"], value="alphabetical" - ) - with gr.Row(): - previous_session_id_button = gr.Button( - "Previous Session ID", label="Previous Session ID" - ) - next_session_id_button = gr.Button( - "Next Session ID", - label="Next Session ID", - ) - jump_session_id_button = gr.Button( - "Go To Session ID", - label="Go To Session ID", - ) - with gr.Row(): - jump_session_id_slider = gr.Slider( - minimum=0, - maximum=len(session_visualizer) - 1, - label="Jump To Session", - value=0, - step=1, - ) - - with gr.Row(): - previous_turn_button = gr.Button("Previous Turn", label="Previous Turn") - next_turn_button = gr.Button("Next Turn", label="Next Turn") - - with gr.Row(): - agent_turn_session_id_textbox = gr.JSON(label="Agent Turn Metadata") - - with gr.Column(): - output_image_gallery = gr.Gallery(label="Images For Current Turn") - - with gr.Row(): - checkboxgroup_bboxes = gr.CheckboxGroup( - choices=[f"{idx + 1}" for idx in range(session_visualizer.max_bboxes)], - type="index", - label="Show specific bounding boxes", - interactive=True, - ) - - with gr.Row(): - disable_all_boxes_button = gr.Button("Hide all bounding boxes") - show_all_boxes_button = gr.Button("Show all bounding boxes") - - with gr.Row(): - with gr.Column(): - user_turn_session_id_textbox = gr.Textbox( - label="User Turn Utterance \U0000270D", interactive=True - ) - - action_type_dropdown = gr.Dropdown( - label="Action Type", choices=session_visualizer.actions - ) - - object_id_dropdown = gr.Dropdown( - label="Object ID", choices=session_visualizer.assets - ) - - visual_token_dropdown = gr.Dropdown( - label="Visual Token", - choices=list(range(1, session_visualizer.max_bboxes + 1)), - ) - - with gr.Column(): - instruction_annotation_json = gr.JSON(label="Instruction metadata", value={}) - - with gr.Row(): - save_turn_button = gr.Button( - "Save Annotation For Turn", - label="Save Annotation For Turn", - variant="primary", - ) - - previous_session_id_button.click( - fn=session_visualizer.on_previous_session_id, - inputs=[jump_session_id_slider], - outputs=[ - jump_session_id_slider, - session_id_textbox, - session_timestamp_textbox, - session_id_turns, - user_turn_session_id_textbox, - agent_turn_session_id_textbox, - output_image_gallery, - session_turn_index, - ], - ) - - next_session_id_button.click( - fn=session_visualizer.on_next_session_id, - inputs=[jump_session_id_slider], - outputs=[ - jump_session_id_slider, - session_id_textbox, - session_timestamp_textbox, - session_id_turns, - user_turn_session_id_textbox, - agent_turn_session_id_textbox, - output_image_gallery, - session_turn_index, - ], - ) - - jump_session_id_button.click( - fn=session_visualizer.on_jump_session_id_textbox, - inputs=[session_id_textbox], - outputs=[ - jump_session_id_slider, - session_id_textbox, - session_timestamp_textbox, - session_id_turns, - user_turn_session_id_textbox, - agent_turn_session_id_textbox, - output_image_gallery, - session_turn_index, - ], - ) - - jump_session_id_slider.change( - fn=session_visualizer.on_jump_session_id_slider, - inputs=[jump_session_id_slider], - outputs=[ - jump_session_id_slider, - session_id_textbox, - session_timestamp_textbox, - session_id_turns, - user_turn_session_id_textbox, - agent_turn_session_id_textbox, - output_image_gallery, - session_turn_index, - ], - ) - - sort_sessions_dropdown.change( - fn=session_visualizer.sort_sessions, - inputs=[jump_session_id_slider, sort_sessions_dropdown], - outputs=[jump_session_id_slider], - ) - - previous_turn_button.click( - 
fn=session_visualizer.on_previous_turn, - inputs=[session_id_turns, session_turn_index], - outputs=[ - user_turn_session_id_textbox, - agent_turn_session_id_textbox, - output_image_gallery, - session_turn_index, - ], - ) - - next_turn_button.click( - fn=session_visualizer.on_next_turn, - inputs=[session_id_turns, session_turn_index], - outputs=[ - user_turn_session_id_textbox, - agent_turn_session_id_textbox, - output_image_gallery, - session_turn_index, - ], - ) - - disable_all_boxes_button.click( - fn=session_visualizer.on_hide_all_boxes, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - ], - outputs=[checkboxgroup_bboxes, output_image_gallery], - ) - - show_all_boxes_button.click( - fn=session_visualizer.on_show_all_boxes, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - ], - outputs=[checkboxgroup_bboxes, output_image_gallery], - ) - - checkboxgroup_bboxes.change( - fn=session_visualizer.on_show_specific_boxes, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - checkboxgroup_bboxes, - ], - outputs=[output_image_gallery], - ) - - user_turn_session_id_textbox.change( - fn=session_visualizer.on_update_annotation, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - user_turn_session_id_textbox, - action_type_dropdown, - object_id_dropdown, - visual_token_dropdown, - ], - outputs=[instruction_annotation_json], - ) - - action_type_dropdown.change( - fn=session_visualizer.on_update_annotation, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - user_turn_session_id_textbox, - action_type_dropdown, - object_id_dropdown, - visual_token_dropdown, - ], - outputs=[instruction_annotation_json], - ) - - object_id_dropdown.change( - fn=session_visualizer.on_update_annotation, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - user_turn_session_id_textbox, - action_type_dropdown, - object_id_dropdown, - visual_token_dropdown, - ], - outputs=[instruction_annotation_json], - ) - visual_token_dropdown.change( - fn=session_visualizer.on_update_annotation, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - user_turn_session_id_textbox, - action_type_dropdown, - object_id_dropdown, - visual_token_dropdown, - ], - outputs=[instruction_annotation_json], - ) - - save_turn_button.click( - fn=session_visualizer.on_save_annotation_for_turn, - inputs=[ - session_id_textbox, - session_id_turns, - session_turn_index, - instruction_annotation_json, - ], - ) - - block.launch(share=True) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--output_annotation_json", - default="session_annotations/session_annotations.json", - help="Path to output annotation json file.", - ) - - parser.add_argument( - "--output_features_directory", - default="session_annotations/features/", - help="Path to output annotation feature directory.", - ) - - parser.add_argument( - "--cache_dir", - default="sessions", - help="Path to cache directory storing raw session metadata while annotating.", - ) - - parser.add_argument( - "--s3_sessions_bucket_url", - help="S3 bucket to where all the sessions are stored", - default="s3://emma-simbot-live-challenge", - ) - - parser.add_argument( - "--ignore_session_suffix", help="Ignore session suffix", action="store_true" - ) - args = parser.parse_args() - main(args) diff --git a/notebooks/simbot_dataset_visualization_app.py b/notebooks/simbot_dataset_visualization_app.py index 
e125778..81a70ed 100644 --- a/notebooks/simbot_dataset_visualization_app.py +++ b/notebooks/simbot_dataset_visualization_app.py @@ -13,7 +13,7 @@ from transformers import AutoTokenizer from emma_policy.datamodules.simbot_action_dataset import SimBotActionDataset -from emma_policy.datamodules.simbot_nlu_dataset import SimBotNLUDataset +from emma_policy.datamodules.simbot_cr_dataset import SimBotCRDataset logging.basicConfig() @@ -86,9 +86,9 @@ def get_data_from_action_dataset(args: argparse.Namespace) -> dict[str, Any]: return data_dict -def get_data_from_nlu_dataset(args: argparse.Namespace) -> dict[str, Any]: - """Get the visualization data from the NLU dataset.""" - train_dataset = SimBotNLUDataset( +def get_data_from_cr_dataset(args: argparse.Namespace) -> dict[str, Any]: + """Get the visualization data from the CR dataset.""" + train_dataset = SimBotCRDataset( dataset_db_path=args.dataset_db, tokenizer=AutoTokenizer.from_pretrained("heriot-watt/emma-base"), is_train=True, @@ -98,13 +98,9 @@ def get_data_from_nlu_dataset(args: argparse.Namespace) -> dict[str, Any]: data_per_action = defaultdict(list) for index, instance in tqdm(enumerate(train_dataset)): # type: ignore[arg-type] - data.append(instance.raw_target["nlu_class"]) - data_per_object[instance.raw_target["object_type"]].append( - instance.raw_target["nlu_class"] - ) - data_per_action[instance.raw_target["action_type"]].append( - instance.raw_target["nlu_class"] - ) + data.append(instance.raw_target["cr_class"]) + data_per_object[instance.raw_target["object_type"]].append(instance.raw_target["cr_class"]) + data_per_action[instance.raw_target["action_type"]].append(instance.raw_target["cr_class"]) if index == len(train_dataset) - 1: break @@ -119,8 +115,8 @@ def get_data_for_visualization(args: argparse.Namespace) -> dict[str, Any]: if args.cache_dir.exists(): with open(args.cache_dir) as file_in: return json.load(file_in) - elif args.dataset_type == "nlu": - return get_data_from_nlu_dataset(args) + elif args.dataset_type == "cr": + return get_data_from_cr_dataset(args) return get_data_from_action_dataset(args) @@ -181,7 +177,7 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--dataset_type", type=str, - choices=["nlu", "action"], + choices=["cr", "action"], help="Type of the dataset", ) parser.add_argument( diff --git a/notebooks/teach_attention_masks.ipynb b/notebooks/teach_attention_masks.ipynb deleted file mode 100644 index 3b70401..0000000 --- a/notebooks/teach_attention_masks.ipynb +++ /dev/null @@ -1,256 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We'll create the local and global masks for the entire sequence assuming that input tokens are concatenated in the following order: [cnn scene tokens, object tokens, language tokens].\n", - "\n", - "Since we can only create the mask after the padding, we use scene, object and text temporal vectors with values:\n", - "\n", - "- -1 for history tokens\n", - "- 0 for padding tokens\n", - "- 1 .. N, the number of the corresponding future frame\n", - "\n", - "Text input tokens will always be history tokens. We make text input tokens global."
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "# Example with 3 history frames, 2 future frames and 1 padding token.\n", - "scene_temporal_ids1 = torch.Tensor([-1, -1, -1, 1, 2, 0])\n", - "# The history frames have a total of 6 objects, future frame 1 has 3 objects and future frame 2 has 2 objects.\n", - "object_temporal_ids1 = torch.Tensor([-1, -1, -1, -1, -1, -1, 1, 1, 1, 2, 2, 0, 0])\n", - "# There is no text for future frames.\n", - "text_temporal_ids1 = torch.Tensor([-1, -1, -1, -1, -1, -1, -1, 0, 0])" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# Create a second sample with 2 history and 3 future frames\n", - "scene_temporal_ids2 = torch.Tensor([-1, -1, 1, 2, 3, 0])\n", - "object_temporal_ids2 = torch.Tensor([-1, -1, -1, -1, 1, 1, 1, 2, 2, 3, 0, 0, 0])\n", - "text_temporal_ids2 = torch.Tensor([-1, -1, -1, -1, -1, 0, 0, 0, 0])\n", - "\n", - "# Concatenate them in a batch\n", - "scene_temporal_ids = torch.stack([scene_temporal_ids1, scene_temporal_ids2])\n", - "object_temporal_ids = torch.stack([object_temporal_ids1, object_temporal_ids2])\n", - "text_temporal_ids = torch.stack([text_temporal_ids1, text_temporal_ids2])" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from emma_policy.datamodules.collate import (\n", - " make_text_history_global_pattern,\n", - " make_encoder_causal_mask_batch,\n", - ")\n", - "\n", - "attention2d = make_encoder_causal_mask_batch(\n", - " scene_temporal_ids,\n", - " object_temporal_ids,\n", - " text_temporal_ids,\n", - " dtype=scene_temporal_ids.dtype,\n", - ")\n", - "global_attention = make_text_history_global_pattern(\n", - " scene_temporal_ids,\n", - " object_temporal_ids,\n", - " text_temporal_ids,\n", - " dtype=scene_temporal_ids.dtype,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2D attention mask shape: batch size x total tokens x total tokens = torch.Size([2, 28, 28])\n", - "Global attention mask shape: batch size x total tokens = torch.Size([2, 28])\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 1., 1., 1., 1., 1., 0., 0., 0., 0.]])" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(f\"2D attention mask shape: batch size x total tokens x total tokens = {attention2d.shape}\")\n", - "assert (\n", - " attention2d.shape[2]\n", - " == scene_temporal_ids.shape[1] + object_temporal_ids.shape[1] + text_temporal_ids.shape[1]\n", - ")\n", - "print(f\"Global attention mask shape: batch size x total tokens = {global_attention.shape}\")\n", - "assert (\n", - " global_attention.shape[1]\n", - " == scene_temporal_ids.shape[1] + object_temporal_ids.shape[1] + text_temporal_ids.shape[1]\n", - ")\n", - "global_attention" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can check any possible combination. The element (i, j) of the 2D attention mask is 1 if element i is allowed to attend to element j."
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "First item in the batch.\n", - "Scene-to-scene attention: 3 history frames, 2 future frames, 1 padding\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[1., 1., 1., 0., 0., 0.],\n", - " [1., 1., 1., 0., 0., 0.],\n", - " [1., 1., 1., 0., 0., 0.],\n", - " [1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 0.],\n", - " [0., 0., 0., 0., 0., 0.]])" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(\"First item in the batch.\")\n", - "print(\"Scene-to-scene attention: 3 history frames, 2 future frames, 1 padding\")\n", - "scene_len = scene_temporal_ids.shape[-1]\n", - "attention2d[0, :scene_len, :scene_len]" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Scene-to-objects attention: 6 history objects, 3 objects in frame 1, 2 objects in frame 2, 2 paddings\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(\n", - " \"Scene-to-objects attention: 6 history objects, 3 objects in frame 1, 2 objects in frame 2, 2 paddings\"\n", - ")\n", - "object_len = object_temporal_ids.shape[-1]\n", - "attention2d[0, :scene_len, scene_len : scene_len + object_len]" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Scene-to-text attention: 7 history text tokens\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0.]])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(\"Scene-to-text attention: 7 history text tokens\")\n", - "text_len = text_temporal_ids.shape[-1]\n", - "attention2d[0, :scene_len, scene_len + object_len :]" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "0ee53ab302d70dc2b4b6ceff365a75f0f8d5471af86eaa2f96d460774c6ebc79" - }, - "kernelspec": { - "display_name": "Python 3.9.7 ('emma')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/plot_ablation_results.py b/plot_ablation_results.py new file mode 100644 index 0000000..c9744c0 --- /dev/null +++ b/plot_ablation_results.py @@ -0,0 +1,34 @@ +import numpy as np +import matplotlib.pyplot as plt + +ablation_percentages = np.array([0, 0.25, 0.5, 0.75, 1.0]) + 
+# the 2nd and the second-to-last values are swapped +vision_aug_results = np.array([27.15, 34.72, 35.24, 36.64, 36.81]) +cdf_aug_results = np.array([34.09, 34.89, 36.61, 36.81, 36.81]) +human_results = np.array([19.17, 19.17, 19.17, 19.17, 19.17]) +y_ticks = [17, 20, 23, 26, 29, 32, 35, 38] + +fig = plt.figure() +ax = fig.add_subplot(111) +ax.plot(ablation_percentages, human_results, "-.g", label="DTC") +ax.plot(ablation_percentages, vision_aug_results, "-bs", label="Visual Aug") + +ax2 = ax.twinx() +ax2.plot(ablation_percentages, cdf_aug_results, "--ro", label="CDF Aug") +# fig.legend(loc="upper right") + +# ax.set_xlabel("Pr") +ax.set_yticks(y_ticks) +ax2.set_yticks(y_ticks) +ax.set_ylabel(r"Vision Augmentations") +ax2.set_ylabel(r"CDF Augmentations") + +ax.grid() + +fig.legend(loc="upper center", bbox_to_anchor=(0.5, 0.425), fancybox=True, ncol=3) +plt.xticks(ablation_percentages, ablation_percentages) +# plt.show() +plt.title("Performance curves when ablating augmentations") +ax.set_xlabel("Proportion of train instances") +plt.savefig("human.pdf", transparent=True) diff --git a/scripts/download_features_for_db.py b/scripts/download_features_for_db.py deleted file mode 100644 index cbe7068..0000000 --- a/scripts/download_features_for_db.py +++ /dev/null @@ -1,132 +0,0 @@ -import argparse -import logging -from pathlib import Path -from typing import Literal - -import boto3 -import botocore -from emma_datasets.common import get_progress -from emma_datasets.datamodels import BaseInstance, Instance -from emma_datasets.datamodels.datasets import TeachEdhInstance -from emma_datasets.datamodels.datasets.simbot import SimBotInstructionInstance -from emma_datasets.db import DatasetDb - -from emma_policy.utils import get_logger - - -log = get_logger(__name__) - -logging.getLogger("boto3").setLevel(logging.CRITICAL) -logging.getLogger("botocore").setLevel(logging.CRITICAL) -logging.getLogger("nose").setLevel(logging.CRITICAL) -logging.getLogger("s3transfer").setLevel(logging.CRITICAL) -logging.getLogger("urllib3").setLevel(logging.CRITICAL) - - -class FixtureDownload: - """Downloads the features for all instances given a db.""" - - def __init__( - self, - input_db_path: Path, - instance_type: Literal["pretrain", "teach_edh", "simbot"] = "pretrain", - ) -> None: - self._db = DatasetDb(input_db_path) - - self._instance_model: BaseInstance = Instance - self.instance_type = instance_type - if self.instance_type == "teach_edh": - self._instance_model = TeachEdhInstance - elif self.instance_type == "simbot": - self._instance_model = SimBotInstructionInstance - - self._s3 = boto3.client("s3") - - self._progress = get_progress() - self._task_id = self._progress.add_task( - f"Downloading features from {input_db_path}", total=len(self._db), comment="" - ) - - def run(self) -> None: - """Do the downloading for the fixtures.""" - with self._progress, self._db: # noqa: WPS316 - for _, _, data in self._db: - instance = self._instance_model.parse_raw(data) - - if ( # noqa: WPS337 - self._instance_model != SimBotInstructionInstance - and instance.is_full_trajectory - ): - local_feature_paths = instance.features_path - else: - local_feature_paths = [instance.features_path] - self._download_feature_files( - instance=instance, local_feature_paths=local_feature_paths - ) - self._progress.advance(self._task_id) - - def _download_feature_files( - self, - instance: BaseInstance, - local_feature_paths: list[Path], - ) -> None: - """Download the feature files.""" - for local_feature_path in local_feature_paths: -
local_feature_path.parent.mkdir(parents=True, exist_ok=True) - - if "alfred" in local_feature_path.parts: - s3_path = self._get_paths_for_alfred(local_feature_path) - else: - s3_path = self._get_paths(local_feature_path) - - self._download_file(s3_path, local_feature_path) - - if self.instance_type == "teach_edh": - local_future_feature_path = instance.future_features_path - local_future_feature_path.parent.mkdir(parents=True, exist_ok=True) - s3_path = self._get_paths(local_future_feature_path) - self._download_file(s3_path, local_future_feature_path) - - def _download_file(self, s3_path: Path, local_path: Path) -> None: - try: - self._s3.download_file("emma-simbot", str(s3_path), str(local_path)) - except botocore.exceptions.ClientError: - log.error(f"Failed to download {local_path}") - - def _get_paths(self, local_feature_path: Path) -> Path: - """Get the paths as is, without needing any special handling.""" - idx2split = local_feature_path.parts.index("datasets") - feature_path = Path(*local_feature_path.parts[idx2split + 1 :]) - s3_feature_path = Path("datasets", feature_path) - - return s3_feature_path - - def _get_paths_for_alfred(self, local_feature_path: Path) -> Path: - idx2split = local_feature_path.parts.index("alfred") - feature_path = Path(*local_feature_path.parts[idx2split + 1 :]) - s3_feature_path = Path("datasets", "alfred", "full_2.1.0", feature_path) - - return s3_feature_path - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "--input_db", - help="Path to the input database", - type=Path, - default="storage/fixtures/instances.db", - ) - - parser.add_argument( - "--instance_type", - help="Type of instance model used within the DatasetDb", - choices=["pretrain", "teach_edh", "simbot"], - default="pretrain", - ) - - args = parser.parse_args() - - downloader = FixtureDownload(args.input_db, args.instance_type) - downloader.run() diff --git a/scripts/slurm/SLURM.md b/scripts/slurm/SLURM.md deleted file mode 100644 index 730e461..0000000 --- a/scripts/slurm/SLURM.md +++ /dev/null @@ -1,12 +0,0 @@ -# Submitting a job to Cirrus GPU Nodes - -The following examples are specific to Cirrus and project ec202. - -The official documentation for using the Cirrus GPU nodes can be found [here](https://cirrus.readthedocs.io/en/main/user-guide/gpu.html). - -You can find an example submission script in `scripts/slurm/job_multi_node.slurm`. -You need to make at least the following modifications: - -1. Modify the number of nodes (each node has 4 GPUs). -2. Activate your environment. Cirrus suggests creating a custom miniconda environment. For further instructions, see [here](https://cirrus.readthedocs.io/en/main/user-guide/python.html?highlight=pytorch#custom-miniconda3-environments). -3. Add your personal WANDB_API_KEY, which you can find in the settings of your wandb profile. diff --git a/scripts/slurm/job_multi_node.slurm b/scripts/slurm/job_multi_node.slurm deleted file mode 100644 index 87114e5..0000000 --- a/scripts/slurm/job_multi_node.slurm +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash -# -#SBATCH --job-name=emma_pretraining -#SBATCH --partition=gpu-cascade -#SBATCH --qos=gpu -#SBATCH --exclusive -#SBATCH --nodes=2 -#SBATCH --gres=gpu:4 -#SBATCH --time=48:00:00 -#SBATCH --account=ec202 - -# Load the required modules -module load nvidia/nvhpc -# The following loads the central pytorch environment.
Instead you need to initialise your environment: -module load pytorch/1.11.0-gpu - -# Add your wandb credentials -export WANDB_API_KEY=fakekey3 -export WANDB_CONFIG_DIR=/work/ec202/ec202/shared/.config/wadnb -export WANDB_CACHE_DIR=/work/ec202/ec202/shared/.cache/wadnb - -# Set the following environment variables when running on multiple nodes -# We assume 4 GPUs per node -export SLURM_NTASKS=$((4 * SLURM_NNODES)) -export SLURM_NTASKS_PER_NODE=$(expr ${SLURM_NTASKS} \/ ${SLURM_NNODES}) -export SLURM_TASKS_PER_NODE="${SLURM_NTASKS_PER_NODE}(x${SLURM_NNODES})" - -srun python run.py trainer=ddp trainer.num_nodes=$SLURM_NNODES trainer.devices=$SLURM_TASKS_PER_NODE diff --git a/src/emma_policy/commands/clean_simbot_db_from_unmatched_instances.py b/src/emma_policy/commands/clean_simbot_db_from_unmatched_instances.py index e52322b..1f25468 100644 --- a/src/emma_policy/commands/clean_simbot_db_from_unmatched_instances.py +++ b/src/emma_policy/commands/clean_simbot_db_from_unmatched_instances.py @@ -25,8 +25,8 @@ SimBotActionDataset, compressed_mask_is_bbox, ) -from emma_policy.datamodules.simbot_nlu_dataset import ( - SimBotNLUDataset, +from emma_policy.datamodules.simbot_cr_dataset import ( + SimBotCRDataset, action_is_object_interaction, ) from emma_policy.utils import decompress_simbot_mask, get_logger @@ -49,7 +49,7 @@ def __init__( valid_output_db_file: Path, iou_threshold: float = 0.5, minimum_iou_threshold: float = 0.1, - simbot_db_type: Literal["action", "nlu"] = "action", + simbot_db_type: Literal["action", "cr"] = "action", matching_strategy: Literal["threshold_only", "threshold_and_label"] = "threshold_only", model_name: str = "heriot-watt/emma-base", ) -> None: @@ -67,7 +67,7 @@ def __init__( self._object_assets_to_names = arena_definitions["asset_to_label"] self._label_to_idx = arena_definitions["label_to_idx"] - self.dataset: Union[SimBotActionDataset, SimBotNLUDataset] + self.dataset: Union[SimBotActionDataset, SimBotCRDataset] if simbot_db_type == "action": self._action_idx = -1 self._purge_instance = self._discard_action_unmatched_instance @@ -79,9 +79,9 @@ def __init__( ) else: self._action_idx = 0 - self._purge_instance = self._discard_nlu_unmatched_instance + self._purge_instance = self._discard_cr_unmatched_instance self.dataset_name = DatasetName.simbot_clarifications.name - self.dataset = SimBotNLUDataset( + self.dataset = SimBotCRDataset( dataset_db_path=valid_input_db_file, tokenizer=tokenizer, iou_threshold=args.iou_threshold, @@ -226,7 +226,7 @@ def _discard_action_unmatched_instance( return instance return None - def _discard_nlu_unmatched_instance( + def _discard_cr_unmatched_instance( self, instance: SimBotInstructionInstance ) -> Optional[SimBotInstructionInstance]: """Discard instances where the target object does not match any predicted bounding box.""" @@ -342,7 +342,7 @@ def _write_db( ) parser.add_argument("--valid_output_db_path", type=Path) parser.add_argument( - "--simbot_db_type", choices=["action", "nlu"], help="The type of SimBot task." + "--simbot_db_type", choices=["action", "cr"], help="The type of SimBot task." 
) parser.add_argument( "--iou_threshold", diff --git a/src/emma_policy/commands/find_simbot_centroid_features.py b/src/emma_policy/commands/find_simbot_centroid_features.py deleted file mode 100644 index 9f8af3d..0000000 --- a/src/emma_policy/commands/find_simbot_centroid_features.py +++ /dev/null @@ -1,243 +0,0 @@ -import argparse -from typing import Any, Union - -import torch -from emma_datasets.constants.simbot.simbot import get_arena_definitions -from emma_datasets.datamodels.datasets.simbot import SimBotInstructionInstance -from emma_datasets.datamodels.datasets.utils.simbot_utils.instruction_processing import ( - get_object_asset_from_object_id, -) -from emma_datasets.db import DatasetDb -from sklearn.metrics import classification_report, confusion_matrix -from torch.nn import CosineSimilarity -from torchvision.ops import masks_to_boxes -from tqdm import tqdm - -from emma_policy.datamodules.base_dataset import best_match_features -from emma_policy.utils import decompress_simbot_mask - - -class EntityFeatureClassifier: - """EntityFeatureClassifier class. - - Used to determine the sub-classes based on pure feature averaging and cosine similarity. - """ - - def __init__( - self, - train_db: str, - test_db: str, - save_path: str, - feature_size: float = 2048, - ) -> None: - self.train_db = train_db - self.test_db = test_db - self.save_path = save_path - - arena_definitions = get_arena_definitions() - self._assets_to_labels = get_arena_definitions()["asset_to_label"] - self._special_asset_to_readable_name = arena_definitions["special_asset_to_readable_name"] - self._remaining_subclasses = { - "Computer_Monitor_01": "Computer", - "Computer_Monitor_Broken": "Computer", - "Computer_Monitor_New": "Computer", - "Lab_Terminal": "Computer", - "TAMPrototypeHead_01": "Computer", - "Desk_01": "Table", - "SM_Prop_Table_02": "Table", - } - self._special_asset_to_readable_name.update(self._remaining_subclasses) - self._special_asset_running_avg = { - readable_name: {"count": 1, "centroid": torch.zeros(feature_size)} # type: ignore[call-overload] - for _, readable_name in self._special_asset_to_readable_name.items() - } - self._similarity = CosineSimilarity() - - def run(self) -> None: - """Run the classification pipeline.""" - self.run_for_split(input_db=self.train_db, split="train") - self._save_centroids() - if self.test_db is not None: - self._test_results: dict[str, list[str]] = {"groundtruth": [], "prediction": []} - self.run_for_split(input_db=self.test_db, split="test") - report = classification_report( - self._test_results["groundtruth"], self._test_results["prediction"] - ) - print(report) # noqa: WPS421 - cm = confusion_matrix( - self._test_results["groundtruth"], self._test_results["prediction"] - ) - print(cm) # noqa: WPS421 - - def run_for_split(self, input_db: str, split: str = "train") -> None: - """Run the classifier for a single split.""" - db = DatasetDb(input_db) - db_size = len(db) - - for idx in tqdm(range(db_size)): - instance_str = db[idx] - instance = SimBotInstructionInstance.parse_raw(instance_str) - for action_idx, action in enumerate(instance.actions): - action_type = action.type - action_metadata = getattr(action, action_type.lower()) - action_object_metadata = action_metadata.get("object", None) - - # Ignore actions that do not have an object id - if ( # noqa: WPS337 - action_type == "Examine" - or action_object_metadata is None - or "id" not in action_object_metadata - or action_object_metadata.get("mask", None) is None - ): - continue - - features = 
torch.load(instance.features_path[action_idx])["frames"][0]["features"] - self._run_for_action( - action_type, - action_object_metadata, - features, - instance.vision_augmentation, - split, - ) - - def _run_for_object( - self, - object_asset: str, - object_mask: Union[list[int], list[list[int]]], - features: dict[str, Any], - vision_augmentation: bool = False, - split: str = "train", - ) -> None: - if object_asset not in self._special_asset_to_readable_name: - return - - readable_name = self._special_asset_to_readable_name[object_asset] - matched_indices, ground_truth_flags = self._gt_bbox_from_features( - object_mask, features, vision_augmentation - ) - - if split == "train" and ground_truth_flags[0].item(): - self._update_running_average( - object_asset=readable_name, - bbox_features=features["bbox_features"][matched_indices[0].item()], - ) - - elif split == "test" and ground_truth_flags[0].item(): - self._update_test_metrics( - object_asset=readable_name, - bbox_features=features["bbox_features"][matched_indices[0].item()], - ) - - def _run_for_action( - self, - action_type: str, - action_object_metadata: dict[str, Any], - features: dict[str, Any], - vision_augmentation: bool = False, - split: str = "train", - ) -> None: - if action_type == "Search": - for object_idx, object_id in enumerate(action_object_metadata["id"]): - object_asset = get_object_asset_from_object_id(object_id, self._assets_to_labels) - self._run_for_object( - object_asset=object_asset, - object_mask=action_object_metadata["mask"][object_idx], - features=features, - vision_augmentation=vision_augmentation, - ) - - else: - object_asset = get_object_asset_from_object_id( - action_object_metadata["id"], self._assets_to_labels - ) - self._run_for_object( - object_asset=object_asset, - object_mask=action_object_metadata["mask"], - features=features, - vision_augmentation=vision_augmentation, - split=split, - ) - - def _gt_bbox_from_features( - self, - mask: Union[list[int], list[list[int]]], - features: dict[str, Any], - vision_augmentation: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: - if vision_augmentation: - gt_bbox = torch.tensor(mask).unsqueeze(0) - else: - gt_binary_mask = decompress_simbot_mask(mask) # type: ignore[arg-type] - gt_bbox = masks_to_boxes(torch.tensor(gt_binary_mask).unsqueeze(0)) - - return best_match_features( - ground_truth_bbox=gt_bbox, - object_coordinates_bbox=features["bbox_coords"], - threshold=0.5, - ) - - def _update_running_average(self, object_asset: str, bbox_features: torch.Tensor) -> None: - running_counter = self._special_asset_running_avg[object_asset]["count"] - self._special_asset_running_avg[object_asset]["centroid"] = ( - (running_counter - 1) * self._special_asset_running_avg[object_asset]["centroid"] - + bbox_features - ) / running_counter - self._special_asset_running_avg[object_asset]["count"] += 1 - - def _update_test_metrics(self, object_asset: str, bbox_features: torch.Tensor) -> None: - assets = list(self._special_asset_running_avg.keys()) - centroids = [ - special_asset_dict["centroid"] - for special_asset_dict in list(self._special_asset_running_avg.values()) - ] - similarity_value = self._similarity(bbox_features, torch.stack(centroids)) - most_similar_vector = similarity_value.argmax().item() - - predicted_asset = assets[most_similar_vector] - self._test_results["groundtruth"].append(object_asset) - self._test_results["prediction"].append(predicted_asset) - - def _save_centroids(self) -> None: - centroids = {} - for special_asset, special_asset_dict in 
self._special_asset_running_avg.items(): - # Ignore assets that we didnt find any examples during train - if special_asset_dict["count"] > 1: - centroids[special_asset] = special_asset_dict["centroid"] - readable_to_class = { - readable_name: self._assets_to_labels.get(asset, asset) - for asset, readable_name in self._special_asset_to_readable_name.items() - } - - centroids_per_class: dict[str, dict[str, torch.Tensor]] = {} - for readable_name, centroid in centroids.items(): - object_class = readable_to_class[readable_name] - if object_class not in centroids_per_class: - centroids_per_class[object_class] = {} - centroids_per_class[object_class][readable_name] = centroid - - torch.save(centroids_per_class, self.save_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--train_db", - help="Path to train db.", - ) - parser.add_argument( - "--test_db", - help="Path to test db.", - ) - parser.add_argument( - "--save_path", - help="Path to output image.", - ) - args = parser.parse_args() - - classifier = EntityFeatureClassifier( - train_db=args.train_db, - test_db=args.test_db, - save_path=args.save_path, - ) - - classifier.run() diff --git a/src/emma_policy/commands/run_simbot_action_api.py b/src/emma_policy/commands/run_simbot_action_api.py index 87bf672..c4afed2 100644 --- a/src/emma_policy/commands/run_simbot_action_api.py +++ b/src/emma_policy/commands/run_simbot_action_api.py @@ -15,7 +15,6 @@ setup_rich_logging, ) from fastapi import FastAPI, Request, Response, status -from opentelemetry import trace from pydantic import BaseSettings, FilePath from transformers import PreTrainedTokenizer from uvicorn import Config, Server @@ -32,13 +31,10 @@ SimBotFindPredictionProcessor, post_process_action, ) -from emma_policy.inference.model_wrapper.simbot_raw_text_matcher import SimBotActionRawTextMatcher from emma_policy.models.simbot_combined_policy import SimBotEmmaCombinedPolicy from emma_policy.models.simbot_emma_policy import SimBotEmmaPolicy -tracer = trace.get_tracer(__name__) - PolicyModelType = Union[SimBotEmmaCombinedPolicy, SimBotEmmaPolicy] @@ -54,26 +50,13 @@ class ApiSettings(BaseSettings): model_type: Literal["combined", "standalone"] = "combined" device: str = "cpu" - raw_text_match_json: Path = Path("storage/constants/simbot_low_level_examples.json") raw_distance_threshold: int = 2 - enable_prediction_patching: bool = True - - # Observability - traces_to_opensearch: bool = False - log_to_cloudwatch: bool = False - aws_profile: str = "TeamProfile" - watchtower_log_group_name: str = "simbot_challenge" - watchtower_log_stream_name: str = "policy/{machine_name}/{logger_name}/{process_id}" - - otlp_endpoint: str = "localhost:4317" - opensearch_service_name: str = "policy" class ApiStore(TypedDict, total=False): """Common state for the API.""" input_builder: SimBotActionInputBuilder - raw_text_matcher: SimBotActionRawTextMatcher tokenizer: PreTrainedTokenizer model: PolicyModelType action_output_processor: SimBotActionPredictionProcessor @@ -133,16 +116,9 @@ async def startup_event() -> None: ) logging.info(f"Model is on device: {api_store['model'].device}") - api_store["action_output_processor"] = SimBotActionPredictionProcessor( - enable_prediction_patching=settings.enable_prediction_patching - ) + api_store["action_output_processor"] = SimBotActionPredictionProcessor() api_store["find_output_processor"] = SimBotFindPredictionProcessor() - api_store["raw_text_matcher"] = SimBotActionRawTextMatcher( - 
raw_text_match_json=settings.raw_text_match_json, - distance_threshold=settings.raw_distance_threshold, - ) - logging.info("Inference service is setup!") @@ -191,45 +167,38 @@ async def generate_find(request: Request, response: Response) -> list[str]: response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR raise request_err - sticky_note_case = api_store["input_builder"].check_sticky_note_case( - simbot_request, is_action=False - ) - if sticky_note_case is not None: - return [sticky_note_case] - (instruction, batch, decoder_input_ids, step_index) = api_store["input_builder"]( simbot_request, task=Task.visual_grounding ) - with tracer.start_as_current_span("Model inference"): - if batch is not None: - if decoder_input_ids is not None: - len_decode = decoder_input_ids.shape[1] - else: - len_decode = 0 - try: - with torch.no_grad(): - model_output = api_store["model"].inference_step( - batch, - decoder_input_ids=decoder_input_ids, - num_beams=api_store["num_beams"], - no_repeat_ngram_size=api_store["no_repeat_ngram_size"], - ) - actions = api_store["tokenizer"].batch_decode( - model_output[:, len_decode:], skip_special_tokens=False - ) - - except Exception as err: - # TODO: report session ID for better debugging - error_message = f"Failed to get next action for request `{simbot_request}" - logger.error(error_message, exc_info=err) - response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - raise err + + if batch is not None: + if decoder_input_ids is not None: + len_decode = decoder_input_ids.shape[1] else: - actions = [""] - logger.debug(f"Empty action for request: {simbot_request}") + len_decode = 0 + try: + with torch.no_grad(): + model_output = api_store["model"].inference_step( + batch, + decoder_input_ids=decoder_input_ids, + num_beams=api_store["num_beams"], + no_repeat_ngram_size=api_store["no_repeat_ngram_size"], + ) + actions = api_store["tokenizer"].batch_decode( + model_output[:, len_decode:], skip_special_tokens=False + ) + + except Exception as err: + # TODO: report session ID for better debugging + error_message = f"Failed to get next action for request `{simbot_request}" + logger.error(error_message, exc_info=err) + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + raise err + else: + actions = [""] + logger.debug(f"Empty action for request: {simbot_request}") - with tracer.start_as_current_span("Post processing"): - post_processed_actions = api_store["find_output_processor"](actions, simbot_request) + post_processed_actions = api_store["find_output_processor"](actions, simbot_request) logger.debug(f"Predicted actions: {post_processed_actions}") return post_processed_actions @@ -250,45 +219,43 @@ async def grab_from_history(request: Request, response: Response) -> Optional[in (_, batch, decoder_input_ids, step_index) = api_store["input_builder"]( simbot_request, task=Task.visual_grounding ) - with tracer.start_as_current_span("Model inference"): - if batch is not None: - len_decode = 0 - try: - with torch.no_grad(): - model_output = api_store["model"].inference_step( - batch, - decoder_input_ids=None, - num_beams=api_store["num_beams"], - no_repeat_ngram_size=api_store["no_repeat_ngram_size"], - ) - actions = api_store["tokenizer"].batch_decode( - model_output[:, len_decode:], skip_special_tokens=False - ) - - except Exception as err: - # TODO: report session ID for better debugging - error_message = f"Failed to get next action for request `{simbot_request}" - logger.error(error_message, exc_info=err) - response.status_code = 
status.HTTP_500_INTERNAL_SERVER_ERROR - raise err - else: - actions = [""] - logger.debug(f"Empty action for request: {simbot_request}") - - with tracer.start_as_current_span("Post processing"): - # Select all step_indexes with an object - filtered_step_idx = [ - step_index[idx] - for idx, action in enumerate(actions) - if "_token" in action and step_index - ] - logger.debug(f"Filtered steps: {filtered_step_idx}") - - unique_ordered_steps = sorted(set(filtered_step_idx)) - logger.debug(f"Sorted ordered steps: {unique_ordered_steps}") - - # most recent timestep with object - most_recent_step = unique_ordered_steps[-1] if unique_ordered_steps else None + + if batch is not None: + len_decode = 0 + try: + with torch.no_grad(): + model_output = api_store["model"].inference_step( + batch, + decoder_input_ids=None, + num_beams=api_store["num_beams"], + no_repeat_ngram_size=api_store["no_repeat_ngram_size"], + ) + actions = api_store["tokenizer"].batch_decode( + model_output[:, len_decode:], skip_special_tokens=False + ) + + except Exception as err: + # TODO: report session ID for better debugging + error_message = f"Failed to get next action for request `{simbot_request}" + logger.error(error_message, exc_info=err) + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + raise err + else: + actions = [""] + logger.debug(f"Empty action for request: {simbot_request}") + + # Select all step_indexes with an object + filtered_step_idx = [ + step_index[idx] for idx, action in enumerate(actions) if "_token" in action and step_index + ] + logger.debug(f"Filtered steps: {filtered_step_idx}") + + unique_ordered_steps = sorted(set(filtered_step_idx)) + logger.debug(f"Sorted ordered steps: {unique_ordered_steps}") + + # most recent timestep with object + most_recent_step = unique_ordered_steps[-1] if unique_ordered_steps else None + logger.debug(f"most recent step: {most_recent_step}") return most_recent_step @@ -309,59 +276,48 @@ async def generate(request: Request, response: Response) -> str: response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR raise request_err - sticky_note_case = api_store["input_builder"].check_sticky_note_case( - simbot_request, is_action=True - ) - if sticky_note_case is not None: - return sticky_note_case - - if api_store["input_builder"].check_carrot_case(simbot_request): - return "dummy look down ." 
- # (batch, decoder_input_ids, step_index) = api_store["input_builder"]( # ) (raw_input, batch, decoder_input_ids, step_index) = api_store["input_builder"]( simbot_request, task=Task.action_execution ) - with tracer.start_as_current_span("Model inference"): - if batch is not None: - max_length = api_store["max_length_per_action_sequence"] - if decoder_input_ids is not None: - max_length += decoder_input_ids.shape[1] - len_decode = decoder_input_ids.shape[1] - else: - len_decode = 0 - try: - with torch.no_grad(): - model_output = api_store["model"].inference_step( - batch, - decoder_input_ids=decoder_input_ids, - num_beams=api_store["num_beams"], - no_repeat_ngram_size=api_store["no_repeat_ngram_size"], - max_length=max_length, - ) - action = api_store["tokenizer"].batch_decode( - model_output[:, len_decode:], skip_special_tokens=False - )[0] - - action = api_store["action_output_processor"]( - prediction=action, - frame_features=simbot_request.environment_history[-1].features, - instruction=raw_input, - ) - - except Exception as err: - # TODO: report session ID for better debugging - error_message = f"Failed to get next action for request `{simbot_request}" - logger.error(error_message, exc_info=err) - response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - raise err + if batch is not None: + max_length = api_store["max_length_per_action_sequence"] + if decoder_input_ids is not None: + max_length += decoder_input_ids.shape[1] + len_decode = decoder_input_ids.shape[1] else: - action = "" - logger.debug(f"Empty action for request: {simbot_request}") + len_decode = 0 + try: + with torch.no_grad(): + model_output = api_store["model"].inference_step( + batch, + decoder_input_ids=decoder_input_ids, + num_beams=api_store["num_beams"], + no_repeat_ngram_size=api_store["no_repeat_ngram_size"], + max_length=max_length, + ) + action = api_store["tokenizer"].batch_decode( + model_output[:, len_decode:], skip_special_tokens=False + )[0] + + action = api_store["action_output_processor"]( + prediction=action, + frame_features=simbot_request.environment_history[-1].features, + instruction=raw_input, + ) + + except Exception as err: + # TODO: report session ID for better debugging + error_message = f"Failed to get next action for request `{simbot_request}" + logger.error(error_message, exc_info=err) + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + raise err + else: + action = "" + logger.debug(f"Empty action for request: {simbot_request}") - with tracer.start_as_current_span("Post processing"): - action = post_process_action(action) + action = post_process_action(action) logger.debug(f"Predicted action: {action}") return action diff --git a/src/emma_policy/commands/run_simbot_nlu.py b/src/emma_policy/commands/run_simbot_cr.py similarity index 56% rename from src/emma_policy/commands/run_simbot_nlu.py rename to src/emma_policy/commands/run_simbot_cr.py index 0a473e7..d039dc1 100644 --- a/src/emma_policy/commands/run_simbot_nlu.py +++ b/src/emma_policy/commands/run_simbot_cr.py @@ -1,41 +1,30 @@ import logging -import sys from argparse import ArgumentParser, Namespace from pathlib import Path from typing import Any, Literal, TypedDict, Union import torch -from emma_common.api.instrumentation import instrument_app -from emma_common.aws.cloudwatch import add_cloudwatch_handler_to_logger from emma_common.datamodels import TorchDataMixin -from emma_common.logging import ( - InstrumentedInterceptHandler, - logger, - setup_logging, - setup_rich_logging, -) +from emma_common.logging import 
logger, setup_rich_logging from fastapi import FastAPI, Request, Response, status -from opentelemetry import trace from pydantic import BaseSettings, FilePath from transformers import PreTrainedTokenizer from uvicorn import Config, Server -from emma_policy._version import __version__ # noqa: WPS436 from emma_policy.datamodules.simbot_combined_datamodule import prepare_combined_tokenizer -from emma_policy.datamodules.simbot_nlu_datamodule import prepare_nlu_tokenizer -from emma_policy.datamodules.simbot_nlu_dataset import SimBotNLUIntents -from emma_policy.inference.model_wrapper.simbot_nlu_input_builder import SimBotNLUInputBuilder -from emma_policy.inference.model_wrapper.simbot_nlu_output_processor import ( - SimBotNLUPredictionProcessor, +from emma_policy.datamodules.simbot_cr_datamodule import prepare_cr_tokenizer +from emma_policy.datamodules.simbot_cr_dataset import SimBotCRIntents +from emma_policy.inference.model_wrapper.simbot_cr_input_builder import SimBotCRInputBuilder +from emma_policy.inference.model_wrapper.simbot_cr_output_processor import ( + SimBotCRPredictionProcessor, ) from emma_policy.models.simbot_combined_policy import SimBotEmmaCombinedPolicy -from emma_policy.models.simbot_nlu_policy import SimBotNLUEmmaPolicy, postprocess_nlu_output +from emma_policy.models.simbot_cr_policy import SimBotCREmmaPolicy, postprocess_cr_output -tracer = trace.get_tracer(__name__) -DEFAULT_ACTION = SimBotNLUIntents.act_one_match.value +DEFAULT_ACTION = SimBotCRIntents.act_one_match.value -NLUModelType = Union[SimBotEmmaCombinedPolicy, SimBotNLUEmmaPolicy] +CRModelType = Union[SimBotEmmaCombinedPolicy, SimBotCREmmaPolicy] class ApiSettings(BaseSettings): @@ -45,31 +34,20 @@ host: str = "0.0.0.0" # noqa: S104 workers: int = 1 log_level: str = "debug" - model_checkpoint_path: FilePath = Path("storage/model/checkpoints/simbot/nlu.ckpt") + model_checkpoint_path: FilePath = Path("storage/model/checkpoints/simbot/cr.ckpt") model_name: str = "heriot-watt/emma-base" model_type: Literal["combined", "standalone"] = "combined" device: str = "cpu" disable_missing_inventory: bool = False - enable_prediction_patching: bool = True - - # Observability - traces_to_opensearch: bool = False - log_to_cloudwatch: bool = False - aws_profile: str = "TeamProfile" - watchtower_log_group_name: str = "simbot_challenge" - watchtower_log_stream_name: str = "nlu/{machine_name}/{logger_name}/{process_id}" - - otlp_endpoint: str = "localhost:4317" - opensearch_service_name: str = "nlu" class ApiStore(TypedDict, total=False): """Common state for the API.""" - input_builder: SimBotNLUInputBuilder + input_builder: SimBotCRInputBuilder tokenizer: PreTrainedTokenizer - model: NLUModelType - output_processor: SimBotNLUPredictionProcessor + model: CRModelType + output_processor: SimBotCRPredictionProcessor num_beams: int no_repeat_ngram_size: int max_generated_text_length: int @@ -87,8 +65,8 @@ def load_model( model_name: str, device: str, model_type: Literal["combined", "standalone"], -) -> NLUModelType: - """Load an NLU checkpoint.""" +) -> CRModelType: + """Load a CR checkpoint.""" if model_type == "combined": model = SimBotEmmaCombinedPolicy( model_name=model_name, @@ -96,7 +74,7 @@ num_beams=api_store["num_beams"], max_generated_text_length=api_store["max_generated_text_length"], ).load_from_checkpoint(checkpoint_path) else: - model = SimBotNLUEmmaPolicy( + model = SimBotCREmmaPolicy( model_name=model_name, num_beams=api_store["num_beams"], max_generated_text_length=api_store["max_generated_text_length"],
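The `/generate` and `/generate_find` endpoints above share one decoding pattern: run beam search with an optional forced decoder prefix, then decode only the tokens that come after that prefix (the `model_output[:, len_decode:]` slice). A minimal, self-contained sketch of that pattern; `t5-small` is purely a stand-in for the EMMA checkpoints the services actually load:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Stand-in model; the real services load EMMA policy checkpoints instead.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

batch = tokenizer("translate English to German: the lamp is on", return_tensors="pt")

# An optional forced prefix, e.g. previously decoded action tokens.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
len_decode = decoder_input_ids.shape[1]

with torch.no_grad():
    model_output = model.generate(
        **batch,
        decoder_input_ids=decoder_input_ids,
        num_beams=5,
        no_repeat_ngram_size=2,
        max_length=len_decode + 16,
    )

# Slice off the forced prefix so only newly generated tokens are decoded,
# mirroring `model_output[:, len_decode:]` in the endpoints above.
actions = tokenizer.batch_decode(model_output[:, len_decode:], skip_special_tokens=True)
print(actions[0])
```

Keeping the prefix out of `batch_decode` is what lets the endpoints return just the next action rather than the whole decoded sequence.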
@@ -116,19 +94,18 @@ async def startup_event() -> None: if settings.model_type == "combined": api_store["tokenizer"] = prepare_combined_tokenizer(settings.model_name) else: - api_store["tokenizer"] = prepare_nlu_tokenizer(settings.model_name) - api_store["input_builder"] = SimBotNLUInputBuilder( + api_store["tokenizer"] = prepare_cr_tokenizer(settings.model_name) + api_store["input_builder"] = SimBotCRInputBuilder( tokenizer=api_store["tokenizer"], device=settings.device, ) api_store["valid_action_types"] = [ - intent.value for intent in SimBotNLUIntents if intent.is_nlu_output + intent.value for intent in SimBotCRIntents if intent.is_cr_output ] - api_store["output_processor"] = SimBotNLUPredictionProcessor( + api_store["output_processor"] = SimBotCRPredictionProcessor( valid_action_types=api_store["valid_action_types"], default_prediction=DEFAULT_ACTION, disable_missing_inventory=settings.disable_missing_inventory, - enable_prediction_patching=settings.enable_prediction_patching, ) logging.info(f"Loading model on device `{settings.device}`") api_store["model"] = load_model( @@ -169,48 +146,33 @@ async def generate(request: Request, response: Response) -> str: response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR raise request_err - with tracer.start_as_current_span("Model inference"): - logger.debug("Preparing the model input") - # If the environment history is greater than 1, - # the agent has already clarified or acted. - if len(simbot_request.environment_history) == 1: - batch, instruction = api_store["input_builder"](simbot_request) - try: # noqa: WPS229 - with torch.no_grad(): - actions = api_store["model"].inference_step(batch) - - decoded_action = postprocess_nlu_output(api_store["tokenizer"], actions)[0] - - action = api_store["output_processor"]( - instruction=instruction, - prediction=decoded_action, - frame_features=simbot_request.environment_history[-1].features, - ) - - except Exception as err: - # TODO: report session ID for better debugging - error_message = f"Failed to get next action for request `{simbot_request}" - logger.error(error_message, exc_info=err) - response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - raise err - else: - action = DEFAULT_ACTION + logger.debug("Preparing the model input") + # If the environment history is greater than 1, + # the agent has already clarified or acted. 
+ if len(simbot_request.environment_history) == 1: + batch, instruction = api_store["input_builder"](simbot_request) + try: # noqa: WPS229 + with torch.no_grad(): + actions = api_store["model"].inference_step(batch) + + decoded_action = postprocess_cr_output(api_store["tokenizer"], actions)[0] + + action = api_store["output_processor"](prediction=decoded_action) + + except Exception as err: + # TODO: report session ID for better debugging + error_message = f"Failed to get next action for request `{simbot_request}" + logger.error(error_message, exc_info=err) + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + raise err + else: + action = DEFAULT_ACTION return action def main() -> None: """Runs a server that serves any instance of an EMMA policy model.""" - if settings.traces_to_opensearch: - instrument_app( - app, - otlp_endpoint=settings.otlp_endpoint, - service_name=settings.opensearch_service_name, - service_version=__version__, - service_namespace="SimBot", - ) - setup_logging(sys.stdout, InstrumentedInterceptHandler()) - else: - setup_rich_logging(rich_traceback_show_locals=False) + setup_rich_logging(rich_traceback_show_locals=False) server = Server( Config( @@ -220,15 +182,6 @@ def main() -> None: log_level=settings.log_level, ) ) - if settings.log_to_cloudwatch: - add_cloudwatch_handler_to_logger( - boto3_profile_name=settings.aws_profile, - log_stream_name=settings.watchtower_log_stream_name, - log_group_name=settings.watchtower_log_group_name, - send_interval=1, - enable_trace_logging=settings.traces_to_opensearch, - ) - server.run() diff --git a/src/emma_policy/commands/run_teach_api.py b/src/emma_policy/commands/run_teach_api.py deleted file mode 100644 index 91848a6..0000000 --- a/src/emma_policy/commands/run_teach_api.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -from io import BytesIO -from typing import Any, Optional - -import httpx -from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile, status -from PIL import Image -from uvicorn import Config, Server - -from emma_policy.inference.api import ApiSettings -from emma_policy.inference.api.edh_parsers import get_edh_history_images, parse_edh_instance -from emma_policy.inference.api.logger import setup_logger -from emma_policy.inference.api.settings import parse_api_args -from emma_policy.inference.api.teach_state import ApiStore -from emma_policy.inference.model_wrapper import PolicyModelWrapper, SimulatorAction - - -logger = logging.getLogger(__name__) - - -settings = ApiSettings() -api_store: ApiStore = {} -app = FastAPI() -logger.info("Initializing TEACh API") - - -@app.on_event("startup") -async def startup_event() -> None: - """Run specific functions when starting up the API.""" - api_args, model_args = parse_api_args() - - api_store["data_dir"] = api_args.data_dir - api_store["images_dir"] = api_args.images_dir - api_store["split"] = api_args.split - - logger.info("Loading model") - api_store["model"] = PolicyModelWrapper.from_argparse( - process_index=1, num_processes=1, model_args=model_args - ) - logging.info("Policy TEACh API is setup!") - - -@app.get("/") -@app.get("/ping") -@app.get("/test") -async def root(response: Response) -> dict[str, Any]: - """Ping the API to make sure it is responding.""" - response.status_code = status.HTTP_200_OK - return {"action": "Look Up", "obj_relative_coord": [0.1, 0.2]} - - -@app.get("/healthcheck", status_code=status.HTTP_200_OK) -async def healthcheck(response: Response) -> str: - """Verify all the APIs are running and working.""" - 
logger.info("Checking Policy API") - policy_response = status.HTTP_200_OK - logger.info(f"Policy API Response: {policy_response}") - - async with httpx.AsyncClient() as client: - logger.info("Checking Perception API") - perception_response = (await client.get(settings.feature_extractor_endpoint)).status_code - logger.info(f"Perception API Response: {perception_response}") - - # Verify all the APIs are available. - all_passed = all( - [response == status.HTTP_200_OK for response in (policy_response, perception_response)] - ) - - if not all_passed: - response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE - return "failed" - - return "success" - - -@app.post("/start_new_edh_instance", status_code=status.HTTP_200_OK) -async def start_new_edh_instance( - edh_name: Optional[str] = Form(...), # noqa: WPS404 - edh_instance: str = Form(...), # noqa: WPS404 - edh_history_images: list[UploadFile] = File(...), # noqa: WPS404 -) -> str: - """Reset the model wrapper to start a new EDH instance.""" - logger.info(f"Starting new EDH instance with name `{edh_name}`") - - parsed_edh_instance = parse_edh_instance(edh_instance) - - logger.debug("Loading PIL images from bytes") - edh_history_image_bytes = [await raw_file.read() for raw_file in edh_history_images] - - try: - logger.debug("Attempting to parse images for EDH history") - parsed_edh_history_images = get_edh_history_images( - parsed_edh_instance, edh_history_image_bytes, api_store["data_dir"], api_store["split"] - ) - except Exception: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get EDH history images", - ) - - if not parsed_edh_history_images: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="No EDH history images present", - ) - - try: - logger.debug("Starting a new EDH instance on the model") - api_store["model"].start_new_edh_instance(parsed_edh_instance, parsed_edh_history_images) - except Exception: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to start new EDH instance `{edh_name}`", - ) - - logger.info("Successfully started a new EDH instance on the model") - return "success" - - -@app.post("/get_next_action", status_code=status.HTTP_200_OK, response_model=SimulatorAction) -async def get_next_action( - img_name: Optional[str] = Form(...), # noqa: WPS404 - edh_name: Optional[str] = Form(...), # noqa: WPS404 - prev_action: Optional[str] = Form(None), # noqa: WPS404 - edh_instance: str = Form(...), # noqa: WPS404 - img: UploadFile = File(...), # noqa: WPS404 -) -> SimulatorAction: - """Get the next action from the model for the given instance.""" - if not img_name or not edh_instance: - logger.warning("Either img or edh_instance is None") - return SimulatorAction(action=None, obj_relative_coord=None) - - parsed_edh_instance = parse_edh_instance(edh_instance) - - logger.info(f"Getting next action for EDH `{parsed_edh_instance.instance_id}`") - - logger.debug("Creating PIL image from the bytes") - raw_image = await img.read() - image = Image.open(BytesIO(raw_image)) - - logger.debug(f"Previous action: {prev_action}") - previous_simulator_action = ( - SimulatorAction.parse_raw(prev_action) if prev_action is not None else None - ) - - try: - logger.debug("Attemtping to get next action from the model") - action, obj_relative_coord = api_store["model"].get_next_action( - image, parsed_edh_instance, previous_simulator_action, img_name, edh_name - ) - except Exception: - error_message = f"Failed to get next action 
for EDH with name `{edh_name}" - logger.error(error_message, exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_message - ) - - logger.info(f"Returning next action `{action}` (EDH `{parsed_edh_instance.instance_id})`") - return SimulatorAction(action=action, obj_relative_coord=obj_relative_coord) - - -def main() -> None: - """Run the API, exactly the same as the way TEACh does it.""" - server = Server( - Config( - "emma_policy.commands.run_teach_api:app", - host=settings.host, - port=settings.port, - log_level=settings.log_level, - ) - ) - - # Separately adjust the log level for EMMA-related modules - setup_logger(emma_log_level=settings.log_level) - - server.run() - - -if __name__ == "__main__": - main() diff --git a/src/emma_policy/datamodules/simbot_action_dataset.py b/src/emma_policy/datamodules/simbot_action_dataset.py index c43472a..3b368c3 100644 --- a/src/emma_policy/datamodules/simbot_action_dataset.py +++ b/src/emma_policy/datamodules/simbot_action_dataset.py @@ -98,9 +98,6 @@ def __getitem__(self, index: int) -> EmmaDatasetItem: if instance.vision_augmentation: return self.simbot_vision_augmentation(instance) return self.simbot_action_execution(instance) - # except Exception as e: - # print(e) - # breakpoint() def simbot_vision_augmentation( # noqa: WPS210, WPS231 self, instance: SimBotInstructionInstance @@ -366,19 +363,13 @@ def map_object_to_visual_token( ] gt_object_dict = action.get_action_data - # If the groundtruth object is a sticky note, the groundtruth bbox - # coordinates are currently provided directly in the mask - # TODO: this could potentially be improved if we have the segmentation masks for the sticky notes as well instead of the bounding boxes object_mask = gt_object_dict["object"]["mask"] - if object_name == "Sticky Note": - ground_truth_bbox = torch.tensor(object_mask[0]).float() + if compressed_mask_is_bbox(object_mask): + ground_truth_bbox = torch.tensor(object_mask, dtype=torch.float32).unsqueeze(0) else: - if compressed_mask_is_bbox(object_mask): - ground_truth_bbox = torch.tensor(object_mask, dtype=torch.float32).unsqueeze(0) - else: - gt_binary_mask = decompress_simbot_mask(object_mask) - ground_truth_bbox = masks_to_boxes(torch.tensor(gt_binary_mask).unsqueeze(0)) + gt_binary_mask = decompress_simbot_mask(object_mask) + ground_truth_bbox = masks_to_boxes(torch.tensor(gt_binary_mask).unsqueeze(0)) ground_truth_bbox[:, (0, 2)] /= self._image_width ground_truth_bbox[:, (1, 3)] /= self._image_height diff --git a/src/emma_policy/datamodules/simbot_combined_datamodule.py b/src/emma_policy/datamodules/simbot_combined_datamodule.py index 11bc76e..49e3839 100644 --- a/src/emma_policy/datamodules/simbot_combined_datamodule.py +++ b/src/emma_policy/datamodules/simbot_combined_datamodule.py @@ -14,7 +14,7 @@ from emma_policy.datamodules.collate import collate_fn from emma_policy.datamodules.emma_dataclasses import EmmaDatasetBatch, EmmaDatasetItem from emma_policy.datamodules.simbot_action_dataset import SimBotActionDataset -from emma_policy.datamodules.simbot_nlu_dataset import SimBotNLUDataset, SimBotNLUIntents +from emma_policy.datamodules.simbot_cr_dataset import SimBotCRDataset, SimBotCRIntents from emma_policy.utils import DistributedWeightedSampler, compute_weights @@ -29,7 +29,7 @@ def prepare_combined_tokenizer( """Add special tokens to tokenizer.""" tokenizer = AutoTokenizer.from_pretrained(model_name) # vad special tokens - vad_special_tokens = [intent.value for intent in SimBotNLUIntents if 
intent.is_special_token] + vad_special_tokens = [intent.value for intent in SimBotCRIntents if intent.is_special_token] action_special_tokens = SimBotAction_SPECIAL_TOKENS combined_special_tokens = vad_special_tokens + action_special_tokens @@ -122,7 +122,7 @@ def setup(self, stage: Optional[str] = None) -> None: ) # Train - train_vad_dataset = SimBotNLUDataset( + train_vad_dataset = SimBotCRDataset( dataset_db_path=self._simbot_vad_train_db_file, tokenizer=self._tokenizer, is_train=True, @@ -144,7 +144,7 @@ def setup(self, stage: Optional[str] = None) -> None: allow_paraphrasing=True, ) - valid_vad_dataset = SimBotNLUDataset( + valid_vad_dataset = SimBotCRDataset( dataset_db_path=self._simbot_vad_valid_db_file, tokenizer=self._tokenizer, is_train=False, @@ -163,7 +163,7 @@ def setup(self, stage: Optional[str] = None) -> None: allow_paraphrasing=True, ) - test_vad_dataset = SimBotNLUDataset( + test_vad_dataset = SimBotCRDataset( dataset_db_path=self._simbot_vad_valid_db_file, tokenizer=self._tokenizer, is_train=False, @@ -222,7 +222,7 @@ def _compute_sample_weights(self) -> list[float]: """Proportional temperature scaling to mitigate action type imbalance.""" action_db = DatasetDb(self._simbot_action_train_db_file) # First pass through the dataset to get action type counts - actions: list[Union[str, SimBotNLUIntents]] = [] + actions: list[Union[str, SimBotCRIntents]] = [] for _, _, instance_str in action_db: instance = SimBotInstructionInstance.parse_raw(instance_str) actions.append(self._get_action_type(instance.actions[-1])) diff --git a/src/emma_policy/datamodules/simbot_nlu_datamodule.py b/src/emma_policy/datamodules/simbot_cr_datamodule.py similarity index 85% rename from src/emma_policy/datamodules/simbot_nlu_datamodule.py rename to src/emma_policy/datamodules/simbot_cr_datamodule.py index 6b89918..c4515c4 100644 --- a/src/emma_policy/datamodules/simbot_nlu_datamodule.py +++ b/src/emma_policy/datamodules/simbot_cr_datamodule.py @@ -7,18 +7,18 @@ from emma_policy.datamodules.collate import collate_fn from emma_policy.datamodules.emma_dataclasses import EmmaDatasetBatch -from emma_policy.datamodules.simbot_nlu_dataset import SimBotNLUDataset, SimBotNLUIntents +from emma_policy.datamodules.simbot_cr_dataset import SimBotCRDataset, SimBotCRIntents from emma_policy.utils import DistributedWeightedSampler, compute_weights -def prepare_nlu_tokenizer( +def prepare_cr_tokenizer( model_name: str = "heriot-watt/emma-base", tokenizer_truncation_side: Literal["left", "right"] = "right", max_lang_tokens: Optional[int] = 64, ) -> PreTrainedTokenizer: """Add special tokens to tokenizer.""" tokenizer = AutoTokenizer.from_pretrained(model_name) - special_tokens = [intent.value for intent in SimBotNLUIntents if intent.is_special_token] + special_tokens = [intent.value for intent in SimBotCRIntents if intent.is_special_token] # doesn't add if they are already there tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) tokenizer.truncation_side = tokenizer_truncation_side @@ -28,13 +28,13 @@ def prepare_nlu_tokenizer( return tokenizer -def get_nlu_classes() -> list[str]: - """Get the NLU classes.""" - return [intent.name for intent in SimBotNLUIntents if intent.is_nlu_output] +def get_cr_classes() -> list[str]: + """Get the CR classes.""" + return [intent.name for intent in SimBotCRIntents if intent.is_cr_output] -class SimBotNLUDataModule(LightningDataModule): - """Data module to load SimBot instructions for the EMMA NLU model.""" +class SimBotCRDataModule(LightningDataModule): + 
"""Data module to load SimBot instructions for the EMMA CR model.""" def __init__( self, @@ -84,7 +84,7 @@ def prepare_data(self) -> None: def setup_tokenizer(self) -> PreTrainedTokenizer: """Add special tokens to tokenizer.""" - self.tokenizer = prepare_nlu_tokenizer( + self.tokenizer = prepare_cr_tokenizer( model_name=self._model_name, tokenizer_truncation_side=self.tokenizer_truncation_side, max_lang_tokens=self._max_lang_tokens, @@ -95,27 +95,27 @@ def setup(self, stage: Optional[str] = None) -> None: """Setup datasets for the dataloaders.""" self.setup_tokenizer() - self._train_dataset = SimBotNLUDataset( + self._train_dataset = SimBotCRDataset( dataset_db_path=self._train_db_file, tokenizer=self.tokenizer, is_train=True, shuffle_objects=True, ) - self._valid_dataset = SimBotNLUDataset( + self._valid_dataset = SimBotCRDataset( dataset_db_path=self._valid_db_file, tokenizer=self.tokenizer, is_train=False, ) - self._test_dataset = SimBotNLUDataset( + self._test_dataset = SimBotCRDataset( dataset_db_path=self._test_db_file, tokenizer=self.tokenizer, is_train=False, ) def train_dataloader(self) -> DataLoader[EmmaDatasetBatch]: - """Generate train dataloader for SimBot NLU instances.""" + """Generate train dataloader for SimBot CR instances.""" if self._weighted_sampling: training_sampler_weights = compute_weights( self._train_dataset.data_intents, @@ -141,7 +141,7 @@ def train_dataloader(self) -> DataLoader[EmmaDatasetBatch]: ) def val_dataloader(self) -> DataLoader[EmmaDatasetBatch]: - """Generate valid dataloader for SimBot NLU instances.""" + """Generate valid dataloader for SimBot CR instances.""" return DataLoader( self._valid_dataset, # type: ignore[arg-type] batch_size=self._val_batch_size, @@ -151,7 +151,7 @@ def val_dataloader(self) -> DataLoader[EmmaDatasetBatch]: ) def test_dataloader(self) -> DataLoader[EmmaDatasetBatch]: - """Generate test dataloader for SimBot NLU instances.""" + """Generate test dataloader for SimBot CR instances.""" return DataLoader( self._test_dataset, # type: ignore[arg-type] batch_size=self._val_batch_size, diff --git a/src/emma_policy/datamodules/simbot_nlu_dataset.py b/src/emma_policy/datamodules/simbot_cr_dataset.py similarity index 93% rename from src/emma_policy/datamodules/simbot_nlu_dataset.py rename to src/emma_policy/datamodules/simbot_cr_dataset.py index 6b3ef60..9ad7a06 100644 --- a/src/emma_policy/datamodules/simbot_nlu_dataset.py +++ b/src/emma_policy/datamodules/simbot_cr_dataset.py @@ -40,8 +40,8 @@ logger = get_logger(__name__) -class SimBotNLUIntents(Enum): - """SimBot NLU intent types.""" +class SimBotCRIntents(Enum): + """SimBot CR intent types.""" act = "" search = "" @@ -72,7 +72,7 @@ def is_special_token(self) -> bool: } @property - def is_nlu_output(self) -> bool: + def is_cr_output(self) -> bool: """Wether an intent is a valid output.""" return self in { self.act_one_match, @@ -98,8 +98,8 @@ def action_is_object_interaction(action: SimBotAction) -> bool: return "officeRoom" not in object_metadata -class SimBotNLUDataset(EmmaBaseDataset[EmmaDatasetItem]): - """Dataset for AreanNLU. +class SimBotCRDataset(EmmaBaseDataset[EmmaDatasetItem]): + """Dataset for AreanCR. Each instance is loaded from the DatasetDb file and converted to an instance of `EmmaDatasetItem` before being returned. 
@@ -126,11 +126,11 @@ def __init__( ) self.is_train = is_train - self.data_intents: list[SimBotNLUIntents] = [] + self.data_intents: list[SimBotCRIntents] = [] self._synthetic_negative_candidates: list[int] = [] self._question_type_intent_map = { - SimBotClarificationTypes.location: SimBotNLUIntents.act_no_match, - SimBotClarificationTypes.disambiguation: SimBotNLUIntents.act_too_many_matches, + SimBotClarificationTypes.location: SimBotCRIntents.act_no_match, + SimBotClarificationTypes.disambiguation: SimBotCRIntents.act_too_many_matches, } self._prepare_data() arena_definitions = get_arena_definitions() @@ -155,7 +155,7 @@ def __len__(self) -> int: @overrides(check_signature=False) def __getitem__(self, index: int) -> EmmaDatasetItem: - """Get the SimBot NLU instance at the given index as an instance of `EmmaDatasetItem`.""" + """Get the SimBot CR instance at the given index as an instance of `EmmaDatasetItem`.""" with self.db: instance_str = self.db[index] instance = SimBotInstructionInstance.parse_raw(instance_str) @@ -250,7 +250,7 @@ def prepare_cdf_action_instance( ) if missing_holding_object: instance.actions[0].inventory_object_id = None - target_text = SimBotNLUIntents.act_missing_inventory.value + target_text = SimBotCRIntents.act_missing_inventory.value object_readable_name = get_object_readable_name_from_object_id( object_id=holding_object, @@ -259,7 +259,7 @@ def prepare_cdf_action_instance( ) else: - target_text = SimBotNLUIntents.act_one_match.value + target_text = SimBotCRIntents.act_one_match.value object_readable_name = self._get_target_object_name( action=instance.actions[0], name_type="readable", @@ -296,10 +296,10 @@ def prepare_human_action_instance( inventory_object_id=instance.actions[0].inventory_object_id, ) # First try to get the target for a clarification - target_text = self._get_nlu_human_question(instance) + target_text = self._get_cr_human_question(instance) # If target_text is an empty list, we have an action if not target_text: - target_text = SimBotNLUIntents.act_one_match.value + target_text = SimBotCRIntents.act_one_match.value object_readable_name = self._get_target_object_name( action=instance.actions[0], name_type="readable", @@ -328,7 +328,7 @@ def prepare_synthetic_action_instance( instruction=instruction, inventory_object_id=instance.actions[0].inventory_object_id, ) - target_text = self._get_nlu_synthetic_too_many_matches(instance) + target_text = self._get_cr_synthetic_too_many_matches(instance) else: instruction, target_text = self._augment_synthetic_action(instance) return instruction, target_text @@ -361,12 +361,12 @@ def prepare_search_instance( visual_features = self._load_visual_features( features_path=negative_instance.features_path[0] ) - target_text = f"{SimBotNLUIntents.search_no_match.value} {object_readable_name}" + target_text = f"{SimBotCRIntents.search_no_match.value} {object_readable_name}" is_negative = True elif action_object_metadata["mask"] is None: # A negative search sample visual_features = self._load_visual_features(features_path=instance.features_path[0]) - target_text = f"{SimBotNLUIntents.search_no_match.value} {object_readable_name}" + target_text = f"{SimBotCRIntents.search_no_match.value} {object_readable_name}" is_negative = True else: @@ -387,13 +387,13 @@ def prepare_search_instance( ) # If there is a matching bounding box, append its visual token to the target text if ground_truth_flags.shape[0] == 0: - target_text = f"{SimBotNLUIntents.search_no_match.value} {object_readable_name}" + target_text = 
f"{SimBotCRIntents.search_no_match.value} {object_readable_name}" is_negative = True elif ground_truth_flags.shape[0] == 1: - target_text = f"{SimBotNLUIntents.search_one_match.value} {object_readable_name}" + target_text = f"{SimBotCRIntents.search_one_match.value} {object_readable_name}" else: target_text = ( - f"{SimBotNLUIntents.search_too_many_matches.value} {object_readable_name}" + f"{SimBotCRIntents.search_too_many_matches.value} {object_readable_name}" ) instruction = self._prepare_search_instruction( @@ -450,7 +450,7 @@ def _augment_synthetic_action(self, instance: SimBotInstructionInstance) -> tupl instruction, target_text = self._augment_synthetic_inventory( instance=instance, missing_inventory_proba=self._one_match_to_missining_inventory_proba, - target_text=SimBotNLUIntents.act_one_match.value, + target_text=SimBotCRIntents.act_one_match.value, ) else: visual_features = self._load_visual_features(features_path=instance.features_path[0]) @@ -466,7 +466,7 @@ def _augment_synthetic_action(self, instance: SimBotInstructionInstance) -> tupl instruction, target_text = self._augment_synthetic_inventory( instance=instance, missing_inventory_proba=self._no_match_to_missining_inventory_proba, - target_text=SimBotNLUIntents.act_no_match.value, + target_text=SimBotCRIntents.act_no_match.value, include_location_in_attributes=False, ) @@ -495,7 +495,7 @@ def _get_synthectic_action_instruction( return instruction - def _get_nlu_human_question(self, instance: SimBotInstructionInstance) -> Optional[str]: + def _get_cr_human_question(self, instance: SimBotInstructionInstance) -> Optional[str]: """Get the target text and question type vector from a human question. Examples to avoid: @@ -519,9 +519,9 @@ def _get_nlu_human_question(self, instance: SimBotInstructionInstance) -> Option return question_as_target - def _get_nlu_synthetic_too_many_matches(self, instance: SimBotInstructionInstance) -> str: + def _get_cr_synthetic_too_many_matches(self, instance: SimBotInstructionInstance) -> str: """Get the target text and question type vector from a synthetic question.""" - question_as_target = SimBotNLUIntents.act_too_many_matches.value + question_as_target = SimBotCRIntents.act_too_many_matches.value object_name = self._get_target_object_name(instance.actions[0], name_type="class") if object_name: question_as_target = f"{question_as_target} {object_name}" @@ -575,7 +575,7 @@ def _augment_synthetic_inventory( inventory_object_id=None, instruction=instruction, ) - target_text = SimBotNLUIntents.act_missing_inventory.value + target_text = SimBotCRIntents.act_missing_inventory.value object_readable_name = get_object_readable_name_from_object_id( object_id=instance.actions[0].inventory_object_id, object_assets_to_names=self._object_assets_to_names, @@ -618,21 +618,21 @@ def _prepare_data(self) -> None: continue self._synthetic_negative_candidates.append(index) - def _get_data_intent(self, instance: SimBotInstructionInstance) -> SimBotNLUIntents: + def _get_data_intent(self, instance: SimBotInstructionInstance) -> SimBotCRIntents: if instance.actions[0].type == "Search": action_object_metadata = instance.actions[0].get_action_data["object"] if action_object_metadata["mask"] is None: - return SimBotNLUIntents.search_no_match - return SimBotNLUIntents.search_one_match + return SimBotCRIntents.search_no_match + return SimBotCRIntents.search_one_match if instance.instruction.necessary_question_answers: qa_pair = instance.instruction.necessary_question_answers[0] return self._question_type_intent_map.get( - 
qa_pair.question_type, SimBotNLUIntents.act_one_match + qa_pair.question_type, SimBotCRIntents.act_one_match ) elif instance.ambiguous: - return SimBotNLUIntents.act_too_many_matches - return SimBotNLUIntents.act_one_match + return SimBotCRIntents.act_too_many_matches + return SimBotCRIntents.act_one_match def _get_instance_frame(self, instance: SimBotInstructionInstance, target_text: str) -> int: """Get either the image infront of you or the image with the target object.""" @@ -656,8 +656,8 @@ def _load_visual_features(self, features_path: Path, frame_idx: int = 0) -> Emma return visual_features def _is_no_match(self, target_text: str) -> bool: - """Check if the instance NLU label is no_match.""" - return SimBotNLUIntents.no_match.value in target_text + """Check if the instance CR label is no_match.""" + return SimBotCRIntents.no_match.value in target_text def _get_target_object_name( self, action: SimBotAction, name_type: Literal["class", "readable"] = "readable" @@ -754,7 +754,7 @@ def _sample_cdf_no_match( inventory_object_id=instance.actions[0].inventory_object_id, ) - target_text = SimBotNLUIntents.act_no_match.value + target_text = SimBotCRIntents.act_no_match.value if object_readable_name: target_text = f"{target_text} {object_readable_name}" diff --git a/src/emma_policy/datamodules/teach_edh_datamodule.py b/src/emma_policy/datamodules/teach_edh_datamodule.py deleted file mode 100644 index aa684f0..0000000 --- a/src/emma_policy/datamodules/teach_edh_datamodule.py +++ /dev/null @@ -1,129 +0,0 @@ -from pathlib import Path -from typing import Literal, Optional, Union - -from pytorch_lightning import LightningDataModule -from torch.utils.data import ConcatDataset, DataLoader -from transformers import AutoTokenizer - -from emma_policy.datamodules.collate import collate_fn -from emma_policy.datamodules.emma_dataclasses import EmmaDatasetBatch -from emma_policy.datamodules.teach_edh_dataset import TeachEdhDataset - - -class TeachEdhDataModule(LightningDataModule): - """Data module to load EDH instances for the EMMA Policy model.""" - - def __init__( - self, - teach_edh_train_db_file: Union[str, Path], - teach_edh_valid_seen_db_file: Union[str, Path], - teach_edh_valid_unseen_db_file: Union[str, Path], - load_valid_data_split: Optional[Literal["seen", "unseen", "both"]] = None, - train_batch_size: int = 8, - val_batch_size: int = 8, - num_workers: int = 0, - model_name: str = "heriot-watt/emma-base", - max_lang_tokens: Optional[int] = None, - max_frames: int = 0, - tokenizer_truncation_side: Literal["left", "right"] = "right", - ) -> None: - super().__init__() - if isinstance(teach_edh_train_db_file, str): - teach_edh_train_db_file = Path(teach_edh_train_db_file) - if isinstance(teach_edh_valid_seen_db_file, str): - teach_edh_valid_seen_db_file = Path(teach_edh_valid_seen_db_file) - if isinstance(teach_edh_valid_unseen_db_file, str): - teach_edh_valid_unseen_db_file = Path(teach_edh_valid_unseen_db_file) - - self._teach_edh_train_db_file = teach_edh_train_db_file - self._teach_edh_valid_seen_db_file = teach_edh_valid_seen_db_file - self._teach_edh_valid_unseen_db_file = teach_edh_valid_unseen_db_file - - # Preparation - self._load_valid_data_split = load_valid_data_split - - # Dataloader constraints - self._max_lang_tokens = max_lang_tokens - self._max_frames = max_frames - self._tokenizer_truncation_side = tokenizer_truncation_side - self._num_workers = num_workers - self._train_batch_size = train_batch_size - self._val_batch_size = val_batch_size - - # Model - self._model_name = 
model_name - - def prepare_data(self) -> None: - """Perform any preparation steps necessary before loading the data to the model.""" - super().prepare_data() - - AutoTokenizer.from_pretrained(self._model_name) - - def setup(self, stage: Optional[str] = None) -> None: - """Setup datasets for the dataloaders.""" - self._tokenizer = AutoTokenizer.from_pretrained(self._model_name) - self._tokenizer.truncation_side = self._tokenizer_truncation_side - - if self._max_lang_tokens: - self._tokenizer.model_max_length = self._max_lang_tokens - - self._train_dataset = TeachEdhDataset( - dataset_db_path=self._teach_edh_train_db_file, - tokenizer=self._tokenizer, - max_frames=self._max_frames, - ) - - if self._load_valid_data_split: - self._valid_seen_dataset = TeachEdhDataset( - dataset_db_path=self._teach_edh_valid_seen_db_file, - tokenizer=self._tokenizer, - max_frames=self._max_frames, - ) - self._valid_unseen_dataset = TeachEdhDataset( - dataset_db_path=self._teach_edh_valid_unseen_db_file, - tokenizer=self._tokenizer, - max_frames=self._max_frames, - ) - - def train_dataloader(self) -> DataLoader[EmmaDatasetBatch]: - """Generate train dataloader for TEACh EDH instances.""" - return DataLoader( - self._train_dataset, # type: ignore[arg-type] - batch_size=self._train_batch_size, - num_workers=self._num_workers, - collate_fn=collate_fn, - shuffle=True, - pin_memory=True, - ) - - def val_dataloader(self) -> DataLoader[EmmaDatasetBatch]: - """Generate validation dataloader for the TEACh EDH instances. - - Default to returning the valid seen dataset because there needs to be a return else it will - causes exceptions down the line. - """ - if self._load_valid_data_split == "unseen": - return DataLoader( - self._valid_unseen_dataset, # type: ignore[arg-type] - batch_size=self._val_batch_size, - num_workers=self._num_workers, - collate_fn=collate_fn, - shuffle=False, - ) - - if self._load_valid_data_split == "both": - return DataLoader( - ConcatDataset([self._valid_seen_dataset, self._valid_unseen_dataset]), - batch_size=self._val_batch_size, - num_workers=self._num_workers, - collate_fn=collate_fn, - shuffle=False, - ) - - return DataLoader( - self._valid_seen_dataset, # type: ignore[arg-type] - batch_size=self._val_batch_size, - num_workers=self._num_workers, - collate_fn=collate_fn, - shuffle=False, - ) diff --git a/src/emma_policy/datamodules/teach_edh_dataset.py b/src/emma_policy/datamodules/teach_edh_dataset.py deleted file mode 100644 index 3d54ffd..0000000 --- a/src/emma_policy/datamodules/teach_edh_dataset.py +++ /dev/null @@ -1,274 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Literal, Union - -import torch -from emma_datasets.datamodels.datasets.teach import ( - ExtendedTeachDriverAction, - TeachDriverAction, - TeachEdhInstance, -) -from overrides import overrides -from transformers import PreTrainedTokenizer - -from emma_policy.common.settings import Settings -from emma_policy.datamodules.base_dataset import EmmaBaseDataset -from emma_policy.datamodules.emma_dataclasses import EmmaDatasetItem, EmmaVisualFeatures -from emma_policy.datamodules.pretrain_dataset import split_action_name -from emma_policy.datamodules.pretrain_instances import Task -from emma_policy.utils import get_logger - - -logger = get_logger(__name__) - -BBOX_DIAMETER = 10 -BBOX_RADIUS = BBOX_DIAMETER / 2 -AI2THOR_CLASS_DICT_FILE = Settings().paths.constants.joinpath("ai2thor_labels.json") - - -class TeachEdhDataset(EmmaBaseDataset[EmmaDatasetItem]): - """Dataset for EDH instances from TEACh. 
- - Each instance is loaded from the DatasetDb file and converted to an instance of - `EmmaDatasetItem` before being returned. - """ - - def __init__( - self, dataset_db_path: Path, tokenizer: PreTrainedTokenizer, max_frames: int = 0 - ) -> None: - super().__init__( - dataset_db_path=dataset_db_path, tokenizer=tokenizer, max_frames=max_frames - ) - - with open(AI2THOR_CLASS_DICT_FILE) as in_file: - self.ai2thor_label_mapping = json.load(in_file) - - @overrides(check_signature=False) - def __getitem__(self, index: int) -> EmmaDatasetItem: - """Get the EDH instance at the given index as an instance of `EmmaDatasetItem`.""" - with self.db: - instance_str: str = self.db[index] - - instance = TeachEdhInstance.parse_raw(instance_str) - return self._convert_instance_to_emma_dataset_item(instance) - - def _convert_instance_to_emma_dataset_item( - self, instance: TeachEdhInstance - ) -> EmmaDatasetItem: - """Convert the EDH instance to an instance of `EmmaDatasetItem`.""" - visual_features, scene_temporal_ids, object_temporal_ids = self._prepare_visual_input( - instance - ) - input_encoding = self.tokenizer( - self._get_input_text_from_instance(instance, visual_features), - return_tensors=self._return_tensor_type, - truncation=True, - ) - - target_encoding = self.tokenizer( - self._get_target_text_from_instance(instance, visual_features), - return_tensors=self._return_tensor_type, - ) - - target_temporal_ids = self._make_target_temporal_ids(target_encoding.input_ids.squeeze(0)) - return EmmaDatasetItem( - # Language - input_token_ids=input_encoding.input_ids.squeeze(0), - text_attention_mask=input_encoding.attention_mask.squeeze(0), - target_token_ids=target_encoding.input_ids.squeeze(0), - decoder_attention_mask=target_encoding.attention_mask.squeeze(0), - target_temporal_ids=target_temporal_ids, - # Visual features - object_attention_mask=visual_features.object_attention_mask, - object_coordinates=visual_features.object_coordinates, - object_features=visual_features.object_features, - object_frame_tokens=visual_features.object_frame_tokens, - scene_attention_mask=visual_features.scene_attention_mask, - scene_coordinates=visual_features.scene_coordinates, - scene_features=visual_features.scene_features, - scene_frame_tokens=visual_features.scene_frame_tokens, - visual_token_ids=visual_features.visual_token_ids, - scene_temporal_ids=scene_temporal_ids, - object_temporal_ids=object_temporal_ids, - # Task - task=self._get_task_as_tensor(Task.action_execution), - ) - - def _get_input_text_from_instance( - self, instance: TeachEdhInstance, visual_features: EmmaVisualFeatures - ) -> str: - """Get the input text from a TEACh EDH instance.""" - input_text = self._get_concatenated_dialog_history(instance) - - actions = self._convert_trajectory_to_text( - actions=instance.extended_driver_action_history, - feature_dicts=self._load_feature_dicts( - instance.features_path, instance.modality, allow_empty=True - ), - visual_features=visual_features, - truncation_side="left", # keep most recent actions - ) - - if actions: - input_text = "{input_text} {sep_token} {action_trajectory}".format( - input_text=input_text, - sep_token=self.tokenizer.sep_token, - action_trajectory=actions, - ) - - # Add action execution task prefix - input_text = self._get_random_template_for_task(Task.action_execution).format( - instruction=input_text, - ) - return input_text - - def _get_target_text_from_instance( - self, instance: TeachEdhInstance, visual_features: EmmaVisualFeatures - ) -> str: - """Get the target text from a TEACh 
EDH instance.""" - return self._convert_trajectory_to_text( - actions=instance.driver_actions_future, - feature_dicts=self._load_feature_dicts( - instance.future_features_path, instance.modality, allow_empty=True - ), - visual_features=visual_features, - truncation_side="right", # keep first actions - ) - - def _get_concatenated_dialog_history( - self, instance: TeachEdhInstance, cleaned: bool = True - ) -> str: - """Get dialog history as a concatenated list of strings.""" - if cleaned: - dialog_history = instance.dialog_history_cleaned - else: - dialog_history = instance.dialog_history - - concat_dialog_history = [ - f"<<{utterance.speaker.lower()}>> {utterance.utterance}" - for utterance in dialog_history - if utterance.utterance - ] - concat_dialog_history[-1] = self._refine_instruction_text(concat_dialog_history[-1]) # type: ignore[assignment] - return " ".join(concat_dialog_history) - - def _convert_trajectory_to_text( - self, - actions: Union[list[ExtendedTeachDriverAction], list[TeachDriverAction]], - feature_dicts: list[dict[str, Any]], - visual_features: EmmaVisualFeatures, - truncation_side: Literal["left", "right"] = "left", - ) -> str: - """Convert a list of driver actions to a single string.""" - if self.max_frames: - feature_dicts = self._truncate_frames(feature_dicts, truncation_side=truncation_side) - actions = self._truncate_frames(actions, truncation_side=truncation_side) - - trajectory = [] - - for action_idx, action in enumerate(actions): - trajectory.extend(split_action_name(action.action_name)) - - if action.obj_interaction_action == 1: - ground_truth_centroid_coord = ( - action.x * feature_dicts[action_idx]["width"], - action.y * feature_dicts[action_idx]["height"], - ) - ground_truth_bbox = torch.tensor( - [ - ground_truth_centroid_coord[0] - BBOX_RADIUS, # x1 - ground_truth_centroid_coord[1] - BBOX_RADIUS, # y1 - ground_truth_centroid_coord[0] + BBOX_RADIUS, # x2 - ground_truth_centroid_coord[1] + BBOX_RADIUS, # y2 - ] - ) - # normalized coordinates - ground_truth_bbox[[0, 2]] /= feature_dicts[action_idx]["width"] - ground_truth_bbox[[1, 3]] /= feature_dicts[action_idx]["height"] - - # Get the index of the objects from the current frame. Frames start from 1. - frame_token = self.tokenizer.convert_tokens_to_ids(f"") - frame_objects = visual_features.object_frame_tokens == frame_token - - matched_index, gt_flags = self._best_match_features( - ground_truth_bbox=ground_truth_bbox.unsqueeze(0), - object_coordinates_bbox=visual_features.object_coordinates[frame_objects], - threshold=0, # this is set to 0 to filter out boxes not matching at all - ) - - # we first add the class of the object we want to interact with - trajectory.append(action.object_name.lower()) - - # then if we have a matching bounding box, we add the visual token as well - found_matched_object = gt_flags[0] - if found_matched_object: - trajectory.append( - self.tokenizer.decode( - visual_features.visual_token_ids[frame_objects][matched_index[0]] - ) - ) - - trajectory[-1] = f"{trajectory[-1]}{self.tokenizer.sep_token}" - - return " ".join(trajectory) - - def _make_image_temporal_ids( - self, feature_len_history: int, feature_len_future: int, object_frame_tokens: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """Get temporal ids for scenes and objects. - - We assign -1 to history tokens and the corresponding frame index to future tokens. 
- """ - scene_temporal_ids = torch.cat( - [torch.full((feature_len_history,), -1), torch.arange(1, feature_len_future + 1)] - ) - # Relying on the fact that frame ids are consecutive tokens - start = self.tokenizer("").input_ids[1] - # We get the object frame id from the frame tokens - object_frame_ids = object_frame_tokens - start + 1 - object_temporal_ids = object_frame_ids - feature_len_history - object_temporal_ids.masked_fill_(object_temporal_ids <= 0, -1) - return scene_temporal_ids, object_temporal_ids - - def _make_target_temporal_ids(self, target_tokens: torch.Tensor) -> torch.Tensor: - """Get future indices for target tokens.""" - target_temporal_ids = torch.zeros_like(target_tokens) - separator_indices = torch.where(target_tokens == self.tokenizer.sep_token_id)[0] - target_temporal_ids[separator_indices + 1] = 1 - # Increment the frame id after each observed separator token - target_temporal_ids = torch.cumsum(target_temporal_ids, -1) + 1 - return target_temporal_ids - - def _prepare_visual_input( - self, instance: TeachEdhInstance - ) -> tuple[EmmaVisualFeatures, torch.Tensor, torch.Tensor]: - """Load history and future visual features and compute temporal ids.""" - # Load history visual features - visual_features = self._load_visual_features( - features_path=instance.features_path, - modality=instance.modality, - truncation_side="left", - allow_empty=True, - ) - len_history = visual_features.scene_features.shape[0] - # Load future visual features - len_future = 0 - if instance.future_features_path.exists(): - visual_features_future = self._load_visual_features( - features_path=instance.future_features_path, - modality=instance.modality, - truncation_side="right", - start_offset=len(visual_features.scene_features), - ) - len_future = visual_features_future.scene_features.shape[0] - # Combine history and future visual features - visual_features = self._concat_visual_features( - [visual_features, visual_features_future] - ) - - scene_temporal_ids, object_temporal_ids = self._make_image_temporal_ids( - feature_len_history=len_history, - feature_len_future=len_future, - object_frame_tokens=visual_features.object_frame_tokens, - ) - return visual_features, scene_temporal_ids, object_temporal_ids diff --git a/src/emma_policy/inference/__init__.py b/src/emma_policy/inference/__init__.py index fce9529..8b13789 100644 --- a/src/emma_policy/inference/__init__.py +++ b/src/emma_policy/inference/__init__.py @@ -1,6 +1 @@ -from emma_policy.inference.actions import ( - TEACH_ACTION_TO_SYNONYMS, - AgentAction, - get_synonyms_to_teach_action_map, -) -from emma_policy.inference.decoded_trajectory_parser import DecodedTrajectoryParser + diff --git a/src/emma_policy/inference/actions.py b/src/emma_policy/inference/actions.py deleted file mode 100644 index 997bd97..0000000 --- a/src/emma_policy/inference/actions.py +++ /dev/null @@ -1,167 +0,0 @@ -import json -from collections.abc import Mapping -from dataclasses import dataclass -from functools import lru_cache -from types import MappingProxyType -from typing import Optional - -import spacy -import torch -from scipy.spatial.distance import cdist - -from emma_policy.common.settings import Settings - - -TEACH_ACTION_TO_SYNONYMS: Mapping[str, set[str]] = MappingProxyType( - { - "Forward": {"forward", "move ahead"}, - "Backward": {"backward"}, - "Turn Left": {"turn left"}, - "Turn Right": {"turn right"}, - "Look Up": {"look up"}, - "Look Down": {"look down"}, - "Pan Left": {"pan left", "strafe left"}, - "Pan Right": {"pan right", "strafe right"}, - 
"Move Up": {"move up"}, - "Move Down": {"move down"}, - "Pickup": {"pickup", "pick", "pick up", "lift"}, - "Place": {"place", "put", "put down", "drop"}, - "Open": {"open"}, - "Close": {"close"}, - "ToggleOn": {"toggle on", "switch on", "turn on"}, - "ToggleOff": {"toggle off", "switch off", "turn off"}, - "Slice": {"slice", "cut"}, - "Dirty": {"dirty"}, - "Clean": {"clean", "wash"}, - "Fill": {"fill"}, - "Empty": {"empty"}, - "Pour": {"pour"}, - "Break": {"break", "smash"}, - } -) - -AI2THOR_CLASS_DICT_FILE = Settings().paths.constants.joinpath("ai2thor_labels.json") -AI2THOR_VECTORS_DICT_FILE = Settings().paths.constants.joinpath("ai2thor_vectors.pt") -TEACH_ACTION_DEFINITIONS = Settings().paths.constants.joinpath("teach_default_definitions.json") - - -@lru_cache(maxsize=1) -def get_synonyms_to_teach_action_map() -> dict[str, str]: - """Convert synonyms per action into a map to make it easier to get the correct action.""" - return { - synonym: action - for action, synonym_set in TEACH_ACTION_TO_SYNONYMS.items() - for synonym in synonym_set - } - - -@lru_cache(maxsize=1) -def load_teach_objects_to_indices_map() -> dict[str, int]: - """Load teach object map dictionary.""" - with open(AI2THOR_CLASS_DICT_FILE) as in_file: - object_indices_map = json.load(in_file)["label_to_idx"] - return object_indices_map - - -@lru_cache(maxsize=1) -def get_lowercase_to_teach_object_map() -> dict[str, str]: - """Map lowercase object names to teach objects.""" - with open(AI2THOR_CLASS_DICT_FILE) as in_file: - object_indices_map = json.load(in_file)["label_to_idx"] - lower_case_map = {object_name.lower(): object_name for object_name in object_indices_map} - return lower_case_map - - -@lru_cache(maxsize=1) -def prepare_ai2thor_object_and_similarity() -> tuple[spacy.Language, dict[str, torch.Tensor]]: - """Get vectors and similarities of teach objects.""" - text_processing = spacy.load( - "en_core_web_lg", - exclude=["tagger", "attribute_ruler", "parser", "senter", "lemmatizer", "ner"], - ) - return text_processing, torch.load(AI2THOR_VECTORS_DICT_FILE) - - -@lru_cache(maxsize=1) -def teach_action_types() -> dict[str, str]: - """Load action types mapping for teach actions.""" - with open(TEACH_ACTION_DEFINITIONS) as in_file: - action_definitions = json.load(in_file)["definitions"]["actions"] - action_types = { - action["action_name"]: action["action_type"] for action in action_definitions - } - return action_types - - -@dataclass -class AgentAction: - """A class that represents a robot action performed in the environment.""" - - action: str - object_label: Optional[str] = None - raw_object_label: Optional[str] = None - object_visual_token: Optional[str] = None - object_to_index = load_teach_objects_to_indices_map() - text_processing, object_similarities = prepare_ai2thor_object_and_similarity() - - def get_object_index_from_visual_token(self) -> Optional[int]: - """Get the index of the object - bounding box that matches the visual token. - - A visual token has the form . X is the index of the object starting with 1. - """ - if self.object_visual_token is not None: - return int(self.object_visual_token.split("_")[-1][0]) - 1 - return None - - def get_object_index_from_label(self, bbox_probas: torch.Tensor) -> Optional[int]: - """Get the index of the object bounding box that matches the object label. - - If there are multiple objects with the same label, pick the one with the highest - confidence. 
- """ - if self.object_label is None: - return None - - bbox_labels = torch.argmax(bbox_probas, -1) - - object_index = self.object_to_index[self.object_label] - object_index_in_bbox = torch.where(bbox_labels == object_index)[0] - - if len(object_index_in_bbox) > 1: - # if we have multiple objects with the same label, select the one with the highest - # confidence score - object_probas = bbox_probas[object_index_in_bbox] - most_confident_object_index = object_probas[:, object_index].argmax() - return int(object_index_in_bbox[most_confident_object_index].item()) - - if len(object_index_in_bbox) == 1: - return int(object_index_in_bbox[0].item()) - - return None - - def get_similarity_based_object_index(self, bbox_probas: torch.Tensor) -> Optional[int]: - """Get the index of the object bounding box that has the most similar the object label.""" - if self.object_label is None: - return None - - bbox_label_indices = torch.argmax(bbox_probas, -1) - object_index = self.object_to_index[self.object_label] - similarities = self.object_similarities["similarities"][object_index, bbox_label_indices] - object_index_in_bbox = torch.argmax(similarities) - - return int(object_index_in_bbox.item()) - - def get_similarity_based_raw_object_index(self, bbox_probas: torch.Tensor) -> Optional[int]: - """Get the index of the object bounding box based on similarity with a raw object name.""" - if self.raw_object_label is None: - return None - bbox_label_indices = torch.argmax(bbox_probas, -1) - object_vector = torch.tensor(self.text_processing(self.raw_object_label).vector).unsqueeze( - 0 - ) - vectors = self.object_similarities["vectors"][bbox_label_indices] - similarities = torch.tensor(1 - cdist(object_vector, vectors, "cosine")) - - object_index_in_bbox = torch.argmax(similarities) - - return int(object_index_in_bbox.item()) diff --git a/src/emma_policy/inference/api/__init__.py b/src/emma_policy/inference/api/__init__.py index 5fba28d..8b13789 100644 --- a/src/emma_policy/inference/api/__init__.py +++ b/src/emma_policy/inference/api/__init__.py @@ -1 +1 @@ -from emma_policy.inference.api.settings import ApiSettings, parse_api_args + diff --git a/src/emma_policy/inference/api/edh_parsers.py b/src/emma_policy/inference/api/edh_parsers.py deleted file mode 100644 index 7b32e02..0000000 --- a/src/emma_policy/inference/api/edh_parsers.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging -from io import BytesIO -from pathlib import Path - -from emma_datasets.datamodels.datasets import TeachEdhInstance -from fastapi import HTTPException, status -from PIL import Image - -from emma_policy.inference.api.teach_state import TeachDatasetSplit - - -logger = logging.getLogger("uvicorn.error") - - -def parse_edh_instance(raw_edh_instance: str) -> TeachEdhInstance: - """Parse raw EDH instance into structure form.""" - try: - return TeachEdhInstance.parse_raw(raw_edh_instance) - except Exception: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Could not parse EDH instance", - ) - - -def get_edh_history_images_from_dir( - edh_instance: TeachEdhInstance, data_dir: Path, dataset_split: TeachDatasetSplit -) -> list[Image.Image]: - """Load the EDH history images from the drive.""" - image_dir = data_dir.joinpath("images", dataset_split, edh_instance.game_id) - - edh_history_images = [ - Image.open(image_dir.joinpath(image_file_name)) - for image_file_name in edh_instance.driver_image_history - ] - - return edh_history_images - - -def get_edh_history_images( - edh_instance: TeachEdhInstance, - 
raw_images: list[bytes], - data_dir: Path, - dataset_split: TeachDatasetSplit, -) -> list[Image.Image]: - """Convert the EDH history images from the request to a list of PIL Images. - - The API _should_ be returning a list of images as bytes. These need to be converted back into - PIL Images so we can do something with them. - """ - if not edh_instance.driver_image_history: - return [] - - logging.info(f"Attempting to load {len(raw_images)} images from bytes") - edh_history_images = [Image.open(BytesIO(raw_image)) for raw_image in raw_images] - - if not edh_history_images: - logger.info("Attempting to load EDH history images from disk") - edh_history_images = get_edh_history_images_from_dir(edh_instance, data_dir, dataset_split) - - if not edh_history_images: - logger.error(f"History images are empty for EDH instance `{edh_instance.game_id}`") - - return edh_history_images diff --git a/src/emma_policy/inference/api/settings.py b/src/emma_policy/inference/api/settings.py deleted file mode 100644 index 03d711e..0000000 --- a/src/emma_policy/inference/api/settings.py +++ /dev/null @@ -1,40 +0,0 @@ -from argparse import ArgumentParser, Namespace -from pathlib import Path - -from pydantic import AnyHttpUrl, BaseSettings - - -class ApiSettings(BaseSettings): - """Common settings, which can also be got from the environment vars.""" - - port: int = 5000 - host: str = "0.0.0.0" # noqa: S104 - log_level: str = "info" - feature_extractor_endpoint: AnyHttpUrl = "http://0.0.0.0:5500" # type: ignore[assignment] - - -def parse_api_args() -> tuple[Namespace, list[str]]: - """Parse any arguments, with any extras being provided as model arguments.""" - arg_parser = ArgumentParser() - - arg_parser.add_argument( - "--data_dir", - type=Path, - required=True, - help='Base data directory containing subfolders "games" and "edh_instances"', - ) - arg_parser.add_argument( - "--images_dir", - type=Path, - required=True, - help="Images directory containing inference image output", - ) - arg_parser.add_argument( - "--split", - type=str, - default="valid_seen", - choices=["train", "valid_seen", "valid_unseen", "test_seen", "test_unseen"], - help="One of train, valid_seen, valid_unseen, test_seen, test_unseen", - ) - - return arg_parser.parse_known_args() diff --git a/src/emma_policy/inference/api/teach_state.py b/src/emma_policy/inference/api/teach_state.py deleted file mode 100644 index 741a832..0000000 --- a/src/emma_policy/inference/api/teach_state.py +++ /dev/null @@ -1,16 +0,0 @@ -from pathlib import Path -from typing import Literal, TypedDict - -from emma_policy.inference.model_wrapper import PolicyModelWrapper - - -TeachDatasetSplit = Literal["train", "valid_seen", "valid_unseen", "test_seen", "test_unseen"] - - -class ApiStore(TypedDict, total=False): - """Common state for the API.""" - - data_dir: Path - images_dir: Path - split: TeachDatasetSplit - model: PolicyModelWrapper diff --git a/src/emma_policy/inference/decoded_trajectory_parser.py b/src/emma_policy/inference/decoded_trajectory_parser.py deleted file mode 100644 index df42fc5..0000000 --- a/src/emma_policy/inference/decoded_trajectory_parser.py +++ /dev/null @@ -1,121 +0,0 @@ -import logging -from typing import Literal - -from emma_policy.inference.actions import ( - AgentAction, - get_lowercase_to_teach_object_map, - get_synonyms_to_teach_action_map, -) - - -logger = logging.getLogger(__name__) - - -ExecutionDomain = Literal["TEACh", "AI2THOR"] - - -class DecodedTrajectoryParser: - """Convert the decoded action trajectory from the model to execute 
on the given domain.""" - - def __init__( - self, - execution_domain: ExecutionDomain, - action_delimiter: str, - eos_token: str, - ) -> None: - self._execution_domain = execution_domain - self._action_delimiter = action_delimiter - - self._synonym_to_action_map = get_synonyms_to_teach_action_map() - self.eos_token = eos_token - self.lowercase_to_teach_objects = get_lowercase_to_teach_object_map() - - def __call__(self, decoded_trajectory: str) -> AgentAction: - """Convert the decoded trajectory into a single executable action.""" - logger.debug(f"Decoded trajectory: `{decoded_trajectory}`") - - decoded_actions_list = self._separate_decoded_trajectory(decoded_trajectory) - - if not decoded_actions_list or decoded_actions_list[0].endswith(self.eos_token): - # if the list is empty it means that we generated only the action delimiter - # or if we have an action that ends with EOS - return AgentAction(action="Stop") - - return self._convert_action_to_executable_form(decoded_actions_list[0]) - - def _separate_decoded_trajectory(self, decoded_trajectory: str) -> list[str]: - """Split the decoded trajectory string into a list of action strings. - - Uses the given action delimiter (which is likely going to be the tokenizer SEP token). - - Also removes any blank strings from the list of actions. - """ - split_actions = decoded_trajectory.split(self._action_delimiter) - return [action for action in split_actions if action] - - def _get_teach_action_from_tokens(self, action_tokens: list[str]) -> tuple[str, list[str]]: - """Get the TEACh action from the decoded action string. - - Assumptions: - - The action appears at the start of the `decoded_action_string`. - - The action can be of a length more than 1. - - Example: - - If decoded_action == `forward`, then return `Forward` - - If decoded_action == `pickup mug`, then return `Pickup` - """ - parsed_action_name = None - - action_name = None - index = len(action_tokens) - while index > 0: - action_name = " ".join(action_tokens[:index]) - - if action_name in self._synonym_to_action_map: - parsed_action_name = action_name - break - - index -= 1 - - if parsed_action_name is None: - # edge case: we were not able to map the current action, just return an empty action - return "", action_tokens - - return self._synonym_to_action_map[parsed_action_name], action_tokens[index:] - - def _convert_action_to_executable_form(self, action_str: str) -> AgentAction: - """Convert the decoded action string into an executable form. - - We need to handle different cases: - - Index 0: Should be the TEACh API Action - - Index 1: Should be the object class (when available) - - Index 2: Should be the visual token (when available) - - We are assuming that the visual token will only ever be present after the object class.
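-        Example (illustrative, assuming the <vis_token_X> surface form used elsewhere in this module): the string `pickup mug <vis_token_4>` parses to the TEACh action `Pickup` with object label `Mug` and visual token `<vis_token_4>`.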
- """ - action_tokens = action_str.strip().split(" ") - - teach_action, teach_action_params = self._get_teach_action_from_tokens(action_tokens) - - object_label = None - object_visual_token = None - raw_object_label = None - - for action_param in teach_action_params: - action_param = action_param.strip() - - if action_param.startswith("<") and action_param.endswith(">"): - object_visual_token = action_param.strip() - else: - obj_label = self.lowercase_to_teach_objects.get(action_param) - if obj_label is not None: - object_label = self.lowercase_to_teach_objects[action_param] - else: - raw_object_label = action_param - - return AgentAction( - action=teach_action, - object_label=object_label, - raw_object_label=raw_object_label, - object_visual_token=object_visual_token, - ) diff --git a/src/emma_policy/inference/model_wrapper/__init__.py b/src/emma_policy/inference/model_wrapper/__init__.py index 60219b5..3c8a32e 100644 --- a/src/emma_policy/inference/model_wrapper/__init__.py +++ b/src/emma_policy/inference/model_wrapper/__init__.py @@ -1,6 +1,3 @@ -from emma_policy.inference.model_wrapper.base import BaseModelWrapper, SimulatorAction -from emma_policy.inference.model_wrapper.policy import PolicyModelWrapper from emma_policy.inference.model_wrapper.simbot_action_input_builder import ( SimBotActionInputBuilder, ) -from emma_policy.inference.model_wrapper.simbot_raw_text_matcher import SimBotActionRawTextMatcher diff --git a/src/emma_policy/inference/model_wrapper/base.py b/src/emma_policy/inference/model_wrapper/base.py deleted file mode 100644 index 487069d..0000000 --- a/src/emma_policy/inference/model_wrapper/base.py +++ /dev/null @@ -1,94 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Optional - -from PIL.Image import Image -from pydantic import BaseModel - - -class SimulatorAction(BaseModel): - """Dictionary containing the previous action taken by the agent. - - Args: - action: Action taken by the agent in the environment. - obj_relative_coord: Relative (x, y) coordinate indicating the object in the image. - - Note: The TEACh wrapper on AI2-THOR examines the ground truth segmentation mask of the - agent's egocentric image, selects an object in a 10x10 pixel patch around the pixel - indicated by the coordinate if the desired action can be performed on it, and executes - the action in AI2-THOR. - """ - - action: Optional[str] - obj_relative_coord: Optional[tuple[float, float]] - - -class BaseModelWrapper(ABC): - """Base wrapper to use so that the model can communicate with the inference engine. - - This has been implemented in line with `teach.inference.teach_model.TeachModel` from the - `alexa/teach` repo: https://github.com/alexa/teach/blob/main/src/teach/inference/teach_model.py - - Documentation included any additional questions/discoveries found during implementation of the - inference engine. - """ - - @abstractmethod - def __init__(self, process_index: int, num_processes: int, model_args: list[str]) -> None: - """A model will be initialized for each evaluation process. - - Args: - process_index: Index of the process that LAUNCHED the model - num_processes: Total number of processes that are launched in parallel - model_args: Extra CLI arguments to `teach_eval` which get passed to the model. - How relevant is this for our model? 
- """ - - @abstractmethod - def get_next_action( - self, - img: Image, - edh_instance: Any, - prev_action: Optional[SimulatorAction], - img_name: Optional[str] = None, - edh_name: Optional[str] = None, - ) -> tuple[str, Optional[tuple[float, float]]]: - """Get the next predicted action from the model. - - Called at each timestep. - - Args: - img: Agent's egocentric image. - edh_instance: EDH Instance from the file - prev_action: Previous action taken by the agent, if any - img_name: File name of the image - edh_name: File name for the EDH instance - - Returns: - - action name from `all_agent_actions` - - obj_relative_coord: A relative (x, y) coordinate (values between 0 and 1) - indicating an object in the image - """ - raise NotImplementedError - - @abstractmethod - def start_new_edh_instance( - self, - edh_instance: Any, - edh_history_images: list[Image], - edh_name: Optional[str] = None, - ) -> bool: - """Start a new EDH instance, resetting the state and anything else. - - Called at the start of each EDH instance AFTER the environment has ben set to the initial - state, but before actions are requested from the model. - - Args: - edh_instance: EDH Instance from the file - edh_history_images: Images loaded from the files specified in - `edh_instance['driver_image_history']` - edh_name: File name for the EDH instance - - Returns: - True if successfully created an EDH instance - """ - raise NotImplementedError diff --git a/src/emma_policy/inference/model_wrapper/policy.py b/src/emma_policy/inference/model_wrapper/policy.py deleted file mode 100644 index cd0c690..0000000 --- a/src/emma_policy/inference/model_wrapper/policy.py +++ /dev/null @@ -1,467 +0,0 @@ -import dataclasses -import logging -from argparse import ArgumentParser -from pathlib import Path -from random import randint -from typing import Optional - -import numpy as np -import torch -from emma_datasets.datamodels.datasets import TeachEdhInstance -from PIL import Image -from pytorch_lightning import LightningModule -from transformers.generation_stopping_criteria import StoppingCriteriaList - -from emma_policy.datamodules.batch_attention_masks import make_batch_attention_masks -from emma_policy.datamodules.emma_dataclasses import ( - EmmaDatasetBatch, - EmmaDatasetItem, - EmmaDatasetPadding, -) -from emma_policy.inference.actions import AgentAction, teach_action_types -from emma_policy.inference.api.settings import ApiSettings -from emma_policy.inference.decoded_trajectory_parser import DecodedTrajectoryParser -from emma_policy.inference.model_wrapper.base import BaseModelWrapper, SimulatorAction -from emma_policy.inference.model_wrapper.stopping_criteria import ActionStopCriteria -from emma_policy.inference.model_wrapper.teach_edh_inference_dataset import ( - TeachEdhInferenceDataset, -) -from emma_policy.inference.model_wrapper.teach_edh_inference_state import EdhInstanceInferenceState -from emma_policy.models.emma_policy import EmmaPolicy - - -logger = logging.getLogger(__name__) - -IMAGE_SIMILARITY_ABSOLUTE_THRESHOLD = 1e-5 - - -class PolicyModelWrapper(BaseModelWrapper): - """Wrapper around the EMMA Policy model for performing inference.""" - - def __init__( - self, - process_index: int, - num_processes: int, - model_checkpoint_path: Path, - model_name: str = "heriot-watt/emma-base", - max_frames: int = 100, - max_target_len: int = 10, - max_lang_tokens: Optional[int] = None, - device_id: int = -1, - generation_num_beams: int = 1, - no_repeat_ngram_size: int = 0, - ) -> None: - - self._device = 
self._get_device(process_index, device_id) - self._model_name = model_name - self._model = self._setup_model(model_checkpoint_path) - - feature_extractor_endpoint = ApiSettings().feature_extractor_endpoint - logger.info(f"Using feature extractor API at `{feature_extractor_endpoint}`") - - self._teach_edh_inference_dataset = TeachEdhInferenceDataset.from_model_name( - model_name=model_name, - max_frames=max_frames, - max_lang_tokens=max_lang_tokens, - feature_extractor_endpoint=feature_extractor_endpoint, - ) - - self._tokenizer = self._teach_edh_inference_dataset.tokenizer - self._parse_decoded_trajectory = DecodedTrajectoryParser( - execution_domain="TEACh", action_delimiter=".", eos_token=self._tokenizer.eos_token - ) - - self._edh_instance_state = EdhInstanceInferenceState( - max_target_len, - max_target_len, - max_past_decoding_steps=max_frames - 1, - eos_token_id=self._tokenizer.eos_token_id, - ) - - self.action_stop = StoppingCriteriaList( - [ - ActionStopCriteria( - action_sep_token_id=self._tokenizer.sep_token_id, # type: ignore[arg-type] - eos_token_id=self._tokenizer.eos_token_id, # type: ignore[arg-type] - ) - ] - ) - self._generation_num_beams = generation_num_beams - self.no_repeat_ngram_size = no_repeat_ngram_size - - # Update the torch device used by the Perception API to ensure they're the same - self._teach_edh_inference_dataset.client.update_device(self._device) - self._action_types = teach_action_types() - - @classmethod - def from_argparse( - cls, process_index: int, num_processes: int, model_args: list[str] - ) -> "PolicyModelWrapper": - """Create the policy model from argparse.""" - arg_parser = ArgumentParser("EMMA Policy Model Wrapper") - - arg_parser.add_argument( - "--model_checkpoint_path", - type=Path, - required=True, - help="Path to the model checkpoint file.", - ) - arg_parser.add_argument( - "--model_name", - type=str, - default="heriot-watt/emma-base", - help="Name of the pretrained model to setup the correct checkpoint", - ) - arg_parser.add_argument( - "--device_id", type=int, default=-1, help="CPU/GPU device id. Use -1 for CPU" - ) - arg_parser.add_argument( - "--generation_num_beams", - type=int, - default=1, - help="Number of beams for beam search. 
1 means no beam search.", - ) - arg_parser.add_argument( - "--max_frames", - type=int, - default=32, # noqa: WPS432 - help="Set max number of frames for the model to decode for.", - ) - arg_parser.add_argument( - "--max_target_len", - type=int, - default=14, # noqa: WPS432 - help="Set the max target tokens for each decoding step.", - ) - arg_parser.add_argument( - "--no_repeat_ngram_size", - type=int, - default=0, - help="if > 0 all ngrams of that size can occur only once.", - ) - parsed_model_args = arg_parser.parse_args(model_args) - - logger.debug(parsed_model_args) - - return cls( - process_index=process_index, - num_processes=num_processes, - model_checkpoint_path=parsed_model_args.model_checkpoint_path, - model_name=parsed_model_args.model_name, - device_id=parsed_model_args.device_id, - generation_num_beams=parsed_model_args.generation_num_beams, - max_frames=parsed_model_args.max_frames, - max_target_len=parsed_model_args.max_target_len, - no_repeat_ngram_size=parsed_model_args.no_repeat_ngram_size, - ) - - def start_new_edh_instance( - self, - edh_instance: TeachEdhInstance, - edh_history_images: list[Image.Image], - edh_name: Optional[str] = None, - ) -> bool: - """Reset the model ready for a new EDH instance.""" - self._teach_edh_inference_dataset.start_new_edh_instance(edh_instance, edh_history_images) - - self._edh_instance_state.reset_state() - - return True - - def get_next_action( - self, - img: Image, - edh_instance: TeachEdhInstance, - prev_action: Optional[SimulatorAction], - img_name: Optional[str] = None, - edh_name: Optional[str] = None, - ) -> tuple[str, Optional[tuple[float, float]]]: - """Get the next predicted action from the model. - - Called at each timestep. - - Args: - img: Agent's egocentric image. - edh_instance: EDH Instance from the file. - prev_action: Previous action taken by the agent, if any - img_name: File name of the image - edh_name: File name for the EDH instance - - Returns: - - action name for the TEACh API - - obj_relative_coord: A relative (x, y) coordinate (values between 0 and 1) - indicating an object in the image if available - """ - logger.info(f"Getting next action for EDH instance `{edh_instance.instance_id}`") - - dataset_instance = self._convert_edh_to_dataset_instance(current_frame=img) - - model_input_batch = self._create_model_input_from_dataset_instance(dataset_instance) - - output_token_ids = self._predict_action_trajectory(model_input_batch) - - next_action = self._parse_predicted_action(output_token_ids) - - is_agent_stuck = self._is_agent_stuck( - previous_action=prev_action, next_action=next_action, current_frame=img - ) - - if is_agent_stuck: - logger.debug(f"Agent is stuck: Previous `{prev_action}` -> Next `{next_action}`") - next_action, output_token_ids = self.handle_repeatedly_walking_into_obstruction( - model_input_batch - ) - - # TODO: Any other cases that need handling? 
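-        # Persist this step's output tokens in the EDH instance state so that later decoding steps are conditioned on the full action history.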
- - self._edh_instance_state.update_state( - instance=dataset_instance, output_token_ids=output_token_ids - ) - - # Update the previous frame - self._teach_edh_inference_dataset.previous_frame = img - - return next_action.action, self._get_object_relative_coordinates_from_action( - next_action, dataset_instance - ) - - def handle_repeatedly_walking_into_obstruction( - self, model_input_batch: EmmaDatasetBatch - ) -> tuple[AgentAction, torch.Tensor]: - """Get the agent to turn if it keeps walking into an obstacle.""" - turn_token = self._tokenizer.encode( - " turn", add_special_tokens=False, return_tensors="pt" - )[0] - - extended_decoding_input_ids = self._edh_instance_state.decoding_input_ids.copy() - extended_decoding_input_ids.append(turn_token) - decoding_input_ids = torch.cat(extended_decoding_input_ids) - - output_token_ids = self._predict_action_trajectory( - model_input_batch, decoding_input_ids=decoding_input_ids - ) - new_action = self._parse_predicted_action(output_token_ids) - return new_action, output_token_ids - - def _is_agent_stuck( - self, - previous_action: Optional[SimulatorAction], - next_action: AgentAction, - current_frame: Image, - ) -> bool: - """Determine whether or not the agent is stuck. - - Perform two main checks: - - Does the previous frame match the current frame? - - Did the agent predict `Forward` again? - """ - is_agent_predicting_the_same_action = ( - previous_action is not None and previous_action.action == next_action.action - ) - prev_frame = np.asarray(self._teach_edh_inference_dataset.previous_frame) - current_frame = np.array(current_frame) - if prev_frame.shape != current_frame.shape: - return False - - is_images_identical = np.allclose( - prev_frame, - current_frame, - atol=IMAGE_SIMILARITY_ABSOLUTE_THRESHOLD, - ) - - return is_agent_predicting_the_same_action and bool(is_images_identical) - - def _compute_center_from_bbox(self, bbox_coordinates: torch.Tensor) -> tuple[float, float]: - """Compute the centroid of a given bounding box. - - Args: - bbox_coordinates (torch.Tensor): Coordinates as XYXY (x1, y1, x2, y2) - - Returns: - Relative (y, x) coordinates of the center of the bounding box. - """ - x_center = ((bbox_coordinates[0] + bbox_coordinates[2]) / 2).item() - y_center = ((bbox_coordinates[1] + bbox_coordinates[3]) / 2).item() - return (y_center, x_center) - - def _setup_model(self, model_checkpoint_path: Path) -> LightningModule: - """Setup the model from the checkpoint.""" - model = EmmaPolicy(model_name=self._model_name) - model = model.load_from_checkpoint( - model_checkpoint_path.as_posix(), strict=False, map_location=self._device - ) - model.eval() - - return model - - def _get_device(self, process_index: int, device_id: int = -1) -> torch.device: - """Get the device for the model. - - This does it the exact same way they did it for ET, provided the device_id is not greater - than -1.
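-        For example: with 4 visible GPUs and the default device_id of -1, process index 5 is mapped to cuda:1 (5 % 4), while any device_id > -1 pins the model to that GPU directly.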
- """ - if not torch.cuda.is_available(): - return torch.device("cpu") - - gpu_count = torch.cuda.device_count() - logger.info(f"{gpu_count} GPUs detected") - - model_device_id = device_id if device_id > -1 else process_index % gpu_count - - device = torch.device(f"cuda:{model_device_id}") - logger.info(f"Device used: {device}") - - return device - - def _convert_edh_to_dataset_instance(self, current_frame: Image.Image) -> EmmaDatasetItem: - """Convert the TEACh EDH instance to the EmmaDatasetItem for the model.""" - dataset_instance = self._teach_edh_inference_dataset.get_next_dataset_instance( - current_frame=current_frame - ) - # Add some dummy token ids for predicting the next action - extended_decoding_input_ids = self._edh_instance_state.decoding_input_ids.copy() - extended_decoding_input_ids.append( - torch.zeros(self._edh_instance_state.step_max_target_length, dtype=torch.int64) - ) - dataset_instance.target_token_ids = torch.cat(extended_decoding_input_ids) - # Add some dummy target temporal ids for next prediction - extended_target_temporal_ids = self._edh_instance_state.target_temporal_ids.copy() - extended_target_temporal_ids.append( - torch.full( - size=(self._edh_instance_state.step_max_target_length,), - fill_value=self._edh_instance_state.decoding_step, - dtype=torch.int64, - ) - ) - dataset_instance.target_temporal_ids = torch.cat(extended_target_temporal_ids) - dataset_instance.decoder_attention_mask = torch.ones_like( - dataset_instance.target_temporal_ids, dtype=torch.int64 - ) - return dataset_instance - - def _create_model_input_from_dataset_instance( - self, dataset_instance: EmmaDatasetItem - ) -> EmmaDatasetBatch: - """Create the batched input for the model from the dataset instance. - - Collate lists of samples into batches after padding. - """ - fields = dataclasses.fields(EmmaDatasetItem) - padding = EmmaDatasetPadding() - - raw_batch = { - field.name: getattr(dataset_instance, field.name).unsqueeze(0) - for field in fields - if getattr(dataset_instance, field.name) is not None - } - make_batch_attention_masks(raw_batch, padding_value=padding.attention_mask) - return EmmaDatasetBatch(**raw_batch) - - def _predict_action_trajectory( - self, model_input: EmmaDatasetBatch, decoding_input_ids: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Get the model to predict the action trajectory.""" - if decoding_input_ids is None: - decoding_input_ids = torch.cat(self._edh_instance_state.decoding_input_ids) - with torch.no_grad(): - output_token_ids = self._model.inference_step( # type: ignore[operator] - model_input, - decoder_input_ids=decoding_input_ids.unsqueeze(0), - max_length_per_action_sequence=self._edh_instance_state.total_max_target_length, - action_stop=self.action_stop, - num_beams=self._generation_num_beams, - no_repeat_ngram_size=self.no_repeat_ngram_size, - )[0] - - return output_token_ids - - def _parse_predicted_action(self, model_output_tokens: torch.Tensor) -> AgentAction: - """Convert the predicted action from the model into the actual action. - - If it's the first decoding step, ignore the initial special tokens (e.g. ). 
- """ - next_action_token_ids = model_output_tokens[ - self._edh_instance_state.previous_decoded_token_length : - ] - - if self._edh_instance_state.is_first_decoding_step: - next_action_token_ids = next_action_token_ids[1:] - - next_action_raw_string = self._tokenizer.decode( - next_action_token_ids, skip_special_tokens=False - ) - - next_action = self._parse_decoded_trajectory(next_action_raw_string) - - return next_action - - def _get_object_relative_coordinates_from_action( - self, action: AgentAction, teach_item: EmmaDatasetItem - ) -> Optional[tuple[float, float]]: - """Return relative (x, y) coordinates indicating the position of the object in the image. - - Note: The TEACh wrapper on AI2-THOR examines the ground truth segmentation mask of the - agent's egocentric image, selects an object in a 10x10 pixel patch around the pixel - indicated by the coordinate if the desired action can be performed on it, and - executes the action in AI2-THOR. - """ - action_type = self._action_types.get(action.action) - if action_type is not None and action_type != "ObjectInteraction": - return None - - # Attempt to index the object label - object_index = action.get_object_index_from_label( - bbox_probas=self._teach_edh_inference_dataset.get_current_object_probas() - ) - if object_index is not None: - logger.debug(f"Attempt to get object index from [b]label[/]: IDX `{object_index}`") - return self._compute_center_from_bbox( - bbox_coordinates=self._teach_edh_inference_dataset.get_current_coordinates()[ - object_index - ] - ) - - # Attempt to index the visual token - object_index = action.get_object_index_from_visual_token() - if object_index is not None: - logger.debug( - f"Attempt to get object index from [b]visual token[/]: IDX `{object_index}`" - ) - return self._compute_center_from_bbox( - bbox_coordinates=self._teach_edh_inference_dataset.get_current_coordinates()[ - object_index - ] - ) - - # Attempt to get an object with the most similar label - object_index = action.get_similarity_based_object_index( - bbox_probas=self._teach_edh_inference_dataset.get_current_object_probas() - ) - if object_index is not None: - logger.debug( - f"Attempt to get object index with [b]most similar visual token[/]: IDX `{object_index}`" - ) - return self._compute_center_from_bbox( - bbox_coordinates=teach_item.object_coordinates[object_index] - ) - - # Attempt to get an object with the most similar name that was not an AI2THOR label - object_index = action.get_similarity_based_raw_object_index( - bbox_probas=self._teach_edh_inference_dataset.get_current_object_probas() - ) - logger.debug( - f"Attempt to get object with [b]most similar name that is not an AI2THOR label[/]: IDX `{object_index}`" - ) - - # Pick a random object - if object_index is None: - object_index = randint( - 0, len(self._teach_edh_inference_dataset.get_current_coordinates()) - 1 - ) - logger.debug(f"Get [b]random object[/]: IDX `{object_index}`") - - return self._compute_center_from_bbox( - bbox_coordinates=self._teach_edh_inference_dataset.get_current_coordinates()[ - object_index - ], - ) diff --git a/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py b/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py index 1d1b20e..952ba31 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py +++ b/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py @@ -1,5 +1,4 @@ import logging -import re from typing import Any, Optional import torch @@ -114,63 +113,6 @@ def 
__call__(self, request: EmmaPolicyRequest, task: Task) -> ActionBuilderOutpu logger.error(f"Found unsupported task: {task}") return (instruction, batch, decoder_input_ids, step_index) - def check_carrot_case(self, request: EmmaPolicyRequest) -> bool: - """Check if the previous action toggled the carrot machine.""" - if len(request.environment_history) < 2: - return False - - previous_action = request.environment_history[-2] - if previous_action.output is None: - return False - - if "" in previous_action.output: - return False - - return previous_action.output.startswith( - "toggle everything's a carrot machine Optional[str]: - """Check if the instruction refers to a sticky note.""" - features = request.environment_history[-1].features - entity_labels = features[0].entity_labels - - ignore_instruction = any( - [ - len(request.environment_history) > 1, - len(features) > 1, - entity_labels is None, - len(request.dialogue_history) > 1, - request.dialogue_history[-1].role == "agent", - ] - ) - if ignore_instruction: - return None - - patterns = "|".join( - [ - r"\S?sticky\s+", - r"\S?stickynote\s+", - r"\S?note\S?", - r"\S?clue\S?", - r"\S?hint\S?", - r"\S?postit\S?", - r"\S?posted\S?", - ] - ) - search_pattern = f"({patterns})" - search_result = re.search(search_pattern, request.dialogue_history[-1].utterance) - - if search_result is not None and "Sticky Note" in entity_labels: - vis_token = entity_labels.index("Sticky Note") + 1 - if is_action: - return f"goto sticky note ." - return f" " - - return None - def _prepare_decoder_input_ids( self, previous_actions: Optional[str] = None ) -> Optional[torch.Tensor]: diff --git a/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py b/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py index 798c039..5de100c 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py +++ b/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py @@ -1,4 +1,3 @@ -import re from typing import Optional import torch @@ -22,10 +21,9 @@ def post_process_action(action: str) -> str: class SimBotActionPredictionProcessor: """Process SimBot Action predictions.""" - def __init__(self, enable_prediction_patching: bool = True) -> None: + def __init__(self) -> None: self._button_colors = ["blue", "green", "red"] self._stop_token = "" # noqa: S105 - self._enable_prediction_patching = enable_prediction_patching def __call__( self, @@ -38,46 +36,8 @@ def __call__( if instruction is None or entity_labels is None: return prediction - if not self._enable_prediction_patching: - return prediction - - if "frame_token" in prediction and "vis_token" in prediction: - prediction_after_robot_arm = self._special_robotics_lab_button_case( - instruction, prediction, entity_labels - ) - prediction_after_button = self._special_colorchanger_button_case( - instruction, prediction_after_robot_arm, entity_labels - ) - - prediction_after_special_monitor = self._special_monitor_toggle_case( - instruction, prediction_after_button, entity_labels - ) - - prediction_after_carrot = self._special_carrot_case( - prediction_after_special_monitor, entity_labels - ) - - prediction_after_cartridge = self._special_cartridge_case( - instruction, prediction_after_carrot, entity_labels - ) - - return prediction_after_cartridge return prediction - def _is_toggle_instruction(self, instruction: str) -> bool: - return any( - [ - " toggle " in instruction, - " activate " in instruction, - " turn " in instruction, - " switch " in 
instruction, - " flip " in instruction, - " push " in instruction, - " press " in instruction, - " use " in instruction, - ] - ) - def _get_detected_objects( self, frame_features: list[EmmaExtractedFeatures] ) -> Optional[list[str]]: @@ -87,166 +47,6 @@ def _get_detected_objects( class_labels = [label.lower() for label in class_labels] return class_labels - def _special_robotics_lab_button_case( # noqa: WPS231 - self, instruction: str, prediction: str, entity_labels: list[str] - ) -> str: - if "" not in prediction: - return prediction - - is_toggle_instruction = self._is_toggle_instruction(instruction) - button_in_instruction = "button" in instruction - - if is_toggle_instruction and button_in_instruction: - frame_token_id = self._get_frame_token_from_prediction(prediction) - token_id = None - if "robot arm" in entity_labels: - token_id = entity_labels.index("robot arm") + 1 - entity = "robot arm" - elif "emotion tester" in entity_labels: - token_id = entity_labels.index("emotion tester") + 1 - entity = "emotion tester" - elif "printer" in entity_labels: - token_id = entity_labels.index("printer") + 1 - entity = "printer" - elif "coffee unmaker" in entity_labels: - token_id = entity_labels.index("coffee unmaker") + 1 - entity = "coffee unmaker" - - # TODO: check if we should only replace the prediction when no computer is present - if token_id is not None and frame_token_id is not None: - return f"toggle {entity} ." - return prediction - - def _special_cartridge_case( - self, instruction: str, prediction: str, entity_labels: list[str] - ) -> str: - should_ignore = ( - "" not in prediction - or "cartridge" not in instruction - or "pickup" not in prediction - ) - if should_ignore: - return prediction - - vis_token = self._get_visual_token_from_prediction(prediction) - frame_token_id = self._get_frame_token_from_prediction(prediction) - if "printer cartridge" not in entity_labels or vis_token is None or frame_token_id is None: - return prediction - - new_vis_token = entity_labels.index("printer cartridge") + 1 - return f"pickup printer cartridge ." - - def _special_carrot_case(self, prediction: str, entity_labels: list[str]) -> str: - """Remove the token whenever we are toggling the carrot machine. - - There is a bug in the arena where the agent gets a visual effects frame as the next frame - whenever it tries to toggle the carrot machine. To handle this remove the stop token at the - current time step and at the next timestep make a dummy action. - """ - vis_token = self._get_visual_token_from_prediction(prediction) - - prediction_toggles_carrot_machine = ( - vis_token - and entity_labels[vis_token - 1] == "everything's a carrot machine" - and "toggle" in prediction - ) - frame_token_id = self._get_frame_token_from_prediction(prediction) - if prediction_toggles_carrot_machine and frame_token_id: - return f"toggle everything's a carrot machine ." - - # TODO: do we need force placing? - tried_to_pick_up_carrot_machine = ( - vis_token - and "pickup" in prediction - and entity_labels[vis_token - 1] == "everything's a carrot machine" - ) - if "carrot" in entity_labels and tried_to_pick_up_carrot_machine: - new_vis_token = entity_labels.index("carrot") + 1 - return ( - f"pickup carrot ." 
- ) - return prediction - - def _special_colorchanger_button_case( - self, instruction: str, prediction: str, entity_labels: list[str] - ) -> str: - if "" not in prediction: - return prediction - - frame_token_id = self._get_frame_token_from_prediction(prediction) - if frame_token_id is None: - return prediction - - pattern = r".*(the )?(red|blue|green)( one| button)?\.$" - match = re.search(pattern, instruction) - if match is not None: - color_result = re.search("(red|blue|green)", match.group()) - if color_result is not None: - color = color_result.group() - color_button = f"{color} button" - if color is not None: - if color_button in entity_labels: - token_id = entity_labels.index(color_button) + 1 # noqa: WPS220 - toggle_action = self._make_toggle( # noqa: WPS220 - "button", frame_token_id, token_id - ) - return toggle_action # noqa: WPS220 - - return prediction - - def _special_monitor_toggle_case( # noqa: WPS212, WPS231 - self, instruction: str, prediction: str, entity_labels: list[str] - ) -> str: - - is_toggle_instruction = self._is_toggle_instruction(instruction) - if not is_toggle_instruction or "" not in prediction: - return prediction - - # pickup bowl -> 11> 11> -> 11 - frame_token_id = self._get_frame_token_from_prediction(prediction) - if frame_token_id is None: - return prediction - - laser_condition = "laser monitor" in entity_labels - if "laser" in instruction and laser_condition: - token_id = entity_labels.index("laser monitor") + 1 - return self._make_toggle("freeze ray monitor", frame_token_id, token_id) - - freeze_ray_monitor_in_bbox = "freeze ray monitor" in entity_labels - if "freeze" in instruction and freeze_ray_monitor_in_bbox: - token_id = entity_labels.index("freeze ray monitor") + 1 - return self._make_toggle("freeze ray monitor", frame_token_id, token_id) - - gravity_flipper_monitor_in_bbox = "gravity monitor" in entity_labels - if "gravity" in instruction and gravity_flipper_monitor_in_bbox: - token_id = entity_labels.index("gravity monitor") + 1 - return self._make_toggle("gravity monitor", frame_token_id, token_id) - - embiggenator_monitor_in_bbox = "embiggenator monitor" in entity_labels - if "embiggenator" in instruction and embiggenator_monitor_in_bbox: - token_id = entity_labels.index("embiggenator monitor") + 1 - return self._make_toggle("embiggenator monitor", frame_token_id, token_id) - - is_portal_generator = "portal" in instruction or "generator" in instruction - portal_generator_monitor_in_bbox = "portal generator monitor" in entity_labels - if is_portal_generator and portal_generator_monitor_in_bbox: - token_id = entity_labels.index("portal generator monitor") + 1 - return self._make_toggle("portal generator monitor", frame_token_id, token_id) - return prediction - - def _make_toggle(self, object_class: str, frame_token: int, vis_token: int) -> str: - return f"toggle {object_class} {self._stop_token}." 
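- # Illustrative note (editor assumption, inferred from the parsing logic below): predictions interleave language with special tokens, e.g. "toggle embiggenator monitor <frame_token_2> <vis_token_7> <stop>."; the two helpers below recover the integer ids from the frame and visual tokens.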
- - def _get_visual_token_from_prediction(self, prediction: str) -> Optional[int]: - if "<vis_token" in prediction: - return int(prediction.split("<vis_token_")[-1].split(">")[0]) - return None - - def _get_frame_token_from_prediction(self, prediction: str) -> Optional[int]: - if "<frame_token" in prediction: - return int(prediction.split("<frame_token_")[-1].split(">")[0]) - return None - - class SimBotFindPredictionProcessor: - """Process SimBot Find predictions.""" diff --git a/src/emma_policy/inference/model_wrapper/simbot_nlu_input_builder.py b/src/emma_policy/inference/model_wrapper/simbot_cr_input_builder.py similarity index 97% rename from src/emma_policy/inference/model_wrapper/simbot_nlu_input_builder.py rename to src/emma_policy/inference/model_wrapper/simbot_cr_input_builder.py index 70bf0b1..ee90ac9 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_nlu_input_builder.py +++ b/src/emma_policy/inference/model_wrapper/simbot_cr_input_builder.py @@ -21,8 +21,8 @@ FeatureDictsType = list[dict[str, Any]] -class SimBotNLUInputBuilder: - """Build the input for the Emma NLU model.""" +class SimBotCRInputBuilder: + """Build the input for the Emma CR model.""" def __init__(self, tokenizer: PreTrainedTokenizer, device: str = "cpu") -> None: self._tokenizer = tokenizer @@ -66,7 +66,7 @@ def _prepare_input_text(self, request: EmmaPolicyRequest) -> tuple[BatchEncoding # Remove the QA instruction = instruction.split("<>")[0].strip() - logger.debug(f"Preparing NLU input for instruction: {instruction}") + logger.debug(f"Preparing CR input for instruction: {instruction}") source_text = f"Predict the system act: {instruction}" tokenized_instruction = self._tokenizer.encode_plus( source_text, return_tensors="pt", truncation=True diff --git a/src/emma_policy/inference/model_wrapper/simbot_nlu_output_processor.py b/src/emma_policy/inference/model_wrapper/simbot_cr_output_processor.py similarity index 62% rename from src/emma_policy/inference/model_wrapper/simbot_nlu_output_processor.py rename to src/emma_policy/inference/model_wrapper/simbot_cr_output_processor.py index 00056db..9fcc35d 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_nlu_output_processor.py +++ b/src/emma_policy/inference/model_wrapper/simbot_cr_output_processor.py @@ -3,81 +3,36 @@ from emma_common.datamodels import EmmaExtractedFeatures -from emma_policy.datamodules.simbot_nlu_dataset import SimBotNLUIntents +from emma_policy.datamodules.simbot_cr_dataset import SimBotCRIntents -class SimBotNLUPredictionProcessor: - """Process SimBot NLU predictions.""" +class SimBotCRPredictionProcessor: + """Process SimBot CR predictions.""" def __init__( self, valid_action_types: list[str], default_prediction: str, disable_missing_inventory: bool = False, - enable_prediction_patching: bool = True, ) -> None: self.valid_action_types = valid_action_types self._disable_missing_inventory = disable_missing_inventory self._default_prediction = default_prediction - self._enable_prediction_patching = enable_prediction_patching - def __call__( # noqa: WPS231 - self, instruction: str, prediction: str, frame_features: list[EmmaExtractedFeatures] - ) -> str: + def __call__(self, prediction: str) -> str: """Process the prediction.""" disable_missing_inventory = ( - prediction.startswith(SimBotNLUIntents.act_missing_inventory.value) + prediction.startswith(SimBotCRIntents.act_missing_inventory.value) and self._disable_missing_inventory ) if disable_missing_inventory: return self._default_prediction - sticky_note_case = self._check_sticky_note(instruction, prediction, frame_features) - if sticky_note_case is not None: - return sticky_note_case - object_name = self._get_target_object(prediction) if object_name is
None: return prediction - if self._enable_prediction_patching: - class_labels = self._get_detected_objects(frame_features=frame_features) - if prediction.startswith(SimBotNLUIntents.act_no_match.value): - prediction = self._special_robotics_lab_button_case( - prediction=prediction, - class_labels=class_labels, - ) - - prediction = self._special_carrot_machine_case( - instruction=instruction, - prediction=prediction, - class_labels=class_labels, - ) - - prediction = self._special_color_changer_case( - instruction=instruction, - prediction=prediction, - class_labels=class_labels, - ) - elif prediction.startswith(SimBotNLUIntents.act_too_many_matches.value): - prediction = self._rule_based_ambiguity_check( - prediction=prediction, - class_labels=class_labels, - object_name=object_name, - ) - elif prediction.startswith(SimBotNLUIntents.search.value): - prediction = self._special_color_changer_case( - instruction=instruction, - prediction=prediction, - class_labels=class_labels, - ) - if prediction.startswith(SimBotNLUIntents.act.value): - prediction = self._special_monitor_toggle_case( - instruction=instruction, - prediction=prediction, - class_labels=class_labels, - ) - new_prediction = self._overwrite_the_nlu_prediction(prediction, object_name) + new_prediction = self._overwrite_the_cr_prediction(prediction, object_name) return new_prediction def _prediction_type_is_valid(self, prediction: str) -> bool: @@ -85,21 +40,21 @@ def _prediction_type_is_valid(self, prediction: str) -> bool: prediction_type = prediction.split(" ")[0] return prediction_type in self.valid_action_types - def _overwrite_the_nlu_prediction(self, prediction: str, object_name: Optional[str]) -> str: - """Check if the predicted NLU output needs to be overwritten.""" + def _overwrite_the_cr_prediction(self, prediction: str, object_name: Optional[str]) -> str: + """Check if the predicted CR output needs to be overwritten.""" # If the predicted prediction is not valid return the default prediction if not self._prediction_type_is_valid(prediction): return self._default_prediction # For search intents only return object_name - if prediction.startswith(SimBotNLUIntents.search.value): - return f"{SimBotNLUIntents.search.value} {self._get_target_object(prediction)}" + if prediction.startswith(SimBotCRIntents.search.value): + return f"{SimBotCRIntents.search.value} {self._get_target_object(prediction)}" # For act one_match intents only return - if prediction.startswith(SimBotNLUIntents.act_one_match.value): - return SimBotNLUIntents.act_one_match.value + if prediction.startswith(SimBotCRIntents.act_one_match.value): + return SimBotCRIntents.act_one_match.value return prediction def _get_target_object(self, prediction: str) -> Optional[str]: - """Extract the target object from the NLU prediction.""" + """Extract the target object from the CR prediction.""" split_parts = prediction.split(" ") return " ".join(split_parts[1:]) if len(split_parts) > 1 else None @@ -116,7 +71,7 @@ def _rule_based_ambiguity_check( self, prediction: str, class_labels: Optional[list[str]], object_name: str ) -> str: """Change too_many_matches prediction if there is one detected object.""" - # For now, overwrite the NLU only if there are no multiples in front of you + # For now, overwrite the CR only if there are no multiples in front of you # So if there's only one object that you are looking at, assume no ambiguity if class_labels is None: return prediction @@ -242,41 +197,3 @@ def _special_monitor_toggle_case( # noqa: WPS212, WPS231 return " portal 
generator monitor" return prediction - - def _check_sticky_note( - self, instruction: str, prediction: str, frame_features: list[EmmaExtractedFeatures] - ) -> Optional[str]: - """Check if the instruction refers to a sticky note.""" - entity_labels = frame_features[0].entity_labels - - ignore_instruction = any( - [ - len(frame_features) > 1, - entity_labels is None, - ] - ) - if ignore_instruction: - return None - - patterns = "|".join( - [ - r"\S?sticky\s+", - r"\S?stickynote\s+", - r"\S?note\S?", - r"\S?clue\S?", - r"\S?hint\S?", - r"\S?postit\S?", - r"\S?posted\S?", - ] - ) - search_pattern = f"({patterns})" - search_result = re.search(search_pattern, instruction) - - if search_result is None: - return None - - if prediction.startswith(SimBotNLUIntents.search.value): - return f"{SimBotNLUIntents.search.value} sticky note" - if "Sticky Note" in entity_labels: - return self._default_prediction - return " sticky note" diff --git a/src/emma_policy/inference/model_wrapper/simbot_raw_text_matcher.py b/src/emma_policy/inference/model_wrapper/simbot_raw_text_matcher.py deleted file mode 100644 index 7fbbed6..0000000 --- a/src/emma_policy/inference/model_wrapper/simbot_raw_text_matcher.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import logging -import re -from pathlib import Path -from typing import Optional - -from emma_common.datamodels import EmmaPolicyRequest, SpeakerRole - -from emma_policy.utils.simbot_raw_text_matching import levenshtein_distance - - -logger = logging.getLogger(__name__) - - -# deprecated! -class SimBotActionRawTextMatcher: - """Simple raw text matcher used to minimise latency cost for trivial actions.""" - - def __init__(self, raw_text_match_json: Path, distance_threshold: int = 2) -> None: - with open(raw_text_match_json) as fp: - self.raw_text_matching = json.load(fp) - self.distance_threshold = distance_threshold - self._wake_words = ["Alexa", "Amazon", "Echo", "Computer", "Ziggy"] - - def __call__(self, input_request: EmmaPolicyRequest) -> Optional[str]: - """Process the input request.""" - if len(input_request.environment_history) > 1: - logger.warning( - "Received environment history for raw text match action prediction. This will be ignored." - ) - - if len(input_request.dialogue_history) >= 2: - logger.warning( - "Received multiple turns in the dialogue history. Only the first one will be considered." - ) - - request_utterance = input_request.dialogue_history[0] - if request_utterance.role != SpeakerRole.user: - logger.debug( - f"The curret request does not have a user utterance: {input_request}. Returning None." 
- ) - return None - processed_str = self.preprocess_text(request_utterance.utterance) - for action, action_metadata in self.raw_text_matching.items(): - action_templates = action_metadata["examples"] - min_distance_for_action = min( - [ - levenshtein_distance(processed_str, action_template) - for action_template in action_templates - ] - ) - - if min_distance_for_action < self.distance_threshold: - output_string = self.postprocess_text(self.raw_text_matching[action]["command"]) - logger.debug(f"Matched input request to raw output action {output_string}.") - return output_string - logger.debug("Could not match input request to raw output action.") - return None - - def preprocess_text(self, input_string: str) -> str: - """Preprocess the raw input string.""" - new_string = re.sub(r"[^\w\s]", "", input_string) - new_string = new_string.strip().lower() - - # Remove wake words - for wake_word in self._wake_words: - new_string = new_string.replace(wake_word.lower(), "") - - new_string = new_string.replace("okay", "") - - # Remove polite intros - new_string = new_string.replace("can you", "") - new_string = new_string.replace("can you please", "") - new_string = new_string.replace("could you", "") - new_string = new_string.replace("could you please", "") - new_string = new_string.replace("please", "") - - new_string = " ".join(new_string.split()) - return new_string - - def postprocess_text(self, output_string: str) -> str: - """Postprocess the output string. - - This should return a string that is suitable for the experience hub to handle. - """ - return f"{output_string.lower()} ." diff --git a/src/emma_policy/inference/model_wrapper/teach_edh_inference_dataset.py b/src/emma_policy/inference/model_wrapper/teach_edh_inference_dataset.py deleted file mode 100644 index d8ce702..0000000 --- a/src/emma_policy/inference/model_wrapper/teach_edh_inference_dataset.py +++ /dev/null @@ -1,238 +0,0 @@ -import logging -from typing import Any, Optional - -import torch -from emma_datasets.datamodels.datasets import TeachEdhInstance -from overrides import overrides -from PIL.Image import Image -from pydantic import AnyHttpUrl -from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizer - -from emma_policy.api.clients import FeatureExtractorClient -from emma_policy.datamodules.base_dataset import prepare_emma_visual_features -from emma_policy.datamodules.emma_dataclasses import EmmaDatasetItem, EmmaVisualFeatures -from emma_policy.datamodules.pretrain_instances import Task -from emma_policy.datamodules.teach_edh_dataset import TeachEdhDataset - - -logger = logging.getLogger(__name__) - - -class TeachEdhInferenceDataset(TeachEdhDataset): - """TeachEdh Dataset for inference.""" - - def __init__( - self, - tokenizer: PreTrainedTokenizer, - feature_extractor_endpoint: AnyHttpUrl, - max_frames: int = 100, - ) -> None: - # This is what is expected by the `TeachEdhDataset` - self.tokenizer = tokenizer - self.max_frames = max_frames - self.shuffle_objects = False - self.previous_frame: Optional[Image] = None - - self.client = FeatureExtractorClient(feature_extractor_endpoint) - - self._trajectory_visual_features: list[EmmaVisualFeatures] = [] - self._history_visual_features: EmmaVisualFeatures - self._original_history_length: int = -1 - self._feature_dicts: list[dict[str, Any]] = [] - self._input_encoding: BatchEncoding - self._current_bbox_probas: Optional[torch.Tensor] - self._current_coordinates: Optional[torch.Tensor] - - @classmethod - def from_model_name( - cls, - model_name: str, -
feature_extractor_endpoint: AnyHttpUrl, - max_frames: int = 0, - max_lang_tokens: Optional[int] = None, - ) -> "TeachEdhInferenceDataset": - """Instantiate TeachEdhInferenceDataset.""" - tokenizer = AutoTokenizer.from_pretrained(model_name) - if max_lang_tokens: - tokenizer.model_max_length = max_lang_tokens - - return cls( - tokenizer=tokenizer, - max_frames=max_frames, - feature_extractor_endpoint=feature_extractor_endpoint, - ) - - def __len__(self) -> int: - """Return the total number of instances within the database.""" - return 1 - - @overrides(check_signature=False) - def __getitem__(self, index: int) -> None: - """Get the single instance during inference.""" - raise NotImplementedError("Don't call __getitem__") - - def start_new_edh_instance( - self, - edh_instance: TeachEdhInstance, - edh_history_images: list[Image], - edh_name: Optional[str] = None, - ) -> bool: - """Clear the state and start a new EDH instance.""" - logger.debug(f"Preparing visual features for `{edh_instance.instance_id}`") - self._original_history_length = min(self.max_frames, len(edh_history_images)) - edh_history_images = edh_history_images[: self._original_history_length] - self._feature_dicts = [ - {"width": image.size[0], "height": image.size[1]} for image in edh_history_images - ] - - self._history_visual_features = self.prepare_visual_features(edh_history_images) - self._trajectory_visual_features = [] - - logger.debug(f"Tokenizing input text `{edh_instance.instance_id}`") - self._input_encoding = self.tokenizer( - self._get_input_text_from_instance(edh_instance, self._history_visual_features), - return_tensors=self._return_tensor_type, - truncation=True, - ) - self._current_bbox_probas = None - self._current_coordinates = None - self.previous_frame = edh_history_images[-1] - - logger.debug(f"Model prepared `{edh_instance.instance_id}`") - return True - - def get_next_dataset_instance(self, current_frame: Image) -> EmmaDatasetItem: - """Get the emma input given the current egocentric view.""" - return self._convert_instance_to_emma_dataset_item(current_frame) - - def get_current_object_probas(self) -> torch.Tensor: - """Return the bounding box probabilities from the current egocentric view.""" - if self._current_bbox_probas is None: - raise AssertionError( - "Do not try to get current object probabilities before calling `get_next_dataset_instance`" - ) - - return self._current_bbox_probas - - def get_current_coordinates(self) -> torch.Tensor: # noqa: WPS615 - """Return the bbox coordinates from the current egocentric view.""" - if self._current_coordinates is None: - raise AssertionError( - "Do not try to get current coordinates before calling `get_next_dataset_instance`" - ) - return self._current_coordinates - - def prepare_visual_features( - self, edh_history_images: list[Image], start_offset: int = 0 - ) -> EmmaVisualFeatures: - """Prepare an EmmaVisualFeatures object.""" - if self.max_frames: - edh_history_images = edh_history_images[-self.max_frames :] - - # TODO: make this work in batches - logger.debug("Building the feature dicts") - feature_dicts: list[dict[str, Any]] = [] - - for idx, edh_history_image in enumerate(edh_history_images): - logger.debug(f"Requesting features for image {idx+1}/{len(edh_history_images)}") - - feature_response = self.client.extract_single_image(edh_history_image).dict() - feature_response["width"] = edh_history_image.size[0] - feature_response["height"] = edh_history_image.size[1] - - feature_dicts.append(feature_response) - - self._current_bbox_probas =
feature_dicts[-1]["bbox_probas"] - self._current_coordinates = feature_dicts[-1]["bbox_coords"] - - logger.debug("Converting feature dicts to `EmmaVisualFeatures` object") - return prepare_emma_visual_features( - feature_dicts=feature_dicts, tokenizer=self.tokenizer, start_offset=start_offset - ) - - def _prepare_visual_input( - self, current_frame: Image - ) -> tuple[EmmaVisualFeatures, torch.Tensor, torch.Tensor]: - """Load history and future visual features and compute temporal ids.""" - offset = self._original_history_length + len(self._trajectory_visual_features) - # Update the features seen in the trajectory - self._trajectory_visual_features.append( - self.prepare_visual_features(edh_history_images=[current_frame], start_offset=offset) - ) - self._trajectory_visual_features = self._truncate_frames( - self._trajectory_visual_features, truncation_side="left" - ) - # Fix frame tokens after truncation - for idx, frame_features in enumerate(self._trajectory_visual_features): - new_frame_token = self.tokenizer.convert_tokens_to_ids( - f"" - ) - self._trajectory_visual_features[idx].scene_frame_tokens = torch.tensor( - [new_frame_token] - ) - self._trajectory_visual_features[idx].object_frame_tokens = torch.tensor( - [new_frame_token] * frame_features.object_frame_tokens.shape[0], # noqa: WPS435 - ) - - # Concatenate history and trajectory tokens - visual_features_list = [self._history_visual_features] + self._trajectory_visual_features - visual_features = self._concat_visual_features(visual_features_list) - - scene_temporal_ids, object_temporal_ids = self._make_image_temporal_ids( - feature_len_history=self._original_history_length, - feature_len_future=len(self._trajectory_visual_features), - object_frame_tokens=visual_features.object_frame_tokens, - ) - return visual_features, scene_temporal_ids, object_temporal_ids - - def _convert_instance_to_emma_dataset_item(self, current_frame: Image) -> EmmaDatasetItem: - """Convert the EDH instance to an instance of `EmmaDatasetItem`.""" - visual_features, scene_temporal_ids, object_temporal_ids = self._prepare_visual_input( - current_frame=current_frame - ) - - return EmmaDatasetItem( - # Language - input_token_ids=self._input_encoding.input_ids.squeeze(0), - text_attention_mask=self._input_encoding.attention_mask.squeeze(0), - # Visual features - object_attention_mask=visual_features.object_attention_mask, - object_coordinates=visual_features.object_coordinates, - object_features=visual_features.object_features, - object_frame_tokens=visual_features.object_frame_tokens, - scene_attention_mask=visual_features.scene_attention_mask, - scene_coordinates=visual_features.scene_coordinates, - scene_features=visual_features.scene_features, - scene_frame_tokens=visual_features.scene_frame_tokens, - visual_token_ids=visual_features.visual_token_ids, - scene_temporal_ids=scene_temporal_ids, - object_temporal_ids=object_temporal_ids, - # Task - task=self._get_task_as_tensor(Task.action_execution), - ) - - def _get_input_text_from_instance( - self, instance: TeachEdhInstance, visual_features: EmmaVisualFeatures - ) -> str: - """Get the input text from a TEACh EDH instance.""" - input_text = self._get_concatenated_dialog_history(instance) - - actions = self._convert_trajectory_to_text( - actions=instance.extended_driver_action_history, - feature_dicts=self._feature_dicts, - visual_features=visual_features, - truncation_side="left", # keep most recent actions - ) - - if actions: - input_text = "{input_text} {sep_token} {action_trajectory}".format( - 
input_text=input_text, - sep_token=self.tokenizer.sep_token, - action_trajectory=actions, - ) - - # Add action execution task prefix - input_text = self._get_random_template_for_task(Task.action_execution).format( - instruction=input_text, - ) - return input_text diff --git a/src/emma_policy/inference/model_wrapper/teach_edh_inference_state.py b/src/emma_policy/inference/model_wrapper/teach_edh_inference_state.py deleted file mode 100644 index 2963c09..0000000 --- a/src/emma_policy/inference/model_wrapper/teach_edh_inference_state.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Optional - -import torch - -from emma_policy.datamodules.emma_dataclasses import EmmaDatasetItem - - -class EdhInstanceInferenceState: - """EDH Instance state used during inference.""" - - def __init__( - self, - step_max_target_length: int, - total_max_target_length: int, - max_past_decoding_steps: int = 99, - decoding_step: int = 1, - decoding_input_ids: Optional[list[torch.Tensor]] = None, - target_temporal_ids: Optional[list[torch.Tensor]] = None, - eos_token_id: Optional[int] = 2, - ) -> None: - self.step_max_target_length = step_max_target_length - self.total_max_target_length = total_max_target_length - self.max_past_decoding_steps = max_past_decoding_steps - self.decoding_step = decoding_step - self._eos_token_id = eos_token_id - self.decoding_input_ids: list[torch.Tensor] = ( - decoding_input_ids - if decoding_input_ids is not None - else [torch.tensor([self._eos_token_id], dtype=torch.int64)] - ) - self.target_temporal_ids: list[torch.Tensor] = ( - target_temporal_ids - if target_temporal_ids is not None - else [torch.empty(0, dtype=torch.int64)] - ) - - self.reset_state() - - @property - def previous_decoded_token_length(self) -> int: - """Get the length of the previously decoded tokens.""" - return sum(len(tensor) for tensor in self.decoding_input_ids) - - @property - def is_first_decoding_step(self) -> bool: - """Return True if it is currently the first decoding step.""" - return self.decoding_step == 1 - - def reset_state(self) -> None: - """Reset the EDH state.""" - self.decoding_step = 1 - self.decoding_input_ids = [torch.tensor([self._eos_token_id], dtype=torch.int64)] - self.target_temporal_ids = [torch.empty(0, dtype=torch.int64)] - self.total_max_target_length = self.step_max_target_length - - def update_state(self, instance: EmmaDatasetItem, output_token_ids: torch.Tensor) -> None: - """Update the state to prepare for the next prediction.""" - new_token_length = output_token_ids.shape[0] - self.previous_decoded_token_length - # Fix the target token ids. Append to "past values" only the ones that were generated - self.decoding_input_ids.append(output_token_ids[-new_token_length:]) - self.decoding_input_ids = self.decoding_input_ids[-self.max_past_decoding_steps :] - - # Fix the target temporal ids. 
Append the step number to the positions of the generated - # output - if instance.target_temporal_ids is not None: - self.target_temporal_ids.append( - torch.full( - size=(new_token_length,), - fill_value=self.decoding_step, - dtype=torch.int64, - ) - ) - # Make sure that after truncation the target temporal ids start from 1 - if len(self.target_temporal_ids) > self.max_past_decoding_steps: - self.target_temporal_ids = self.target_temporal_ids[ - -self.max_past_decoding_steps : - ] - for idx, temporal_ids in enumerate(self.target_temporal_ids, 1): - self.target_temporal_ids[idx - 1] = torch.full_like(temporal_ids, idx) - - # Increase the decoding step counter - self.decoding_step = min(self.decoding_step + 1, self.max_past_decoding_steps + 1) - # Update the total_max_target_length for the next decoding step - self.total_max_target_length = output_token_ids.shape[0] + self.step_max_target_length diff --git a/src/emma_policy/models/simbot_nlu_policy.py b/src/emma_policy/models/simbot_cr_policy.py similarity index 92% rename from src/emma_policy/models/simbot_nlu_policy.py rename to src/emma_policy/models/simbot_cr_policy.py index 93289e2..d216e60 100644 --- a/src/emma_policy/models/simbot_nlu_policy.py +++ b/src/emma_policy/models/simbot_cr_policy.py @@ -14,10 +14,10 @@ ) from emma_policy.datamodules.emma_dataclasses import EmmaDatasetBatch -from emma_policy.datamodules.simbot_nlu_datamodule import prepare_nlu_tokenizer +from emma_policy.datamodules.simbot_cr_datamodule import prepare_cr_tokenizer from emma_policy.models.emma_policy import EmmaPolicy from emma_policy.models.model_output_emma import EmmaSeq2SeqLMOutput -from emma_policy.utils.simbot_nlu_metrics import SimbotActionTypeF1, SimbotNLUExactMatch +from emma_policy.utils.simbot_cr_metrics import SimbotActionTypeF1, SimbotCRExactMatch PredictType = Union[ @@ -32,7 +32,7 @@ ForcedWordIdsList = list[list[list[int]]] -def postprocess_nlu_output(tokenizer: PreTrainedTokenizer, output: PredictType) -> list[str]: +def postprocess_cr_output(tokenizer: PreTrainedTokenizer, output: PredictType) -> list[str]: """Remove special tokens from predicted outputs.""" special_tokens = [ tokenizer.bos_token, @@ -53,7 +53,7 @@ def remove_sequence_special_tokens(sentence: str, special_tokens: list[str]) -> return sentence -class SimBotNLUEmmaPolicy(EmmaPolicy): +class SimBotCREmmaPolicy(EmmaPolicy): """Emma Lightning Module.""" def __init__( @@ -64,12 +64,12 @@ def __init__( save_results_path: Optional[Path] = None, **kwargs: Any, ) -> None: - super().__init__(model_name=f"{model_name}-nlu", **kwargs) + super().__init__(model_name=f"{model_name}-cr", **kwargs) self.model_name = model_name self._question_answers: dict[str, list[str]] = {"predictions": [], "references": []} self._num_beams = num_beams - self._tokenizer = prepare_nlu_tokenizer(model_name=model_name) + self._tokenizer = prepare_cr_tokenizer(model_name=model_name) self._min_length = 1 self._max_generated_text_length = max_generated_text_length @@ -95,7 +95,7 @@ def __init__( self._num_beams += 1 # constraints need num_beams > 1 self.task_metrics = None # type: ignore[assignment] self.validation_action_type_F1 = SimbotActionTypeF1(tokenizer=self._tokenizer) - self.validation_accuracy = SimbotNLUExactMatch() + self.validation_accuracy = SimbotCRExactMatch() self._results_path = save_results_path self._test_results: dict[str, list[str]] = { @@ -168,7 +168,7 @@ def test_step(self, batch: EmmaDatasetBatch, batch_idx: int) -> None: self._test_results["groundtruths"].extend( [sample["references"] for
sample in batch.raw_target] # type: ignore[union-attr] ) - sent = postprocess_nlu_output(self._tokenizer, prediction_output) + sent = postprocess_cr_output(self._tokenizer, prediction_output) self._test_results["predictions"].extend(sent) @overrides(check_signature=False) @@ -200,7 +200,7 @@ def predict_step(self, batch: EmmaDatasetBatch, batch_idx: int) -> PredictType: def inference_step(self, batch: EmmaDatasetBatch, batch_idx: int = 0) -> PredictType: """Inference step.""" return self.predict_step(batch, batch_idx) - # return postprocess_nlu_output(self.tokenizer, output_tokens) + # return postprocess_cr_output(self.tokenizer, output_tokens) def compute_metrics(self, prediction_output: torch.Tensor, batch: EmmaDatasetBatch) -> None: """Compute the evaluation metrics.""" diff --git a/src/emma_policy/utils/simbot_nlu_metrics.py b/src/emma_policy/utils/simbot_cr_metrics.py similarity index 98% rename from src/emma_policy/utils/simbot_nlu_metrics.py rename to src/emma_policy/utils/simbot_cr_metrics.py index 9ce0ed6..732530a 100644 --- a/src/emma_policy/utils/simbot_nlu_metrics.py +++ b/src/emma_policy/utils/simbot_cr_metrics.py @@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer -class SimbotNLUExactMatch(Metric): +class SimbotCRExactMatch(Metric): """Metric for an exact match.""" def __init__(self, dist_sync_on_step: bool = True, threshold: float = 0.5) -> None: diff --git a/storage/constants/simbot_low_level_examples.json b/storage/constants/simbot_low_level_examples.json deleted file mode 100644 index b0f63e0..0000000 --- a/storage/constants/simbot_low_level_examples.json +++ /dev/null @@ -1,441 +0,0 @@ -{ - "GoToQuantumLab": { - "examples": [ - "go back to the quantum app", - "go back to the quantum lab", - "go back to the quantum rap", - "go to the quantum app", - "go to the quantum lab", - "go to the quantum rap", - "head to the quantum app", - "head to the quantum lab", - "head to the quantum rap", - "move to the quantum app", - "move to the quantum lab", - "move to the quantum rap", - "return to the quantum app", - "return to the quantum lab", - "return to the quantum rap" - ], - "command": "Goto Lab2" - }, - "GoToRoboticsLab": { - "examples": [ - "go back to the robotic slab", - "go back to the robotics app", - "go back to the robotics lab", - "go back to the robotics rap", - "go to the robotic slab", - "go to the robotics app", - "go to the robotics lab", - "go to the robotics rap", - "head to the robotic slab", - "head to the robotics app", - "head to the robotics lab", - "head to the robotics rap", - "move to the robotic slab", - "move to the robotics app", - "move to the robotics lab", - "move to the robotics rap", - "return to the robotic slab", - "return to the robotics app", - "return to the robotics lab", - "return to the robotics rap" - ], - "command": "Goto Lab1" - }, - "GoToBreakRoom": { - "examples": [ - "go back to the break room", - "go back to the breakroom", - "go to the break room", - "go to the breakroom", - "head to the break room", - "head to the breakroom", - "move to the break room", - "move to the breakroom", - "return to the break room", - "return to the breakroom" - ], - "command": "Goto BreakRoom" - }, - "GoToMainOffice": { - "examples": [ - "go back to the big office", - "go back to the main office", - "go to the big office", - "go to the main office", - "head to the big office", - "head to the main office", - "move to the big office", - "move to the main office", - "return to the big office", - "return to the main office" - ], - "command": "Goto
MainOffice" - }, - "GoToSmallOffice": { - "examples": [ - "go back to the small office", - "go to the small office", - "head to the small office", - "move to the small office", - "return to the small office" - ], - "command": "Goto SmallOffice" - }, - "GoToReception": { - "examples": [ - "go back to the reception", - "go to the reception", - "head to the reception", - "move to the reception", - "return to the reception" - ], - "command": "Goto Reception" - }, - "GoToWarehouse": { - "examples": [ - "go back to the warehouse", - "go to the warehouse", - "head to the warehouse", - "move to the warehouse", - "return to the warehouse" - ], - "command": "Goto Warehouse" - }, - "Movefoward": { - "examples": [ - "advance forward", - "continue a bit forward", - "continue ahead", - "continue forward", - "continue moving forward", - "continue walking forward", - "for a few steps, move forward", - "forward", - "go forward", - "go straight", - "head forward", - "keep going forward", - "keep moving forward", - "move a bit forward", - "move a few steps forward", - "move a little forward", - "move ahead", - "move forward", - "move forward a bit", - "move forward for a few steps", - "move slightly forward", - "onward", - "roll ahead", - "roll forward", - "run ahead", - "run forward", - "step forward", - "step forward a bit", - "straight", - "take a step", - "take a step forward", - "walk a bit forward", - "walk a few steps forward", - "walk a few steps onwards", - "walk ahead", - "walk ahead for a few steps", - "walk forward", - "walk forward for a few steps", - "walk slightly forward", - "walk straight", - "walk towards forward", - "walk towards the front" - ], - "command": "Move Forward" - }, - "MoveBackward": { - "examples": [ - "backward", - "continue a bit backwards", - "continue backwards", - "continue moving backward", - "continue walking backwards", - "for a few steps, move backwards", - "go backwards", - "move a bit back", - "move a bit backwards", - "move a little back", - "move a little backwards", - "move back", - "move back a bit", - "move back a few steps", - "move back a little", - "move back for a few steps", - "move backward", - "move backwards a bit", - "move backwards a few steps", - "move backwards a little", - "move backwards for a few steps", - "move slightly back", - "move slightly backwards", - "step back", - "step back a bit", - "take a step back", - "take a step backwards", - "walk a bit back", - "walk a bit backwards", - "walk a few steps back", - "walk a few steps backwards", - "walk back", - "walk back a bit", - "walk back a few steps", - "walk back for a few steps", - "walk back slightly", - "walk backwards", - "walk backwards a few steps", - "walk backwards for a few steps", - "walk slightly back", - "walk slightly backwards" - ], - "command": "Move Backward" - }, - "Rotate Right": { - "examples": [ - "go right", - "go rite", - "go white", - "go write", - "look right", - "look rite", - "look white", - "look write", - "make a right turn", - "make a rite turn", - "make a white turn", - "make a write turn", - "move right", - "move rite", - "move white", - "move write", - "right turn", - "rite turn", - "rotate a bit to the right", - "rotate a bit to the rite", - "rotate a bit to the white", - "rotate a bit to the write", - "rotate right", - "rotate rite", - "rotate to the right", - "rotate to the rite", - "rotate to the white", - "rotate to the write", - "rotate to your right", - "rotate white", - "rotate write", - "turn a bit to the right", - "turn a bit to the rite", - "turn a 
bit to the white", - "turn a bit to the write", - "turn clockwise", - "turn on your right hand side", - "turn on your rite hand side", - "turn on your white hand side", - "turn on your write hand side", - "turn right", - "turn rite", - "turn to the right", - "turn to the right direction", - "turn to the rite", - "turn to the rite direction", - "turn to the white", - "turn to the white direction", - "turn to the write", - "turn to the write direction", - "turn to your right", - "turn to your rite", - "turn to your white", - "turn to your write", - "turn white", - "turn write", - "white turn", - "write turn" - ], - "command": "Rotate Right" - }, - "Rotate Left": { - "examples": [ - "go left", - "left turn", - "look left", - "make a left turn", - "move left", - "rotate a bit to the left", - "rotate left", - "rotate to the left", - "rotate to your left", - "take a left", - "turn a bit to the left", - "turn anticlockwise", - "turn counterclockwise", - "turn left", - "turn left a little bit", - "turn on the left", - "turn on your left hand side", - "turn to the left", - "turn to the left direction", - "turn to your left" - ], - "command": "Rotate Left" - }, - "LookDown": { - "examples": [ - "look at the floor", - "look at your feet", - "look down", - "look down to the floor", - "look downwards", - "look to the floor" - ], - "command": "Look Down" - }, - "LookUp": { - "examples": [ - "look at the ceiling", - "look at the roof", - "look to the ceiling", - "look to the roof", - "look up", - "look up to the ceiling", - "look upwards" - ], - "command": "Look Up" - }, - "TurnAround": { - "examples": ["do a uturn", "look behind you", "spin around", "turn around"], - "command": "Turn Around" - }, - "ExamineStickyNote": { - "examples": [ - "examine the stickynote", - "examine the sticky note", - "examine the note", - "examine the post-it", - "examine the post it", - "examine the postit", - "examine the poster eight", - "examine the posted label", - "examine the postit label", - "examine the note in front of you", - "examine the sticky note in front of you", - "examine the stickynote in front of you", - "examine the post it in front of you", - "examine the postit in front of you", - "examine the post-it in front of you", - "examine the poster eight in front of you", - "examine the posted", - "examine the posted label", - "examine the posted in front of you", - "examine the posted label in front of you", - "examine the posted message,", - "examine the posted message label", - "examine the posted massage,", - "examine the posted massage label", - "examine the posted label message", - "examine the posted label massage", - "examine the posted message in front of you", - "examine the posted massage in front of you", - "examine the posted message label in front of you", - "examine the posted massage label in front of you", - "examine the posted label message in front of you", - "examine the posted label massage in front of you", - "read the sticky note", - "read the stickynote", - "read the note", - "read the post-it", - "read the postit", - "read the post it", - "read the poster eight", - "read the posted label", - "read the postit label", - "read the sticky note in front of you", - "read the stickynote in front of you", - "read the note in front of you", - "read the post-it in front of you", - "read the postit in front of you", - "read the post it in front of you", - "read the poster eight in front of you", - "read the posted in front of you", - "read the posted label in front of you", - "read the posted 
message", - "read the posted massage", - "read the posted message label", - "read the posted massage label", - "read the posted label message", - "read the posted label massage", - "read the posted message in front of you", - "read the posted massage in front of you", - "read the posted message label in front of you", - "read the posted massage label in front of you", - "read the posted label message in front of you", - "read the posted label massage in front of you", - "open the sticky note", - "open the stickynote", - "open the note", - "open the post-it", - "open the post it", - "open the postit", - "open the poster eight", - "open the note in front of you", - "open the posted label", - "open the postit label", - "open the sticky note in front of you", - "open the stickynote in front of you", - "open the post it in front of you", - "open the postit in front of you", - "open the post-it in front of you", - "open the posted in front of you", - "open the poster eight in front of you", - "open the posted label in front of you", - "open the posted message", - "open the posted massage", - "open the posted message label", - "open the posted massage label", - "open the posted label message", - "open the posted label massage", - "open the posted message in front of you", - "open the posted massage in front of you", - "open the posted message label in front of you", - "open the posted massage label in front of you", - "open the posted label message in front of you", - "open the posted label massage in front of you", - "take the stickynote", - "take the sticky note", - "take the note", - "take the post-it", - "take the post it", - "take the postit", - "take the poster eight", - "take the note in front of you", - "take the posted label", - "take the postit label", - "take the sticky note in front of you", - "take the stickynote in front of you", - "take the post it in front of you", - "take the postit in front of you", - "take the post-it in front of you", - "take the posted in front of you", - "take the posted label in front of you", - "take the posted message", - "take the posted massage", - "take the posted message label", - "take the posted massage label", - "take the posted label message", - "take the posted label massage", - "take the posted message in front of you", - "take the posted massage in front of you", - "take the posted message label in front of you", - "take the posted massage label in front of you", - "take the posted label message in front of you", - "take the posted label massage in front of you" - ], - "command": "Examine Sticky Note " - } -} diff --git a/tests/datamodules/test_teach_edh_batch_creation.py b/tests/datamodules/test_teach_edh_batch_creation.py deleted file mode 100644 index a2ab27a..0000000 --- a/tests/datamodules/test_teach_edh_batch_creation.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -from pytest_cases import parametrize - -from emma_policy.datamodules.batch_attention_masks import ( - make_mask_from_temporal_ids, - make_text_history_global_pattern, -) -from emma_policy.datamodules.teach_edh_datamodule import TeachEdhDataModule - - -@parametrize( - "total_seq_len,text_attention_mask,target_mask", - [ - ( - 4, - torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]]), - torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]]), - ), - ( - 5, - torch.tensor([[1, 1, 1], [0, 0, 0], [1, 1, 0]]), - torch.tensor([[0, 0, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 1, 1, 0]]), # noqa: WPS221 - ), - ], -) -def test_text_history_global_attention( - total_seq_len: int, text_attention_mask: torch.Tensor, 
target_mask: torch.Tensor -) -> None: - """Check global attention output for dummy inputs.""" - output = make_text_history_global_pattern( - total_seq_len=total_seq_len, - text_attention_mask=text_attention_mask, - dtype=text_attention_mask.dtype, - ) - assert torch.equal(output, target_mask) - - -def test_text_history_global_attention_counts( - teach_edh_datamodule: TeachEdhDataModule, -) -> None: - """Ensure that the global attention mask has as many 1s as text tokens.""" - for batch in teach_edh_datamodule.train_dataloader(): - assert batch.global_attention_mask.sum() == batch.text_attention_mask.sum() - - -@parametrize( - "scene_temporal_ids,object_temporal_ids,text_temporal_ids, target_mask", - [ - ( - torch.tensor([[-1, 1, 0]]), - torch.tensor([[-1, -1, 1, 2]]), - torch.tensor([[-1, -1]]), - torch.tensor( - [ - [1, 0, 0, 1, 1, 0, 0, 1, 1], - [1, 1, 0, 1, 1, 1, 0, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 0, 0, 1, 1, 0, 0, 1, 1], - [1, 0, 0, 1, 1, 0, 0, 1, 1], - [1, 1, 0, 1, 1, 1, 0, 1, 1], - [1, 1, 0, 1, 1, 1, 1, 1, 1], - [1, 0, 0, 1, 1, 0, 0, 1, 1], - [1, 0, 0, 1, 1, 0, 0, 1, 1], - ] - ).unsqueeze(0), - ), - ( - torch.tensor([[1, 2], [-1, 1]]), - torch.tensor([[1, 1, 2], [-1, 1, 0]]), - torch.tensor([[0], [-1]]), - torch.tensor( - [ - [ - [1, 0, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [1, 0, 1, 1, 0, 0], - [1, 0, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ], - [ - [1, 0, 1, 0, 0, 1], - [1, 1, 1, 1, 0, 1], - [1, 0, 1, 0, 0, 1], - [1, 1, 1, 1, 0, 1], - [0, 0, 0, 0, 0, 0], - [1, 0, 1, 0, 0, 1], - ], - ] - ), - ), - ], -) -def test_encoder_full_attention_mask( - scene_temporal_ids: torch.Tensor, - object_temporal_ids: torch.Tensor, - text_temporal_ids: torch.Tensor, - target_mask: torch.Tensor, -) -> None: - """Check 2D attention output for dummy inputs.""" - input_temporal_ids = torch.cat( - [scene_temporal_ids, object_temporal_ids, text_temporal_ids], - dim=1, - ) - output = make_mask_from_temporal_ids( - source_temporal_ids=input_temporal_ids, - target_temporal_ids=input_temporal_ids, - dtype=text_temporal_ids.dtype, - ) - assert torch.equal(output, target_mask) diff --git a/tests/datamodules/test_teach_edh_datamodule.py b/tests/datamodules/test_teach_edh_datamodule.py deleted file mode 100644 index 9e39407..0000000 --- a/tests/datamodules/test_teach_edh_datamodule.py +++ /dev/null @@ -1,60 +0,0 @@ -from pytest_cases import parametrize_with_cases -from torch.utils.data import ConcatDataset - -from emma_policy.datamodules.emma_dataclasses import EmmaDatasetBatch -from emma_policy.datamodules.teach_edh_datamodule import TeachEdhDataModule -from tests.fixtures.datamodules import TeachEdhDataModuleCases - - -def test_dataloader_creates_train_batches(teach_edh_datamodule: TeachEdhDataModule) -> None: - # Ensure that the train dataloader is making batches - for batch in iter(teach_edh_datamodule.train_dataloader()): - assert isinstance(batch, EmmaDatasetBatch) - - -@parametrize_with_cases("teach_edh_datamodule", cases=TeachEdhDataModuleCases, glob="valid_seen") -def test_dataloader_creates_valid_seen_batches(teach_edh_datamodule: TeachEdhDataModule) -> None: - valid_dataloader = teach_edh_datamodule.val_dataloader() - - # Ensure the valid dataloader is using the valid seen dataset - assert valid_dataloader.dataset == teach_edh_datamodule._valid_seen_dataset - - # Ensure that the valid dataloader is making batches - for batch in teach_edh_datamodule.val_dataloader(): - assert isinstance(batch, EmmaDatasetBatch) - - -@parametrize_with_cases("teach_edh_datamodule",
cases=TeachEdhDataModuleCases, glob="valid_unseen") -def test_dataloader_creates_valid_unseen_batches(teach_edh_datamodule: TeachEdhDataModule) -> None: - valid_dataloader = teach_edh_datamodule.val_dataloader() - - # Ensure the valid dataloader is using the valid unseen dataset - assert valid_dataloader.dataset == teach_edh_datamodule._valid_unseen_dataset - - # Ensure that the valid dataloader is making batches - for batch in teach_edh_datamodule.val_dataloader(): - assert isinstance(batch, EmmaDatasetBatch) - - -@parametrize_with_cases( - "teach_edh_datamodule", cases=TeachEdhDataModuleCases, glob="valid_seen_and_unseen" -) -def test_dataloader_uses_both_seen_and_unseen_valid_instances( - teach_edh_datamodule: TeachEdhDataModule, -) -> None: - valid_dataloader = teach_edh_datamodule.val_dataloader() - - # Ensure the dataset given to the dataloader is the ConcatDataset - assert isinstance(valid_dataloader.dataset, ConcatDataset) - - # Ensure that both the valid seen and valid unseen datasets are in the ConcatDataset - for dataset in valid_dataloader.dataset.datasets: - assert dataset in { - teach_edh_datamodule._valid_seen_dataset, - teach_edh_datamodule._valid_unseen_dataset, - } - assert dataset != teach_edh_datamodule._train_dataset - - # Ensure that the valid dataloader is making batches - for batch in teach_edh_datamodule.val_dataloader(): - assert isinstance(batch, EmmaDatasetBatch) diff --git a/tests/datamodules/test_teach_edh_dataset.py b/tests/datamodules/test_teach_edh_dataset.py deleted file mode 100644 index 54ef289..0000000 --- a/tests/datamodules/test_teach_edh_dataset.py +++ /dev/null @@ -1,144 +0,0 @@ -import itertools -from pathlib import Path - -import torch -from emma_datasets.datamodels import DatasetSplit -from emma_datasets.datamodels.datasets import TeachEdhInstance -from emma_datasets.db import DatasetDb -from filelock import FileLock -from pytest_cases import fixture, parametrize - -from emma_policy.datamodules.emma_dataclasses import EmmaDatasetItem -from emma_policy.datamodules.teach_edh_dataset import TeachEdhDataset -from emma_policy.models.tokenizer_emma import EmmaTokenizer - - -@fixture -def teach_edh_dataset( - cached_db_dir_path: Path, - teach_edh_instances_db: dict[DatasetSplit, Path], - emma_tokenizer: EmmaTokenizer, -) -> TeachEdhDataset: - """Merge all the TEACh EDH instances into a single DatasetDb to test all the instances.""" - output_db_path = cached_db_dir_path.joinpath("teach_merged.db") - - with FileLock(cached_db_dir_path.joinpath("teach_edh_dataset.lock")): - - if not output_db_path.exists(): - output_db = DatasetDb(output_db_path, readonly=False) - teach_split_dbs = itertools.chain.from_iterable( - [DatasetDb(db_dir) for db_dir in teach_edh_instances_db.values()] - ) - - with output_db: - data_idx = 0 - for _, _, instance in teach_split_dbs: - output_db[(data_idx, f"teach_edh_{data_idx}")] = instance - data_idx += 1 - - return TeachEdhDataset(dataset_db_path=output_db_path, tokenizer=emma_tokenizer) - - -def test_dataset_can_get_instances_without_error(teach_edh_dataset: TeachEdhDataset) -> None: - """Ensure instances can be retrieved without error.""" - total_num_instances = len(teach_edh_dataset.db) - - for idx in range(total_num_instances): - dataset_item = teach_edh_dataset[idx] - assert isinstance(dataset_item, EmmaDatasetItem) - - -def test_dataset_creates_input_text_without_errors(teach_edh_dataset: TeachEdhDataset) -> None: - """Verify the dataset can create input text without erroring.""" - total_num_instances =
len(teach_edh_dataset.db) - - for idx in range(total_num_instances): - with teach_edh_dataset.db: - instance_str: str = teach_edh_dataset.db[idx] - - instance = TeachEdhInstance.parse_raw(instance_str) - visual_features, _, _ = teach_edh_dataset._prepare_visual_input(instance) - input_text = teach_edh_dataset._get_input_text_from_instance(instance, visual_features) - - assert input_text - assert isinstance(input_text, str) - - -def test_dataset_creates_target_text_without_errors(teach_edh_dataset: TeachEdhDataset) -> None: - """Verify the dataset creates target text without errors.""" - total_num_instances = len(teach_edh_dataset.db) - - for idx in range(total_num_instances): - with teach_edh_dataset.db: - instance_str: str = teach_edh_dataset.db[idx] - - instance = TeachEdhInstance.parse_raw(instance_str) - visual_features, _, _ = teach_edh_dataset._prepare_visual_input(instance) - target_text = teach_edh_dataset._get_target_text_from_instance(instance, visual_features) - - assert target_text - assert isinstance(target_text, str) - - -@parametrize("unknown_visual_token_threshold", [0.5]) -def test_parsed_visual_tokens_are_not_all_unknown( - teach_edh_dataset: TeachEdhDataset, - emma_tokenizer: EmmaTokenizer, - unknown_visual_token_threshold: float, -) -> None: - total_num_instances = len(teach_edh_dataset.db) - - for idx in range(total_num_instances): - with teach_edh_dataset.db: - instance_str: str = teach_edh_dataset.db[idx] - - instance = TeachEdhInstance.parse_raw(instance_str) - visual_features, _, _ = teach_edh_dataset._prepare_visual_input(instance) - - # Checking the unknowns in the action history - input_text = teach_edh_dataset._get_input_text_from_instance(instance, visual_features) - target_text = teach_edh_dataset._get_target_text_from_instance(instance, visual_features) - - all_interaction_actions = list( - filter( - lambda action: action.obj_interaction_action == 1, - itertools.chain(instance.driver_action_history, instance.driver_actions_future), - ) - ) - - # Get the maximum number of visual tokens - max_visual_token_count = len(all_interaction_actions) - - parsed_unk_token_count = input_text.count(emma_tokenizer.unk_token) + target_text.count( - emma_tokenizer.unk_token - ) - - if parsed_unk_token_count > max_visual_token_count * unknown_visual_token_threshold: - raise AssertionError("The number of unknowns in the `input_text` is too high.") - - -@parametrize( - "target_tokens,expected_target_tokens", - [ - ( - torch.tensor([16, 15, 42, 43, 370, 27, 28, 2]), - torch.tensor([1, 1, 1, 1, 1, 2, 2, 2]), - ), - ( - torch.tensor([14, 15, 16, 370, 17, 18, 19, 20, 370, 370, 2]), - torch.tensor([1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 4]), - ), - ], -) -def test_target_temporal_ids( - teach_edh_dataset: TeachEdhDataset, - target_tokens: torch.Tensor, - expected_target_tokens: torch.Tensor, -) -> None: - """Ensure that temporal ids for target tokens are constructed correctly.
- - Separator token id=370 - """ - target_temporal_ids = teach_edh_dataset._make_target_temporal_ids(target_tokens=target_tokens) - - assert torch.equal(target_temporal_ids, expected_target_tokens) diff --git a/tests/fixtures/datamodules.py b/tests/fixtures/datamodules.py index f93cda4..df29712 100644 --- a/tests/fixtures/datamodules.py +++ b/tests/fixtures/datamodules.py @@ -5,7 +5,6 @@ from pytest_cases import fixture, parametrize from emma_policy.datamodules.pretrain_datamodule import EmmaPretrainDataModule -from emma_policy.datamodules.teach_edh_datamodule import TeachEdhDataModule @fixture @@ -29,77 +28,3 @@ def emma_pretrain_datamodule( dm.setup() return dm - - -@fixture -@parametrize("valid_data_split", ["seen", "unseen", "both"]) -def teach_edh_datamodule( - teach_edh_instances_db: dict[DatasetSplit, Path], - valid_data_split: Literal["seen", "unseen", "both"], -) -> TeachEdhDataModule: - datamodule = TeachEdhDataModule( - teach_edh_train_db_file=teach_edh_instances_db[DatasetSplit.train], - teach_edh_valid_seen_db_file=teach_edh_instances_db[DatasetSplit.valid_seen], - teach_edh_valid_unseen_db_file=teach_edh_instances_db[DatasetSplit.valid_unseen], - load_valid_data_split=valid_data_split, - ) - datamodule.prepare_data() - datamodule.setup() - - return datamodule - - -class TeachEdhDataModuleCases: - def case_valid_seen( - self, teach_edh_instances_db: dict[DatasetSplit, Path] - ) -> TeachEdhDataModule: - datamodule = TeachEdhDataModule( - teach_edh_train_db_file=teach_edh_instances_db[DatasetSplit.train], - teach_edh_valid_seen_db_file=teach_edh_instances_db[DatasetSplit.valid_seen], - teach_edh_valid_unseen_db_file=teach_edh_instances_db[DatasetSplit.valid_unseen], - load_valid_data_split="seen", - ) - datamodule.prepare_data() - datamodule.setup() - - return datamodule - - def case_valid_unseen( - self, teach_edh_instances_db: dict[DatasetSplit, Path] - ) -> TeachEdhDataModule: - datamodule = TeachEdhDataModule( - teach_edh_train_db_file=teach_edh_instances_db[DatasetSplit.train], - teach_edh_valid_seen_db_file=teach_edh_instances_db[DatasetSplit.valid_seen], - teach_edh_valid_unseen_db_file=teach_edh_instances_db[DatasetSplit.valid_unseen], - load_valid_data_split="unseen", - ) - datamodule.prepare_data() - datamodule.setup() - - return datamodule - - def case_valid_seen_and_unseen( - self, teach_edh_instances_db: dict[DatasetSplit, Path] - ) -> TeachEdhDataModule: - datamodule = TeachEdhDataModule( - teach_edh_train_db_file=teach_edh_instances_db[DatasetSplit.train], - teach_edh_valid_seen_db_file=teach_edh_instances_db[DatasetSplit.valid_seen], - teach_edh_valid_unseen_db_file=teach_edh_instances_db[DatasetSplit.valid_unseen], - load_valid_data_split="both", - ) - datamodule.prepare_data() - datamodule.setup() - - return datamodule - - -class DataModuleCases: - def case_pretrain( - self, emma_pretrain_datamodule: EmmaPretrainDataModule - ) -> EmmaPretrainDataModule: - return emma_pretrain_datamodule - - def case_teach_edh_datamodule( - self, teach_edh_datamodule: TeachEdhDataModule - ) -> TeachEdhDataModule: - return teach_edh_datamodule diff --git a/tests/fixtures/instance_dbs.py b/tests/fixtures/instance_dbs.py index 8805c7c..f4c73aa 100644 --- a/tests/fixtures/instance_dbs.py +++ b/tests/fixtures/instance_dbs.py @@ -76,52 +76,3 @@ def __get__(self, obj: TeachEdhInstance, obj_type: Any = None) -> Path: # noqa: *obj._future_features_path.parts[dataset_index:], ) return self() - - -@fixture(scope="session") -def teach_edh_instances_db( - cached_db_dir_path: Path, 
fixtures_root: Path, session_mocker: MockerFixture -) -> dict[DatasetSplit, Path]: - """Create a DatasetDb of TEACh EDH instances and cache it to use across tests. - - Additionally, this fixture mocks the features path of each TeachEdhInstance to point to - the fixtures dir. - """ - session_mocker.patch.object( - TeachEdhInstance, - "features_path", - new_callable=TeachEdhInstanceFeaturesPathPropertyMock, - ) - - session_mocker.patch.object( - TeachEdhInstance, - "future_features_path", - new_callable=TeachEdhInstanceFutureFeaturesPathPropertyMock, - ) - - teach_dataset_splits = {DatasetSplit.train, DatasetSplit.valid_seen, DatasetSplit.valid_unseen} - - all_instance_dbs: dict[DatasetSplit, Path] = { - dataset_split: cached_db_dir_path.joinpath(f"teach_{dataset_split.name}.db") - for dataset_split in teach_dataset_splits - } - - with FileLock(cached_db_dir_path.joinpath("teach_edh_dataset_splits.lock")): - for dataset_split, db_path in all_instance_dbs.items(): - if not db_path.exists(): - progress = get_progress() - instance_creator = DownstreamInstanceCreator(TeachEdhInstance, progress) - - instance_iterator = instance_creator( - input_data=fixtures_root.joinpath( - "teach_edh", "edh_instances", dataset_split.name - ).glob("*.json"), - progress=progress, - ) - db = DatasetDb(db_path, readonly=False) - - with db: - for idx, instance in enumerate(instance_iterator): - db[(idx, f"teach_edh_{idx}")] = instance # noqa: WPS220 - - return all_instance_dbs diff --git a/tests/fixtures/teach_api.py b/tests/fixtures/teach_api.py deleted file mode 100644 index 48efa99..0000000 --- a/tests/fixtures/teach_api.py +++ /dev/null @@ -1,174 +0,0 @@ -import json -from collections.abc import Generator -from io import BytesIO -from pathlib import Path -from typing import Any -from unittest.mock import patch - -from emma_datasets.datamodels.datasets import TeachEdhInstance -from fastapi.testclient import TestClient -from PIL import Image -from pytest_cases import fixture -from pytest_mock import MockerFixture -from requests_mock import Mocker - -from emma_policy.api.clients import FeatureExtractorClient -from emma_policy.commands.run_teach_api import app -from emma_policy.common.settings import Settings -from emma_policy.inference.model_wrapper import PolicyModelWrapper -from tests.fixtures.instance_dbs import ( - TeachEdhInstanceFeaturesPathPropertyMock, - TeachEdhInstanceFutureFeaturesPathPropertyMock, - ) - - -@fixture(scope="module") -def edh_instance_path(fixtures_root: Path) -> Path: - """Get and return the path to the EDH instance.""" - return fixtures_root.joinpath( - "teach_edh", "edh_instances", "train", "1c70e34df85e61c8_6282.edh1.json" - ) - - -@fixture(scope="module") -def inference_images_path(fixtures_root: Path) -> Path: - """Get the path to where the images are kept.""" - return fixtures_root.joinpath("teach_edh", "inference_images") - - -@fixture(scope="module") -def teach_edh_instance(edh_instance_path: Path, session_mocker: MockerFixture) -> TeachEdhInstance: - """Get the TEACh EDH Instance used for the tests.""" - session_mocker.patch.object( - TeachEdhInstance, - "features_path", - new_callable=TeachEdhInstanceFeaturesPathPropertyMock, - ) - - session_mocker.patch.object( - TeachEdhInstance, - "future_features_path", - new_callable=TeachEdhInstanceFutureFeaturesPathPropertyMock, - ) - - return TeachEdhInstance.parse_file(edh_instance_path) - - -@fixture(scope="module") -def edh_instance_next_image( - teach_edh_instance: TeachEdhInstance, inference_images_path: Path -) -> Image.Image: -
"""Load the next frame that the agent would be given.""" - next_image_name = teach_edh_instance.driver_images_future[0] - next_image_path = inference_images_path.joinpath(next_image_name) - image = Image.open(next_image_path) - - return image - - -@fixture(scope="module") -def teach_edh_instance_history_images( - teach_edh_instance: TeachEdhInstance, inference_images_path: Path -) -> list[Image.Image]: - """Convert the driver history images into a list of PIL images. - - Note: InferenceRunner provides a list of `PIL.Image.Image` - """ - images = [] - - for image_file_name in teach_edh_instance.driver_image_history: - image_path = inference_images_path.joinpath(image_file_name) - original_image = Image.open(image_path) - images.append(original_image) - - return images - - -@fixture(scope="module") -def teach_edh_instance_future_images( - teach_edh_instance: TeachEdhInstance, inference_images_path: Path -) -> list[Image.Image]: - """Convert the driver future images into a list of PIL images. - - Note: InferenceRunner provides a list of `PIL.Image.Image` - """ - images = [] - - for image_file_name in teach_edh_instance.driver_images_future: - image_path = inference_images_path.joinpath(image_file_name) - original_image = Image.open(image_path) - images.append(original_image) - - return images - - -@fixture -def policy_model_wrapper(fixtures_root: Path, requests_mock: Mocker) -> PolicyModelWrapper: - """Create a policy model wrapper so no need to keep repeating the args.""" - model_checkpoint_path = fixtures_root.joinpath("teach_tiny.ckpt") - - perception_update_device_path = FeatureExtractorClient( - Settings().feature_extractor_endpoint - )._update_model_device_endpoint - - requests_mock.post(perception_update_device_path) - - model_wrapper = PolicyModelWrapper( - process_index=1, - num_processes=1, - model_checkpoint_path=model_checkpoint_path, - model_name="heriot-watt/emma-tiny", - ) - - return model_wrapper - - -@fixture(scope="module") -def client(fixtures_root: Path) -> Generator[TestClient, None, None]: - """Get an API client which can be used for testing.""" - data_dir = fixtures_root.joinpath("teach_edh") - images_dir = fixtures_root.joinpath("teach_edh", "inference_images") - split = "train" - - patched_argv = ["main", "--data_dir", data_dir, "--images_dir", images_dir, "--split", split] - - with patch("sys.argv", patched_argv): - yield TestClient(app) - - -@fixture(scope="module") -def start_new_edh_instance_request_body(edh_instance_path: Path) -> dict[str, Any]: - """Get an example request body the API should be able to receive. - - This has been adapted from the `RemoteModel` class in `alexa/teach`: - https://github.com/alexa/teach/blob/2e5be94ebdef4910a61cb1bce069d80b0079d1d3/src/teach/inference/remote_model.py#L93-L94 - """ - raw_edh_instance = json.loads(edh_instance_path.read_bytes()) - - request_body = { - "edh_name": raw_edh_instance.get("instance_id", None), - "edh_instance": json.dumps(raw_edh_instance), - } - - return request_body - - -@fixture(scope="module") -def start_new_edh_instance_request_files( - teach_edh_instance_history_images: list[Image.Image], -) -> list[tuple[str, tuple[str, BytesIO, str]]]: - """Convert images into expected format. - - This has been taken from `RemoteModel` class in `alexa/teach`. 
- """ - images = [] - idx = 0 - - for image in teach_edh_instance_history_images: - image_in_memory = BytesIO() - image.save(image_in_memory, "jpeg") - image_in_memory.seek(0) - images.append(("edh_history_images", (f"history{idx}", image_in_memory, "image/jpeg"))) - idx += 1 - - return images diff --git a/tests/inference/test_decode_trajectory_for_domain.py b/tests/inference/test_decode_trajectory_for_domain.py index be18edf..c5f7848 100644 --- a/tests/inference/test_decode_trajectory_for_domain.py +++ b/tests/inference/test_decode_trajectory_for_domain.py @@ -1,98 +1,9 @@ import pytest from pytest_cases import fixture, parametrize_with_cases -from emma_policy.inference import ( - TEACH_ACTION_TO_SYNONYMS, - AgentAction, - DecodedTrajectoryParser, - get_synonyms_to_teach_action_map, -) from emma_policy.models.tokenizer_emma import EmmaTokenizer @fixture(scope="module") def action_delimiter(emma_tokenizer: EmmaTokenizer) -> str: return emma_tokenizer.sep_token - - -class DecodedTeachTrajectories: - """Various cases to ensure the TEACh trajectories are parsed correctly.""" - - def case_forward(self) -> tuple[str, AgentAction]: - trajectory = "forward ." - api_action = AgentAction("Forward") - - return trajectory, api_action - - def case_move_ahead(self) -> tuple[str, AgentAction]: - trajectory = "move ahead ." - api_action = AgentAction("Forward") - - return trajectory, api_action - - def case_stop_token(self) -> tuple[str, AgentAction]: - trajectory = "" - api_action = AgentAction("Stop") - - return trajectory, api_action - - def case_interaction_object_and_vis_token(self) -> tuple[str, AgentAction]: - trajectory = "pick up mug ." - api_action = AgentAction( # noqa: S106 - "Pickup", object_label="Mug", object_visual_token="" - ) - - return trajectory, api_action - - def case_interaction_object_and_no_vis_token(self) -> tuple[str, AgentAction]: - trajectory = "pick up mug ." - api_action = AgentAction("Pickup", object_label="Mug", object_visual_token=None) - - return trajectory, api_action - - def case_interaction_invalid_object_and_vis_token(self) -> tuple[str, AgentAction]: - trajectory = "pick up mugs ." - api_action = AgentAction( # noqa: S106 - "Pickup", - object_label=None, - object_visual_token="", - raw_object_label="mugs", - ) - - return trajectory, api_action - - @pytest.mark.skip(reason="We assume that this case is not possible.") - def case_only_interaction_visual_token(self) -> tuple[str, AgentAction]: - trajectory = "pick up ." - api_action = AgentAction( # noqa: S106 - "Pickup", object_label=None, object_visual_token="" - ) - - return trajectory, api_action - - -@parametrize_with_cases("decoded_actions,expected_output", cases=DecodedTeachTrajectories) -def test_decoded_action_trajectories_are_converted_properly( - decoded_actions: str, expected_output: AgentAction, action_delimiter: str -) -> None: - trajectory_parser = DecodedTrajectoryParser( # noqa: S106 - execution_domain="TEACh", action_delimiter=action_delimiter, eos_token="" - ) - parsed_trajectory = trajectory_parser(decoded_actions) - - assert parsed_trajectory == expected_output - - -def test_all_synonyms_are_mapped_to_teach_actions() -> None: - """Ensure that each synonym is correctly mapped to one of the TEACh actions. - - Count the total number of synonyms across the mapping, and ensure that the count is identical - to the size of the converted map. 
- """ - total_synonyms_count = sum( - len(synonym_set) for synonym_set in TEACH_ACTION_TO_SYNONYMS.values() - ) - - synonyms_actions_map = get_synonyms_to_teach_action_map() - - assert len(synonyms_actions_map) == total_synonyms_count diff --git a/tests/inference/test_policy_model_wrapper.py b/tests/inference/test_policy_model_wrapper.py index 92eb2a4..b5fb0b8 100644 --- a/tests/inference/test_policy_model_wrapper.py +++ b/tests/inference/test_policy_model_wrapper.py @@ -39,139 +39,3 @@ def load_frame_features_like_api_response(features_path: Path) -> list[dict[str, def test_model_is_loaded_from_checkpoint(policy_model_wrapper: PolicyModelWrapper) -> None: """Verify the model has been loaded from the checkpoint correctly.""" assert not policy_model_wrapper._model.training - - -@pytest.mark.skip(reason="Using inference like this will be deprecated in the near future.") -def test_new_edh_instance_is_initialized( - single_feature_extractor_endpoint: str, - policy_model_wrapper: PolicyModelWrapper, - teach_edh_instance: TeachEdhInstance, - teach_edh_instance_history_images: list[Image.Image], - requests_mock: Mocker, -) -> None: - """Verify that a new EDH instance is properly initialized within the wrapper.""" - history_features = load_frame_features_like_api_response(teach_edh_instance.features_path) - - requests_mock.register_uri( - "POST", - single_feature_extractor_endpoint, - [{"json": features} for features in history_features], - ) - - policy_model_wrapper.start_new_edh_instance( - edh_instance=teach_edh_instance, - edh_history_images=teach_edh_instance_history_images, - edh_name=teach_edh_instance.instance_id, - ) - - assert policy_model_wrapper._edh_instance_state.decoding_step == 1 - assert ( - policy_model_wrapper._teach_edh_inference_dataset.previous_frame - == teach_edh_instance_history_images[-1] - ) - assert len(policy_model_wrapper._teach_edh_inference_dataset._feature_dicts) == len( - teach_edh_instance_history_images - ) - - -@pytest.mark.skip(reason="Using inference like this will be deprecated in the near future.") -def test_next_action_can_be_predicted( - single_feature_extractor_endpoint: str, - policy_model_wrapper: PolicyModelWrapper, - teach_edh_instance: TeachEdhInstance, - teach_edh_instance_history_images: list[Image.Image], - edh_instance_next_image: Image.Image, - requests_mock: Mocker, -) -> None: - """Verify that the next action can be predicted after starting a new edh instance.""" - history_features = load_frame_features_like_api_response(teach_edh_instance.features_path) - future_features = load_frame_features_like_api_response( - teach_edh_instance.future_features_path - ) - - requests_mock.register_uri( - "POST", - single_feature_extractor_endpoint, - [{"json": features} for features in itertools.chain(history_features, future_features)], - ) - - policy_model_wrapper.start_new_edh_instance( - edh_instance=teach_edh_instance, - edh_history_images=teach_edh_instance_history_images, - edh_name=teach_edh_instance.instance_id, - ) - - assert policy_model_wrapper._edh_instance_state.decoding_step == 1 - assert ( - policy_model_wrapper._teach_edh_inference_dataset.previous_frame - == teach_edh_instance_history_images[-1] - ) - assert len(policy_model_wrapper._teach_edh_inference_dataset._feature_dicts) == len( - teach_edh_instance_history_images - ) - - previous_action = None - previous_state = deepcopy(policy_model_wrapper._edh_instance_state) - - next_action, action_coords = policy_model_wrapper.get_next_action( - edh_instance_next_image, teach_edh_instance, 
previous_action - ) - # Verify the decoding step has increased by 1 - assert ( - policy_model_wrapper._edh_instance_state.decoding_step == previous_state.decoding_step + 1 - ) - - -@pytest.mark.skip(reason="Not all future images have been downloaded/added to the fixtures") -def test_successive_next_actions_can_be_predicted( - single_feature_extractor_endpoint: str, - policy_model_wrapper: PolicyModelWrapper, - teach_edh_instance: TeachEdhInstance, - teach_edh_instance_history_images: list[Image.Image], - teach_edh_instance_future_images: list[Image.Image], - requests_mock: Mocker, -) -> None: - """Verify all successive next actions can be predicted.""" - history_features = load_frame_features_like_api_response(teach_edh_instance.features_path) - future_features = load_frame_features_like_api_response( - teach_edh_instance.future_features_path - ) - - requests_mock.register_uri( - "POST", - single_feature_extractor_endpoint, - [{"json": features} for features in itertools.chain(history_features, future_features)], - ) - - policy_model_wrapper.start_new_edh_instance( - edh_instance=teach_edh_instance, - edh_history_images=teach_edh_instance_history_images, - edh_name=teach_edh_instance.instance_id, - ) - - assert policy_model_wrapper._edh_instance_state.decoding_step == 1 - assert ( - policy_model_wrapper._teach_edh_inference_dataset.previous_frame - == teach_edh_instance_history_images[-1] - ) - assert len(policy_model_wrapper._teach_edh_inference_dataset._feature_dicts) == len( - teach_edh_instance_history_images - ) - - previous_action = None - previous_state = deepcopy(policy_model_wrapper._edh_instance_state) - - for future_image in teach_edh_instance_future_images: - next_action, action_coords = policy_model_wrapper.get_next_action( - future_image, teach_edh_instance, previous_action - ) - - # Verify the decoding step has increased by 1 - assert ( - policy_model_wrapper._edh_instance_state.decoding_step - == previous_state.decoding_step + 1 - ) - - # Update the state tracking - previous_action = SimulatorAction(action=next_action, obj_relative_coord=action_coords) - previous_state = deepcopy(policy_model_wrapper._edh_instance_state) diff --git a/tests/models/test_pretrain_model.py b/tests/models/test_pretrain_model.py index cd215e3..278a729 100644 --- a/tests/models/test_pretrain_model.py +++ b/tests/models/test_pretrain_model.py @@ -5,7 +5,6 @@ from transformers import PreTrainedModel from emma_policy.datamodules.pretrain_datamodule import EmmaPretrainDataModule -from emma_policy.datamodules.teach_edh_datamodule import TeachEdhDataModule from emma_policy.models.model_output_emma import EmmaSeq2SeqLMOutput @@ -16,15 +15,11 @@ def case_pretrain_datamodule( return emma_pretrain_datamodule -def case_teach_edh_datamodule(teach_edh_datamodule: TeachEdhDataModule) -> TeachEdhDataModule: - return teach_edh_datamodule - - # ----------------------------------- Tests ---------------------------------- # @parametrize_with_cases("datamodule", cases=".", glob="*_datamodule") def test_pretrain_model_forward_works_on_train_data( emma_model_for_causal_lm: PreTrainedModel, - datamodule: Union[EmmaPretrainDataModule, TeachEdhDataModule], + datamodule: EmmaPretrainDataModule, ) -> None: train_loader = datamodule.train_dataloader() batch = next(iter(train_loader))