Skip to content

Commit

Permalink
Fixed RL examples to work with new gym API (pytorch#1051)
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Sep 14, 2022
1 parent dc51eb1 commit 8428996
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
14 changes: 7 additions & 7 deletions reinforcement_learning/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


env = gym.make('CartPole-v1')
env.seed(args.seed)
env.reset(seed=args.seed)
torch.manual_seed(args.seed)


Expand Down Expand Up @@ -56,7 +56,7 @@ def forward(self, x):
"""
x = F.relu(self.affine1(x))

# actor: chooses action to take from state s_t
# actor: chooses action to take from state s_t
# by returning probability of each action
action_prob = F.softmax(self.action_head(x), dim=-1)

Expand All @@ -65,7 +65,7 @@ def forward(self, x):

# return values for both actor and critic as a tuple of 2 values:
# 1. a list with the probability of each action over the action space
# 2. the value from state s_t
# 2. the value from state s_t
return action_prob, state_values


Expand Down Expand Up @@ -113,7 +113,7 @@ def finish_episode():
for (log_prob, value), R in zip(saved_actions, returns):
advantage = R - value.item()

# calculate actor (policy) loss
# calculate actor (policy) loss
policy_losses.append(-log_prob * advantage)

# calculate critic (value) loss using L1 smooth loss
Expand Down Expand Up @@ -141,18 +141,18 @@ def main():
for i_episode in count(1):

# reset environment and episode reward
state = env.reset()
state, _ = env.reset()
ep_reward = 0

# for each episode, only run 9999 steps so that we don't
# for each episode, only run 9999 steps so that we don't
# loop forever while learning
for t in range(1, 10000):

# select action from policy
action = select_action(state)

# take the action
state, reward, done, _ = env.step(action)
state, reward, done, _, _ = env.step(action)

if args.render:
env.render()
Expand Down
5 changes: 3 additions & 2 deletions reinforcement_learning/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ def finish_episode():
def main():
running_reward = 10
for i_episode in count(1):
state, ep_reward = env.reset(), 0
state, _ = env.reset()
ep_reward = 0
for t in range(1, 10000): # Don't infinite loop while learning
action = select_action(state)
state, reward, done, _ = env.step(action)
state, reward, done, _, _ = env.step(action)
if args.render:
env.render()
policy.rewards.append(reward)
Expand Down

0 comments on commit 8428996

Please sign in to comment.