From 903a4af13e1a93af60ba86a7bdab20ef605f3b01 Mon Sep 17 00:00:00 2001 From: gpantaz Date: Wed, 29 Nov 2023 16:12:33 +0000 Subject: [PATCH] feat: update policy --- .../pretrain_instances/is_train_instance.py | 1 + tests/fixtures/datamodules.py | 7 ++ tests/inference/test_policy_model_wrapper.py | 6 -- tests/inference/test_simbot_api.py | 68 ------------------- tests/inference/test_teach_api.py | 37 ---------- 5 files changed, 8 insertions(+), 111 deletions(-) delete mode 100644 tests/inference/test_simbot_api.py delete mode 100644 tests/inference/test_teach_api.py diff --git a/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py b/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py index e5174fb..e3cb73c 100644 --- a/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py +++ b/src/emma_policy/datamodules/pretrain_instances/is_train_instance.py @@ -15,6 +15,7 @@ def load_coco_ids(coco_splits_path: Path) -> set[str]: We only extract the image ID's, which are in the form `COCO_val2014_000000238836`. """ with open(coco_splits_path) as in_file: + print(coco_splits_path) data_list = json.load(in_file) image_ids: set[str] = set() diff --git a/tests/fixtures/datamodules.py b/tests/fixtures/datamodules.py index df29712..1fbad6a 100644 --- a/tests/fixtures/datamodules.py +++ b/tests/fixtures/datamodules.py @@ -28,3 +28,10 @@ def emma_pretrain_datamodule( dm.setup() return dm + + +class DataModuleCases: + def case_pretrain( + self, emma_pretrain_datamodule: EmmaPretrainDataModule + ) -> EmmaPretrainDataModule: + return emma_pretrain_datamodule diff --git a/tests/inference/test_policy_model_wrapper.py b/tests/inference/test_policy_model_wrapper.py index b5fb0b8..409aa6c 100644 --- a/tests/inference/test_policy_model_wrapper.py +++ b/tests/inference/test_policy_model_wrapper.py @@ -11,7 +11,6 @@ from emma_policy.api.clients import FeatureExtractorClient from emma_policy.common.settings import Settings -from emma_policy.inference.model_wrapper import PolicyModelWrapper, SimulatorAction @pytest.fixture(scope="module") @@ -34,8 +33,3 @@ def load_frame_features_like_api_response(features_path: Path) -> list[dict[str, ] return response_features - - -def test_model_is_loaded_from_checkpoint(policy_model_wrapper: PolicyModelWrapper) -> None: - """Verify the model has been loaded from the checkpoint correctly.""" - assert not policy_model_wrapper._model.training diff --git a/tests/inference/test_simbot_api.py b/tests/inference/test_simbot_api.py deleted file mode 100644 index 724676c..0000000 --- a/tests/inference/test_simbot_api.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Optional - -from emma_common.datamodels import DialogueUtterance, EmmaPolicyRequest, SpeakerRole -from pytest_cases import parametrize -from transformers import AutoTokenizer - -from emma_policy.datamodules.pretrain_instances import Task -from emma_policy.inference.model_wrapper.simbot_action_input_builder import ( - SimBotActionInputBuilder, -) - - -@parametrize( - "input_request, target", - [ - ( - EmmaPolicyRequest( - dialogue_history=[DialogueUtterance(role=SpeakerRole.user, utterance="")], - environment_history=[], - ), - None, - ), - ( - EmmaPolicyRequest( - dialogue_history=[ - DialogueUtterance(role=SpeakerRole.user, utterance="Instruction"), - DialogueUtterance(role=SpeakerRole.agent, utterance="Is this a question?"), - DialogueUtterance(role=SpeakerRole.user, utterance="Maybe"), - ], - environment_history=[], - ), - ("<> instruction. <> is this a question? <> maybe."), - ), - ( - EmmaPolicyRequest( - dialogue_history=[ - DialogueUtterance(role=SpeakerRole.user, utterance="Instruction"), - ], - environment_history=[], - ), - "<> instruction.", - ), - ( - EmmaPolicyRequest( - dialogue_history=[ - DialogueUtterance(role=SpeakerRole.user, utterance="Instruction1"), - DialogueUtterance(role=SpeakerRole.agent, utterance="Is this a question?"), - DialogueUtterance(role=SpeakerRole.user, utterance="Maybe"), - DialogueUtterance(role=SpeakerRole.user, utterance="Instruction2"), - ], - environment_history=[], - ), - "<> instruction1. <> is this a question? <> maybe. <> instruction2.", - ), - ], -) -def test_simbot_action_builder_parses_dialogue_history( - input_request: EmmaPolicyRequest, - target: Optional[str], -) -> None: - """Test that the action builder parses a request properly.""" - tokenizer = AutoTokenizer.from_pretrained("heriot-watt/emma-base") - builder = SimBotActionInputBuilder(tokenizer=tokenizer) - output = builder._parse_dialogue_from_request(input_request, task=Task.action_execution) - assert output == target - if output is not None: - input_text = builder._prepare_input_text(instruction=output, task=Task.action_execution) - assert input_text diff --git a/tests/inference/test_teach_api.py b/tests/inference/test_teach_api.py deleted file mode 100644 index 2a0d8d1..0000000 --- a/tests/inference/test_teach_api.py +++ /dev/null @@ -1,37 +0,0 @@ -from io import BytesIO -from typing import Any - -import pytest -from fastapi.testclient import TestClient - - -def test_ping_works(client: TestClient) -> None: - """Verify the API can be pinged.""" - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"action": "Look Up", "obj_relative_coord": [0.1, 0.2]} - - -@pytest.mark.skip() -def test_get_edh_history_images_convert_bytes_to_pillow_images() -> None: - raise NotImplementedError - - -@pytest.mark.skip() -def test_start_new_instance_prepares_the_model_properly( - client: TestClient, - start_new_edh_instance_request_body: dict[str, Any], - start_new_edh_instance_request_files: list[tuple[str, tuple[str, BytesIO, str]]], -) -> None: - response = client.post( - "/start_new_edh_instance", - data=start_new_edh_instance_request_body, - files=start_new_edh_instance_request_files, - ) - - assert response.status_code == 200 - - -@pytest.mark.skip(reason="Not implemented yet") -def test_get_next_action_returns_dict_for_the_inference_runner(client: TestClient) -> None: - raise NotImplementedError