From 69ab43ec19a9fb9892c26a0f99bab1d7c52c6466 Mon Sep 17 00:00:00 2001 From: Joery Date: Sat, 11 Nov 2023 11:03:21 +0100 Subject: [PATCH] Extend State bound with an Optional protocol that has an `options` attribute. Extend Environment.reset to accept an `options` field. --- conftest.py | 4 ++- examples/counting_env.py | 4 ++- jit_env/__init__.py | 3 ++- jit_env/_core.py | 57 ++++++++++++++++++++++++++++++++-------- jit_env/compat.py | 9 ++++--- jit_env/wrappers.py | 27 ++++++++++++++----- 6 files changed, 80 insertions(+), 24 deletions(-) diff --git a/conftest.py b/conftest.py index 817f384..0111eef 100644 --- a/conftest.py +++ b/conftest.py @@ -20,7 +20,9 @@ def __init__(self): def reset( self, - key: jax.random.KeyArray + key: jax.random.KeyArray, + *, + options: jit_env.EnvOptions = None ) -> tuple[DummyState, jit_env.TimeStep]: return DummyState(key=key), jit_env.restart(jax.numpy.zeros(())) diff --git a/examples/counting_env.py b/examples/counting_env.py index 1a80038..e624989 100644 --- a/examples/counting_env.py +++ b/examples/counting_env.py @@ -42,7 +42,9 @@ def __init__(self, maximum: int | Integer[jax.Array, '']): def reset( self, - key: PRNGKeyArray + key: PRNGKeyArray, + *, + options=None ) -> tuple[MyState, jit_env.TimeStep]: state = MyState(key=key, count=jnp.zeros((), jnp.int32)) return state, jit_env.restart(state.count, shape=()) diff --git a/jit_env/__init__.py b/jit_env/__init__.py index c22d7b6..501870d 100644 --- a/jit_env/__init__.py +++ b/jit_env/__init__.py @@ -11,7 +11,8 @@ State as State, Observation as Observation, RewardT as Reward, - DiscountT as Discount + DiscountT as Discount, + EnvOptions ) from jit_env import specs diff --git a/jit_env/_core.py b/jit_env/_core.py index c018ddc..8b2932c 100644 --- a/jit_env/_core.py +++ b/jit_env/_core.py @@ -11,10 +11,13 @@ from typing import ( Any, TYPE_CHECKING, TypeVar, Generic, Sequence, - Protocol, Callable + Protocol, Callable, Union ) +from typing_extensions import TypeAlias from dataclasses import field 
+import jax + if TYPE_CHECKING: # pragma: no cover # See: https://github.com/python/mypy/issues/6239 from dataclasses import dataclass @@ -42,15 +45,31 @@ class StateProtocol(Protocol): key: PRNGKeyArray -# The following should all be valid Jax types +class StateWithOptionsProtocol(StateProtocol): + """Extension of StateProtocol to memorize Environment.reset Options.""" + options: EnvOptions = None + + +# The following should all be valid Jax types (not explicitly enforced) +State = TypeVar("State", bound=Union[StateProtocol, StateWithOptionsProtocol]) Action = TypeVar("Action") Observation = TypeVar("Observation") -StepT = TypeVar("StepT", bound=Int8[Array, '']) -RewardT = TypeVar("RewardT", bound=PyTree[ArrayLike]) -DiscountT = TypeVar("DiscountT", bound=PyTree[ArrayLike]) - -State = TypeVar("State", bound=StateProtocol) +StepT = TypeVar("StepT", bound=PyTree[ + Union[Int8[Array, ''], int, jax.ShapeDtypeStruct, None] +]) +RewardT = TypeVar("RewardT", bound=PyTree[ + Union[ArrayLike, jax.ShapeDtypeStruct, None] +]) +DiscountT = TypeVar("DiscountT", bound=PyTree[ + Union[ArrayLike, jax.ShapeDtypeStruct, None] +]) + +# Modify Environment behaviour through reset with a generic datastructure. +# See also the Flax documentation on how to effectively manipulate such a +# structure in combination with Jax Transforms like Jit or Vmap: +# - https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html +EnvOptions: TypeAlias = Union[PyTree[Any], dict[str, Any], None] class StepType: @@ -156,15 +175,31 @@ def unwrapped(self) -> Environment[ @abc.abstractmethod def reset( - self, - key: PRNGKeyArray + self, + key: PRNGKeyArray, + *, + options: EnvOptions = None ) -> tuple[ State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']] ]: """Starts a new episode as a functionally pure transformation. + Optionally, one can pass in an `options` structure to modify the + reset behaviour slightly. 
Note that this poses the dangers of + non-homogenous computations (i.e., non-SIMD). As an example, if + array shapes change due to a dimensionality flag in `options`, then + a jax.jit compiled Environment may re-compile or even crash! + + Thus, the `options` should be carefully used when doing something like + curriculum learning or adversarial Environment creation. Functionality + like increasing the number of obstacles/ enemies in a maze is fine as + this is agnostic to the environment dimensionality. + Args: key: Pseudo RNG Key to initialize `State` with. + * + options (kw-only): + Optional arguments for modifying environment parameters. Returns: A tuple of `State` and `TimeStep` at indices; @@ -349,10 +384,10 @@ def unwrapped(self) -> Environment[ """Helper function to unpack Composite Environments to the base.""" return self.env.unwrapped - def reset(self, key: PRNGKeyArray) -> tuple[ + def reset(self, key: PRNGKeyArray, *, options: EnvOptions = None) -> tuple[ State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']] ]: - return self.env.reset(key) + return self.env.reset(key, options=options) def step(self, state: State, action: Action) -> tuple[ State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']] diff --git a/jit_env/compat.py b/jit_env/compat.py index 0431989..1fce9bf 100644 --- a/jit_env/compat.py +++ b/jit_env/compat.py @@ -88,10 +88,13 @@ class ToDeepmindEnv(dm_env.Environment): def __init__( self, env: _core.Environment, - rng: _jax.random.KeyArray = _jax.random.PRNGKey(0) + rng: _jax.random.KeyArray = _jax.random.PRNGKey(0), + *, + options: _core.EnvOptions = None ): self.env = env self.rng = rng + self.options = options self._state = None @@ -101,7 +104,7 @@ def __init__( def reset(self) -> dm_env.TimeStep: self.rng, key = _jax.random.split(self.rng) - self._state, step = self.env.reset(key) + self._state, step = self.env.reset(key, options=self.options) return dm_env.restart(step.observation) def step(self, action) -> 
dm_env.TimeStep: @@ -248,7 +251,7 @@ def reset( self._seed(seed) self.rng, reset_key = _jax.random.split(self.rng) - self.env_state, step = self.env.reset(reset_key) + self.env_state, step = self.env.reset(reset_key, options=options) return step.observation, (step.extras or {}) diff --git a/jit_env/wrappers.py b/jit_env/wrappers.py index 600b16e..9b67999 100644 --- a/jit_env/wrappers.py +++ b/jit_env/wrappers.py @@ -11,6 +11,8 @@ import jax as _jax +from jaxtyping import PRNGKeyArray as _PRNGKeyArray + import jit_env from jit_env import _core from jit_env import specs as _specs @@ -48,9 +50,11 @@ def __repr__(self) -> str: def reset( self, - key: _jax.random.KeyArray + key: _PRNGKeyArray, + *, + options: _core.EnvOptions = None ) -> tuple[_core.State, _core.TimeStep]: - return self._reset_fun(key) + return self._reset_fun(key, options=options) def step( self, @@ -99,9 +103,11 @@ def __repr__(self) -> str: def reset( self, - key: _jax.random.KeyArray # (Batch, dim_key) + key: _PRNGKeyArray, # (Batch, dim_key) + *, + options: _core.EnvOptions = None ) -> tuple[_core.State, _core.TimeStep]: - return self._reset_fun(key) + return self._reset_fun(key, options=options) def step( self, @@ -217,9 +223,13 @@ def __repr__(self) -> str: def reset( self, - key: _jax.random.KeyArray + key: _PRNGKeyArray, + *, + options: _core.EnvOptions = None ) -> tuple[_core.State, _core.TimeStep]: - return super().reset(_jax.random.split(key, num=self.num)) + return super().reset( + _jax.random.split(key, num=self.num), options=options + ) class ResetMixin: @@ -257,8 +267,11 @@ def _auto_reset( and extras fields are copied from the `step` argument. """ + # See jit_env.StateWithOptionsProtocol + options = getattr(state, 'options', None) + key, _ = _jax.random.split(state.key) - state, reset_timestep = self.env.reset(key) + state, reset_timestep = self.env.reset(key, options=options) timestep = _core.TimeStep( step_type=reset_timestep.step_type, # Overwrite step