diff --git a/pymochi/models.py b/pymochi/models.py index 5ca9a5b..7ef426d 100644 --- a/pymochi/models.py +++ b/pymochi/models.py @@ -35,7 +35,7 @@ def forward(self, input: Tensor, target: Tensor, weight: Tensor) -> Tensor: class MochiGaussianNLLLoss(torch.nn.GaussianNLLLoss): """ - Mochi version of GaussianNLLLoss with no reduction that accepts weights (1/sigma) rather than var = sigma^2. + Mochi version of GaussianNLLLoss with no reduction that accepts fitness weights (1/sigma) rather than var = sigma^2. """ def forward(self, input: Tensor, target: Tensor, weight: Tensor) -> Tensor: return F.gaussian_nll_loss(input, target, torch.pow(weight, -2), full = True, eps = 0, reduction = "none")