diff --git a/.flake8 b/.flake8 index b36791b..a3f985a 100644 --- a/.flake8 +++ b/.flake8 @@ -14,6 +14,8 @@ extend-ignore = E501, # Stop finding commented out code because it's mistaking shape annotations for code E800, + # Let pre-commit complain about sorting + I, # Don't complain about asserts S101, # Stop complaining about using functions from random @@ -110,6 +112,7 @@ allowed-domain-names = data utils util + obj params per-file-ignores = diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ad9b343..50c6ff2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,3 @@ -# using default_language_version -default_language_version: - node: 16.14.2 - repos: # -------------------------- Version control checks -------------------------- # - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/README.md b/README.md index a0238ea..e89c371 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,7 @@ python run.py experiment=nlvr2_downstream.yaml #### DTC - Unified model -When initializing from the pretrained model, which doesn't include the special tokens for the downstream CR and action prediction tasks, you will need to manually edit the vocabulary size in the [model config]([heriot-watt/emma-base-combined/config.json](https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json)https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json). For initialization from the pretrained `emma-base`, set the `vocab_size` to 10252. +When initializing from the pretrained model, which doesn't include the special tokens for the downstream CR and action prediction tasks, you will need to manually edit the vocabulary size in the [model config](<[heriot-watt/emma-base-combined/config.json](https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json)https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json>). For initialization from the pretrained `emma-base`, set the `vocab_size` to 10252. ``` python run.py experiment=simbot_combined.yaml diff --git a/plot_ablation_results.py b/plot_ablation_results.py index cf19a39..4a81609 100644 --- a/plot_ablation_results.py +++ b/plot_ablation_results.py @@ -1,5 +1,5 @@ -import matplotlib.pyplot as plt import numpy as np +from matplotlib import pyplot as plt ablation_percentages = np.array([0, 0.25, 0.5, 0.75, 1.0]) @@ -11,24 +11,24 @@ y_ticks = [17, 20, 23, 26, 29, 32, 35, 38] fig = plt.figure() -ax = fig.add_subplot(111) +ax = fig.add_subplot(111) # noqa: WPS432 ax.plot(ablation_percentages, human_results, "-.g", label="DTC") ax.plot(ablation_percentages, vision_aug_results, "-bs", label="Visual Aug") ax2 = ax.twinx() -ax2.plot(ablation_percentages, cdf_aug_results, "--ro", label="CDF Aug") +ax2.plot(ablation_percentages, cdf_aug_results, "--ro", label="CDF Aug") # type: ignore[attr-defined] # fig.legend(loc="upper right") # ax.set_xlabel("Pr") ax.set_yticks(y_ticks) ax2.set_yticks(y_ticks) -ax.set_ylabel(r"Vision Augmentations") -ax2.set_ylabel(r"CDF Augmentations") +ax.set_ylabel("Vision Augmentations") +ax2.set_ylabel("CDF Augmentations") ax.grid() fig.legend(loc="upper center", bbox_to_anchor=(0.5, 0.425), fancybox=True, ncol=3) -plt.xticks(ablation_percentages, ablation_percentages) +plt.xticks(ablation_percentages, ablation_percentages) # type: ignore[arg-type] # plt.show() plt.title("Performance curves when ablating augmentations") ax.set_xlabel("Proportion of train instances") diff --git a/src/emma_policy/commands/run_simbot_action_api.py b/src/emma_policy/commands/run_simbot_action_api.py index c4afed2..83600cf 100644 --- a/src/emma_policy/commands/run_simbot_action_api.py +++ b/src/emma_policy/commands/run_simbot_action_api.py @@ -1,25 +1,16 @@ import logging -import sys from argparse import ArgumentParser, Namespace from pathlib import Path from typing import Any, Literal, Optional, TypedDict, Union import torch -from emma_common.api.instrumentation import instrument_app -from emma_common.aws.cloudwatch import add_cloudwatch_handler_to_logger from emma_common.datamodels import TorchDataMixin -from emma_common.logging import ( - InstrumentedInterceptHandler, - logger, - setup_logging, - setup_rich_logging, -) +from emma_common.logging import logger, setup_rich_logging from fastapi import FastAPI, Request, Response, status from pydantic import BaseSettings, FilePath from transformers import PreTrainedTokenizer from uvicorn import Config, Server -from emma_policy._version import __version__ # noqa: WPS436 from emma_policy.datamodules.pretrain_instances import Task from emma_policy.datamodules.simbot_action_datamodule import prepare_action_tokenizer from emma_policy.datamodules.simbot_combined_datamodule import prepare_combined_tokenizer @@ -141,21 +132,6 @@ async def healthcheck(response: Response) -> str: return "success" -# [deprecated!] -# @app.post("/generate_raw_text_match", status_code=status.HTTP_200_OK) -# async def generate_raw_text_match(request: Request, response: Response) -> Optional[str]: -# """Endpoint for simple raw text matching.""" -# try: -# simbot_request = GenerateRequest.parse_obj(await request.json()) -# except Exception as request_err: -# logging.exception("Unable to parse request", exc_info=request_err) -# response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR -# raise request_err -# with tracer.start_as_current_span("Raw text match"): -# output_string = api_store["raw_text_matcher"](simbot_request) -# return output_string - - @app.post("/generate_find", status_code=status.HTTP_200_OK) async def generate_find(request: Request, response: Response) -> list[str]: """Endpoint for find.""" @@ -324,17 +300,7 @@ async def generate(request: Request, response: Response) -> str: def main() -> None: """Runs the server.""" - if settings.traces_to_opensearch: - instrument_app( - app, - otlp_endpoint=settings.otlp_endpoint, - service_name=settings.opensearch_service_name, - service_version=__version__, - service_namespace="SimBot", - ) - setup_logging(sys.stdout, InstrumentedInterceptHandler()) - else: - setup_rich_logging(rich_traceback_show_locals=False) + setup_rich_logging(rich_traceback_show_locals=False) server = Server( Config( @@ -345,15 +311,6 @@ def main() -> None: ) ) - if settings.log_to_cloudwatch: - add_cloudwatch_handler_to_logger( - boto3_profile_name=settings.aws_profile, - log_stream_name=settings.watchtower_log_stream_name, - log_group_name=settings.watchtower_log_group_name, - send_interval=1, - enable_trace_logging=settings.traces_to_opensearch, - ) - server.run() @@ -361,7 +318,6 @@ def parse_api_args() -> Namespace: """Parse any arguments.""" arg_parser = ArgumentParser() - # TODO: move this to an inference config arg_parser.add_argument( "--tokenizer_truncation_side", type=str, @@ -399,4 +355,3 @@ def parse_api_args() -> Namespace: if __name__ == "__main__": main() - main() diff --git a/src/emma_policy/datamodules/pretrain_instances/convert_to_pretrain_instances.py b/src/emma_policy/datamodules/pretrain_instances/convert_to_pretrain_instances.py index b1d44e9..7f8d7c0 100644 --- a/src/emma_policy/datamodules/pretrain_instances/convert_to_pretrain_instances.py +++ b/src/emma_policy/datamodules/pretrain_instances/convert_to_pretrain_instances.py @@ -88,7 +88,7 @@ def __getitem__(self, task: Task) -> Iterator[PretrainInstance]: """Get the pretraining instances for the given task.""" return self.instance_task_map[task] - @property # type: ignore[misc] + @property @image_task_check def mlm(self) -> Iterator[PretrainInstance]: # noqa: WPS231 """Get pretrain instances for the MLM task.""" @@ -125,7 +125,7 @@ def mlm(self) -> Iterator[PretrainInstance]: # noqa: WPS231 yield from all_captions - @property # type: ignore[misc] + @property @image_task_check def itm(self) -> Iterator[PretrainInstance]: """Get the pretrain instances for the ITM task.""" @@ -137,7 +137,7 @@ def itm(self) -> Iterator[PretrainInstance]: for caption in self.instance.captions ) - @property # type: ignore[misc] + @property @image_task_check def visual_grounding(self) -> Iterator[PretrainInstance]: """Get the pretrain instances for the visual grounding task.""" @@ -150,7 +150,7 @@ def visual_grounding(self) -> Iterator[PretrainInstance]: task=Task.visual_grounding, ) - @property # type: ignore[misc] + @property @image_task_check def dense_captioning(self) -> Iterator[PretrainInstance]: """Get the pretrain instances for the dense captioning task.""" @@ -163,7 +163,7 @@ def dense_captioning(self) -> Iterator[PretrainInstance]: task=Task.dense_captioning, ) - @property # type: ignore[misc] + @property @image_task_check def relation_detection(self) -> Iterator[PretrainInstance]: """Get the pretrain instances for the relation detection task.""" @@ -210,7 +210,7 @@ def relation_detection(self) -> Iterator[PretrainInstance]: task=Task.relation_detection, ) - @property # type: ignore[misc] + @property @image_task_check def captioning(self) -> Iterator[PretrainInstance]: """Get the pretrain instances for the captioning task.""" @@ -222,7 +222,7 @@ def captioning(self) -> Iterator[PretrainInstance]: for caption in self.instance.captions ) - @property # type: ignore[misc] + @property @image_task_check def vqa(self) -> Iterator[PretrainInstance]: """Get the pretrain instances for the VQA task.""" @@ -238,7 +238,7 @@ def vqa(self) -> Iterator[PretrainInstance]: for qa_pair in self.instance.qa_pairs ) - @property # type: ignore[misc] + @property @video_task_check def instruction_prediction(self) -> Iterator[PretrainInstance]: """Get the pretrain instance for the instruction prediction for a given subgoal.""" @@ -259,7 +259,7 @@ def instruction_prediction(self) -> Iterator[PretrainInstance]: for caption in self.instance.captions ) - @property # type: ignore[misc] + @property @video_task_check def goal_prediction(self) -> Iterator[PretrainInstance]: """Get the pretrain instance for the goal prediction for a given trajectory.""" @@ -282,7 +282,7 @@ def goal_prediction(self) -> Iterator[PretrainInstance]: for task_description in self.instance.task_description ) - @property # type: ignore[misc] + @property @video_task_check def action_execution(self) -> Iterator[PretrainInstance]: """Get the pretrain instance for the action execution task given a subgoal instruction.""" @@ -304,7 +304,7 @@ def action_execution(self) -> Iterator[PretrainInstance]: for caption in self.instance.captions ) - @property # type: ignore[misc] + @property @video_task_check def vtm(self) -> Iterator[PretrainInstance]: """Get the pretrain instance for the video-text matching task given a subgoal.""" @@ -327,7 +327,7 @@ def vtm(self) -> Iterator[PretrainInstance]: for caption in self.instance.captions ) - @property # type: ignore[misc] + @property @video_task_check def fom(self) -> Iterator[PretrainInstance]: """Get the pretrain instance for the feature order modeling task given a subgoal.""" @@ -346,7 +346,7 @@ def fom(self) -> Iterator[PretrainInstance]: for caption in self.instance.captions ) - @property # type: ignore[misc] + @property @video_task_check def vmlm(self) -> Iterator[PretrainInstance]: """Get pretrain instances for the video MLM task.""" diff --git a/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py b/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py index e3cb73c..e5174fb 100644 --- a/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py +++ b/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py @@ -15,7 +15,6 @@ def load_coco_ids(coco_splits_path: Path) -> set[str]: We only extract the image ID's, which are in the form `COCO_val2014_000000238836`. """ with open(coco_splits_path) as in_file: - print(coco_splits_path) data_list = json.load(in_file) image_ids: set[str] = set() diff --git a/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py b/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py index 952ba31..87b60c3 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py +++ b/src/emma_policy/inference/model_wrapper/simbot_action_input_builder.py @@ -176,7 +176,7 @@ def _parse_environment_history_from_request( """Parse the feature dicts and actions from the current request.""" feature_dicts: list[dict[str, torch.Tensor]] = [] step_index: list[int] = [] - previous_actions = [] + previous_actions: list[str] = [] total_steps = len(request.environment_history) for idx, step in enumerate(request.environment_history, 1): if step.output is None and idx < total_steps: @@ -196,7 +196,7 @@ def _parse_environment_history_from_request( if previous_actions: # Currently the implementation allows None previous actios # but in practice this should never happen. - previous_actions_str = " ".join(previous_actions) # type: ignore[arg-type] + previous_actions_str = " ".join(previous_actions) return (feature_dicts, previous_actions_str, step_index) def _prepare_input_text(self, instruction: str, task: Task) -> BatchEncoding: diff --git a/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py b/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py index 5de100c..3694ed5 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py +++ b/src/emma_policy/inference/model_wrapper/simbot_action_output_processor.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, cast import torch from emma_common.datamodels import EmmaExtractedFeatures, EmmaPolicyRequest @@ -99,4 +99,4 @@ def _get_largest_entity( width = bbox_coords[:, 2] - bbox_coords[:, 0] height = bbox_coords[:, 3] - bbox_coords[:, 1] areas = width * height - return indices[torch.argmax(areas).item()] # type:ignore[call-overload] + return indices[cast(int, torch.argmax(areas).item())] diff --git a/src/emma_policy/inference/model_wrapper/simbot_cr_output_processor.py b/src/emma_policy/inference/model_wrapper/simbot_cr_output_processor.py index 9fcc35d..62f1192 100644 --- a/src/emma_policy/inference/model_wrapper/simbot_cr_output_processor.py +++ b/src/emma_policy/inference/model_wrapper/simbot_cr_output_processor.py @@ -122,7 +122,7 @@ def _special_color_changer_case( if match is not None: color_result = re.search("(red|blue|green)", match.group()) if color_result is not None: - color = color_result.group() # type: ignore[union-attr] + color = color_result.group() color_button = f"{color} button" if color_button in class_labels: return self._default_prediction diff --git a/src/emma_policy/models/emma_policy.py b/src/emma_policy/models/emma_policy.py index 1b3e5d5..ae931a1 100644 --- a/src/emma_policy/models/emma_policy.py +++ b/src/emma_policy/models/emma_policy.py @@ -71,9 +71,7 @@ def num_training_steps(self) -> int: ) dataset_size = int(dataset_size * self.trainer.limit_train_batches) else: - dataset_size = len( # type: ignore[unreachable] - self.trainer.datamodule.train_dataloader() - ) + dataset_size = len(self.trainer.datamodule.train_dataloader()) # type: ignore[attr-defined] num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) # Check if using tpus diff --git a/src/emma_policy/models/simbot_combined_policy.py b/src/emma_policy/models/simbot_combined_policy.py index d90ff37..0bb001b 100644 --- a/src/emma_policy/models/simbot_combined_policy.py +++ b/src/emma_policy/models/simbot_combined_policy.py @@ -88,7 +88,11 @@ def on_test_epoch_end(self) -> None: # noqa: WPS231 decoder_input_ids = self._decoder_input_ids outputs = { - example_id: {"prediction": generated_action, "groundtruth": gt, "teacher_forcing": dec} # type: ignore[misc] + example_id: { + "prediction": generated_action, + "groundtruth": gt, + "teacher_forcing": dec, + } for example_id, generated_action, gt, dec in zip( all_example_ids, generated_actions, @@ -107,14 +111,10 @@ def on_test_epoch_end(self) -> None: # noqa: WPS231 generated_actions = [None for _ in range(world_size)] # type: ignore[misc] torch.distributed.all_gather_object(generated_actions, self._generated_actions) if torch.distributed.get_rank() == 0: - all_example_ids = list( - itertools.chain.from_iterable(all_example_ids) # type: ignore[arg-type] - ) - generated_actions = list( - itertools.chain.from_iterable(generated_actions) # type: ignore[arg-type] - ) + all_example_ids = list(itertools.chain.from_iterable(all_example_ids)) + generated_actions = list(itertools.chain.from_iterable(generated_actions)) outputs = { - example_id: generated_action # type:ignore[misc] + example_id: generated_action # type: ignore[misc] for example_id, generated_action in zip(all_example_ids, generated_actions) } self._save_results(outputs) @@ -317,16 +317,16 @@ def _test_instance(self, batch: EmmaDatasetBatch) -> PredictType: if batch.decoder_input_ids is None: raise AssertionError("Expected decoder input ids for single instance testing") - separator_positions = torch.where( - batch.decoder_input_ids[0] == self._separator_token_id # type:ignore[index] - )[0] + separator_positions = torch.where(batch.decoder_input_ids[0] == self._separator_token_id)[ + 0 + ] if separator_positions.shape[0] > 1: end_index = int(separator_positions[-2].item()) + 1 - decoder_input_ids = batch.decoder_input_ids[:, :end_index] # type:ignore[index] + decoder_input_ids = batch.decoder_input_ids[:, :end_index] else: - decoder_input_ids = batch.decoder_input_ids[:, 0].unsqueeze(0) # type:ignore[index] + decoder_input_ids = batch.decoder_input_ids[:, 0].unsqueeze(0) outputs = self.inference_step( batch, decoder_input_ids=decoder_input_ids, max_length=self._max_generated_text_length @@ -350,5 +350,6 @@ def _test_instance(self, batch: EmmaDatasetBatch) -> PredictType: return outputs def _save_results(self, outputs: dict[str, Any]) -> None: - with open(self._results_path, "w") as fp: # type: ignore[arg-type] + assert self._results_path is not None + with open(self._results_path, "w") as fp: json.dump(outputs, fp, indent=4) diff --git a/src/emma_policy/models/simbot_emma_policy.py b/src/emma_policy/models/simbot_emma_policy.py index a1250d7..fada675 100644 --- a/src/emma_policy/models/simbot_emma_policy.py +++ b/src/emma_policy/models/simbot_emma_policy.py @@ -67,7 +67,11 @@ def on_test_epoch_end(self) -> None: # noqa: WPS231 decoder_input_ids = self._decoder_input_ids outputs = { - example_id: {"prediction": generated_action, "groundtruth": gt, "teacher_forcing": dec} # type: ignore[misc] + example_id: { + "prediction": generated_action, + "groundtruth": gt, + "teacher_forcing": dec, + } for example_id, generated_action, gt, dec in zip( all_example_ids, generated_actions, @@ -86,14 +90,10 @@ def on_test_epoch_end(self) -> None: # noqa: WPS231 generated_actions = [None for _ in range(world_size)] # type: ignore[misc] torch.distributed.all_gather_object(generated_actions, self._generated_actions) if torch.distributed.get_rank() == 0: - all_example_ids = list( - itertools.chain.from_iterable(all_example_ids) # type: ignore[arg-type] - ) - generated_actions = list( - itertools.chain.from_iterable(generated_actions) # type: ignore[arg-type] - ) + all_example_ids = list(itertools.chain.from_iterable(all_example_ids)) + generated_actions = list(itertools.chain.from_iterable(generated_actions)) outputs = { - example_id: generated_action # type:ignore[misc] + example_id: generated_action # type: ignore[misc] for example_id, generated_action in zip(all_example_ids, generated_actions) } self._save_results(outputs) @@ -290,16 +290,16 @@ def _test_instance(self, batch: EmmaDatasetBatch) -> PredictType: if batch.decoder_input_ids is None: raise AssertionError("Expected decoder input ids for single instance testing") - separator_positions = torch.where( - batch.decoder_input_ids[0] == self._separator_token_id # type:ignore[index] - )[0] + separator_positions = torch.where(batch.decoder_input_ids[0] == self._separator_token_id)[ + 0 + ] if separator_positions.shape[0] > 1: end_index = int(separator_positions[-2].item()) + 1 - decoder_input_ids = batch.decoder_input_ids[:, :end_index] # type:ignore[index] + decoder_input_ids = batch.decoder_input_ids[:, :end_index] else: - decoder_input_ids = batch.decoder_input_ids[:, 0].unsqueeze(0) # type:ignore[index] + decoder_input_ids = batch.decoder_input_ids[:, 0].unsqueeze(0) outputs = self.inference_step( batch, decoder_input_ids=decoder_input_ids, max_length=self._max_generated_text_length diff --git a/src/emma_policy/train.py b/src/emma_policy/train.py index 5d4b5a6..386b94d 100644 --- a/src/emma_policy/train.py +++ b/src/emma_policy/train.py @@ -156,7 +156,7 @@ def from_hydra_config(cls, config: DictConfig) -> "TrainModel": ) if "resize_embeddings" in config.model and config.model.resize_embeddings: model.resize_model_embeddings( # type: ignore[operator] - tokenizer=datamodule.setup_tokenizer() + tokenizer=datamodule.setup_tokenizer() # type: ignore[attr-defined] ) callbacks: list[Callback] = [] diff --git a/src/emma_policy/utils/boxes.py b/src/emma_policy/utils/boxes.py index 7b8fb71..36fef40 100644 --- a/src/emma_policy/utils/boxes.py +++ b/src/emma_policy/utils/boxes.py @@ -8,7 +8,7 @@ import torch -RawBoxType = Union[list[float], tuple[float, ...], torch.Tensor, np.ndarray] +RawBoxType = Union[list[float], tuple[float, ...], torch.Tensor, np.ndarray] # type: ignore[type-arg] @unique diff --git a/tests/fixtures/instance_dbs.py b/tests/fixtures/instance_dbs.py index 713bc00..87c9c47 100644 --- a/tests/fixtures/instance_dbs.py +++ b/tests/fixtures/instance_dbs.py @@ -33,7 +33,7 @@ def pretrain_db_dir_path(cached_db_dir_path: Path, instances_db_path: Path) -> P return cached_db_dir_path -class TeachEdhInstanceFeaturesPathPropertyMock(PropertyMock): # type: ignore[misc] +class TeachEdhInstanceFeaturesPathPropertyMock(PropertyMock): """Mock the `features_path` property within the TeachEdhInstance. The features path within each instance is derived automatically and NOT hard-coded into the @@ -41,7 +41,7 @@ class TeachEdhInstanceFeaturesPathPropertyMock(PropertyMock): # type: ignore[mi and return something else. """ - def __get__(self, obj: TeachEdhInstance, obj_type: Any = None) -> Path: # noqa: WPS110 + def __get__(self, obj: TeachEdhInstance, obj_type: Any = None) -> Path: # type: ignore[override] """Get the features path from the fixtures. This updates the `return_value`, which is used by `unittest.Mock` to return a value. @@ -51,7 +51,7 @@ def __get__(self, obj: TeachEdhInstance, obj_type: Any = None) -> Path: # noqa: return self() -class TeachEdhInstanceFutureFeaturesPathPropertyMock(PropertyMock): # type: ignore[misc] +class TeachEdhInstanceFutureFeaturesPathPropertyMock(PropertyMock): """Mock the `future_features_path` property within the TeachEdhInstance. The future features path within each instance is derived automatically and NOT hard-coded into @@ -59,7 +59,7 @@ class TeachEdhInstanceFutureFeaturesPathPropertyMock(PropertyMock): # type: ign property and return something else. """ - def __get__(self, obj: TeachEdhInstance, obj_type: Any = None) -> Path: # noqa: WPS110 + def __get__(self, obj: TeachEdhInstance, obj_type: Any = None) -> Path: # type: ignore[override] """Get the future features path from the fixtures. This updates the `return_value`, which is used by `unittest.Mock` to return a value.