Skip to content

Commit

Permalink
Fix dtype for MPS in reinforcement learning example (#19982)
Browse files — browse the repository at this point in the history
  • Loading branch information
swyo authored Jun 21, 2024
1 parent cec6ae1 commit d3a0ada
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/fabric/reinforcement_learning/train_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main(args: argparse.Namespace):
# Single environment step
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
rewards[step] = torch.tensor(reward, device=device).view(-1)
rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)

if "final_info" in info:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
# Single environment step
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
rewards[step] = torch.tensor(reward, device=device).view(-1)
rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)

if "final_info" in info:
Expand Down

0 comments on commit d3a0ada

Please sign in to comment.