Skip to content

Commit

Permalink
fix lint issues
Browse files Browse the repository at this point in the history
  • Loading branch information
amitkparekh committed Dec 5, 2023
1 parent 6b2bfad commit 9657f54
Show file tree
Hide file tree
Showing 16 changed files with 65 additions and 113 deletions.
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ extend-ignore =
E501,
# Stop finding commented out code because it's mistaking shape annotations for code
E800,
# Let pre-commit complain about sorting
I,
# Don't complain about asserts
S101,
# Stop complaining about using functions from random
Expand Down Expand Up @@ -110,6 +112,7 @@ allowed-domain-names =
data
utils
util
obj
params

per-file-ignores =
Expand Down
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# using default_language_version
default_language_version:
node: 16.14.2

repos:
# -------------------------- Version control checks -------------------------- #
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ python run.py experiment=nlvr2_downstream.yaml

#### DTC - Unified model

When initializing from the pretrained model, which doesn't include the special tokens for the downstream CR and action prediction tasks, you will need to manually edit the vocabulary size in the [model config]([heriot-watt/emma-base-combined/config.json](https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json)https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json). For initialization from the pretrained `emma-base`, set the `vocab_size` to 10252.
When initializing from the pretrained model, which doesn't include the special tokens for the downstream CR and action prediction tasks, you will need to manually edit the vocabulary size in the [model config](<[heriot-watt/emma-base-combined/config.json](https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json)https://github.com/emma-heriot-watt/policy/blob/main/heriot-watt/emma-base-cr/config.json>). For initialization from the pretrained `emma-base`, set the `vocab_size` to 10252.

