Skip to content

Commit

Permalink
Re-write PlanMontezuma
Browse files Browse the repository at this point in the history
Signed-off-by: guillemdb <[email protected]>
  • Loading branch information
guillemdb committed Oct 16, 2024
1 parent 178a29f commit 398aab5
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 146 deletions.
195 changes: 62 additions & 133 deletions src/plangym/videogames/montezuma.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
"""Implementation of the montezuma environment adapted for planning problems."""

from typing import Iterable, Any
from typing import Iterable, Any, SupportsFloat

import cv2
import gymnasium as gym
import numpy
from numpy import ndarray

from plangym.core import wrap_callable
from plangym.utils import remove_time_limit
from plangym.videogames.atari import AtariEnv


# ------------------------------------------------------------------------------
# Copyright (c) 2018-2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
from plangym.videogames.atari import AtariEnv, ale_to_ram


class MontezumaPosLevel:
Expand Down Expand Up @@ -82,7 +72,7 @@ def __repr__(self):
42, # Torch
]

KNOWN_XY = [None] * 24
KNOWN_XY: list[None | tuple[int, int]] = [None] * 24

KEY_BITS = 0x8 | 0x4 | 0x2

Expand All @@ -100,35 +90,28 @@ def __init__(
only_keys: bool = False,
death_room_8: bool = True,
render_mode: str = "rgb_array",
x_repeat: int = 1,
): # TODO: version that also considers the room objects were found in
"""Initialize a :class:`CustomMontezuma`."""
# spec = gym_registry.spec("MontezumaRevengeDeterministic-v4")
# not actually needed, but we feel safer
# spec.max_episode_steps = int(1e100)
# spec.max_episode_time = int(1e100)
self.render_mode = render_mode
self.score_objects = score_objects
self.check_death = check_death
self.objects_from_pixels = objects_from_pixels
self.objects_remember_rooms = objects_remember_rooms
self.only_keys = only_keys
self.coords_obs = obs_type == "coords"
self._x_repeat = x_repeat
self._death_room_8 = death_room_8

env = gym.make("MontezumaRevengeDeterministic-v4", render_mode=self.render_mode)
self.env = remove_time_limit(env)
self.unwrapped.seed(0)
self.env.reset()
self.score_objects = score_objects
self.ram = None
self.check_death = check_death
self.cur_steps = 0
self.cur_score = 0
self.rooms = {}
self.room_time = (None, None)
self.room_threshold = 40
self.unwrapped.seed(0)
self.coords_obs = obs_type == "coords"
self.state = []
self.ram_death_state = -1
self._x_repeat = 2
self._death_room_8 = death_room_8
self.cur_lives = 5
self.ignore_ram_death = False
self.objects_from_pixels = objects_from_pixels
self.objects_remember_rooms = objects_remember_rooms
self.only_keys = only_keys
self.pos = MontezumaPosLevel(0, 0, 0, 0, 0)
if self.coords_obs:
shape = self.get_coords().shape
Expand All @@ -139,64 +122,66 @@ def __init__(
shape=shape,
)

@staticmethod
def get_room_xy(room: int) -> None | tuple[int, int]:
"""Get the tuple that encodes the provided room."""
if room >= len(KNOWN_XY) or room < 0:
return None
if KNOWN_XY[room] is None:
for y, loc in enumerate(PYRAMID):
if room in loc:
KNOWN_XY[int(room)] = (loc.index(room), y)
break
return KNOWN_XY[room]

def __getattr__(self, e):
"""Forward to gym environment."""
return getattr(self.env, e)

def get_ram(self):
"""Return the current RAM."""
return ale_to_ram(self.env.unwrapped.ale)

