Fix NaN loss for sigmoid(x) == 1

In the current implementation of binary_cross_entropy_with_logit, the
loss will actually be NaN due to taking log(0), which occurs for high
logits passing through a sigmoid and an affine transformation:

inp.affine(-1., 1.)?.log()?
^      ^              ^
|      |              |
1.0    |              |
       0.0            |
                      NaN

The proposed implementation is taken more or less directly from PyTorch:
https://github.com/pytorch/pytorch/blob/41977a05314bbf537e1c5d6cf5916a368d1907d9/aten/src/ATen/native/Loss.cpp#L362
BeneSim committed Oct 14, 2024 · 1 parent 712adde · commit fa07a09
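
To illustrate the failure mode, here is a minimal standalone sketch in plain f32 Rust (not candle code; the 28.8318 logit is borrowed from the new test below):

// Standalone f32 sketch (no candle): for a large logit, sigmoid(x) rounds to
// exactly 1.0, so the old formulation takes the log of 0 and the mean is poisoned.
fn main() {
    let x: f32 = 28.8318; // large logit, taken from the new test case
    let sigmoid = 1.0 / (1.0 + (-x).exp()); // rounds to exactly 1.0 in f32
    let old_term = (1.0 - sigmoid).ln(); // ln(0.0), non-finite
    println!("sigmoid(x) = {sigmoid}, log(1 - sigmoid(x)) = {old_term}");

    // The reformulated loss never evaluates log(1 - sigmoid(x)):
    // loss = (1 - target) * x - log(sigmoid(x))
    let target: f32 = 0.0;
    let new_term = (1.0 - target) * x - sigmoid.ln();
    println!("stable per-element loss = {new_term}"); // finite: 28.8318
}
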
Showing 2 changed files with 52 additions and 5 deletions.
7 changes: 2 additions & 5 deletions candle-nn/src/loss.rs
@@ -60,13 +60,10 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
 ///
 /// The resulting tensor is a scalar containing the average value over the batch.
 pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
-    let inp = crate::ops::sigmoid(inp)?;
+    let log_sigmoid_input = crate::ops::sigmoid(inp)?.log()?;

-    let left_side = target * inp.log()?;
-    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
+    let loss = (1.0 - target)?.mul(inp)?.sub(&log_sigmoid_input)?.mean_all()?;

-    let loss = left_side? + right_side?;
-    let loss = loss?.neg()?.mean_all()?;

     Ok(loss)
 }
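
For reference, the rewrite relies on the identity log(1 - sigmoid(x)) = log(sigmoid(x)) - x, which turns -target * log(sigmoid(x)) - (1 - target) * log(1 - sigmoid(x)) into (1 - target) * x - log(sigmoid(x)), the elementwise expression in the added line. A quick standalone f32 check of that identity in plain Rust (not candle code; the logit/target pair is taken from the test data below):

// Standalone f32 check (no candle): the naive BCE expression and the
// reformulated one agree where both stay representable.
fn main() {
    let (x, z): (f32, f32) = (1.3081, 1.0); // logit/target pair from the test data
    let s = 1.0 / (1.0 + (-x).exp()); // sigmoid(x)
    let naive = -(z * s.ln() + (1.0 - z) * (1.0 - s).ln());
    let stable = (1.0 - z) * x - s.ln();
    println!("naive = {naive}, stable = {stable}"); // both ~0.2393
}
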
50 changes: 50 additions & 0 deletions candle-nn/tests/loss.rs
@@ -86,3 +86,53 @@ fn binary_cross_entropy_with_logit() -> Result<()> {
    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
    Ok(())
}

/*
Test high logit
Equivalent python code:
import torch
import torch.nn.functional as F
inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
                    [ 0.0419,  0.0763, -1.0457, -1.6692],
                    [-1.0494,  0.8111,  1.5723,  1.2315],
                    [ 1.3081,  0.6641,  1.1802, -0.2547],
                    [ 0.5292,  0.7636,  0.3692, 28.8318]])
target = torch.Tensor([[0., 1., 0., 0.],
                       [0., 1., 0., 0.],
                       [0., 0., 0., 1.],
                       [1., 0., 0., 0.],
                       [0., 0., 1., 0.]])
print(F.binary_cross_entropy_with_logits(inp, target))
*/
#[test]
fn binary_cross_entropy_with_high_logit() -> Result<()> {
    let cpu = Device::Cpu;

    let inp = [
        [2.3611f32, -0.8813, -0.5006, -0.2178],
        [0.0419, 0.0763, -1.0457, -1.6692],
        [-1.0494, 0.8111, 1.5723, 1.2315],
        [1.3081, 0.6641, 1.1802, -0.2547],
        [0.5292, 0.7636, 0.3692, 28.8318],
    ];

    let target = [
        [0.0f32, 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
    ];

    let inp = Tensor::new(&inp, &cpu)?;
    let target = Tensor::new(&target, &cpu)?;

    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;

    assert_eq!(to_vec0_round(&loss, 4)?, 2.246);
    Ok(())
}
