Skip to content

Commit

Permalink
feat: update policy
Browse files Browse the repository at this point in the history
  • Loading branch information
gpantaz committed Nov 29, 2023
1 parent 26597bf commit 903a4af
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def load_coco_ids(coco_splits_path: Path) -> set[str]:
We only extract the image ID's, which are in the form `COCO_val2014_000000238836`.
"""
with open(coco_splits_path) as in_file:
print(coco_splits_path)
data_list = json.load(in_file)

image_ids: set[str] = set()
Expand Down
7 changes: 7 additions & 0 deletions tests/fixtures/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,10 @@ def emma_pretrain_datamodule(
dm.setup()

return dm


class DataModuleCases:
def case_pretrain(
self, emma_pretrain_datamodule: EmmaPretrainDataModule
) -> EmmaPretrainDataModule:
return emma_pretrain_datamodule
6 changes: 0 additions & 6 deletions tests/inference/test_policy_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from emma_policy.api.clients import FeatureExtractorClient
from emma_policy.common.settings import Settings
from emma_policy.inference.model_wrapper import PolicyModelWrapper, SimulatorAction


@pytest.fixture(scope="module")
Expand All @@ -34,8 +33,3 @@ def load_frame_features_like_api_response(features_path: Path) -> list[dict[str,
]

return response_features


def test_model_is_loaded_from_checkpoint(policy_model_wrapper: PolicyModelWrapper) -> None:
"""Verify the model has been loaded from the checkpoint correctly."""
assert not policy_model_wrapper._model.training
68 changes: 0 additions & 68 deletions tests/inference/test_simbot_api.py

This file was deleted.

37 changes: 0 additions & 37 deletions tests/inference/test_teach_api.py

This file was deleted.

0 comments on commit 903a4af

Please sign in to comment.