From 39f8072eac851f0abcff2acaf9b9f15b3caf0917 Mon Sep 17 00:00:00 2001
From: Avnish Narayan <38871737+avnishn@users.noreply.github.com>
Date: Thu, 6 Jan 2022 14:34:20 -0800
Subject: [PATCH] [RLlib] [MultiAgentEnv Refactor #2] Change space types for
 `BaseEnvs` and `MultiAgentEnvs` (#21063)

---
 rllib/env/base_env.py                         |  41 ++--
 rllib/env/multi_agent_env.py                  | 216 ++++++++++++++++--
 rllib/env/tests/test_multi_agent_env.py       |  50 ++++
 rllib/env/vector_env.py                       |  21 --
 rllib/tests/test_nested_observation_spaces.py |   9 +
 5 files changed, 277 insertions(+), 60 deletions(-)
 create mode 100644 rllib/env/tests/test_multi_agent_env.py

diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py
index 760152c79369b..905bbd30d851e 100644
--- a/rllib/env/base_env.py
+++ b/rllib/env/base_env.py
@@ -1,6 +1,6 @@
 import logging
 from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
-    Union
+    Union, Set
 
 import gym
 import ray
@@ -198,14 +198,13 @@ def get_sub_environments(
         return []
 
     @PublicAPI
-    def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
-        """Return the agent ids for each sub-environment.
+    def get_agent_ids(self) -> Set[AgentID]:
+        """Return the agent ids for the sub-environment.
 
         Returns:
-            A dict mapping from env_id to a list of agent_ids.
+            All agent ids for the environment.
         """
-        logger.warning("get_agent_ids() has not been implemented")
-        return {}
+        return {_DUMMY_AGENT_ID}
 
     @PublicAPI
     def try_render(self, env_id: Optional[EnvID] = None) -> None:
@@ -234,8 +233,8 @@ def get_unwrapped(self) -> List[EnvType]:
 
     @PublicAPI
     @property
-    def observation_space(self) -> gym.spaces.Dict:
-        """Returns the observation space for each environment.
+    def observation_space(self) -> gym.Space:
+        """Returns the observation space for each agent.
 
         Note: samples from the observation space need to be preprocessed into a
             `MultiEnvDict` before being used by a policy.
@@ -248,7 +247,7 @@ def observation_space(self) -> gym.spaces.Dict:
     @PublicAPI
     @property
     def action_space(self) -> gym.Space:
-        """Returns the action space for each environment.
+        """Returns the action space for each agent.
 
         Note: samples from the action space need to be preprocessed into a
             `MultiEnvDict` before being passed to `send_actions`.
@@ -270,6 +269,7 @@ def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
         Returns:
             A random action for each environment.
         """
+        logger.warning("action_space_sample() has not been implemented")
         del agent_id
         return {}
 
@@ -286,6 +286,7 @@ def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
             A random action for each environment.
         """
         logger.warning("observation_space_sample() has not been implemented")
+        del agent_id
         return {}
 
     @PublicAPI
@@ -326,8 +327,7 @@ def action_space_contains(self, x: MultiEnvDict) -> bool:
         """
         return self._space_contains(self.action_space, x)
 
-    @staticmethod
-    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+    def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
         """Check if the given space contains the observations of x.
 
         Args:
@@ -337,17 +337,14 @@ def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
         Returns:
             True if the observations of x are contained in space.
""" - # this removes the agent_id key and inner dicts - # in MultiEnvDicts - flattened_obs = { - env_id: list(obs.values()) - for env_id, obs in x.items() - } - ret = True - for env_id in flattened_obs: - for obs in flattened_obs[env_id]: - ret = ret and space[env_id].contains(obs) - return ret + agents = set(self.get_agent_ids()) + for multi_agent_dict in x.values(): + for agent_id, obs in multi_agent_dict: + if (agent_id not in agents) or ( + not space[agent_id].contains(obs)): + return False + + return True # Fixed agent identifier when there is only the single agent in the env diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index c5025bc949a05..aa5430f3ce7c0 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -1,15 +1,19 @@ import gym -from typing import Callable, Dict, List, Tuple, Type, Optional, Union +import logging +from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext -from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI +from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI, \ + DeveloperAPI from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \ MultiEnvDict # If the obs space is Dict type, look for the global state under this key. ENV_STATE = "state" +logger = logging.getLogger(__name__) + @PublicAPI class MultiAgentEnv(gym.Env): @@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env): referred to as "agents" or "RL agents". """ + def __init__(self): + self.observation_space = None + self.action_space = None + self._agent_ids = {} + + # do the action and observation spaces map from agent ids to spaces + # for the individual agents? + self._spaces_in_preferred_format = None + @PublicAPI def reset(self) -> MultiAgentDict: """Resets the env and returns observations from ready agents. @@ -81,6 +94,113 @@ def step( """ raise NotImplementedError + @ExperimentalAPI + def observation_space_contains(self, x: MultiAgentDict) -> bool: + """Checks if the observation space contains the given key. + + Args: + x: Observations to check. + + Returns: + True if the observation space contains the given all observations + in x. + """ + if not hasattr(self, "_spaces_in_preferred_format") or \ + self._spaces_in_preferred_format is None: + self._spaces_in_preferred_format = \ + self._check_if_space_maps_agent_id_to_sub_space() + if self._spaces_in_preferred_format: + return self.observation_space.contains(x) + + logger.warning("observation_space_contains() has not been implemented") + return True + + @ExperimentalAPI + def action_space_contains(self, x: MultiAgentDict) -> bool: + """Checks if the action space contains the given action. + + Args: + x: Actions to check. + + Returns: + True if the action space contains all actions in x. + """ + if not hasattr(self, "_spaces_in_preferred_format") or \ + self._spaces_in_preferred_format is None: + self._spaces_in_preferred_format = \ + self._check_if_space_maps_agent_id_to_sub_space() + if self._spaces_in_preferred_format: + return self.action_space.contains(x) + + logger.warning("action_space_contains() has not been implemented") + return True + + @ExperimentalAPI + def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict: + """Returns a random action for each environment, and potentially each + agent in that environment. + + Args: + agent_ids: List of agent ids to sample actions for. 
+                If None or empty list, sample actions for all agents in the
+                environment.
+
+        Returns:
+            A random action for each agent.
+        """
+        if not hasattr(self, "_spaces_in_preferred_format") or \
+                self._spaces_in_preferred_format is None:
+            self._spaces_in_preferred_format = \
+                self._check_if_space_maps_agent_id_to_sub_space()
+        if self._spaces_in_preferred_format:
+            if agent_ids is None:
+                agent_ids = self.get_agent_ids()
+            samples = self.action_space.sample()
+            return {agent_id: samples[agent_id] for agent_id in agent_ids}
+        logger.warning("action_space_sample() has not been implemented")
+        del agent_ids
+        return {}
+
+    @ExperimentalAPI
+    def observation_space_sample(self,
+                                 agent_ids: list = None) -> MultiAgentDict:
+        """Returns a random observation for each agent.
+
+        If agent_ids is None, an observation is sampled for every agent in
+        the environment; otherwise only for the agents in agent_ids.
+
+        Args:
+            agent_ids: List of agent ids to sample observations for. If None
+                or empty list, sample observations for all agents in the
+                environment.
+
+        Returns:
+            A random observation for each agent.
+        """
+        if not hasattr(self, "_spaces_in_preferred_format") or \
+                self._spaces_in_preferred_format is None:
+            self._spaces_in_preferred_format = \
+                self._check_if_space_maps_agent_id_to_sub_space()
+        if self._spaces_in_preferred_format:
+            if agent_ids is None:
+                agent_ids = self.get_agent_ids()
+            samples = self.observation_space.sample()
+            samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
+            return samples
+        logger.warning("observation_space_sample() has not been implemented")
+        del agent_ids
+        return {}
+
+    @PublicAPI
+    def get_agent_ids(self) -> Set[AgentID]:
+        """Returns a set of agent ids in the environment.
+
+        Returns:
+            Set of agent ids.
+        """
+        if not isinstance(self._agent_ids, set):
+            self._agent_ids = set(self._agent_ids)
+        return self._agent_ids
+
     @PublicAPI
     def render(self, mode=None) -> None:
         """Tries to render the environment."""
 
         # By default, do nothing.
         pass
 
-# yapf: disable
-# __grouping_doc_begin__
+    # yapf: disable
+    # __grouping_doc_begin__
     @ExperimentalAPI
     def with_agent_groups(
-        self,
-        groups: Dict[str, List[AgentID]],
-        obs_space: gym.Space = None,
+            self,
+            groups: Dict[str, List[AgentID]],
+            obs_space: gym.Space = None,
             act_space: gym.Space = None) -> "MultiAgentEnv":
         """Convenience method for grouping together agents in this env.
@@ -132,8 +252,9 @@ def with_agent_groups(
         from ray.rllib.env.wrappers.group_agents_wrapper import \
             GroupAgentsWrapper
         return GroupAgentsWrapper(self, groups, obs_space, act_space)
-# __grouping_doc_end__
-# yapf: enable
+
+    # __grouping_doc_end__
+    # yapf: enable
 
     @PublicAPI
     def to_base_env(
@@ -182,6 +303,20 @@ def to_base_env(
 
         return env
 
+    @DeveloperAPI
+    def _check_if_space_maps_agent_id_to_sub_space(self) -> bool:
+        # Do the action and observation spaces map from agent ids to spaces
+        # for the individual agents?
+        obs_space_check = (
+            hasattr(self, "observation_space")
+            and isinstance(self.observation_space, gym.spaces.Dict)
+            and set(self.observation_space.spaces.keys()) ==
+            self.get_agent_ids())
+        action_space_check = (
+            hasattr(self, "action_space")
+            and isinstance(self.action_space, gym.spaces.Dict)
+            and set(self.action_space.spaces.keys()) == self.get_agent_ids())
+        return obs_space_check and action_space_check
+
 
 def make_multi_agent(
         env_name_or_creator: Union[str, Callable[[EnvContext], EnvType]],
@@ -242,6 +377,40 @@ def __init__(self, config=None):
             self.dones = set()
             self.observation_space = self.agents[0].observation_space
             self.action_space = self.agents[0].action_space
+            self._agent_ids = set(range(num))
+
+        @override(MultiAgentEnv)
+        def observation_space_sample(self,
+                                     agent_ids: list = None) -> MultiAgentDict:
+            if agent_ids is None:
+                agent_ids = list(range(len(self.agents)))
+            obs = {
+                agent_id: self.observation_space.sample()
+                for agent_id in agent_ids
+            }
+
+            return obs
+
+        @override(MultiAgentEnv)
+        def action_space_sample(self,
+                                agent_ids: list = None) -> MultiAgentDict:
+            if agent_ids is None:
+                agent_ids = list(range(len(self.agents)))
+            actions = {
+                agent_id: self.action_space.sample()
+                for agent_id in agent_ids
+            }
+
+            return actions
+
+        @override(MultiAgentEnv)
+        def action_space_contains(self, x: MultiAgentDict) -> bool:
+            return all(self.action_space.contains(val) for val in x.values())
+
+        @override(MultiAgentEnv)
+        def observation_space_contains(self, x: MultiAgentDict) -> bool:
+            return all(
+                self.observation_space.contains(val) for val in x.values())
 
         @override(MultiAgentEnv)
         def reset(self):
@@ -277,7 +446,7 @@ def __init__(self, make_env: Callable[[int], EnvType],
 
         Args:
             make_env (Callable[[int], EnvType]): Factory that produces a new
-                MultiAgentEnv intance. Must be defined, if the number of
+                MultiAgentEnv instance. Must be defined if the number of
                 existing envs is less than num_envs.
             existing_envs (List[MultiAgentEnv]): List of already existing
                 multi-agent envs.
@@ -355,18 +524,31 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
     @override(BaseEnv)
     @PublicAPI
     def observation_space(self) -> gym.spaces.Dict:
-        space = {
-            _id: env.observation_space
-            for _id, env in enumerate(self.envs)
-        }
-        return gym.spaces.Dict(space)
+        return self.envs[0].observation_space
 
     @property
     @override(BaseEnv)
     @PublicAPI
     def action_space(self) -> gym.Space:
-        space = {_id: env.action_space for _id, env in enumerate(self.envs)}
-        return gym.spaces.Dict(space)
+        return self.envs[0].action_space
+
+    @override(BaseEnv)
+    def observation_space_contains(self, x: MultiEnvDict) -> bool:
+        return all(
+            self.envs[0].observation_space_contains(val) for val in x.values())
+
+    @override(BaseEnv)
+    def action_space_contains(self, x: MultiEnvDict) -> bool:
+        return all(
+            self.envs[0].action_space_contains(val) for val in x.values())
+
+    @override(BaseEnv)
+    def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
+        return self.envs[0].observation_space_sample(agent_ids)
+
+    @override(BaseEnv)
+    def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
+        return self.envs[0].action_space_sample(agent_ids)
 
 
 class _MultiAgentEnvState:
diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py
new file mode 100644
index 0000000000000..2657825776050
--- /dev/null
+++ b/rllib/env/tests/test_multi_agent_env.py
@@ -0,0 +1,50 @@
+import pytest
+from ray.rllib.env.multi_agent_env import make_multi_agent
+from ray.rllib.tests.test_nested_observation_spaces import NestedMultiAgentEnv
+
+
+class TestMultiAgentEnv:
+    def test_space_in_preferred_format(self):
+        env = NestedMultiAgentEnv()
+        spaces_in_preferred_format = \
+            env._check_if_space_maps_agent_id_to_sub_space()
+        assert spaces_in_preferred_format, "Space is not in preferred " \
+                                           "format"
+        env2 = make_multi_agent("CartPole-v1")()
+        spaces_in_preferred_format = \
+            env2._check_if_space_maps_agent_id_to_sub_space()
+        assert not spaces_in_preferred_format, "Space should not be in " \
+                                               "preferred format but is."
+
+    def test_spaces_sample_contain_in_preferred_format(self):
+        env = NestedMultiAgentEnv()
+        # This environment has spaces in the preferred format for
+        # multi-agent environments: Dict spaces mapping agent ids to the
+        # individual agents' sub-spaces.
+        obs = env.observation_space_sample()
+        assert env.observation_space_contains(obs), "Observation space does " \
+                                                    "not contain obs"
+
+        action = env.action_space_sample()
+        assert env.action_space_contains(action), "Action space does " \
+                                                  "not contain action"
+
+    def test_spaces_sample_contain_not_in_preferred_format(self):
+        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
+        # This environment has spaces that are not in the preferred format.
+        # In that case, users must override the observation_space_contains,
+        # action_space_contains, observation_space_sample, and
+        # action_space_sample methods in order to do proper checks.
+        obs = env.observation_space_sample()
+        assert env.observation_space_contains(obs), "Observation space does " \
+                                                    "not contain obs"
+        action = env.action_space_sample()
+        assert env.action_space_contains(action), "Action space does " \
+                                                  "not contain action"
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py
index 20018ad0076a7..18e0476f96a2c 100644
--- a/rllib/env/vector_env.py
+++ b/rllib/env/vector_env.py
@@ -339,24 +339,3 @@ def observation_space(self) -> gym.spaces.Dict:
     @PublicAPI
     def action_space(self) -> gym.Space:
         return self._action_space
-
-    @staticmethod
-    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
-        """Check if the given space contains the observations of x.
-
-        Args:
-            space: The space to if x's observations are contained in.
-            x: The observations to check.
-
-        Note: With vector envs, we can process the raw observations
-            and ignore the agent ids and env ids, since vector envs'
-            sub environements are guaranteed to be the same
-
-        Returns:
-            True if the observations of x are contained in space.
-        """
-        for _, multi_agent_dict in x.items():
-            for _, element in multi_agent_dict.items():
-                if not space.contains(element):
-                    return False
-        return True
diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py
index ec2b4e0afe66f..5903c71ad682c 100644
--- a/rllib/tests/test_nested_observation_spaces.py
+++ b/rllib/tests/test_nested_observation_spaces.py
@@ -120,6 +120,15 @@ def step(self, action):
 
 class NestedMultiAgentEnv(MultiAgentEnv):
     def __init__(self):
+        self.observation_space = spaces.Dict({
+            "dict_agent": DICT_SPACE,
+            "tuple_agent": TUPLE_SPACE
+        })
+        self.action_space = spaces.Dict({
+            "dict_agent": spaces.Discrete(1),
+            "tuple_agent": spaces.Discrete(1)
+        })
+        self._agent_ids = {"dict_agent", "tuple_agent"}
         self.steps = 0
 
     def reset(self):
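
Usage sketch (illustrative, not part of the diff): with spaces in the new
"preferred format", i.e. `gym.spaces.Dict` instances keyed by the same agent
ids that `get_agent_ids()` returns, the inherited `*_space_sample()` and
`*_space_contains()` helpers work with no overrides. The `TwoAgentEnv` class,
its agent ids "a0"/"a1", and its sub-spaces below are made-up assumptions for
demonstration only.

    import gym

    from ray.rllib.env.multi_agent_env import MultiAgentEnv


    class TwoAgentEnv(MultiAgentEnv):
        """Hypothetical env whose spaces map agent ids to sub-spaces."""

        def __init__(self):
            super().__init__()
            self._agent_ids = {"a0", "a1"}
            # Dict spaces keyed by agent id -> the "preferred format".
            self.observation_space = gym.spaces.Dict({
                "a0": gym.spaces.Discrete(4),
                "a1": gym.spaces.Discrete(4),
            })
            self.action_space = gym.spaces.Dict({
                "a0": gym.spaces.Discrete(2),
                "a1": gym.spaces.Discrete(2),
            })

        def reset(self):
            return {"a0": 0, "a1": 0}

        def step(self, action_dict):
            obs = {"a0": 0, "a1": 0}
            rewards = {aid: 1.0 for aid in action_dict}
            dones = {"a0": True, "a1": True, "__all__": True}
            return obs, rewards, dones, {}


    env = TwoAgentEnv()
    assert env._check_if_space_maps_agent_id_to_sub_space()

    # Inherited helpers sample/check per agent id, no overrides needed.
    actions = env.action_space_sample()          # e.g. {"a0": 1, "a1": 0}
    assert env.action_space_contains(actions)
    obs = env.observation_space_sample(["a0"])   # restrict to agent "a0"
    assert set(obs) == {"a0"}

Envs whose spaces are not per-agent Dict spaces (like the `make_multi_agent`
CartPole wrapper exercised in the tests) should instead override these four
methods themselves, as the `MultiEnv` class in this patch does.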