From b5a4bc1f349edb5eb69d571e1807e6d849ada017 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 26 Nov 2020 14:24:42 +0100 Subject: [PATCH 1/7] WIP. --- rllib/agents/ddpg/ddpg_torch_policy.py | 2 +- rllib/agents/dqn/dqn_torch_policy.py | 8 +-- rllib/agents/marwil/marwil_torch_policy.py | 2 +- rllib/agents/ppo/appo_torch_policy.py | 2 +- rllib/agents/ppo/ppo_tf_policy.py | 7 +- rllib/agents/ppo/ppo_torch_policy.py | 2 +- rllib/agents/ppo/tests/test_appo.py | 1 + rllib/agents/sac/sac_torch_policy.py | 2 +- rllib/contrib/maddpg/maddpg_policy.py | 2 +- .../collectors/simple_list_collector.py | 6 +- rllib/evaluation/rollout_worker.py | 5 +- rllib/evaluation/sampler.py | 11 +++- .../tests/test_trajectory_view_api.py | 14 ++-- .../policy/episode_env_aware_policy.py | 66 ++++++++++++++++++- rllib/policy/sample_batch.py | 2 +- rllib/policy/torch_policy.py | 4 ++ rllib/utils/exploration/exploration.py | 2 +- rllib/utils/sgd.py | 2 +- rllib/utils/tf_ops.py | 37 ++++++++--- 19 files changed, 136 insertions(+), 41 deletions(-) diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index 445564466ca0..39123783a421 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -274,7 +274,7 @@ def setup_late_mixins(policy, obs_space, action_space, config): optimizer_fn=make_ddpg_optimizers, validate_spaces=validate_spaces, before_init=before_init_fn, - after_init=setup_late_mixins, + before_loss_init=setup_late_mixins, action_distribution_fn=get_distribution_inputs_and_class, make_model_and_action_dist=build_ddpg_models_and_action_dist, apply_gradients_fn=apply_gradients_fn, diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index 0d633fa0602d..9b9d535d95b4 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -317,9 +317,9 @@ def setup_early_mixins(policy: Policy, obs_space, action_space, LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) -def after_init(policy: Policy, obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict) -> None: +def before_loss_init(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict) -> None: ComputeTDErrorMixin.__init__(policy) TargetNetworkMixin.__init__(policy, obs_space, action_space, config) # Move target net to device (this is done automatically for the @@ -397,7 +397,7 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error}, extra_action_out_fn=extra_action_out_fn, before_init=setup_early_mixins, - after_init=after_init, + before_loss_init=before_loss_init, mixins=[ TargetNetworkMixin, ComputeTDErrorMixin, diff --git a/rllib/agents/marwil/marwil_torch_policy.py b/rllib/agents/marwil/marwil_torch_policy.py index f10e71c2d83c..e88e5e312f40 100644 --- a/rllib/agents/marwil/marwil_torch_policy.py +++ b/rllib/agents/marwil/marwil_torch_policy.py @@ -81,5 +81,5 @@ def setup_mixins(policy, obs_space, action_space, config): get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG, stats_fn=stats, postprocess_fn=postprocess_advantages, - after_init=setup_mixins, + before_loss_init=setup_mixins, mixins=[ValueNetworkMixin]) diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index 2b9a3389b1d5..f24dfc6d1b54 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ 
b/rllib/agents/ppo/appo_torch_policy.py @@ -331,7 +331,7 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, extra_grad_process_fn=apply_grad_clipping, optimizer_fn=choose_optimizer, before_init=setup_early_mixins, - after_init=setup_late_mixins, + before_loss_init=setup_late_mixins, make_model=make_appo_model, mixins=[ LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin, diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index c0442eb9242c..5a8ef7b801e9 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -47,7 +47,12 @@ def ppo_surrogate_loss( # RNN case: Mask away 0-padded chunks at end of time axis. if state: - max_seq_len = tf.reduce_max(train_batch["seq_lens"]) + # Derive max_seq_len from the data itself, not from the seq_lens + # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still + # 0-padded up to T=5 (as it's the case for attention nets). + B = tf.shape(train_batch["seq_lens"])[0] + max_seq_len = tf.shape(logits)[0] // B + mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) mask = tf.reshape(mask, [-1]) diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index 69bdc4b65728..58637fa0a64b 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -265,7 +265,7 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space, postprocess_fn=postprocess_ppo_gae, extra_grad_process_fn=apply_grad_clipping, before_init=setup_config, - after_init=setup_mixins, + before_loss_init=setup_mixins, mixins=[ LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, ValueNetworkMixin diff --git a/rllib/agents/ppo/tests/test_appo.py b/rllib/agents/ppo/tests/test_appo.py index 20a6ddd29de5..259f010b386f 100644 --- a/rllib/agents/ppo/tests/test_appo.py +++ b/rllib/agents/ppo/tests/test_appo.py @@ -24,6 +24,7 @@ def test_appo_compilation(self): for _ in framework_iterator(config): print("w/o v-trace") _config = config.copy() + _config["vtrace"] = False trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0") for i in range(num_iterations): print(trainer.train()) diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index b4b22586517e..a1a8f996bc23 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -489,7 +489,7 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, extra_grad_process_fn=apply_grad_clipping, optimizer_fn=optimizer_fn, validate_spaces=validate_spaces, - after_init=setup_late_mixins, + before_loss_init=setup_late_mixins, make_model_and_action_dist=build_sac_model_and_action_dist, mixins=[TargetNetworkMixin, ComputeTDErrorMixin], action_distribution_fn=action_distribution_fn, diff --git a/rllib/contrib/maddpg/maddpg_policy.py b/rllib/contrib/maddpg/maddpg_policy.py index 01b95092ffa6..35a4b0cce602 100644 --- a/rllib/contrib/maddpg/maddpg_policy.py +++ b/rllib/contrib/maddpg/maddpg_policy.py @@ -1,8 +1,8 @@ import ray from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip, _adjust_nstep from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models import ModelCatalog +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.policy.policy import Policy diff --git 
a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index f03d38cec2e6..98eda003ce9d 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -238,7 +238,7 @@ def add_postprocessed_batch_for_training( """ for view_col, data in batch.items(): # Skip columns that are not used for training. - if view_col in view_requirements and \ + if view_col not in view_requirements or \ not view_requirements[view_col].used_for_training: continue self.buffers[view_col].extend(data) @@ -465,8 +465,7 @@ def postprocess_episode(self, pre_batch = collector.build(policy.view_requirements) pre_batches[agent_id] = (policy, pre_batch) - # Apply postprocessor. - post_batches = {} + # Apply reward clipping before calling postprocessing functions. if self.clip_rewards is True: for _, (_, pre_batch) in pre_batches.items(): pre_batch["rewards"] = np.sign(pre_batch["rewards"]) @@ -477,6 +476,7 @@ def postprocess_episode(self, a_min=-self.clip_rewards, a_max=self.clip_rewards) + post_batches = {} for agent_id, (_, pre_batch) in pre_batches.items(): # Entire episode is said to be done. # Error if no DONE at end of this agent's trajectory. diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index f98d29e328e1..6d6d532d57c6 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -257,7 +257,8 @@ def __init__( directory if specified. log_dir (str): Directory where logs can be placed. log_level (str): Set the root log level on creation. - callbacks (DefaultCallbacks): Custom training callbacks. + callbacks (Type[DefaultCallbacks]): Custom sub-class of + DefaultCallbacks for training/policy/rollout-worker callbacks. input_creator (Callable[[IOContext], InputReader]): Function that returns an InputReader object for loading previous generated experiences. @@ -340,7 +341,7 @@ def gen_rollouts(): self.callbacks: "DefaultCallbacks" = callbacks() else: from ray.rllib.agents.callbacks import DefaultCallbacks - self.callbacks: "DefaultCallbacks" = DefaultCallbacks() + self.callbacks: DefaultCallbacks = DefaultCallbacks() self.worker_index: int = worker_index self.num_workers: int = num_workers model_config: ModelConfigDict = model_config or {} diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 7c098dd6529a..447b50b709a3 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -1058,15 +1058,20 @@ def _process_observations_w_trajectory_view_api( "new_obs": filtered_obs, } # Add extra-action-fetches to collectors. - values_dict.update(**episode.last_pi_info_for(agent_id)) + pol = policies[policy_id] + for key, value in episode.last_pi_info_for(agent_id).items(): + values_dict[key] = value + # Env infos for this agent. 
+ if "infos" in pol.view_requirements: + values_dict["infos"] = agent_infos _sample_collector.add_action_reward_next_obs( episode.episode_id, agent_id, env_id, policy_id, agent_done, values_dict) if not agent_done: item = PolicyEvalData( - env_id, agent_id, filtered_obs, infos[env_id].get( - agent_id, {}), None if last_observation is None else + env_id, agent_id, filtered_obs, agent_infos, None + if last_observation is None else episode.rnn_state_for(agent_id), None if last_observation is None else episode.last_action_for(agent_id), diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 7897b32267e8..bd2488c47ed6 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -10,7 +10,7 @@ from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.examples.policy.episode_env_aware_policy import \ - EpisodeEnvAwarePolicy + EpisodeEnvAwareLSTMPolicy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement @@ -121,7 +121,6 @@ def test_traj_view_simple_performance(self): obs_space = Box(-1.0, 1.0, shape=(700, )) from ray.rllib.examples.env.random_env import RandomMultiAgentEnv - from ray.tune import register_env register_env("ma_env", lambda c: RandomMultiAgentEnv({ "num_agents": 2, @@ -147,7 +146,6 @@ def policy_fn(agent_id): "policy_mapping_fn": policy_fn, } num_iterations = 2 - # Only works in torch so far. for _ in framework_iterator(config, frameworks="torch"): print("w/ traj. view API") config["_use_trajectory_view_api"] = True @@ -253,7 +251,7 @@ def test_traj_view_lstm_functionality(self): rollout_fragment_length = 200 assert rollout_fragment_length % max_seq_len == 0 policies = { - "pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}), + "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}), } def policy_fn(agent_id): @@ -316,8 +314,8 @@ def analyze_rnn_batch(batch, max_seq_len): state_in_1 = batch["state_in_1"][idx] # Check postprocessing outputs. - if "postprocessed_column" in batch: - postprocessed_col_t = batch["postprocessed_column"][idx] + if "2xobs" in batch: + postprocessed_col_t = batch["2xobs"][idx] assert (obs_t == postprocessed_col_t / 2.0).all() # Check state-in/out and next-obs values. @@ -386,8 +384,8 @@ def analyze_rnn_batch(batch, max_seq_len): r_t = batch["rewards"][k] # Check postprocessing outputs. - if "postprocessed_column" in batch: - postprocessed_col_t = batch["postprocessed_column"][k] + if "2xobs" in batch: + postprocessed_col_t = batch["2xobs"][k] assert (obs_t == postprocessed_col_t / 2.0).all() # Check state-in/out and next-obs values. 
diff --git a/rllib/examples/policy/episode_env_aware_policy.py b/rllib/examples/policy/episode_env_aware_policy.py index 44605cbd8909..ff8f91271479 100644 --- a/rllib/examples/policy/episode_env_aware_policy.py +++ b/rllib/examples/policy/episode_env_aware_policy.py @@ -8,7 +8,7 @@ from ray.rllib.utils.annotations import override -class EpisodeEnvAwarePolicy(RandomPolicy): +class EpisodeEnvAwareLSTMPolicy(RandomPolicy): """A Policy that always knows the current EpisodeID and EnvID and returns these in its actions.""" @@ -78,5 +78,67 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None): - sample_batch["postprocessed_column"] = sample_batch["obs"] * 2.0 + sample_batch["2xobs"] = sample_batch["obs"] * 2.0 + return sample_batch + + +class EpisodeEnvAwareAttentionPolicy(RandomPolicy): + """A Policy that always knows the current EpisodeID and EnvID and + returns these in its actions.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.state_space = Box(-1.0, 1.0, (1, )) + self.config["model"] = {"max_seq_len": 50} + + class _fake_model: + pass + + self.model = _fake_model() + self.model.inference_view_requirements = { + SampleBatch.AGENT_INDEX: ViewRequirement(), + SampleBatch.EPS_ID: ViewRequirement(), + "env_id": ViewRequirement(), + "t": ViewRequirement(), + SampleBatch.OBS: ViewRequirement(), + "state_in_0": ViewRequirement( + "state_out_0", + # Provide state outs -50 to -1 as "state-in". + data_rel_pos="-50:-1", + # Repeat the incoming state every n time steps (usually max seq + # len). + batch_repeat_value=self.config["model"]["max_seq_len"], + space=self.state_space) + } + + self.view_requirements = dict(super()._get_default_view_requirements(), + **self.model.inference_view_requirements) + + @override(Policy) + def is_recurrent(self): + return True + + @override(Policy) + def compute_actions_from_input_dict(self, + input_dict, + explore=None, + timestep=None, + **kwargs): + ts = input_dict["t"] + print(ts) + # Always return [episodeID, envID] as actions. + actions = np.array([[ + input_dict[SampleBatch.AGENT_INDEX][i], + input_dict[SampleBatch.EPS_ID][i], input_dict["env_id"][i] + ] for i, _ in enumerate(input_dict["obs"])]) + states = [np.array([[ts[i]] for i in range(len(input_dict["obs"]))])] + self.global_timestep += 1 + return actions, states, {} + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + sample_batch["3xobs"] = sample_batch["obs"] * 3.0 return sample_batch diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 89d5abdb08f8..db5ead008246 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -81,7 +81,7 @@ def __init__(self, *args, **kwargs): if self.seq_lens is not None and len(self.seq_lens) > 0: self.count = sum(self.seq_lens) else: - self.count = len(self.data[k]) + self.count = len(next(iter(self.data.values()))) # Keeps track of new columns added after initial ones. self.new_columns = [] diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 7b20a34ac145..f294b510dba0 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -354,6 +354,8 @@ def compute_gradients(self, ) train_batch = self._lazy_tensor_dict(postprocessed_batch) + + # Calculate the actual policy loss. 
loss_out = force_list( self._loss(self, self.model, self.dist_class, train_batch)) @@ -369,6 +371,7 @@ def compute_gradients(self, assert len(loss_out) == len(self._optimizers) + # assert not any(torch.isnan(l) for l in loss_out) fetches = self.extra_compute_grad_fetches() # Loop through all optimizers. @@ -376,6 +379,7 @@ def compute_gradients(self, all_grads = [] for i, opt in enumerate(self._optimizers): + # Erase gradients in all vars of this optimizer. opt.zero_grad() # Recompute gradients of loss over all variables. loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1)) diff --git a/rllib/utils/exploration/exploration.py b/rllib/utils/exploration/exploration.py index a612df235129..4b88fbf97305 100644 --- a/rllib/utils/exploration/exploration.py +++ b/rllib/utils/exploration/exploration.py @@ -184,7 +184,7 @@ def get_exploration_loss(self, policy_loss: List[TensorType], Policy's own loss function and maybe the Model's custom loss. train_batch (SampleBatch): The training data to calculate the loss(es) for. This train data has already gone through - this Exploration's `preprocess_train_batch()` method. + this Exploration's `postprocess_trajectory()` method. Returns: List[TensorType]: The updated list of loss terms. diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index b267d04f08b2..d5576e0fa57d 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -66,7 +66,7 @@ def minibatches(samples, sgd_minibatch_size): # Replace with `if samples.seq_lens` check. if "state_in_0" in samples.data or "state_out_0" in samples.data: if log_once("not_shuffling_rnn_data_in_simple_mode"): - logger.warning("Not shuffling RNN data for SGD in simple mode") + logger.warning("Not time-shuffling RNN data for SGD.") else: samples.shuffle() diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 39d2ce0033f8..d0344d123567 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -37,14 +37,14 @@ def explained_variance(y, pred): return tf.maximum(-1.0, 1 - (diff_var / y_var)) -def get_placeholder(*, space=None, value=None, name=None): +def get_placeholder(*, space=None, value=None, name=None, time_axis=False): from ray.rllib.models.catalog import ModelCatalog if space is not None: if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): return ModelCatalog.get_action_placeholder(space, None) return tf1.placeholder( - shape=(None, ) + space.shape, + shape=(None, ) + ((None, ) if time_axis else ()) + space.shape, dtype=tf.float32 if space.dtype == np.float64 else space.dtype, name=name, ) @@ -52,8 +52,9 @@ def get_placeholder(*, space=None, value=None, name=None): assert value is not None shape = value.shape[1:] return tf1.placeholder( - shape=(None, ) + (shape if isinstance(shape, tuple) else tuple( - shape.as_list())), + shape=(None, ) + ((None, ) + if time_axis else ()) + (shape if isinstance( + shape, tuple) else tuple(shape.as_list())), dtype=tf.float32 if value.dtype == np.float64 else value.dtype, name=name, ) @@ -132,10 +133,11 @@ def make_tf_callable(session_or_none, dynamic_shape=False): def make_wrapper(fn): if session_or_none: - placeholders = [] + args_placeholders = [] + kwargs_placeholders = {} symbolic_out = [None] - def call(*args): + def call(*args, **kwargs): args_flat = [] for a in args: if type(a) is list: @@ -153,13 +155,30 @@ def call(*args): shape = () else: shape = v.shape - placeholders.append( + args_placeholders.append( tf1.placeholder( dtype=v.dtype, shape=shape, name="arg_{}".format(i))) - symbolic_out[0] = fn(*placeholders) - feed_dict = 
dict(zip(placeholders, args)) + for k, v in kwargs.items(): + if dynamic_shape: + if len(v.shape) > 0: + shape = (None, ) + v.shape[1:] + else: + shape = () + else: + shape = v.shape + kwargs_placeholders[k] = \ + tf1.placeholder( + dtype=v.dtype, + shape=shape, + name="kwarg_{}".format(k)) + symbolic_out[0] = fn(*args_placeholders, + **kwargs_placeholders) + feed_dict = dict(zip(args_placeholders, args)) + feed_dict.update( + {kwargs_placeholders[k]: kwargs[k] + for k in kwargs.keys()}) ret = session_or_none.run(symbolic_out[0], feed_dict) return ret From 043768059c5aa87d34edeb5993bdb6dacef28927 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 26 Nov 2020 14:28:11 +0100 Subject: [PATCH 2/7] Fix. --- rllib/evaluation/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 447b50b709a3..cb54c9372faa 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -1033,7 +1033,9 @@ def _process_observations_w_trajectory_view_api( agent_id) episode._set_last_observation(agent_id, filtered_obs) episode._set_last_raw_obs(agent_id, raw_obs) - episode._set_last_info(agent_id, infos[env_id].get(agent_id, {})) + # Infos from the environment. + agent_infos = infos[env_id].get(agent_id, {}) + episode._set_last_info(agent_id, agent_infos) # Record transition info if applicable. if last_observation is None: From 86911d9acd23bda6e06f37ac40bd95b0c25fdd14 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 26 Nov 2020 16:15:26 +0100 Subject: [PATCH 3/7] WIP. --- .../evaluation/collectors/sample_collector.py | 12 ++++++---- .../collectors/simple_list_collector.py | 21 +++++++++++++----- rllib/evaluation/sampler.py | 6 +++-- .../tests/test_trajectory_view_api.py | 8 +++---- .../policy/episode_env_aware_policy.py | 10 +++++---- .../policy/rock_paper_scissors_dummies.py | 2 +- rllib/models/modelv2.py | 3 ++- rllib/models/tf/recurrent_net.py | 4 ++-- rllib/models/torch/recurrent_net.py | 4 ++-- rllib/policy/policy.py | 22 +++++-------------- rllib/policy/view_requirement.py | 13 ++++++----- 11 files changed, 58 insertions(+), 47 deletions(-) diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index 4689c9261d8a..6ebc3d097018 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta): @abstractmethod def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, - policy_id: PolicyID, init_obs: TensorType) -> None: + policy_id: PolicyID, t: int, + init_obs: TensorType) -> None: """Adds an initial obs (after reset) to this collector. Since the very first observation in an environment is collected w/o @@ -48,6 +49,8 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, values for. env_id (EnvID): The environment index (in a vectorized setup). policy_id (PolicyID): Unique id for policy controlling the agent. + t (int): The time step (episode length - 1). The initial obs has + ts=-1(!), then an action/reward/next-obs at t=0, etc.. init_obs (TensorType): Initial observation (after env.reset()). Examples: @@ -172,9 +175,10 @@ def postprocess_episode(self, MultiAgentBatch. Used for batch_mode=`complete_episodes`. Returns: - Any: An ID that can be used in `build_multi_agent_batch` to - retrieve the samples that have been postprocessed as a - ready-built MultiAgentBatch. 
+ Optional[MultiAgentBatch]: If `build` is True, the + SampleBatch or MultiAgentBatch built from `episode` (either + just from that episde or from the `_PolicyCollectorGroup` + in the `episode.batch_builder` property). """ raise NotImplementedError diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 98eda003ce9d..9382db5a93a5 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -53,7 +53,7 @@ def __init__(self, shift_before: int = 0): self.count = 0 def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, - env_id: EnvID, init_obs: TensorType, + env_id: EnvID, t: int, init_obs: TensorType, view_requirements: Dict[str, ViewRequirement]) -> None: """Adds an initial observation (after reset) to the Agent's trajectory. @@ -63,6 +63,8 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, agent_id (AgentID): Unique ID for the agent we are adding the initial observation for. env_id (EnvID): The environment index (in a vectorized setup). + t (int): The time step (episode length - 1). The initial obs has + ts=-1(!), then an action/reward/next-obs at t=0, etc.. init_obs (TensorType): The initial observation tensor (after `env.reset()`). view_requirements (Dict[str, ViewRequirements]) @@ -74,8 +76,13 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, SampleBatch.EPS_ID: episode_id, SampleBatch.AGENT_INDEX: agent_id, "env_id": env_id, + "t": t, }) self.buffers[SampleBatch.OBS].append(init_obs) + self.buffers[SampleBatch.EPS_ID].append(episode_id) + self.buffers[SampleBatch.AGENT_INDEX].append(agent_id) + self.buffers["env_id"].append(env_id) + self.buffers["t"].append(t) def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ None: @@ -133,7 +140,7 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \ continue # OBS are already shifted by -1 (the initial obs starts one ts # before all other data columns). - shift = view_req.shift - \ + shift = view_req.data_rel_pos - \ (1 if data_col == SampleBatch.OBS else 0) if data_col not in np_data: np_data[data_col] = to_float_np_array(self.buffers[data_col]) @@ -183,7 +190,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: for col, data in single_row.items(): if col in self.buffers: continue - shift = self.shift_before - (1 if col == SampleBatch.OBS else 0) + shift = self.shift_before - (1 if col in [ + SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, + "env_id", "t" + ] else 0) # Python primitive or dict (e.g. INFOs). if isinstance(data, (int, float, bool, str, dict)): self.buffers[col] = [0 for _ in range(shift)] @@ -356,7 +366,7 @@ def episode_step(self, episode_id: EpisodeID) -> None: @override(_SampleCollector) def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, - env_id: EnvID, policy_id: PolicyID, + env_id: EnvID, policy_id: PolicyID, t: int, init_obs: TensorType) -> None: # Make sure our mappings are up to date. agent_key = (episode.episode_id, agent_id) @@ -376,6 +386,7 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, episode_id=episode.episode_id, agent_id=agent_id, env_id=env_id, + t=t, init_obs=init_obs, view_requirements=view_reqs) @@ -425,7 +436,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # Create the batch of data from the different buffers. 
data_col = view_req.data_col or view_col time_indices = \ - view_req.shift - ( + view_req.data_rel_pos - ( 1 if data_col in [SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX] else 0) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index cb54c9372faa..804c269aa749 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api( # Record transition info if applicable. if last_observation is None: _sample_collector.add_init_obs(episode, agent_id, env_id, - policy_id, filtered_obs) + policy_id, episode.length - 1, + filtered_obs) else: # Add actions, rewards, next-obs to collectors. values_dict = { @@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api( # Add initial obs to buffer. _sample_collector.add_init_obs( - new_episode, agent_id, env_id, policy_id, filtered_obs) + new_episode, agent_id, env_id, policy_id, + new_episode.length - 1, filtered_obs) item = PolicyEvalData( env_id, agent_id, filtered_obs, diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index bd2488c47ed6..3fc23289c129 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -59,7 +59,7 @@ def test_traj_view_normal_case(self): assert view_req_policy[key].data_col is None else: assert view_req_policy[key].data_col == SampleBatch.OBS - assert view_req_policy[key].shift == 1 + assert view_req_policy[key].data_rel_pos == 1 rollout_worker = trainer.workers.local_worker() sample_batch = rollout_worker.sample() expected_count = \ @@ -99,10 +99,10 @@ def test_traj_view_lstm_prev_actions_and_rewards(self): if key == SampleBatch.PREV_ACTIONS: assert view_req_policy[key].data_col == SampleBatch.ACTIONS - assert view_req_policy[key].shift == -1 + assert view_req_policy[key].data_rel_pos == -1 elif key == SampleBatch.PREV_REWARDS: assert view_req_policy[key].data_col == SampleBatch.REWARDS - assert view_req_policy[key].shift == -1 + assert view_req_policy[key].data_rel_pos == -1 elif key not in [ SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS, SampleBatch.PREV_REWARDS @@ -110,7 +110,7 @@ def test_traj_view_lstm_prev_actions_and_rewards(self): assert view_req_policy[key].data_col is None else: assert view_req_policy[key].data_col == SampleBatch.OBS - assert view_req_policy[key].shift == 1 + assert view_req_policy[key].data_rel_pos == 1 trainer.stop() def test_traj_view_simple_performance(self): diff --git a/rllib/examples/policy/episode_env_aware_policy.py b/rllib/examples/policy/episode_env_aware_policy.py index ff8f91271479..89d4c525efbd 100644 --- a/rllib/examples/policy/episode_env_aware_policy.py +++ b/rllib/examples/policy/episode_env_aware_policy.py @@ -28,14 +28,16 @@ class _fake_model: "t": ViewRequirement(), SampleBatch.OBS: ViewRequirement(), SampleBatch.PREV_ACTIONS: ViewRequirement( - SampleBatch.ACTIONS, space=self.action_space, shift=-1), + SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1), SampleBatch.PREV_REWARDS: ViewRequirement( - SampleBatch.REWARDS, shift=-1), + SampleBatch.REWARDS, data_rel_pos=-1), } for i in range(2): self.model.inference_view_requirements["state_in_{}".format(i)] = \ ViewRequirement( - "state_out_{}".format(i), shift=-1, space=self.state_space) + "state_out_{}".format(i), + data_rel_pos=-1, + space=self.state_space) self.model.inference_view_requirements[ "state_out_{}".format(i)] = \ 
ViewRequirement(space=self.state_space) @@ -43,7 +45,7 @@ class _fake_model: self.view_requirements = dict( **{ SampleBatch.NEXT_OBS: ViewRequirement( - SampleBatch.OBS, shift=1), + SampleBatch.OBS, data_rel_pos=1), SampleBatch.ACTIONS: ViewRequirement(space=self.action_space), SampleBatch.REWARDS: ViewRequirement(), SampleBatch.DONES: ViewRequirement(), diff --git a/rllib/examples/policy/rock_paper_scissors_dummies.py b/rllib/examples/policy/rock_paper_scissors_dummies.py index 72b2fdd518c9..011d49e5633e 100644 --- a/rllib/examples/policy/rock_paper_scissors_dummies.py +++ b/rllib/examples/policy/rock_paper_scissors_dummies.py @@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs): self.view_requirements.update({ "state_in_0": ViewRequirement( "state_out_0", - shift=-1, + data_rel_pos=-1, space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32)) }) diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 4ac047c59db2..21bc139d4d6c 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -61,7 +61,8 @@ def __init__(self, obs_space: gym.spaces.Space, self.time_major = self.model_config.get("_time_major") # Basic view requirement for all models: Use the observation as input. self.inference_view_requirements = { - SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space), + SampleBatch.OBS: ViewRequirement( + data_rel_pos=0, space=self.obs_space), } # TODO: (sven): Get rid of `get_initial_state` once Trajectory diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index f939c7ae36a6..57ecb22f05e0 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -178,10 +178,10 @@ def __init__(self, obs_space: gym.spaces.Space, if model_config["lstm_use_prev_action"]: self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \ ViewRequirement(SampleBatch.ACTIONS, space=self.action_space, - shift=-1) + data_rel_pos=-1) if model_config["lstm_use_prev_reward"]: self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ - ViewRequirement(SampleBatch.REWARDS, shift=-1) + ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py index d558bf3dbf74..8e0d2263c7d2 100644 --- a/rllib/models/torch/recurrent_net.py +++ b/rllib/models/torch/recurrent_net.py @@ -159,10 +159,10 @@ def __init__(self, obs_space: gym.spaces.Space, if model_config["lstm_use_prev_action"]: self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \ ViewRequirement(SampleBatch.ACTIONS, space=self.action_space, - shift=-1) + data_rel_pos=-1) if model_config["lstm_use_prev_reward"]: self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ - ViewRequirement(SampleBatch.REWARDS, shift=-1) + ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 80a2fae1bbd3..6499dbc3b774 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -564,13 +564,14 @@ def _get_default_view_requirements(self): SampleBatch.OBS: ViewRequirement(space=self.observation_space), SampleBatch.NEXT_OBS: ViewRequirement( data_col=SampleBatch.OBS, - shift=1, + data_rel_pos=1, space=self.observation_space), SampleBatch.ACTIONS: ViewRequirement(space=self.action_space), SampleBatch.REWARDS: ViewRequirement(), SampleBatch.DONES: ViewRequirement(), 
SampleBatch.INFOS: ViewRequirement(), SampleBatch.EPS_ID: ViewRequirement(), + SampleBatch.UNROLL_ID: ViewRequirement(), SampleBatch.AGENT_INDEX: ViewRequirement(), "t": ViewRequirement(), } @@ -616,7 +617,7 @@ def _initialize_loss_from_dummy_batch( batch_for_postproc.count = self._dummy_batch.count postprocessed_batch = self.postprocess_trajectory(batch_for_postproc) if state_outs: - B = 4 # For RNNs, have B=2, T=[depends on sample_batch_size] + B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size] # TODO: (sven) This hack will not work for attention net traj. # view setup. i = 0 @@ -656,7 +657,8 @@ def _initialize_loss_from_dummy_batch( # Tag those only needed for post-processing. for key in batch_for_postproc.accessed_keys: if key not in train_batch.accessed_keys and \ - key in self.view_requirements: + key in self.view_requirements and \ + key not in self.model.inference_view_requirements: self.view_requirements[key].used_for_training = False # Remove those not needed at all (leave those that are needed # by Sampler to properly execute sample collection). @@ -679,18 +681,6 @@ def _initialize_loss_from_dummy_batch( "postprocessing function.".format(key)) else: del self.view_requirements[key] - # Add those data_cols (again) that are missing and have - # dependencies by view_cols. - for key in list(self.view_requirements.keys()): - vr = self.view_requirements[key] - if vr.data_col is not None and \ - vr.data_col not in self.view_requirements: - used_for_training = \ - vr.data_col in train_batch.accessed_keys - self.view_requirements[vr.data_col] = \ - ViewRequirement( - space=vr.space, - used_for_training=used_for_training) def _get_dummy_batch_from_view_requirements( self, batch_size: int = 1) -> SampleBatch: @@ -726,7 +716,7 @@ def _update_model_inference_view_requirements_from_init_state(self): model.inference_view_requirements["state_in_{}".format(i)] = \ ViewRequirement( "state_out_{}".format(i), - shift=-1, + data_rel_pos=-1, space=Box(-1.0, 1.0, shape=state.shape)) model.inference_view_requirements["state_out_{}".format(i)] = \ ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape)) diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py index 3264b759b532..8813c3aca288 100644 --- a/rllib/policy/view_requirement.py +++ b/rllib/policy/view_requirement.py @@ -29,7 +29,7 @@ class ViewRequirement: def __init__(self, data_col: Optional[str] = None, space: gym.Space = None, - shift: Union[int, List[int]] = 0, + data_rel_pos: Union[int, List[int]] = 0, used_for_training: bool = True): """Initializes a ViewRequirement object. @@ -40,13 +40,14 @@ def __init__(self, space (gym.Space): The gym Space used in case we need to pad data in inaccessible areas of the trajectory (t<0 or t>H). Default: Simple box space, e.g. rewards. - shift (Union[int, List[int]]): Single shift value of list of - shift values to use relative to the underlying `data_col`. + data_rel_pos (Union[int, str, List[int]]): Single shift value or + list of relative positions to use (relative to the underlying + `data_col`). Example: For a view column "prev_actions", you can set - `data_col="actions"` and `shift=-1`. + `data_col="actions"` and `data_rel_pos=-1`. Example: For a view column "obs" in an Atari framestacking fashion, you can set `data_col="obs"` and - `shift=[-3, -2, -1, 0]`. + `data_rel_pos=[-3, -2, -1, 0]`. used_for_training (bool): Whether the data will be used for training. If False, the column will not be copied into the final train batch. 
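Note: since this rename touches every ViewRequirement call site, here is a minimal, self-contained sketch of the renamed argument in use; the observation space and the extra "obs_4x" column are made up for illustration:

# Illustrative only: declaring view columns with `data_rel_pos`
# (formerly `shift`).
import numpy as np
from gym.spaces import Box

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

obs_space = Box(-1.0, 1.0, shape=(4, ), dtype=np.float32)

view_requirements = {
    # Current observation (relative position 0).
    SampleBatch.OBS: ViewRequirement(space=obs_space),
    # "prev_rewards": the rewards column, shifted one step into the past.
    SampleBatch.PREV_REWARDS: ViewRequirement(
        SampleBatch.REWARDS, data_rel_pos=-1),
    # Framestacking-style column: the last four observations [t-3 .. t].
    "obs_4x": ViewRequirement(
        SampleBatch.OBS, data_rel_pos=[-3, -2, -1, 0], space=obs_space),
}
# The attention-net example policy in this patch additionally passes a
# string range, e.g. data_rel_pos="-50:-1", together with a
# `batch_repeat_value` argument.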
@@ -54,5 +55,5 @@ def __init__(self, self.data_col = data_col self.space = space or gym.spaces.Box( float("-inf"), float("inf"), shape=()) - self.shift = shift + self.data_rel_pos = data_rel_pos self.used_for_training = used_for_training From 5e269c4945ca01575ea1de5eb21a38c5eda40b45 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 26 Nov 2020 16:37:19 +0100 Subject: [PATCH 4/7] Fix. --- rllib/evaluation/collectors/simple_list_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 98eda003ce9d..3fe304ab8ea0 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -157,8 +157,9 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \ batch = SampleBatch(batch_data) if SampleBatch.UNROLL_ID not in batch.data: - batch.data[SampleBatch.UNROLL_ID] = np.repeat( - _AgentCollector._next_unroll_id, batch.count) + if SampleBatch.UNROLL_ID in view_requirements: + batch.data[SampleBatch.UNROLL_ID] = np.repeat( + _AgentCollector._next_unroll_id, batch.count) _AgentCollector._next_unroll_id += 1 # This trajectory is continuing -> Copy data at the end (in the size of @@ -256,7 +257,6 @@ def build(self): """ # Create batch from our buffers. batch = SampleBatch(self.buffers) - assert SampleBatch.UNROLL_ID in batch.data # Clear buffers for future samples. self.buffers.clear() # Reset count to 0. From 787810d3eeaf8b5b9c47bdcf5a2275ad34e38bbf Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 26 Nov 2020 17:45:50 +0100 Subject: [PATCH 5/7] Fix. --- rllib/evaluation/collectors/simple_list_collector.py | 10 +++++++--- rllib/policy/policy.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 3fe304ab8ea0..b7a9e5244240 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -157,9 +157,12 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \ batch = SampleBatch(batch_data) if SampleBatch.UNROLL_ID not in batch.data: - if SampleBatch.UNROLL_ID in view_requirements: - batch.data[SampleBatch.UNROLL_ID] = np.repeat( - _AgentCollector._next_unroll_id, batch.count) + # TODO: (sven) Once we have the additional + # model.preprocess_train_batch in place (attention net PR), we + # should not even need UNROLL_ID anymore: + # Add "if SampleBatch.UNROLL_ID in view_requirements:" here. + batch.data[SampleBatch.UNROLL_ID] = np.repeat( + _AgentCollector._next_unroll_id, batch.count) _AgentCollector._next_unroll_id += 1 # This trajectory is continuing -> Copy data at the end (in the size of @@ -257,6 +260,7 @@ def build(self): """ # Create batch from our buffers. batch = SampleBatch(self.buffers) + assert SampleBatch.UNROLL_ID in batch.data # Clear buffers for future samples. self.buffers.clear() # Reset count to 0. 
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 80a2fae1bbd3..81a952018b08 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -572,6 +572,7 @@ def _get_default_view_requirements(self): SampleBatch.INFOS: ViewRequirement(), SampleBatch.EPS_ID: ViewRequirement(), SampleBatch.AGENT_INDEX: ViewRequirement(), + SampleBatch.UNROLL_ID: ViewRequirement(), "t": ViewRequirement(), } From 7113c3228251b2e887ff24a48fe451328a23293a Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 26 Nov 2020 19:34:55 +0100 Subject: [PATCH 6/7] WIP. --- rllib/evaluation/rollout_worker.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 6d6d532d57c6..31d9d6506bb4 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -272,9 +272,11 @@ def __init__( output_creator (Callable[[IOContext], OutputWriter]): Function that returns an OutputWriter object for saving generated experiences. - remote_worker_envs (bool): If using num_envs > 1, whether to create - those new envs in remote processes instead of in the current - process. This adds overheads, but can make sense if your envs + remote_worker_envs (bool): If using num_envs_per_worker > 1, + whether to create those new envs in remote processes instead of + in the current process. This adds overheads, but can make sense + if your envs are expensive to step/reset (e.g., for StarCraft). + Use this cautiously, overheads are significant! remote_env_batch_wait_ms (float): Timeout that remote workers are waiting when polling environments. 0 (continue when at least one env is ready) is a reasonable default, but optimal From f79faf7fb61cbaa211913ddc9ba67650dfa723ed Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 26 Nov 2020 22:43:27 +0100 Subject: [PATCH 7/7] Fixes and LINT. --- rllib/evaluation/collectors/simple_list_collector.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 84824e4a4f17..7db62c5329c9 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -52,7 +52,7 @@ def __init__(self, shift_before: int = 0): # each time a (non-initial!) observation is added. self.count = 0 - def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, + def add_init_obs(self, episode_id: EpisodeID, agent_index: int, env_id: EnvID, t: int, init_obs: TensorType, view_requirements: Dict[str, ViewRequirement]) -> None: """Adds an initial observation (after reset) to the Agent's trajectory. @@ -60,8 +60,8 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, Args: episode_id (EpisodeID): Unique ID for the episode we are adding the initial observation for. - agent_id (AgentID): Unique ID for the agent we are adding the - initial observation for. + agent_index (int): Unique int index (starting from 0) for the agent + within its episode. env_id (EnvID): The environment index (in a vectorized setup). t (int): The time step (episode length - 1). The initial obs has ts=-1(!), then an action/reward/next-obs at t=0, etc.. 
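Note: the "ts=-1(!)" remark in the docstring above boils down to the following arithmetic (a trivial worked example, values made up):

# Fresh episode, right after env.reset(): no steps have been taken yet.
episode_length = 0
t_initial_obs = episode_length - 1  # -> -1: the slot holding the initial obs
# The first env.step() then logs its action/reward/next-obs at t=0, the
# second at t=1, etc. This matches the `episode.length - 1` value passed to
# `add_init_obs()` in the sampler.py change earlier in this series.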
@@ -74,13 +74,13 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, single_row={ SampleBatch.OBS: init_obs, SampleBatch.EPS_ID: episode_id, - SampleBatch.AGENT_INDEX: agent_id, + SampleBatch.AGENT_INDEX: agent_index, "env_id": env_id, "t": t, }) self.buffers[SampleBatch.OBS].append(init_obs) self.buffers[SampleBatch.EPS_ID].append(episode_id) - self.buffers[SampleBatch.AGENT_INDEX].append(agent_id) + self.buffers[SampleBatch.AGENT_INDEX].append(agent_index) self.buffers["env_id"].append(env_id) self.buffers["t"].append(t) @@ -388,7 +388,7 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, self.agent_collectors[agent_key] = _AgentCollector() self.agent_collectors[agent_key].add_init_obs( episode_id=episode.episode_id, - agent_id=agent_id, + agent_index=episode._agent_index(agent_id), env_id=env_id, t=t, init_obs=init_obs,
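Note on this last rename: SampleBatch.AGENT_INDEX is meant to carry the agent's integer index within its episode, whereas the agent_id handed out by a multi-agent env can be any hashable (often a string) and is therefore not suitable as a numeric batch column. A small self-contained illustration with made-up agent ids:

import numpy as np

# Hypothetical agent ids as a multi-agent env might report them.
agent_ids = ["traffic_light_0", "car_1"]
# What `episode._agent_index(agent_id)` conceptually provides: a stable int
# index per agent, assigned within the episode.
agent_to_index = {aid: i for i, aid in enumerate(agent_ids)}

# The AGENT_INDEX column for three steps taken by "car_1" is now numeric;
# previously the raw agent_id string itself would have ended up in this
# buffer.
agent_index_col = np.array([agent_to_index["car_1"]] * 3)
assert agent_index_col.dtype.kind == "i"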