Skip to content

Commit

Permalink
[BugFix] Fix tictactoeenv.py
Browse files Browse the repository at this point in the history
ghstack-source-id: 99a368cf34cb7a3240ee85e85fb945d39292beb5
Pull Request resolved: #2417
  • Loading branch information
vmoens committed Sep 4, 2024
1 parent 60cd104 commit 6c66796
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions torchrl/envs/custom/tictactoeenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _step(self, state: TensorDict) -> TensorDict:
turn = state["turn"].clone()
action = state["action"]
board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1)
wins = self.win(state["board"], action)
wins = self.win(board, action)

mask = board.flatten(-2, -1) == -1
done = wins | ~mask.any(-1, keepdim=True)
Expand All @@ -234,7 +234,7 @@ def _step(self, state: TensorDict) -> TensorDict:
("player0", "reward"): reward_0.float(),
("player1", "reward"): reward_1.float(),
"board": torch.where(board == -1, board, 1 - board),
"turn": 1 - state["turn"],
"turn": 1 - turn,
"mask": mask,
},
batch_size=state.batch_size,
Expand All @@ -260,13 +260,15 @@ def _set_seed(self, seed: int | None):
def win(board: torch.Tensor, action: torch.Tensor) -> bool:
    """Tell whether the move ``action`` completes a winning line on ``board``.

    A line is winning when its three cells sum to exactly 3. The lines
    inspected are the row and the column containing ``action`` plus both
    diagonals of the board.

    Args:
        board: the board, with the two trailing dimensions being the 3x3
            grid. NOTE(review): presumably the mover's marks are 1 and
            empty cells are -1 by the time this is called -- confirm
            against the caller.
        action: flat cell index in ``[0, 8]`` of the move just played
            (``row = action // 3``, ``col = action % 3``).

    Returns:
        ``True`` if any inspected line sums to 3, ``False`` otherwise.

    Note:
        Every check reduces with a full ``.sum()``, so the result is a single
        boolean. NOTE(review): this looks correct for one un-batched board
        only -- confirm before using it with batched boards or actions.
    """
    row = action // 3  # type: ignore
    col = action % 3  # type: ignore
    # Each candidate line is compared to 3 on its own. Chaining the tests as
    # ``x == 3 | y == 3`` is wrong because ``|`` binds tighter than ``==``:
    # Python would evaluate ``x == (3 | y) == ...`` instead.
    candidate_lines = (
        board[..., row, :],  # row holding the move
        board[..., col],  # column holding the move
        board.diagonal(0, -2, -1),  # main diagonal
        board.flip(-1).diagonal(0, -2, -1),  # anti-diagonal
    )
    return any(bool(line.sum() == 3) for line in candidate_lines)

@staticmethod
def full(board: torch.Tensor) -> bool:
Expand Down

0 comments on commit 6c66796

Please sign in to comment.