
LSTM #147

Open · wants to merge 6 commits into main
Conversation

@akbir (Member) commented Nov 23, 2022

Adding an LSTMAgent.

This adds an LSTM option to the PPO Agents.

  • Achieves parity on IPD
  1. Does your submission pass tests?
  2. Have you linted your code locally prior to submission?

Changes to Core Features:

  • Have you added an explanation of what your changes do and why you'd like us to include them?
  • Have you written new tests for your core changes, as applicable?
  • Have you successfully run tests with your changes locally?

@github-actions github-actions bot added the core label Nov 23, 2022
@akbir akbir changed the title wip: LSTM LSTM Nov 24, 2022
@akbir akbir requested a review from newtonkwan November 29, 2022 16:54
@newtonkwan (Contributor) left a comment

Nice that you're adding LSTM. However, I'm not confident that it works, because I can't tell whether `cell` is being used to calculate the new `hidden`. If you could point me to where that happens in `ppo_lstm.py`, it would help me understand. Also, does this run on the parity tests? When I first developed the memory component, LSTM didn't automatically batch the first dimension, so you had to know the batch size beforehand when passing in inputs; I'm not sure if you found a workaround for that. My misunderstandings may also be down to my not having seen the code base in a while.

If you could show that this works on a few parity tests, show that the `cell` part is being used when updating the state, and address those few minor comments, then I think we're good to go.
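For reference, here is a minimal plain-Python sketch (hypothetical names, toy scalar weights fixed to 1) of why the state must carry both `hidden` and `cell`: the new hidden is derived from the freshly updated cell, so dropping `cell` breaks the recurrence.

```python
import math
from typing import NamedTuple

class LSTMState(NamedTuple):
    hidden: float  # h_t, exposed to the rest of the network
    cell: float    # c_t, internal memory the gates read and write

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x: float, state: LSTMState) -> LSTMState:
    # Toy scalar LSTM: all weights 1, all biases 0, just to show data flow.
    i = sigmoid(x + state.hidden)       # input gate
    f = sigmoid(x + state.hidden)       # forget gate
    o = sigmoid(x + state.hidden)       # output gate
    g = math.tanh(x + state.hidden)     # candidate cell value
    cell = f * state.cell + i * g       # cell update reads the OLD cell
    hidden = o * math.tanh(cell)        # new hidden is computed FROM the new cell
    return LSTMState(hidden=hidden, cell=cell)
```

Two calls with identical inputs and hidden but different cells produce different hiddens, which is the behavior the real `ppo_lstm.py` forward pass would need to exhibit.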

@@ -358,15 +374,27 @@ def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
"""forward function"""
torso = hk.nets.MLP(
Contributor:

why is this being removed?

behavior_values: jnp.ndarray
behavior_log_probs: jnp.ndarray

# GRU specific
Contributor:

Change to "LSTM specific" or "Recurrent specific". Also, wouldn't this need `cell` as well?

seed=seed,
player_id=player_id,
)
else:
Contributor:

I would change this to an `elif args.ppo.rnn_type == "gru"`, then an `else` that raises an error. We wouldn't want any string other than "lstm" to silently set the `rnn_type` to GRU.
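A sketch of the suggested dispatch (function and argument names are hypothetical, not the project's actual API): explicit branches for the two supported values, and a loud failure for anything else.

```python
def make_initial_state(rnn_type: str, hidden_size: int) -> dict:
    """Build a zeroed recurrent state for the requested cell type."""
    if rnn_type == "lstm":
        # LSTM carries both a hidden state and a cell state.
        return {"hidden": [0.0] * hidden_size, "cell": [0.0] * hidden_size}
    elif rnn_type == "gru":
        # GRU carries a single hidden state.
        return {"hidden": [0.0] * hidden_size}
    else:
        # Fail fast instead of silently falling back to GRU.
        raise ValueError(f"Unknown rnn_type: {rnn_type!r} (expected 'lstm' or 'gru')")
```

The point of the `else` is that a typo like `"LSTM"` or `"rnn"` surfaces immediately rather than quietly selecting the wrong architecture.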

agent1: 'PPO'

# Environment
env_id: MountainCar-v0
Contributor:

The file is called pendulum.yaml, but it has `env_id: MountainCar-v0`. Am I missing something?

@@ -2,7 +2,7 @@

# Agents
agent1: 'PPO_memory'
agent2: 'TitForTat'
agent2: 'PPO_memory'
Contributor:

The file is called ppo_mem_v_tft.yaml, but now `agent2: PPO_memory`. Why was this changed?

key = jax.random.split(
agent2._state.random_key, args.popsize * args.num_opps
).reshape(args.popsize, args.num_opps, -1)
if args.ppo.rnn_type == "lstm" and args.agent2 == "PPO_memory":
Contributor:

I need some help understanding this. If we want to use an LSTM, the initial hidden state is a Haiku `LSTMState` object that holds the hidden and cell states; if we want a GRU, the hidden state is produced by `jnp.tile(agent2._mem.hidden, (args.popsize, args.num_opps, 1, 1))`. Are these both NamedTuples, and is `agent2.batch_init()` able to handle both of them?

Maybe I'm missing something that changed in how the agents and the agent methods are initialized, but I don't see any diffs for that file here.
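If both states really are NamedTuples, one generic `batch_init` can cover them by tiling every leaf, whether there is just `hidden` or also `cell`. A plain-Python sketch (names hypothetical; the real code would tile arrays with `jnp.tile` or a `jax.tree_util.tree_map`):

```python
from typing import NamedTuple

class LSTMState(NamedTuple):
    hidden: list
    cell: list

class GRUState(NamedTuple):
    hidden: list

def tile(leaf: list, reps: int) -> list:
    # Stand-in for jnp.tile along a new leading batch axis.
    return [list(leaf) for _ in range(reps)]

def batch_init(state, reps: int):
    # NamedTuples iterate over their fields, so the same path handles
    # an LSTMState (hidden + cell) and a GRUState (hidden only).
    return type(state)(*(tile(leaf, reps) for leaf in state))
```

Because the dispatch is over the fields of the state rather than its concrete type, adding `cell` does not require a second code path.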

hiddens: jnp.ndarray,
):
"""Surrogate loss using clipped probability ratios."""
(distribution, values), _ = network.apply(
Contributor:

Since we are now using an LSTM, is it the case that `network.apply()` now requires both `hidden` and `cell`?

Member Author:

No, it needs the `LSTMHidden`.
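In other words, the call signature doesn't grow: the forward function still takes one state argument, and the LSTM variant simply packs `cell` alongside `hidden` inside it. A hypothetical sketch (toy arithmetic, not the project's actual network):

```python
from typing import NamedTuple, Tuple

class LSTMHidden(NamedTuple):
    hidden: float
    cell: float

def apply(inputs: float, state: LSTMHidden) -> Tuple[float, LSTMHidden]:
    # The caller passes a single state object; hidden and cell travel
    # together inside it, so apply() needs no extra argument.
    new_cell = state.cell + inputs
    new_hidden = new_cell * 0.5
    return new_hidden, LSTMHidden(hidden=new_hidden, cell=new_cell)
```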

initial_hidden_state=initial_hidden_state,
optimizer=optimizer,
random_key=random_key,
gru_dim=args.ppo.hidden_size,
Contributor:

In the ppo file, you could now change this input to something more general, such as `recurrent_dim` rather than `gru_dim`.

@akbir (Member, Author) commented Dec 7, 2022

@newtonkwan - can you pick this PR up and get it into main before the release?

3 participants