diff --git a/src/plangym/videogames/montezuma.py b/src/plangym/videogames/montezuma.py index 274380b..4bd3e55 100644 --- a/src/plangym/videogames/montezuma.py +++ b/src/plangym/videogames/montezuma.py @@ -1,25 +1,15 @@ """Implementation of the montezuma environment adapted for planning problems.""" -from typing import Iterable, Any +from typing import Iterable, Any, SupportsFloat import cv2 import gymnasium as gym import numpy +from numpy import ndarray from plangym.core import wrap_callable from plangym.utils import remove_time_limit -from plangym.videogames.atari import AtariEnv - - -# ------------------------------------------------------------------------------ -# Copyright (c) 2018-2019 Uber Technologies, Inc. -# -# Licensed under the Uber Non-Commercial License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at the root directory of this project. -# -# See the License for the specific language governing permissions and -# limitations under the License. +from plangym.videogames.atari import AtariEnv, ale_to_ram class MontezumaPosLevel: @@ -82,7 +72,7 @@ def __repr__(self): 42, # Torch ] -KNOWN_XY = [None] * 24 +KNOWN_XY: list[None | tuple[int, int]] = [None] * 24 KEY_BITS = 0x8 | 0x4 | 0x2 @@ -100,35 +90,28 @@ def __init__( only_keys: bool = False, death_room_8: bool = True, render_mode: str = "rgb_array", + x_repeat: int = 1, ): # TODO: version that also considers the room objects were found in """Initialize a :class:`CustomMontezuma`.""" - # spec = gym_registry.spec("MontezumaRevengeDeterministic-v4") - # not actually needed, but we feel safer - # spec.max_episode_steps = int(1e100) - # spec.max_episode_time = int(1e100) self.render_mode = render_mode + self.score_objects = score_objects + self.check_death = check_death + self.objects_from_pixels = objects_from_pixels + self.objects_remember_rooms = objects_remember_rooms + self.only_keys = only_keys + self.coords_obs = obs_type == "coords" + self._x_repeat = x_repeat + self._death_room_8 = death_room_8 + env = gym.make("MontezumaRevengeDeterministic-v4", render_mode=self.render_mode) self.env = remove_time_limit(env) + self.unwrapped.seed(0) self.env.reset() - self.score_objects = score_objects self.ram = None - self.check_death = check_death self.cur_steps = 0 self.cur_score = 0 self.rooms = {} - self.room_time = (None, None) - self.room_threshold = 40 - self.unwrapped.seed(0) - self.coords_obs = obs_type == "coords" - self.state = [] - self.ram_death_state = -1 - self._x_repeat = 2 - self._death_room_8 = death_room_8 self.cur_lives = 5 - self.ignore_ram_death = False - self.objects_from_pixels = objects_from_pixels - self.objects_remember_rooms = objects_remember_rooms - self.only_keys = only_keys self.pos = MontezumaPosLevel(0, 0, 0, 0, 0) if self.coords_obs: shape = self.get_coords().shape @@ -139,64 +122,64 @@ def __init__( shape=shape, ) + @staticmethod + def get_room_xy(room: int) -> None | tuple[int, int]: + """Get the tuple that encodes the provided room.""" + if KNOWN_XY[room] is None: + for y, loc in enumerate(PYRAMID): + if room in loc: + KNOWN_XY[int(room)] = (loc.index(room), y) + break + return KNOWN_XY[room] + def __getattr__(self, e): """Forward to gym environment.""" return getattr(self.env, e) + def get_ram(self): + """Return the current RAM.""" + return ale_to_ram(self.env.unwrapped.ale) + def reset(self, seed=None, return_info: bool = False) -> tuple[numpy.ndarray, dict[str, Any]]: """Reset the environment.""" obs, info = self.env.reset() self.cur_lives = 5 for _ in range(3): - obs, *_, _info = self.env.step(0) - self.ram = self.env.unwrapped.ale.getRAM() + obs, *_, info = self.env.step(0) + self.ram = self.get_ram() self.cur_score = 0 self.cur_steps = 0 - self.ram_death_state = -1 self.pos = None - self.pos = self.pos_from_obs( - self.get_face_pixels(obs), - obs, - ) - if self.get_pos().room not in self.rooms: - self.rooms[self.get_pos().room] = ( - False, - obs[50:].repeat(self._x_repeat, axis=1), - ) - self.room_time = (self.get_pos().room, 0) + self.pos = self.pos_from_obs(self.get_face_pixels(obs), obs) + assert self.pos is not None + if self.pos.room not in self.rooms: + self.rooms[self.pos.room] = obs[50:].repeat(self._x_repeat, axis=1) if self.coords_obs: - return self.get_coords() + return self.get_coords(), info return obs, info - def step(self, action) -> tuple[numpy.ndarray, float, bool, bool, dict]: + def step( + self, action + ) -> ( + tuple[ndarray, SupportsFloat, bool, bool, dict[str, Any]] + | tuple[Any, SupportsFloat, bool, bool, dict[str, Any]] + ): """Step the environment.""" obs, reward, done, truncated, info = self.env.step(action) - self.ram = self.env.unwrapped.ale.getRAM() + self.cur_score += reward self.cur_steps += 1 + self.ram = self.get_ram() face_pixels = self.get_face_pixels(obs) - pixel_death = self.is_pixel_death(obs, face_pixels) - ram_death = self.is_ram_death() - # TODO: remove all this stuff - if self.check_death and pixel_death: - done = True - elif self.check_death and not pixel_death and ram_death: - done = True - - self.cur_score += reward self.pos = self.pos_from_obs(face_pixels, obs) - if self.pos.room != self.room_time[0]: # pragma: no cover - self.room_time = (self.pos.room, 0) - self.room_time = (self.pos.room, self.room_time[1] + 1) - if self.pos.room not in self.rooms or ( - self.room_time[1] == self.room_threshold and not self.rooms[self.pos.room][0] - ): - self.rooms[self.pos.room] = ( - self.room_time[1] == self.room_threshold, - obs[50:].repeat(self._x_repeat, axis=1), - ) - if self._death_room_8: + if self.check_death and (self.is_pixel_death(obs, face_pixels) or self.is_ram_death()): + done = True + elif self._death_room_8: done = done or self.pos.room == 8 + + if self.pos.room not in self.rooms: + self.rooms[self.pos.room] = obs[50:].repeat(self._x_repeat, axis=1) + if self.coords_obs: return self.get_coords(), reward, done, truncated, info return obs, reward, done, truncated, info @@ -284,10 +267,6 @@ def state_to_numpy(self) -> numpy.ndarray: state = self.unwrapped.clone_state() return numpy.array((state, None), dtype=object) - def _restore_state(self, state) -> None: - """Restore the state of the game from the provided numpy array.""" - self.unwrapped.restore_state(state) - def get_restore(self) -> tuple: """Return a tuple containing all the information needed to clone the state of the env.""" return ( @@ -295,8 +274,6 @@ def get_restore(self) -> tuple: self.cur_score, self.cur_steps, self.pos, - self.room_time, - self.ram_death_state, self.score_objects, self.cur_lives, ) @@ -308,20 +285,15 @@ def restore(self, data) -> None: score, steps, pos, - room_time, - ram_death_state, self.score_objects, self.cur_lives, ) = data self.env.reset() - self._restore_state(full_state) - self.ram = self.env.unwrapped.ale.getRAM() + self.unwrapped.restore_state(full_state) + self.ram = self.get_ram() self.cur_score = score self.cur_steps = steps self.pos = pos - self.room_time = room_time - assert len(self.room_time) == 2 - self.ram_death_state = ram_death_state def is_transition_screen(self, obs) -> bool: """Return True if the current observation corresponds to a transition between rooms.""" @@ -345,14 +317,12 @@ def get_face_pixels(self, obs) -> set: def is_pixel_death(self, obs, face_pixels): """Return a death signal extracted from the observation of the environment.""" - # There are no face pixels and yet we are not in a transition screen. We + # There are no face pixels, and yet we are not in a transition screen. We # must be dead! if len(face_pixels) == 0: - # All of the screen except the bottom is black: this is not a death but a + # All the screen except the bottom is black: this is not a death but a # room transition. Ignore. - if self.is_transition_screen(obs): # pragma: no cover - return False - return True + return not self.is_transition_screen(obs) # pragma: no cover # We already checked for the presence of no face pixels, however, # sometimes we can die and still have face pixels. In those cases, @@ -364,7 +334,6 @@ def is_pixel_death(self, obs, face_pixels): pixel[1] + neighbor[1], ) in face_pixels: # pragma: no cover return False - return True # pragma: no cover def is_ram_death(self) -> bool: @@ -372,44 +341,11 @@ def is_ram_death(self) -> bool: self.cur_lives = max(self.ram[58], self.cur_lives) return self.ram[55] != 0 or self.ram[58] < self.cur_lives - def get_pos(self) -> MontezumaPosLevel: - """Return the current pos.""" - assert self.pos is not None - return self.pos - - @staticmethod - def get_room_xy(room) -> tuple[None | tuple[int, int]]: - """Get the tuple that encodes the provided room.""" - if KNOWN_XY[room] is None: - for y, loc in enumerate(PYRAMID): - if room in loc: - KNOWN_XY[room] = (loc.index(room), y) - break - return KNOWN_XY[room] - - @staticmethod - def get_room_out_of_bounds(room_x, room_y) -> bool: - """Return a boolean indicating if the provided tuple represents and invalid room.""" - return room_y < 0 or room_x < 0 or room_y >= len(PYRAMID) or room_x >= len(PYRAMID[0]) - - @staticmethod - def get_room_from_xy(room_x, room_y) -> int: - """Return the number of the room from a tuple.""" - return PYRAMID[room_y][room_x] - - @staticmethod - def make_pos(score, pos) -> MontezumaPosLevel: - """Create a MontezumaPosLevel object using the provided data.""" - return MontezumaPosLevel(pos.level, score, pos.room, pos.x, pos.y) - def render(self, mode="human", **kwargs) -> None | numpy.ndarray: """Render the environment.""" return self.env.render() -# ------------------------------------------------------------------------------ - - class MontezumaEnv(AtariEnv): """Plangym implementation of the MontezumaEnv environment optimized for planning.""" @@ -480,26 +416,20 @@ def get_state(self) -> numpy.ndarray: score, steps, pos, - room_time, - ram_death_state, score_objects, cur_lives, ) = data - room_time = room_time if room_time[0] is not None else (-1, -1) - assert len(room_time) == 2 + metadata = numpy.array( [ float(score), float(steps), - float(room_time[0]), - float(room_time[1]), - float(ram_death_state), float(score_objects), float(cur_lives), ], dtype=float, ) - assert len(metadata) == 7 + assert len(metadata) == 4 posarray = numpy.array(pos.tuple, dtype=float) assert len(posarray) == 5 return numpy.concatenate([full_state, metadata, posarray]) @@ -522,16 +452,13 @@ def set_state(self, state: numpy.ndarray): x=float(pos_vals[3]), y=float(pos_vals[4]), ) - score, steps, rt0, rt1, ram_death_state, score_objects, cur_lives = state[-12:-5].tolist() - room_time = (rt0, rt1) if rt0 != -1 and rt1 != -1 else (None, None) + score, steps, score_objects, cur_lives = state[-9:-5].tolist() full_state = state[0] data = ( full_state, score, steps, pos, - room_time, - int(ram_death_state), bool(score_objects), int(cur_lives), ) diff --git a/tests/videogames/test_montezuma.py b/tests/videogames/test_montezuma.py index 69fe44b..f335972 100644 --- a/tests/videogames/test_montezuma.py +++ b/tests/videogames/test_montezuma.py @@ -70,16 +70,6 @@ def test_repr(self, pos_level): class TestCustomMontezuma: - def test_make_pos(self, env): - assert isinstance(env.gym_env.make_pos(1000, env.gym_env.pos), MontezumaPosLevel) - - def test_get_room(self): - env = CustomMontezuma() - env.get_room_xy(3) - env.get_room_from_xy(0, 3) - assert env.get_room_out_of_bounds(99, 99) - assert not env.get_room_out_of_bounds(0, 0) - def test_pos_from_unproc_state(self): env = CustomMontezuma(obs_type="rgb") obs = env.reset() @@ -105,7 +95,7 @@ def test_get_objects_from_pixel(self): assert isinstance(tup, tuple) -class TestMontezume(api_tests.TestPlanEnv): +class TestMontezuma(api_tests.TestPlanEnv): @pytest.mark.parametrize("state", [None, True]) @pytest.mark.parametrize("return_state", [None, True, False]) def test_step(self, env, state, return_state, dt=1): @@ -129,8 +119,6 @@ def test_step(self, env, state, return_state, dt=1): curr_state = env.get_state() curr_state, new_state = curr_state[1:], new_state[1:] assert new_state.shape == curr_state.shape - # FIXME: We are not setting and getting the state properly - return assert (new_state == curr_state).all(), ( f"original: {new_state[new_state != curr_state]} " f"env: {curr_state[new_state != curr_state]}"