Skip to content

Commit

Permalink
Fix week04_approx_rl (#492)
Browse files Browse the repository at this point in the history
* fix errors

* Provide more info inside the comments (#1)
  • Loading branch information
alexeyhorkin authored Mar 13, 2022
1 parent 125f6cc commit b4761a4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions week04_approx_rl/homework_pytorch_main.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@
"N_STEPS = 100\n",
"\n",
"exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)\n",
"for i in range(REPLAY_BUFFER_SIZE // N_STEPS)):\n",
"for i in range(REPLAY_BUFFER_SIZE // N_STEPS):\n",
" if not utils.is_enough_ram(min_available_gb=0.1):\n",
" print(\"\"\"\n",
" Less than 100 Mb RAM available. \n",
Expand Down Expand Up @@ -991,7 +991,7 @@
"\n",
" if step % loss_freq == 0:\n",
" td_loss_history.append(loss.data.cpu().item())\n",
" grad_norm_history.append(grad_norm)\n",
" grad_norm_history.append(grad_norm.cpu())\n",
"\n",
" if step % refresh_target_network_freq == 0:\n",
" # Load agent weights into target_network\n",
Expand Down
12 changes: 6 additions & 6 deletions week04_approx_rl/seminar_pytorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,18 @@
"def compute_td_loss(states, actions, rewards, next_states, is_done, gamma=0.99, check_shapes=False):\n",
" \"\"\" Compute td loss using torch operations only. Use the formula above. \"\"\"\n",
" states = torch.tensor(\n",
" states, dtype=torch.float32) # shape: [batch_size, state_size]\n",
" actions = torch.tensor(actions, dtype=torch.long) # shape: [batch_size]\n",
" rewards = torch.tensor(rewards, dtype=torch.float32) # shape: [batch_size]\n",
" states, dtype=torch.float32) # shape: [batch_size, state_size]\n",
" actions = torch.tensor(actions, dtype=torch.long) # shape: [batch_size]\n",
" rewards = torch.tensor(rewards, dtype=torch.float32) # shape: [batch_size]\n",
" # shape: [batch_size, state_size]\n",
" next_states = torch.tensor(next_states, dtype=torch.float32)\n",
" is_done = torch.tensor(is_done, dtype=torch.uint8) # shape: [batch_size]\n",
" is_done = torch.tensor(is_done, dtype=torch.uint8) # shape: [batch_size]\n",
"\n",
" # get q-values for all actions in current states\n",
" predicted_qvalues = network(states)\n",
" predicted_qvalues = network(states) # shape: [batch_size, n_actions]\n",
"\n",
" # select q-values for chosen actions\n",
" predicted_qvalues_for_actions = predicted_qvalues[\n",
" predicted_qvalues_for_actions = predicted_qvalues[ # shape: [batch_size]\n",
" range(states.shape[0]), actions\n",
" ]\n",
"\n",
Expand Down

0 comments on commit b4761a4

Please sign in to comment.