Skip to content

Commit

Permalink
mean_absolute_difference
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Apr 13, 2022
1 parent 6027862 commit 93a9ba5
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions nn/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ def kl_div(*, target: nn.Tensor, target_type: str,
return kl


@nn.scoped
def mean_absolute_difference(a: nn.Tensor, b: nn.Tensor, *, axis: Optional[nn.Dim] = None) -> nn.Tensor:
"""
Mean absolute difference, mean absolute error (MAE), or L1 loss between two tensors,
i.e. mean_{axis}( abs(a - b) ), where axis is the feature dim by default.
"""
if not axis:
assert a.feature_dim
axis = a.feature_dim
return nn.reduce(nn.abs(a - b), mode="mean", axis=axis)


@nn.scoped
def mean_squared_difference(a: nn.Tensor, b: nn.Tensor, *, axis: Optional[nn.Dim] = None) -> nn.Tensor:
"""
Expand Down

0 comments on commit 93a9ba5

Please sign in to comment.