From d3a0ada4ffac76c09f3198b4e35d60b08d6c6069 Mon Sep 17 00:00:00 2001
From: SW Yoo <71536965+swyo@users.noreply.github.com>
Date: Fri, 21 Jun 2024 23:36:10 +0900
Subject: [PATCH] Fix dtype for MPS in reinforcement learning example (#19982)

---
 examples/fabric/reinforcement_learning/train_fabric.py      | 2 +-
 .../fabric/reinforcement_learning/train_fabric_decoupled.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py
index 1f3f83f3f2025..74b9b378371d3 100644
--- a/examples/fabric/reinforcement_learning/train_fabric.py
+++ b/examples/fabric/reinforcement_learning/train_fabric.py
@@ -146,7 +146,7 @@ def main(args: argparse.Namespace):
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
             done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
-            rewards[step] = torch.tensor(reward, device=device).view(-1)
+            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
             next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:
diff --git a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
index bbc09c977efcf..3849ae0f96a3c 100644
--- a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
+++ b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
@@ -135,7 +135,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
             done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
-            rewards[step] = torch.tensor(reward, device=device).view(-1)
+            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
             next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:
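
Note (not part of the patch): a minimal sketch of why the explicit dtype=torch.float32 is needed. Gymnasium-style vectorized environments typically return rewards as float64 NumPy arrays, and PyTorch's MPS backend does not support float64 tensors, so constructing the reward tensor without a dtype can fail on Apple Silicon. The reward values below are made up for illustration.

import numpy as np
import torch

# Made-up reward array standing in for what envs.step() typically returns
# (Gymnasium vectorized envs emit float64 rewards by default).
reward = np.array([1.0, 0.5, 0.0], dtype=np.float64)

# Fall back to CPU when MPS is unavailable so the sketch runs anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# torch.tensor(reward, device=device) would keep float64, which the MPS
# backend cannot represent; requesting float32 at construction avoids that.
rewards_step = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
print(rewards_step.dtype)  # torch.float32 on both MPS and CPU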