Skip to content

Commit

Permalink
Merge pull request #54 from epignatelli/envs
Browse files Browse the repository at this point in the history
Implement environments
  • Loading branch information
epignatelli authored May 29, 2024
2 parents b05a291 + 01c3178 commit bf22b15
Show file tree
Hide file tree
Showing 35 changed files with 2,214 additions and 317 deletions.
2 changes: 1 addition & 1 deletion assets/COPYRIGHT
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright 2024 https://github.com/Farama-Foundation/Minigrid
The following images are under Apache 2.0 License as per https://github.com/Farama-Foundation/Minigrid/LICENSE.
A copy of the license is provided in the fileassets/LICENSE.
A copy of the license is provided in the file assets/LICENSE.

6 changes: 4 additions & 2 deletions navix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
entities,
grid,
observations,
tasks,
rewards,
environments,
terminations,
config,
spaces,
rendering
rendering,
transitions,
events,
)
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.4.0"
__version__ = "0.5.0"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
120 changes: 83 additions & 37 deletions navix/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations
from typing import Tuple

import jax
import jax.numpy as jnp
from jax import Array
import jax.numpy as jnp
import jax.tree_util as jtu

from .entities import Entities, State
from .entities import Entities
from .states import EventsManager, State
from .components import DISCARD_PILE_COORDS
from .grid import translate, rotate, positions_equal


class Directions:
EAST = 0
SOUTH = 1
WEST = 2
NORTH = 3
EAST = jnp.asarray(0)
SOUTH = jnp.asarray(1)
WEST = jnp.asarray(2)
NORTH = jnp.asarray(3)


def _rotate(state: State, spin: int) -> State:
Expand All @@ -52,17 +54,28 @@ def _rotate(state: State, spin: int) -> State:
return state


def _walkable(state: State, position: Array) -> Array:
def _can_walk_there(state: State, position: Array) -> Tuple[Array, EventsManager]:
# according to the grid
walkable = jnp.equal(state.grid[tuple(position)], 0)
events = jax.lax.cond(
walkable,
lambda: state.events,
lambda: state.events.record_grid_hit(position),
)

for k in state.entities:
same_position = positions_equal(state.entities[k].position, position)
events = jax.lax.cond(
jnp.any(same_position),
lambda x: x.record_walk_into(state.entities[k], position),
lambda x: x,
events,
)
obstructs = jnp.logical_and(
jnp.logical_not(state.entities[k].walkable),
positions_equal(state.entities[k].position, position),
jnp.logical_not(state.entities[k].walkable), same_position
)
walkable = jnp.logical_and(walkable, jnp.any(jnp.logical_not(obstructs)))
return jnp.asarray(walkable, dtype=jnp.bool_)
return jnp.asarray(walkable, dtype=jnp.bool_), events


def _move(state: State, direction: Array) -> State:
Expand All @@ -71,22 +84,12 @@ def _move(state: State, direction: Array) -> State:

player = state.get_player(idx=0)
new_position = translate(player.position, direction)
can_move = _walkable(state, new_position)
can_move, events = _can_walk_there(state, new_position)
new_position = jnp.where(can_move, new_position, player.position)
# update structs
player = player.replace(position=new_position)
state = state.set_player(player)
return state


def undefined(state: State) -> State:
# this is problematic because jax.lax.switch evaluates
# all *python* branches (no XLA computation is performed)
# even though only one is selected
# one option is the following, but this breaks type checking
# def raise_error(state: State) -> State:
# raise ValueError("Undefined action")
# jax.debug.callback(raise_error)
raise ValueError("Undefined action")
return state.replace(events=events)


def noop(state: State) -> State:
Expand Down Expand Up @@ -142,11 +145,27 @@ def pickup(state: State) -> State:
jnp.any(key_found), lambda: player.replace(pocket=key), lambda: player
)

# update events
events = jax.lax.cond(
jnp.any(key_found),
lambda: state.events.record_key_pickup(keys, position_in_front),
lambda: state.events,
)

state = state.set_player(player)
state = state.set_keys(keys)
state = state.set_events(events)
return state


def drop(state: State) -> State:
raise NotImplementedError()


def toggle(state: State) -> State:
raise NotImplementedError()


def open(state: State) -> State:
"""Unlocks and opens an openable object (like a door) if possible"""

Expand Down Expand Up @@ -179,22 +198,49 @@ def open(state: State) -> State:
jnp.any(can_open), lambda: player.replace(pocket=pocket), lambda: player
)

# update events
events = jax.lax.cond(
jnp.any(do_open),
lambda: state.events.record_door_opening(doors, position_in_front),
lambda: state.events,
)

state = state.set_player(player)
state = state.set_doors(doors)
state = state.set_events(events)

return state


def done(state: State) -> State:
return state


# TODO(epignatelli): a mutable dictionary here is dangerous
ACTIONS = {
# -1: undefined,
0: noop,
1: rotate_cw,
2: rotate_ccw,
3: forward,
4: right,
5: backward,
6: left,
7: pickup,
8: open,
}
# DEFAULT_ACTION_SET = (
# rotate_ccw,
# rotate_cw,
# forward,
# pickup,
# drop,
# toggle,
# done
# )
"""Default action set from Minigrid. See
https://github.com/Farama-Foundation/Minigrid/blob/master/minigrid/core/actions.py"""


COMPLETE_ACTION_SET = (
noop,
rotate_cw,
rotate_ccw,
forward,
right,
backward,
left,
pickup,
open,
done,
)


DEFAULT_ACTION_SET = COMPLETE_ACTION_SET
Loading

0 comments on commit bf22b15

Please sign in to comment.