
Commit

binary_cross_entropy
albertz committed Apr 13, 2022
1 parent 7d14813 commit 6027862
Showing 1 changed file with 84 additions and 1 deletion.
85 changes: 84 additions & 1 deletion nn/loss.py
@@ -10,7 +10,7 @@
"""

-from typing import Optional
+from typing import Optional, Union
from .. import nn


@@ -57,6 +57,89 @@ def cross_entropy(*, target: nn.Tensor, estimated: nn.Tensor, estimated_type: str
  return -nn.dot(target, log_prob, reduce=axis)


@nn.scoped
def binary_cross_entropy(*,
                          target: nn.Tensor,
                          pos_estimated: nn.Tensor, pos_estimated_type: str,
                          pos_weight: Optional[Union[float, nn.Tensor]] = None):
"""
Binary cross entropy, or also called sigmoid cross entropy.
:param target: (sparse) target labels, 0 (positive) or 1 (negative), i.e. binary.
:param pos_estimated: positive class logits. probs = sigmoid(logits).
:param pos_estimated_type: "logits" only supported currently
:param pos_weight: weight for positive class.
Code and documentation partly borrowed from TensorFlow.
A value `pos_weight > 1` decreases the false negative count, hence increasing
the recall.
Conversely setting `pos_weight < 1` decreases the false positive count and
increases the precision.
This can be seen from the fact that `pos_weight` is introduced as a
multiplicative coefficient for the positive labels term
in the loss expression:
labels * -log(sigmoid(logits)) * pos_weight +
(1 - labels) * -log(1 - sigmoid(logits))
For brevity, let `x = logits`, `z = labels`. The logistic loss is
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))
For x < 0, to avoid overflow in exp(-x), we reformulate the above
x - x * z + log(1 + exp(-x))
= log(exp(x)) - x * z + log(1 + exp(-x))
= - x * z + log(1 + exp(x))
Hence, to ensure stability and avoid overflow, the implementation uses this
equivalent formulation
max(x, 0) - x * z + log(1 + exp(-abs(x)))
"""
if pos_estimated_type != "logits":
raise NotImplementedError(
f"binary_cross_entropy, pos_estimated_type {pos_estimated_type!r}, only 'logits' supported")
logits = pos_estimated

  if pos_weight is not None:
    # Code adapted from tf.nn.weighted_cross_entropy_with_logits.
    # With q = pos_weight, the logistic loss formula from above becomes
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
    # where l = 1 + (q - 1) * z.
    # To avoid branching, we use the combined version
    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
    log_weight = 1 + (pos_weight - 1) * target
    return (
      (1 - target) * logits +
      log_weight * (nn.log1p(nn.exp(-nn.abs(logits))) + nn.relu(-logits))
    )

  # Code adapted from tf.nn.sigmoid_cross_entropy_with_logits.
  # The logistic loss formula from above is
  #   x - x * z + log(1 + exp(-x))
  # For x < 0, a more numerically stable formula is
  #   -x * z + log(1 + exp(x))
  # Note that these two expressions can be combined into the following:
  #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
  # To allow computing gradients at zero, we use nn.where-based versions
  # of the max and abs functions.
  cond = (logits >= 0)
  relu_logits = nn.where(cond, logits, 0)
  neg_abs_logits = nn.where(cond, -logits, logits)  # pylint: disable=invalid-unary-operand-type
  return (
    relu_logits - logits * target +
    nn.log1p(nn.exp(neg_abs_logits))
  )


@nn.scoped
def kl_div(*, target: nn.Tensor, target_type: str,
           estimated: nn.Tensor, estimated_type: str,

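As a sanity check on the derivation in the docstring, the following standalone NumPy sketch (not part of the commit, and not using the `nn` API; the helper names are made up for illustration) compares the naive sigmoid-based binary cross entropy against the numerically stable form `max(x, 0) - x * z + log(1 + exp(-abs(x)))` used by the unweighted branch.

import numpy as np


def bce_naive(logits, targets):
  # Direct definition via probabilities; overflows / hits log(0) for large |logits|.
  probs = 1.0 / (1.0 + np.exp(-logits))
  return -(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))


def bce_stable(logits, targets):
  # max(x, 0) - x * z + log(1 + exp(-abs(x))), as in the unweighted branch of the commit.
  return np.maximum(logits, 0.0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))


logits = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
targets = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
assert np.allclose(bce_naive(logits, targets), bce_stable(logits, targets))

# The stable form stays finite where the naive one overflows or hits log(0):
print(bce_stable(np.array([1000.0]), np.array([0.0])))  # [1000.]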
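Similarly, a small check of the `pos_weight` branch, again in plain NumPy with made-up helper names: the combined expression `(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))` with `l = 1 + (pos_weight - 1) * z` should match the weighted definition `-(pos_weight * z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x)))`.

import numpy as np


def weighted_bce_naive(logits, targets, pos_weight):
  # Weighted definition via probabilities (numerically unstable for large |logits|).
  probs = 1.0 / (1.0 + np.exp(-logits))
  return -(pos_weight * targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))


def weighted_bce_stable(logits, targets, pos_weight):
  # Combined stable form used when pos_weight is given.
  log_weight = 1.0 + (pos_weight - 1.0) * targets
  return ((1.0 - targets) * logits
          + log_weight * (np.log1p(np.exp(-np.abs(logits))) + np.maximum(-logits, 0.0)))


logits = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
targets = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
assert np.allclose(weighted_bce_naive(logits, targets, 2.0),
                   weighted_bce_stable(logits, targets, 2.0))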