Skip to content

Commit

Permalink
improve code
Browse files Browse the repository at this point in the history
* add device param
* all in torch

Co-authored-by: Albert Zeyer <[email protected]>
  • Loading branch information
christophmluscher and albertz authored Dec 19, 2024
1 parent 3825aeb commit 0d8a9e3
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions i6_models/samplers/log_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@


class LogUniformSampler(nn.Module):
    """
    Samples class indices from an approximately Zipfian (log-uniform) distribution.

    Assumes a count-sorted vocabulary in descending order, i.e. index 0 is the
    most frequent class, so lower indices receive higher sampling probability.
    """

    def __init__(self, num_classes: int, *, device: Optional[torch.device] = None):
        """
        :param num_classes: size of the (count-sorted, descending) vocabulary
        :param device: device the distribution (and thus sampling) lives on;
            defaults to the current default device
        """
        super().__init__()

        self.num_classes = num_classes

        # Approximate Zipf distribution:
        #   P(w) = (log(w + 2) - log(w + 1)) / log(num_classes + 1)
        # Note: the unnormalized terms telescope to 1, so this is already a
        # proper distribution up to the clamp below.
        ws = torch.arange(self.num_classes, dtype=torch.get_default_dtype(), device=device)
        # math.log1p yields an exact Python scalar; no need to build a temporary
        # tensor just to take the log of a single int.
        self._distribution = (torch.log1p(ws + 1) - torch.log1p(ws)) / math.log1p(self.num_classes)
        # Guard against zero/rounded-away probabilities for numerical stability.
        self._distribution.clamp_(min=1e-10)
        # Renormalize after clamping so probabilities sum to exactly 1.
        self._distribution /= self._distribution.sum()

        self._cat_sampler = torch.distributions.categorical.Categorical(probs=self._distribution)

    def sample(self, num_samples: int) -> torch.Tensor:
        """
        :param num_samples: number of indices to draw
        :return: tensor of shape [num_samples] with class indices in [0, num_classes)
        """
        return self._cat_sampler.sample(torch.Size([num_samples]))
Expand Down

0 comments on commit 0d8a9e3

Please sign in to comment.