def reset(self, seed=None, return_info: bool = False) -> tuple[numpy.ndarray, dict[str, Any]]:
"""Reset the environment."""
obs, info = self.env.reset()
self.cur_lives = 5
for _ in range(3):
obs, *_, _info = self.env.step(0)
self.ram = self.env.unwrapped.ale.getRAM()
obs, *_, info = self.env.step(0)
self.ram = self.get_ram()
self.cur_score = 0
self.cur_steps = 0
self.ram_death_state = -1
self.pos = None
self.pos = self.pos_from_obs(
self.get_face_pixels(obs),
obs,
)
if self.get_pos().room not in self.rooms:
self.rooms[self.get_pos().room] = (
False,
obs[50:].repeat(self._x_repeat, axis=1),
)
self.room_time = (self.get_pos().room, 0)
self.pos = self.pos_from_obs(self.get_face_pixels(obs), obs)
assert self.pos is not None
if self.pos.room not in self.rooms:
self.rooms[self.pos.room] = obs[50:].repeat(self._x_repeat, axis=1)
if self.coords_obs:
return self.get_coords()
return self.get_coords(), info
return obs, info

def step(self, action) -> tuple[numpy.ndarray, float, bool, bool, dict]:
def step(
self, action
) -> (
tuple[ndarray, SupportsFloat, bool, bool, dict[str, Any]]
| tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]
):
"""Step the environment."""
obs, reward, done, truncated, info = self.env.step(action)
self.ram = self.env.unwrapped.ale.getRAM()
self.cur_score += reward
self.cur_steps += 1
self.ram = self.get_ram()

face_pixels = self.get_face_pixels(obs)
pixel_death = self.is_pixel_death(obs, face_pixels)
ram_death = self.is_ram_death()
# TODO: remove all this stuff
if self.check_death and pixel_death:
done = True
elif self.check_death and not pixel_death and ram_death:
done = True

self.cur_score += reward
self.pos = self.pos_from_obs(face_pixels, obs)
if self.pos.room != self.room_time[0]: # pragma: no cover
self.room_time = (self.pos.room, 0)
self.room_time = (self.pos.room, self.room_time[1] + 1)
if self.pos.room not in self.rooms or (
self.room_time[1] == self.room_threshold and not self.rooms[self.pos.room][0]
):
self.rooms[self.pos.room] = (
self.room_time[1] == self.room_threshold,
obs[50:].repeat(self._x_repeat, axis=1),
)
if self._death_room_8:
if self.check_death and (self.is_pixel_death(obs, face_pixels) or self.is_ram_death()):
done = True
elif self._death_room_8: # pragma: no cover
done = done or self.pos.room == 8

if self.pos.room not in self.rooms: # pragma: no cover
self.rooms[self.pos.room] = obs[50:].repeat(self._x_repeat, axis=1)

if self.coords_obs:
return self.get_coords(), reward, done, truncated, info
return obs, reward, done, truncated, info
Expand Down Expand Up @@ -270,7 +255,7 @@ def get_objects_from_pixels(self, obs, room, old_objects):
pixel_areas.remove(n_pixels)
cur_object |= 1 << i

if self.only_keys:
if self.only_keys: # pragma: no cover
# These are the key bytes
cur_object &= KEY_BITS
return cur_object
Expand All @@ -284,19 +269,13 @@ def state_to_numpy(self) -> numpy.ndarray:
state = self.unwrapped.clone_state()
return numpy.array((state, None), dtype=object)

def _restore_state(self, state) -> None:
"""Restore the state of the game from the provided numpy array."""
self.unwrapped.restore_state(state)

def get_restore(self) -> tuple:
"""Return a tuple containing all the information needed to clone the state of the env."""
return (
self.state_to_numpy(),
self.cur_score,
self.cur_steps,
self.pos,
self.room_time,
self.ram_death_state,
self.score_objects,
self.cur_lives,
)
Expand All @@ -308,20 +287,15 @@ def restore(self, data) -> None:
score,
steps,
pos,
room_time,
ram_death_state,
self.score_objects,
self.cur_lives,
) = data
self.env.reset()
self._restore_state(full_state)
self.ram = self.env.unwrapped.ale.getRAM()
self.unwrapped.restore_state(full_state)
self.ram = self.get_ram()
self.cur_score = score
self.cur_steps = steps
self.pos = pos
self.room_time = room_time
assert len(self.room_time) == 2
self.ram_death_state = ram_death_state

