Migrate imitation envs to seals (#58)
Rocamonde authored Oct 4, 2022
1 parent 7def17c commit 3d2cd41
Showing 10 changed files with 753 additions and 56 deletions.
2 changes: 1 addition & 1 deletion ci/code_checks.sh
@@ -8,7 +8,7 @@ set -e # quit immediately on error

echo "Source format checking"
flake8 ${SRC_FILES[@]}
black --check ${SRC_FILES}
black --check ${SRC_FILES[@]}
codespell -I .codespell.skip --skip='*.pyc' ${SRC_FILES[@]}

if [ -x "`which circleci`" ]; then
2 changes: 2 additions & 0 deletions mypy.ini
@@ -0,0 +1,2 @@
[mypy]
ignore_missing_imports = true
3 changes: 2 additions & 1 deletion setup.py
@@ -107,6 +107,7 @@ def get_readme() -> str:
"flake8-docstrings",
"flake8-isort",
"isort",
"matplotlib",
"mypy",
"pydocstyle",
"pytest",
@@ -137,7 +138,7 @@ def get_readme() -> str:
packages=find_packages("src"),
package_dir={"": "src"},
package_data={"seals": ["py.typed"]},
install_requires=["gym"],
install_requires=["gym", "numpy"],
tests_require=TESTS_REQUIRE,
extras_require={
# recommended packages for development
261 changes: 230 additions & 31 deletions src/seals/base_envs.py
@@ -23,6 +23,12 @@ class ResettablePOMDP(gym.Env, abc.ABC, Generic[State, Observation, Action]):
meet these criteria.
"""

_state_space: gym.Space
_observation_space: gym.Space
_action_space: gym.Space
_cur_state: Optional[State]
_n_actions_taken: Optional[int]

def __init__(
self,
*,
@@ -41,8 +47,8 @@ def __init__(
self._observation_space = observation_space
self._action_space = action_space

self.cur_state: Optional[State] = None
self._n_actions_taken: Optional[int] = None
self._cur_state = None
self._n_actions_taken = None
self.seed()

@abc.abstractmethod
@@ -86,6 +92,19 @@ def n_actions_taken(self) -> int:
assert self._n_actions_taken is not None
return self._n_actions_taken

@property
def state(self) -> State:
"""Current state."""
assert self._cur_state is not None
return self._cur_state

@state.setter
def state(self, state: State):
"""Set current state."""
if state not in self.state_space:
raise ValueError(f"{state} not in {self.state_space}")
self._cur_state = state

def seed(self, seed=None) -> Sequence[int]:
"""Set random seed."""
if seed is None:
@@ -97,32 +116,57 @@ def seed(self, seed=None) -> Sequence[int]:

def reset(self) -> Observation:
"""Reset episode and return initial observation."""
self.cur_state = self.initial_state()
assert self.cur_state in self.state_space, f"unexpected state {self.cur_state}"
self.state = self.initial_state()
self._n_actions_taken = 0
return self.obs_from_state(self.cur_state)
return self.obs_from_state(self.state)

def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
"""Transition state using given action."""
if self.cur_state is None or self._n_actions_taken is None:
if self._cur_state is None or self._n_actions_taken is None:
raise ValueError("Need to call reset() before first step()")
if action not in self.action_space:
raise ValueError(f"{action} not in {self.action_space}")

old_state = self.cur_state
self.cur_state = self.transition(self.cur_state, action)
assert self.cur_state in self.state_space, f"unexpected state {self.cur_state}"
obs = self.obs_from_state(self.cur_state)
assert obs in self.observation_space, f"{obs} not in {self.observation_space}"
rew = self.reward(old_state, action, self.cur_state)
done = self.terminal(self.cur_state, self._n_actions_taken)
old_state = self.state
self.state = self.transition(self.state, action)
obs = self.obs_from_state(self.state)
assert obs in self.observation_space
reward = self.reward(old_state, action, self.state)
self._n_actions_taken += 1
done = self.terminal(self.state, self.n_actions_taken)

infos = {"old_state": old_state, "new_state": self._cur_state}
return obs, reward, done, infos


class ExposePOMDPStateWrapper(gym.Wrapper, Generic[State, Observation, Action]):
"""A wrapper that exposes the current state of the POMDP as the observation."""

def __init__(self, env: ResettablePOMDP[State, Observation, Action]) -> None:
"""Build wrapper.
Args:
env: POMDP to wrap.
"""
super().__init__(env)
self._observation_space = env.state_space

def reset(self) -> State:
"""Reset environment and return initial state."""
self.env.reset()
return self.env.state

infos = {"old_state": old_state, "new_state": self.cur_state}
return obs, rew, done, infos
def step(self, action) -> Tuple[State, float, bool, dict]:
"""Transition state using given action."""
obs, reward, done, info = self.env.step(action)
return self.env.state, reward, done, info
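
For context, a rough usage sketch of the new wrapper together with TabularModelPOMDP (defined further down in this diff). The tiny two-state matrices and the variable names below are invented for illustration and are not part of the commit:

import numpy as np
from seals.base_envs import ExposePOMDPStateWrapper, TabularModelPOMDP

# Two states, one action: the dynamics deterministically swap the state.
transition = np.array([[[0.0, 1.0]], [[1.0, 0.0]]])         # (n_states, n_actions, n_states)
reward = np.array([0.0, 1.0])                               # per-state reward, shape (n_states,)
observations = np.array([[0.0], [1.0]], dtype=np.float32)   # (n_states, obs_dim)

env = TabularModelPOMDP(
    transition_matrix=transition,
    observation_matrix=observations,
    reward_matrix=reward,
    horizon=5,
)
wrapped = ExposePOMDPStateWrapper(env)
state = wrapped.reset()                        # integer state instead of the observation vector
next_state, rew, done, info = wrapped.step(0)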


class ResettableMDP(ResettablePOMDP[State, State, Action], Generic[State, Action]):
class ResettableMDP(
ResettablePOMDP[State, State, Action],
abc.ABC,
Generic[State, Action],
):
"""ABC for MDPs that are resettable."""

def __init__(
Expand All @@ -148,8 +192,20 @@ def obs_from_state(self, state: State) -> State:
return state


class TabularModelMDP(ResettableMDP[int, int]):
"""Base class for tabular environments with known dynamics."""
# TODO(juan) this does not implement the .render() method,
# so in theory it should not be instantiated directly.
# Not sure why this is not raising an error?
class BaseTabularModelPOMDP(ResettablePOMDP[int, Observation, int]):
"""Base class for tabular environments with known dynamics.
This is the general class that also allows subclassing for creating
MDP (where observation == state) or POMDP (where observation != state).
"""

transition_matrix: np.ndarray
reward_matrix: np.ndarray

state_space: spaces.Discrete

def __init__(
self,
@@ -179,14 +235,28 @@ def __init__(
ValueError: `transition_matrix`, `reward_matrix` or
`initial_state_dist` have shapes different to specified above.
"""
n_states, n_actions, n_next_states = transition_matrix.shape
if n_states != n_next_states:
# The following matrices should conform to the shapes below:

# transition matrix: n_states x n_actions x n_states
n_states = transition_matrix.shape[0]
if n_states != transition_matrix.shape[2]:
raise ValueError(
"Malformed transition_matrix:\n"
f"transition_matrix.shape: {transition_matrix.shape}\n"
f"{n_states} != {n_next_states}",
f"{n_states} != {transition_matrix.shape[2]}",
)

# reward matrix: n_states x n_actions x n_states
# OR n_states x n_actions
# OR n_states
if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
raise ValueError(
"transition_matrix and reward_matrix are not compatible:\n"
f"transition_matrix.shape: {transition_matrix.shape}\n"
f"reward_matrix.shape: {reward_matrix.shape}",
)

# initial state dist: n_states
if initial_state_dist is None:
initial_state_dist = util.one_hot_encoding(0, n_states)
if initial_state_dist.ndim != 1:
@@ -197,28 +267,32 @@ def __init__(
if initial_state_dist.shape[0] != n_states:
raise ValueError(
"transition_matrix and initial_state_dist are not compatible:\n"
f"n_states = {n_states}\n"
f"number of states = {n_states}\n"
f"len(initial_state_dist) = {len(initial_state_dist)}",
)

if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
raise ValueError(
"transition_matrix and reward_matrix are not compatible:\n"
f"transition_matrix.shape: {transition_matrix.shape}\n"
f"reward_matrix.shape: {reward_matrix.shape}",
)

self.transition_matrix = transition_matrix
self.reward_matrix = reward_matrix
self._feature_matrix = None
self.horizon = horizon
self.initial_state_dist = initial_state_dist

super().__init__(
state_space=spaces.Discrete(n_states),
action_space=spaces.Discrete(n_actions),
state_space=self._construct_state_space(),
action_space=self._construct_action_space(),
observation_space=self._construct_observation_space(),
)

def _construct_state_space(self) -> gym.Space:
return spaces.Discrete(self.state_dim)

def _construct_action_space(self) -> gym.Space:
return spaces.Discrete(self.action_dim)

@abc.abstractmethod
def _construct_observation_space(self) -> gym.Space:
pass # pragma: no cover

def initial_state(self) -> int:
"""Samples from the initial state distribution."""
return util.sample_distribution(
@@ -250,3 +324,128 @@ def feature_matrix(self):
n_states = self.state_space.n
self._feature_matrix = np.eye(n_states)
return self._feature_matrix

@property
def state_dim(self):
"""Number of states in this MDP (int)."""
return self.transition_matrix.shape[0]

@property
def action_dim(self) -> int:
"""Number of action vectors (int)."""
return self.transition_matrix.shape[1]


class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]):
"""Tabular model POMDP.
This class is specifically for environments where observation != state,
from both a typing perspective but also by defining the method that
draws observations from the state.
The tabular model is deterministic in drawing observations from the state,
in that given a certain state, the observation is always the same;
a vector with self.obs_dim entries.
"""

observation_matrix: np.ndarray

def __init__(
self,
*,
transition_matrix: np.ndarray,
observation_matrix: np.ndarray,
reward_matrix: np.ndarray,
horizon: float = np.inf,
initial_state_dist: Optional[np.ndarray] = None,
):
"""Initializes a tabular model POMDP."""
self.observation_matrix = observation_matrix
super().__init__(
transition_matrix=transition_matrix,
reward_matrix=reward_matrix,
horizon=horizon,
initial_state_dist=initial_state_dist,
)

# observation matrix: n_states x n_observations
if observation_matrix.shape[0] != self.state_dim:
raise ValueError(
"transition_matrix and observation_matrix are not compatible:\n"
f"transition_matrix.shape[0]: {self.state_dim}\n"
f"observation_matrix.shape[0]: {observation_matrix.shape[0]}",
)

def _construct_observation_space(self) -> gym.Space:
min_val: float
max_val: float
try:
dtype_iinfo = np.iinfo(self.obs_dtype)
min_val, max_val = dtype_iinfo.min, dtype_iinfo.max
except ValueError:
min_val = -np.inf
max_val = np.inf
return spaces.Box(
low=min_val,
high=max_val,
shape=(self.obs_dim,),
dtype=self.obs_dtype,
)

def obs_from_state(self, state: int) -> np.ndarray:
"""Computes observation from state."""
# Copy so it can't be mutated in-place (updates will be reflected in
# self.observation_matrix!)
obs = self.observation_matrix[state].copy()
assert obs.ndim == 1, obs.shape
return obs

@property
def obs_dim(self) -> int:
"""Size of observation vectors for this MDP."""
return self.observation_matrix.shape[1]

@property
def obs_dtype(self) -> int:
"""Data type of observation vectors (e.g. np.float32)."""
return self.observation_matrix.dtype
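
The observation space constructed above picks its bounds from the dtype of observation_matrix. A short sketch of the two branches (values here are illustrative only):

import numpy as np
from gym import spaces

# Integer observation matrices get the full integer range of their dtype...
bounds = np.iinfo(np.int64)
int_space = spaces.Box(low=bounds.min, high=bounds.max, shape=(4,), dtype=np.int64)

# ...while float dtypes make np.iinfo raise ValueError, so the space
# falls back to unbounded (-inf, inf).
float_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)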


class TabularModelMDP(BaseTabularModelPOMDP[int]):
"""Tabular model MDP.
A tabular model MDP is a tabular MDP where the transition and reward
matrices are constant.
"""

def __init__(
self,
*,
transition_matrix: np.ndarray,
reward_matrix: np.ndarray,
horizon: float = np.inf,
initial_state_dist: Optional[np.ndarray] = None,
):
"""Initializes a tabular model MDP.
Args:
transition_matrix: Matrix of shape `(n_states, n_actions, n_states)`
containing transition probabilities.
reward_matrix: Matrix of shape `(n_states, n_actions, n_states)`
containing reward values.
initial_state_dist: Distribution over initial states. Shape `(n_states,)`.
horizon: Maximum number of steps to take in an episode.
"""
super().__init__(
transition_matrix=transition_matrix,
reward_matrix=reward_matrix,
horizon=horizon,
initial_state_dist=initial_state_dist,
)

def obs_from_state(self, state: int) -> int:
"""Identity since observation == state in an MDP."""
return state

def _construct_observation_space(self) -> gym.Space:
return self._construct_state_space()
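
A minimal construction sketch for the MDP variant (matrices and names invented for illustration, not part of the commit):

import numpy as np
from seals.base_envs import TabularModelMDP

# Three states, two actions: action 0 stays put, action 1 moves to the next state.
transition = np.zeros((3, 2, 3))
transition[:, 0, :] = np.eye(3)
transition[:, 1, :] = np.roll(np.eye(3), shift=1, axis=1)
reward = np.array([0.0, 0.0, 1.0])  # per-state reward vector

env = TabularModelMDP(transition_matrix=transition, reward_matrix=reward, horizon=10)
obs = env.reset()                # the observation is the integer state itself
obs, rew, done, _ = env.step(1)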