Skip to content

Commit

Permalink
[RLlib] [MultiAgentEnv Refactor #2] Change space types for BaseEnvs
Browse files Browse the repository at this point in the history
… and `MultiAgentEnvs` (ray-project#21063)
  • Loading branch information
avnishn authored Jan 6, 2022
1 parent 8b4cb45 commit 39f8072
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 60 deletions.
41 changes: 19 additions & 22 deletions rllib/env/base_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
Union
Union, Set

import gym
import ray
Expand Down Expand Up @@ -198,14 +198,13 @@ def get_sub_environments(
return []

@PublicAPI
def get_agent_ids(self) -> Set[AgentID]:
    """Return the agent ids for the sub_environment.

    Returns:
        All agent ids for the environment.
    """
    # Base implementation: a single-agent env, identified by the dummy
    # agent id constant defined at module level.
    return {_DUMMY_AGENT_ID}

@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
Expand Down Expand Up @@ -234,8 +233,8 @@ def get_unwrapped(self) -> List[EnvType]:

@PublicAPI
@property
def observation_space(self) -> gym.spaces.Dict:
"""Returns the observation space for each environment.
def observation_space(self) -> gym.Space:
"""Returns the observation space for each agent.
Note: samples from the observation space need to be preprocessed into a
`MultiEnvDict` before being used by a policy.
Expand All @@ -248,7 +247,7 @@ def observation_space(self) -> gym.spaces.Dict:
@PublicAPI
@property
def action_space(self) -> gym.Space:
"""Returns the action space for each environment.
"""Returns the action space for each agent.
Note: samples from the action space need to be preprocessed into a
`MultiEnvDict` before being passed to `send_actions`.
Expand All @@ -270,6 +269,7 @@ def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
Returns:
A random action for each environment.
"""
logger.warning("action_space_sample() has not been implemented")
del agent_id
return {}

Expand All @@ -286,6 +286,7 @@ def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
A random action for each environment.
"""
logger.warning("observation_space_sample() has not been implemented")
del agent_id
return {}

@PublicAPI
Expand Down Expand Up @@ -326,8 +327,7 @@ def action_space_contains(self, x: MultiEnvDict) -> bool:
"""
return self._space_contains(self.action_space, x)

@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
    """Check if the given space contains the observations of x.

    Args:
        space: The space to check, indexable by agent id (e.g. a
            gym.spaces.Dict mapping agent ids to sub-spaces).
        x: The observations to check, a MultiEnvDict mapping env ids to
            per-agent observation dicts.

    Returns:
        True if the observations of x are contained in space.
    """
    agents = set(self.get_agent_ids())
    for multi_agent_dict in x.values():
        # BUG FIX: iterate .items() — iterating the dict directly yields
        # only the keys, so the 2-tuple unpack below would raise.
        for agent_id, obs in multi_agent_dict.items():
            if (agent_id not in agents) or (
                    not space[agent_id].contains(obs)):
                return False

    return True


# Fixed agent identifier when there is only the single agent in the env
Expand Down
216 changes: 199 additions & 17 deletions rllib/env/multi_agent_env.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import gym
from typing import Callable, Dict, List, Tuple, Type, Optional, Union
import logging
from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI, \
DeveloperAPI
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
MultiEnvDict

# If the obs space is Dict type, look for the global state under this key.
ENV_STATE = "state"

logger = logging.getLogger(__name__)


@PublicAPI
class MultiAgentEnv(gym.Env):
Expand All @@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env):
referred to as "agents" or "RL agents".
"""

def __init__(self):
    # Sub-classes are expected to overwrite these with real Spaces.
    self.observation_space = None
    self.action_space = None
    # Agent ids present in this env. Use an actual set: the original `{}`
    # literal creates an empty *dict*, which get_agent_ids() then has to
    # convert on first access.
    self._agent_ids = set()

    # Do the action and observation spaces map from agent ids to spaces
    # for the individual agents? None = not yet determined (computed
    # lazily by _check_if_space_maps_agent_id_to_sub_space()).
    self._spaces_in_preferred_format = None

@PublicAPI
def reset(self) -> MultiAgentDict:
"""Resets the env and returns observations from ready agents.
Expand Down Expand Up @@ -81,20 +94,127 @@ def step(
"""
raise NotImplementedError

@ExperimentalAPI
def observation_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks whether all observations in ``x`` lie in the observation space.

    Args:
        x: Per-agent observations to check.

    Returns:
        True if every observation in ``x`` is contained in the
        observation space.
    """
    # Lazily determine (once) whether the spaces map agent ids to
    # per-agent sub-spaces.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = \
            self._check_if_space_maps_agent_id_to_sub_space()

    if self._spaces_in_preferred_format:
        return self.observation_space.contains(x)

    # No per-agent space information available -> assume valid.
    logger.warning("observation_space_contains() has not been implemented")
    return True

@ExperimentalAPI
def action_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks whether all actions in ``x`` lie in the action space.

    Args:
        x: Per-agent actions to check.

    Returns:
        True if every action in ``x`` is contained in the action space.
    """
    # Lazily determine (once) whether the spaces map agent ids to
    # per-agent sub-spaces.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = \
            self._check_if_space_maps_agent_id_to_sub_space()

    if self._spaces_in_preferred_format:
        return self.action_space.contains(x)

    # No per-agent space information available -> assume valid.
    logger.warning("action_space_contains() has not been implemented")
    return True

@ExperimentalAPI
def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
    """Returns a random action for each requested agent.

    Args:
        agent_ids: List of agent ids to sample actions for. If None,
            sample actions for all agents in the environment.

    Returns:
        A dict mapping agent id to a sampled action.
    """
    # Lazily determine (once) whether the spaces map agent ids to
    # per-agent sub-spaces.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = \
            self._check_if_space_maps_agent_id_to_sub_space()

    # Without per-agent space information we cannot sample anything.
    if not self._spaces_in_preferred_format:
        logger.warning("action_space_sample() has not been implemented")
        del agent_ids
        return {}

    if agent_ids is None:
        agent_ids = self.get_agent_ids()
    full_sample = self.action_space.sample()
    return {aid: full_sample[aid] for aid in agent_ids}

@ExperimentalAPI
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
    """Returns a random observation for each requested agent.

    Args:
        agent_ids: List of agent ids to sample observations for. If None,
            sample observations for all agents in the environment.

    Returns:
        A dict mapping agent id to a sampled observation.
    """
    # Lazily determine (once) whether the spaces map agent ids to
    # per-agent sub-spaces.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = \
            self._check_if_space_maps_agent_id_to_sub_space()

    # Without per-agent space information we cannot sample anything.
    if not self._spaces_in_preferred_format:
        logger.warning("observation_space_sample() has not been implemented")
        del agent_ids
        return {}

    if agent_ids is None:
        agent_ids = self.get_agent_ids()
    full_sample = self.observation_space.sample()
    return {aid: full_sample[aid] for aid in agent_ids}

@PublicAPI
def get_agent_ids(self) -> Set[AgentID]:
    """Returns the set of agent ids present in the environment.

    Returns:
        Set of agent ids.
    """
    ids = self._agent_ids
    # Normalize to a real set on first access (subclasses may have
    # stored any iterable) and cache the converted value.
    if not isinstance(ids, set):
        ids = set(ids)
        self._agent_ids = ids
    return ids

@PublicAPI
def render(self, mode=None) -> None:
    """Tries to render the environment.

    Args:
        mode: Optional render mode; ignored by this default no-op
            implementation.
    """
    # Rendering is optional for multi-agent envs: do nothing by default.
    return None

# yapf: disable
# __grouping_doc_begin__
# yapf: disable
# __grouping_doc_begin__
@ExperimentalAPI
def with_agent_groups(
self,
groups: Dict[str, List[AgentID]],
obs_space: gym.Space = None,
self,
groups: Dict[str, List[AgentID]],
obs_space: gym.Space = None,
act_space: gym.Space = None) -> "MultiAgentEnv":
"""Convenience method for grouping together agents in this env.
Expand Down Expand Up @@ -132,8 +252,9 @@ def with_agent_groups(
from ray.rllib.env.wrappers.group_agents_wrapper import \
GroupAgentsWrapper
return GroupAgentsWrapper(self, groups, obs_space, act_space)
# __grouping_doc_end__
# yapf: enable

# __grouping_doc_end__
# yapf: enable

@PublicAPI
def to_base_env(
Expand Down Expand Up @@ -182,6 +303,20 @@ def to_base_env(

return env

@DeveloperAPI
def _check_if_space_maps_agent_id_to_sub_space(self) -> bool:
    """Returns True iff both spaces are Dict spaces keyed exactly by agent ids.

    I.e. the action and observation spaces are gym.spaces.Dict whose key
    set equals get_agent_ids().
    """
    def _keyed_by_agent_ids(attr_name):
        # The space must exist, be a Dict space, and its keys must match
        # the env's agent ids exactly.
        if not hasattr(self, attr_name):
            return False
        space = getattr(self, attr_name)
        return (isinstance(space, gym.spaces.Dict)
                and set(space.keys()) == self.get_agent_ids())

    return (_keyed_by_agent_ids("observation_space")
            and _keyed_by_agent_ids("action_space"))


def make_multi_agent(
env_name_or_creator: Union[str, Callable[[EnvContext], EnvType]],
Expand Down Expand Up @@ -242,6 +377,40 @@ def __init__(self, config=None):
self.dones = set()
self.observation_space = self.agents[0].observation_space
self.action_space = self.agents[0].action_space
self._agent_ids = set(range(num))

@override(MultiAgentEnv)
def observation_space_sample(self,
                             agent_ids: list = None) -> MultiAgentDict:
    """Draws one observation-space sample per requested agent.

    Args:
        agent_ids: Agent ids to sample for; defaults to all agents.

    Returns:
        Dict mapping agent id to an independently sampled observation.
    """
    if agent_ids is None:
        agent_ids = list(range(len(self.agents)))
    samples = {}
    # One independent draw per agent (all agents share the same space).
    for aid in agent_ids:
        samples[aid] = self.observation_space.sample()
    return samples

@override(MultiAgentEnv)
def action_space_sample(self,
                        agent_ids: list = None) -> MultiAgentDict:
    """Draws one action-space sample per requested agent.

    Args:
        agent_ids: Agent ids to sample for; defaults to all agents.

    Returns:
        Dict mapping agent id to an independently sampled action.
    """
    if agent_ids is None:
        agent_ids = list(range(len(self.agents)))
    samples = {}
    # One independent draw per agent (all agents share the same space).
    for aid in agent_ids:
        samples[aid] = self.action_space.sample()
    return samples

@override(MultiAgentEnv)
def action_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks that every per-agent action in ``x`` lies in the action space."""
    for action in x.values():
        if not self.action_space.contains(action):
            return False
    return True

@override(MultiAgentEnv)
def observation_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks that every per-agent observation in ``x`` lies in the observation space."""
    for obs in x.values():
        if not self.observation_space.contains(obs):
            return False
    return True

@override(MultiAgentEnv)
def reset(self):
Expand Down Expand Up @@ -277,7 +446,7 @@ def __init__(self, make_env: Callable[[int], EnvType],
Args:
make_env (Callable[[int], EnvType]): Factory that produces a new
MultiAgentEnv intance. Must be defined, if the number of
MultiAgentEnv instance. Must be defined, if the number of
existing envs is less than num_envs.
existing_envs (List[MultiAgentEnv]): List of already existing
multi-agent envs.
Expand Down Expand Up @@ -355,18 +524,31 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
    """Returns the observation space of the first sub-environment.

    NOTE(review): assumes all sub-envs share the same observation
    space -- confirm against the wrapper's construction.
    """
    # BUG FIX: the expression was missing `return`, so the property
    # silently evaluated to None. Also drops the stale pre-refactor
    # code that built a per-env-id gym.spaces.Dict.
    return self.envs[0].observation_space

@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
    """Returns the action space of the first sub-environment.

    NOTE(review): assumes all sub-envs share the same action space --
    confirm against the wrapper's construction.
    """
    # Drops the stale pre-refactor code that built a per-env-id
    # gym.spaces.Dict and returned before this line could run.
    return self.envs[0].action_space

@override(BaseEnv)
def observation_space_contains(self, x: MultiEnvDict) -> bool:
    """Checks every sub-env's observations against the first sub-env's space."""
    # Delegate the per-dict check to the first sub-environment.
    first_env = self.envs[0]
    for multi_agent_obs in x.values():
        if not first_env.observation_space_contains(multi_agent_obs):
            return False
    return True

@override(BaseEnv)
def action_space_contains(self, x: MultiEnvDict) -> bool:
    """Checks every sub-env's actions against the first sub-env's space."""
    # Delegate the per-dict check to the first sub-environment.
    first_env = self.envs[0]
    for multi_agent_actions in x.values():
        if not first_env.action_space_contains(multi_agent_actions):
            return False
    return True

@override(BaseEnv)
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
    """Samples observations by delegating to the first sub-environment."""
    first_env = self.envs[0]
    return first_env.observation_space_sample(agent_ids)

@override(BaseEnv)
def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
    """Samples actions by delegating to the first sub-environment."""
    first_env = self.envs[0]
    return first_env.action_space_sample(agent_ids)


class _MultiAgentEnvState:
Expand Down
Loading

0 comments on commit 39f8072

Please sign in to comment.