```
python run.py experiment=simbot_combined.yaml
Expand Down
12 changes: 6 additions & 6 deletions plot_ablation_results.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import pyplot as plt


ablation_percentages = np.array([0, 0.25, 0.5, 0.75, 1.0])
Expand All @@ -11,24 +11,24 @@
y_ticks = [17, 20, 23, 26, 29, 32, 35, 38]

fig = plt.figure()
ax = fig.add_subplot(111)
ax = fig.add_subplot(111) # noqa: WPS432
ax.plot(ablation_percentages, human_results, "-.g", label="DTC")
ax.plot(ablation_percentages, vision_aug_results, "-bs", label="Visual Aug")

ax2 = ax.twinx()
ax2.plot(ablation_percentages, cdf_aug_results, "--ro", label="CDF Aug")
ax2.plot(ablation_percentages, cdf_aug_results, "--ro", label="CDF Aug") # type: ignore[attr-defined]
# fig.legend(loc="upper right")

# ax.set_xlabel("Pr")
ax.set_yticks(y_ticks)
ax2.set_yticks(y_ticks)
ax.set_ylabel(r"Vision Augmentations")
ax2.set_ylabel(r"CDF Augmentations")
ax.set_ylabel("Vision Augmentations")
ax2.set_ylabel("CDF Augmentations")

ax.grid()

fig.legend(loc="upper center", bbox_to_anchor=(0.5, 0.425), fancybox=True, ncol=3)
plt.xticks(ablation_percentages, ablation_percentages)
plt.xticks(ablation_percentages, ablation_percentages) # type: ignore[arg-type]
# plt.show()
plt.title("Performance curves when ablating augmentations")
ax.set_xlabel("Proportion of train instances")
Expand Down
49 changes: 2 additions & 47 deletions src/emma_policy/commands/run_simbot_action_api.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
import logging
import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Literal, Optional, TypedDict, Union

import torch
from emma_common.api.instrumentation import instrument_app
from emma_common.aws.cloudwatch import add_cloudwatch_handler_to_logger
from emma_common.datamodels import TorchDataMixin
from emma_common.logging import (
InstrumentedInterceptHandler,
logger,
setup_logging,
setup_rich_logging,
)
from emma_common.logging import logger, setup_rich_logging
from fastapi import FastAPI, Request, Response, status
from pydantic import BaseSettings, FilePath
from transformers import PreTrainedTokenizer
from uvicorn import Config, Server

from emma_policy._version import __version__ # noqa: WPS436
from emma_policy.datamodules.pretrain_instances import Task
from emma_policy.datamodules.simbot_action_datamodule import prepare_action_tokenizer
from emma_policy.datamodules.simbot_combined_datamodule import prepare_combined_tokenizer
Expand Down Expand Up @@ -141,21 +132,6 @@ async def healthcheck(response: Response) -> str:
return "success"


# [deprecated!]
# @app.post("/generate_raw_text_match", status_code=status.HTTP_200_OK)
# async def generate_raw_text_match(request: Request, response: Response) -> Optional[str]:
# """Endpoint for simple raw text matching."""
# try:
# simbot_request = GenerateRequest.parse_obj(await request.json())
# except Exception as request_err:
# logging.exception("Unable to parse request", exc_info=request_err)
# response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
# raise request_err
# with tracer.start_as_current_span("Raw text match"):
# output_string = api_store["raw_text_matcher"](simbot_request)
# return output_string


@app.post("/generate_find", status_code=status.HTTP_200_OK)
async def generate_find(request: Request, response: Response) -> list[str]:
"""Endpoint for find."""
Expand Down Expand Up @@ -324,17 +300,7 @@ async def generate(request: Request, response: Response) -> str:

def main() -> None:
"""Runs the server."""
if settings.traces_to_opensearch:
instrument_app(
app,
otlp_endpoint=settings.otlp_endpoint,
service_name=settings.opensearch_service_name,
service_version=__version__,
service_namespace="SimBot",
)
setup_logging(sys.stdout, InstrumentedInterceptHandler())
else:
setup_rich_logging(rich_traceback_show_locals=False)
setup_rich_logging(rich_traceback_show_locals=False)

server = Server(
Config(
Expand All @@ -345,23 +311,13 @@ def main() -> None:
)
)

if settings.log_to_cloudwatch:
add_cloudwatch_handler_to_logger(
boto3_profile_name=settings.aws_profile,
log_stream_name=settings.watchtower_log_stream_name,
log_group_name=settings.watchtower_log_group_name,
send_interval=1,
enable_trace_logging=settings.traces_to_opensearch,
)

server.run()


def parse_api_args() -> Namespace:
"""Parse any arguments."""
arg_parser = ArgumentParser()

# TODO: move this to an inference config
arg_parser.add_argument(
"--tokenizer_truncation_side",
type=str,
Expand Down Expand Up @@ -399,4 +355,3 @@ def parse_api_args() -> Namespace:

if __name__ == "__main__":
main()
main()
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __getitem__(self, task: Task) -> Iterator[PretrainInstance]:
"""Get the pretraining instances for the given task."""
return self.instance_task_map[task]

@property # type: ignore[misc]
@property
@image_task_check
def mlm(self) -> Iterator[PretrainInstance]: # noqa: WPS231
"""Get pretrain instances for the MLM task."""
Expand Down Expand Up @@ -125,7 +125,7 @@ def mlm(self) -> Iterator[PretrainInstance]: # noqa: WPS231

yield from all_captions

@property # type: ignore[misc]
@property
@image_task_check
def itm(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instances for the ITM task."""
Expand All @@ -137,7 +137,7 @@ def itm(self) -> Iterator[PretrainInstance]:
for caption in self.instance.captions
)

@property # type: ignore[misc]
@property
@image_task_check
def visual_grounding(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instances for the visual grounding task."""
Expand All @@ -150,7 +150,7 @@ def visual_grounding(self) -> Iterator[PretrainInstance]:
task=Task.visual_grounding,
)

@property # type: ignore[misc]
@property
@image_task_check
def dense_captioning(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instances for the dense captioning task."""
Expand All @@ -163,7 +163,7 @@ def dense_captioning(self) -> Iterator[PretrainInstance]:
task=Task.dense_captioning,
)

@property # type: ignore[misc]
@property
@image_task_check
def relation_detection(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instances for the relation detection task."""
Expand Down Expand Up @@ -210,7 +210,7 @@ def relation_detection(self) -> Iterator[PretrainInstance]:
task=Task.relation_detection,
)

@property # type: ignore[misc]
@property
@image_task_check
def captioning(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instances for the captioning task."""
Expand All @@ -222,7 +222,7 @@ def captioning(self) -> Iterator[PretrainInstance]:
for caption in self.instance.captions
)

@property # type: ignore[misc]
@property
@image_task_check
def vqa(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instances for the VQA task."""
Expand All @@ -238,7 +238,7 @@ def vqa(self) -> Iterator[PretrainInstance]:
for qa_pair in self.instance.qa_pairs
)

@property # type: ignore[misc]
@property
@video_task_check
def instruction_prediction(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instance for the instruction prediction for a given subgoal."""
Expand All @@ -259,7 +259,7 @@ def instruction_prediction(self) -> Iterator[PretrainInstance]:
for caption in self.instance.captions
)

@property # type: ignore[misc]
@property
@video_task_check
def goal_prediction(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instance for the goal prediction for a given trajectory."""
Expand All @@ -282,7 +282,7 @@ def goal_prediction(self) -> Iterator[PretrainInstance]:
for task_description in self.instance.task_description
)

@property # type: ignore[misc]
@property
@video_task_check
def action_execution(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instance for the action execution task given a subgoal instruction."""
Expand All @@ -304,7 +304,7 @@ def action_execution(self) -> Iterator[PretrainInstance]:
for caption in self.instance.captions
)

@property # type: ignore[misc]
@property
@video_task_check
def vtm(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instance for the video-text matching task given a subgoal."""
Expand All @@ -327,7 +327,7 @@ def vtm(self) -> Iterator[PretrainInstance]:
for caption in self.instance.captions
)

@property # type: ignore[misc]
@property
@video_task_check
def fom(self) -> Iterator[PretrainInstance]:
"""Get the pretrain instance for the feature order modeling task given a subgoal."""
Expand All @@ -346,7 +346,7 @@ def fom(self) -> Iterator[PretrainInstance]:
for caption in self.instance.captions
)

@property # type: ignore[misc]
@property
@video_task_check
def vmlm(self) -> Iterator[PretrainInstance]:
"""Get pretrain instances for the video MLM task."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def load_coco_ids(coco_splits_path: Path) -> set[str]:
We only extract the image ID's, which are in the form `COCO_val2014_000000238836`.
"""
with open(coco_splits_path) as in_file:
print(coco_splits_path)
data_list = json.load(in_file)

image_ids: set[str] = set()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _parse_environment_history_from_request(
"""Parse the feature dicts and actions from the current request."""
feature_dicts: list[dict[str, torch.Tensor]] = []
step_index: list[int] = []
previous_actions = []
previous_actions: list[str] = []
total_steps = len(request.environment_history)
for idx, step in enumerate(request.environment_history, 1):
if step.output is None and idx < total_steps:
Expand All @@ -196,7 +196,7 @@ def _parse_environment_history_from_request(
if previous_actions:
# Currently the implementation allows None previous actios
# but in practice this should never happen.
previous_actions_str = " ".join(previous_actions) # type: ignore[arg-type]
previous_actions_str = " ".join(previous_actions)
return (feature_dicts, previous_actions_str, step_index)

def _prepare_input_text(self, instruction: str, task: Task) -> BatchEncoding:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, cast

import torch
from emma_common.datamodels import EmmaExtractedFeatures, EmmaPolicyRequest
Expand Down Expand Up @@ -99,4 +99,4 @@ def _get_largest_entity(
width = bbox_coords[:, 2] - bbox_coords[:, 0]
height = bbox_coords[:, 3] - bbox_coords[:, 1]
areas = width * height
return indices[torch.argmax(areas).item()] # type:ignore[call-overload]
return indices[cast(int, torch.argmax(areas).item())]
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _special_color_changer_case(
if match is not None:
color_result = re.search("(red|blue|green)", match.group())
if color_result is not None:
color = color_result.group() # type: ignore[union-attr]
color = color_result.group()
color_button = f"{color} button"
if color_button in class_labels:
return self._default_prediction
Expand Down
4 changes: 1 addition & 3 deletions src/emma_policy/models/emma_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def num_training_steps(self) -> int:
)
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
else:
dataset_size = len( # type: ignore[unreachable]
self.trainer.datamodule.train_dataloader()
)
dataset_size = len(self.trainer.datamodule.train_dataloader()) # type: ignore[attr-defined]
num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)

# Check if using tpus
Expand Down
Loading

0 comments on commit 9657f54

Please sign in to comment.