smooth_l1_loss (ivy-llc#22378)
Co-authored-by: Daniel4078 <[email protected]>
2 people authored and arshPratap committed Sep 11, 2023
1 parent 685a484 commit a29ce99
Showing 9 changed files with 540 additions and 0 deletions.
49 changes: 49 additions & 0 deletions ivy/data_classes/array/experimental/losses.py
@@ -48,3 +48,52 @@ def l1_loss(
ivy.array(0.20000000000000004)
"""
return ivy.l1_loss(self._data, target, reduction=reduction, out=out)

def smooth_l1_loss(
self: ivy.Array,
target: Union[ivy.Array, ivy.NativeArray],
/,
*,
beta: Optional[float] = 1.0,
reduction: Optional[str] = "mean",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.smooth_l1_loss. This method simply
wraps the function, and so the docstring for ivy.smooth_l1_loss also applies to
this method with minimal changes.
Parameters
----------
self
input array containing true labels.
target
input array containing targeted labels.
beta
a positive float value that sets the smoothness threshold.
Default: ``1.0``.
reduction
Reduction method for the loss.
Options are 'none', 'mean', or 'sum'.
Default: 'mean'.
out
Optional output array, for writing the result to.
It must have a shape
that the inputs broadcast to.
Returns
-------
ret
The smooth L1 loss between the input array and the targeted labels.
Examples
--------
>>> x = ivy.array([1, 2, 3, 4])
>>> y = ivy.array([2, 2, 2, 2])
>>> z = x.smooth_l1_loss(y, beta=0.5)
>>> print(z)
ivy.array(0.8125)
"""
return ivy.smooth_l1_loss(
self._data, target, beta=beta, reduction=reduction, out=out
)
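The beta=0.5 example in this docstring can be verified by hand from the piecewise definition; the arithmetic below is a reviewer's sketch, not part of the diff.

# Hand check of the docstring example (sketch, not part of the commit):
# smooth_l1(d) = 0.5 * d**2 / beta   if d < beta,  else  d - 0.5 * beta
# x = [1, 2, 3, 4], y = [2, 2, 2, 2], beta = 0.5
# d = |x - y| = [1, 0, 1, 2]
# per-element loss = [0.75, 0.0, 0.75, 1.75]
# mean = 3.25 / 4 = 0.8125, matching ivy.array(0.8125)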
169 changes: 169 additions & 0 deletions ivy/data_classes/container/experimental/losses.py
@@ -159,3 +159,172 @@ def l1_loss(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def _static_smooth_l1_loss(
input: Union[ivy.Container, ivy.Array, ivy.NativeArray],
target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
beta: Optional[Union[float, ivy.Container]] = 1.0,
reduction: Optional[Union[str, ivy.Container]] = "mean",
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.smooth_l1_loss. This method simply
wraps the function, and so the docstring for ivy.smooth_l1_loss also applies to
this method with minimal changes.
Parameters
----------
input
input array or container containing input labels.
target
input array or container containing the targeted labels.
beta
a positive float value that sets the smoothness threshold.
Default: ``1.0``.
reduction
``'none'``: No reduction will be applied to the output.
``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'mean'``.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Returns
-------
ret
The smooth L1 loss between the input array and the targeted labels.
Examples
--------
With :class:`ivy.Container` inputs:
>>> x = ivy.Container(a=ivy.array([1, 0, 2]), b=ivy.array([3, 2, 1]))
>>> y = ivy.Container(a=ivy.array([0.6, 0.2, 0.3]),
...                   b=ivy.array([0.8, 0.2, 0.2]))
>>> z = ivy.Container._static_smooth_l1_loss(x, y)
>>> print(z)
{
    a: ivy.array(0.43333334),
    b: ivy.array(1.1066667)
}
With a mix of :class:`ivy.Array` and :class:`ivy.Container` inputs:
>>> x = ivy.array([1, 0, 2])
>>> y = ivy.Container(a=ivy.array([0.6, 0.2, 0.3]),
...                   b=ivy.array([0.8, 0.2, 0.2]))
>>> z = ivy.Container._static_smooth_l1_loss(x, y)
>>> print(z)
{
    a: ivy.array(0.43333334),
    b: ivy.array(0.44666666)
}
"""
return ContainerBase.cont_multi_map_in_function(
"smooth_l1_loss",
input,
target,
beta=beta,
reduction=reduction,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def smooth_l1_loss(
self: ivy.Container,
target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
beta: Optional[Union[float, ivy.Container]] = 1.0,
reduction: Optional[Union[str, ivy.Container]] = "mean",
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.smooth_l1_loss. This method simply
wraps the function, and so the docstring for ivy.smooth_l1_loss also applies to
this method with minimal changes.
Parameters
----------
self
input container containing input labels.
target
input array or container containing the targeted labels.
beta
a positive float value that sets the smoothness threshold.
Default: ``1.0``.
reduction
``'none'``: No reduction will be applied to the output.
``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'mean'``.
key_chains
The key-chains to apply or not apply the method to. Default is
``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to.
It must have a shape
that the inputs broadcast to.
Returns
-------
ret
The smooth L1 loss between the input array and the targeted labels.
Examples
--------
>>> x = ivy.Container(a=ivy.array([1, 0, 2]), b=ivy.array([3, 2, 1]))
>>> y = ivy.Container(a=ivy.array([0.6, 0.2, 0.3]),
...                   b=ivy.array([0.8, 0.2, 0.2]))
>>> z = x.smooth_l1_loss(y)
>>> print(z)
{
    a: ivy.array(0.43333334),
    b: ivy.array(1.1066667)
}
"""
return self._static_smooth_l1_loss(
self,
target,
beta=beta,
reduction=reduction,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)
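With the default beta of 1.0, the printed container values work out as follows; the arithmetic below is a reviewer's sketch, not part of the diff.

# Hand check of the container example (sketch, default beta = 1.0):
# a: d = |[1, 0, 2] - [0.6, 0.2, 0.3]| = [0.4, 0.2, 1.7]
#    loss = [0.5 * 0.4**2, 0.5 * 0.2**2, 1.7 - 0.5] = [0.08, 0.02, 1.2]
#    mean = 1.3 / 3 ≈ 0.4333
# b: d = |[3, 2, 1] - [0.8, 0.2, 0.2]| = [2.2, 1.8, 0.8]
#    loss = [2.2 - 0.5, 1.8 - 0.5, 0.5 * 0.8**2] = [1.7, 1.3, 0.32]
#    mean = 3.32 / 3 ≈ 1.1067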
25 changes: 25 additions & 0 deletions ivy/functional/backends/jax/experimental/losses.py
@@ -0,0 +1,25 @@
import jax.numpy as jnp
from typing import Optional
from ivy.functional.backends.jax import JaxArray


def smooth_l1_loss(
input: JaxArray,
target: JaxArray,
/,
*,
beta: Optional[float] = 1.0,
reduction: Optional[str] = "mean",
) -> JaxArray:
if beta < 1e-5:
loss = jnp.abs(input - target)
else:
diff = jnp.abs(input - target)
loss = jnp.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)

if reduction == "mean":
return jnp.mean(loss)
elif reduction == "sum":
return jnp.sum(loss)
else:
return loss
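A quick way to exercise this backend is to compare it against a hand-computed value; the snippet below is a self-contained reference check (a sketch, not part of the commit) that transcribes the same piecewise formula.

import jax.numpy as jnp

def _reference_smooth_l1(input, target, beta):
    # Direct transcription of the piecewise definition used above.
    diff = jnp.abs(input - target)
    return jnp.mean(jnp.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta))

x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([2.0, 2.0, 2.0, 2.0])
print(_reference_smooth_l1(x, y, 0.5))  # 0.8125, matching the Array docstring example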
30 changes: 30 additions & 0 deletions ivy/functional/backends/numpy/experimental/losses.py
@@ -0,0 +1,30 @@
import numpy as np
from typing import Optional
from ivy.functional.backends.numpy.helpers import _scalar_output_to_0d_array
from ivy.func_wrapper import with_unsupported_dtypes
from . import backend_version


@with_unsupported_dtypes({"1.25.2 and below": ("bool",)}, backend_version)
@_scalar_output_to_0d_array
def smooth_l1_loss(
input: np.ndarray,
target: np.ndarray,
/,
*,
beta: Optional[float] = 1.0,
reduction: Optional[str] = "mean",
) -> np.ndarray:
if beta < 1e-5:
loss = np.abs(input - target)
else:
diff = np.abs(input - target)
loss = np.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)

if reduction == "mean":
return np.mean(loss)
elif reduction == "sum":
return np.sum(loss)
else:
return loss
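The beta < 1e-5 fallback matches the limiting behaviour of the loss: as beta approaches 0 the quadratic region vanishes and smooth L1 reduces to plain L1. It also keeps np.where from evaluating 0.5 * diff**2 / beta with a zero divisor in the branch it ultimately discards. A standalone illustration (a sketch, not part of the commit):

import numpy as np

diff = np.array([0.0, 1.5])
beta = 0.0
with np.errstate(divide="ignore", invalid="ignore"):
    quad = 0.5 * diff**2 / beta  # [nan, inf]: 0/0 and x/0
# np.where still evaluates the nan/inf branch before discarding it:
print(np.where(diff < beta, quad, diff - 0.5 * beta))  # [0.  1.5]
# The beta < 1e-5 fallback sidesteps this entirely:
print(np.abs(diff))                                    # [0.  1.5]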
30 changes: 30 additions & 0 deletions ivy/functional/backends/paddle/experimental/losses.py
@@ -34,3 +34,33 @@ def l1_loss(
reduction: Optional[str] = "mean",
) -> paddle.Tensor:
return F.l1_loss(input, target, reduction=reduction)


@with_unsupported_device_and_dtypes(
{
"2.5.1 and below": {
"cpu": (
"int8",
"int16",
"int32",
"int64",
"uint8",
"complex64",
"complex128",
"bool",
)
}
},
backend_version,
)
def smooth_l1_loss(
input: paddle.Tensor,
target: paddle.Tensor,
/,
*,
beta: Optional[float] = 1.0,
reduction: Optional[str] = "mean",
) -> paddle.Tensor:
    # Paddle's API names the smoothing threshold ``delta`` rather than ``beta``.
    return paddle.nn.functional.smooth_l1_loss(
        input, target, reduction=reduction, delta=beta
    )
24 changes: 24 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/losses.py
@@ -0,0 +1,24 @@
import tensorflow as tf
from typing import Optional
from ivy.func_wrapper import with_unsupported_dtypes
from . import backend_version


@with_unsupported_dtypes({"2.13.0 and below": ("bool",)}, backend_version)
def smooth_l1_loss(
input: tf.Tensor,
target: tf.Tensor,
/,
*,
beta: Optional[float] = 1.0,
reduction: Optional[str] = "mean",
) -> tf.Tensor:
diff = tf.abs(input - target)
loss = tf.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)

if reduction == "mean":
return tf.reduce_mean(loss)
elif reduction == "sum":
return tf.reduce_sum(loss)
else:
return loss
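Unlike the JAX and NumPy backends above, this version has no beta < 1e-5 fallback, so a zero or near-zero beta would push a division by (almost) zero through the unselected tf.where branch. A guarded variant in the same style as the other backends could look like the sketch below; it is an assumption about intent, not part of the commit.

import tensorflow as tf

def smooth_l1_loss_guarded(input, target, beta=1.0, reduction="mean"):
    # Hypothetical variant adding the small-beta fallback used by the
    # JAX and NumPy backends; not part of this commit.
    diff = tf.abs(input - target)
    if beta < 1e-5:
        loss = diff  # beta -> 0 limit: plain L1
    else:
        loss = tf.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
    if reduction == "mean":
        return tf.reduce_mean(loss)
    elif reduction == "sum":
        return tf.reduce_sum(loss)
    return loss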
30 changes: 30 additions & 0 deletions ivy/functional/backends/torch/experimental/losses.py
@@ -22,3 +22,33 @@ def l1_loss(
target,
reduction=reduction,
)


@with_unsupported_dtypes(
{
"2.0.1 and below": (
"complex",
"uint8",
"int8",
"int16",
"int32",
"int64",
"bool",
)
},
backend_version,
)
def smooth_l1_loss(
input: torch.Tensor,
target: torch.Tensor,
/,
*,
beta: Optional[float] = 1.0,
reduction: Optional[str] = "mean",
) -> torch.Tensor:
return torch.nn.functional.smooth_l1_loss(
input,
target,
beta=beta,
reduction=reduction,
)
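This backend delegates directly to torch.nn.functional.smooth_l1_loss, so it can be cross-checked against the hand-computed docstring value; the snippet below is a usage sketch, not part of the commit.

import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([2.0, 2.0, 2.0, 2.0])
# Matches the hand-computed Array docstring example above:
print(F.smooth_l1_loss(x, y, beta=0.5, reduction="mean"))  # tensor(0.8125)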