Skip to content

Commit

Permalink
[Feature] Deterministic sample for Masked one-hot
Browse files Browse the repository at this point in the history
ghstack-source-id: 27787eab47324c5af152f706d81687e71b5b9803
Pull Request resolved: #2440
  • Loading branch information
vmoens committed Sep 17, 2024
1 parent 0a410ff commit a6d7545
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,17 @@ def sample(
) -> torch.Tensor:
...

@property
def deterministic_sample(self):
return self.mode

@property
def mode(self) -> torch.Tensor:
if hasattr(self, "logits"):
return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
else:
return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
return super().log_prob(value.argmax(dim=-1))

Expand Down

0 comments on commit a6d7545

Please sign in to comment.