def is_transition_screen(self, obs) -> bool:
"""Return True if the current observation corresponds to a transition between rooms."""
Expand All @@ -345,14 +319,12 @@ def get_face_pixels(self, obs) -> set:

def is_pixel_death(self, obs, face_pixels):
"""Return a death signal extracted from the observation of the environment."""
# There are no face pixels and yet we are not in a transition screen. We
# There are no face pixels, and yet we are not in a transition screen. We
# must be dead!
if len(face_pixels) == 0:
# All of the screen except the bottom is black: this is not a death but a
# All the screen except the bottom is black: this is not a death but a
# room transition. Ignore.
if self.is_transition_screen(obs): # pragma: no cover
return False
return True
return not self.is_transition_screen(obs) # pragma: no cover

# We already checked for the presence of no face pixels, however,
# sometimes we can die and still have face pixels. In those cases,
Expand All @@ -364,52 +336,18 @@ def is_pixel_death(self, obs, face_pixels):
pixel[1] + neighbor[1],
) in face_pixels: # pragma: no cover
return False

return True # pragma: no cover

def is_ram_death(self) -> bool:
"""Return a death signal extracted from the ram of the environment."""
self.cur_lives = max(self.ram[58], self.cur_lives)
return self.ram[55] != 0 or self.ram[58] < self.cur_lives

def get_pos(self) -> MontezumaPosLevel:
"""Return the current pos."""
assert self.pos is not None
return self.pos

@staticmethod
def get_room_xy(room) -> tuple[None | tuple[int, int]]:
"""Get the tuple that encodes the provided room."""
if KNOWN_XY[room] is None:
for y, loc in enumerate(PYRAMID):
if room in loc:
KNOWN_XY[room] = (loc.index(room), y)
break
return KNOWN_XY[room]

@staticmethod
def get_room_out_of_bounds(room_x, room_y) -> bool:
"""Return a boolean indicating if the provided tuple represents and invalid room."""
return room_y < 0 or room_x < 0 or room_y >= len(PYRAMID) or room_x >= len(PYRAMID[0])

@staticmethod
def get_room_from_xy(room_x, room_y) -> int:
"""Return the number of the room from a tuple."""
return PYRAMID[room_y][room_x]

@staticmethod
def make_pos(score, pos) -> MontezumaPosLevel:
"""Create a MontezumaPosLevel object using the provided data."""
return MontezumaPosLevel(pos.level, score, pos.room, pos.x, pos.y)

def render(self, mode="human", **kwargs) -> None | numpy.ndarray:
"""Render the environment."""
return self.env.render()


# ------------------------------------------------------------------------------


class MontezumaEnv(AtariEnv):
"""Plangym implementation of the MontezumaEnv environment optimized for planning."""

Expand Down Expand Up @@ -480,26 +418,20 @@ def get_state(self) -> numpy.ndarray:
score,
steps,
pos,
room_time,
ram_death_state,
score_objects,
cur_lives,
) = data
room_time = room_time if room_time[0] is not None else (-1, -1)
assert len(room_time) == 2

metadata = numpy.array(
[
float(score),
float(steps),
float(room_time[0]),
float(room_time[1]),
float(ram_death_state),
float(score_objects),
float(cur_lives),
],
dtype=float,
)
assert len(metadata) == 7
assert len(metadata) == 4
posarray = numpy.array(pos.tuple, dtype=float)
assert len(posarray) == 5
return numpy.concatenate([full_state, metadata, posarray])
Expand All @@ -522,16 +454,13 @@ def set_state(self, state: numpy.ndarray):
x=float(pos_vals[3]),
y=float(pos_vals[4]),
)
score, steps, rt0, rt1, ram_death_state, score_objects, cur_lives = state[-12:-5].tolist()
room_time = (rt0, rt1) if rt0 != -1 and rt1 != -1 else (None, None)
score, steps, score_objects, cur_lives = state[-9:-5].tolist()
full_state = state[0]
data = (
full_state,
score,
steps,
pos,
room_time,
int(ram_death_state),
bool(score_objects),
int(cur_lives),
)
Expand Down
Loading

0 comments on commit 398aab5

Please sign in to comment.