diff --git a/ivy/array/activations.py b/ivy/array/activations.py
index e8d7d7520d00c..31442dd2757d3 100644
--- a/ivy/array/activations.py
+++ b/ivy/array/activations.py
@@ -135,3 +135,29 @@ def softplus(
         """
         return ivy.softplus(self._data, beta=beta, threshold=threshold, out=out)
+
+    def log_softmax(
+        self: ivy.Array,
+        /,
+        *,
+        axis: Optional[int] = None,
+        out: Optional[ivy.Array] = None,
+    ) -> ivy.Array:
+        """
+        ivy.Array instance method variant of ivy.log_softmax.
+        This method simply wraps the function,
+        and so the docstring for ivy.log_softmax also applies to this method
+        with minimal changes.
+
+        Examples
+        --------
+        >>> x = ivy.array([-1.0, -0.98, 2.3])
+        >>> y = x.log_softmax()
+        >>> print(y)
+        ivy.array([-3.37, -3.35, -0.0719])
+
+        >>> x = ivy.array([2.0, 3.4, -4.2])
+        >>> print(x.log_softmax())
+        ivy.array([-1.62, -0.221, -7.82])
+        """
+        return ivy.log_softmax(self._data, axis=axis, out=out)
diff --git a/ivy/container/activations.py b/ivy/container/activations.py
index da62c25beeaba..9a0fd103c4b8e 100644
--- a/ivy/container/activations.py
+++ b/ivy/container/activations.py
@@ -758,3 +758,141 @@ def softplus(
             map_sequences=map_sequences,
             out=out,
         )
+
+    @staticmethod
+    def static_log_softmax(
+        x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
+        /,
+        *,
+        axis: Optional[Union[int, ivy.Container]] = None,
+        key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+        to_apply: bool = True,
+        prune_unapplied: bool = False,
+        map_sequences: bool = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        ivy.Container static method variant of ivy.log_softmax.
+        This method simply wraps the function, and so the docstring
+        for ivy.log_softmax also applies to this method with minimal changes.
+
+        Parameters
+        ----------
+        x
+            input container.
+        axis
+            the axis along which the log_softmax should be computed.
+        key_chains
+            The key-chains to apply or not apply the method to. Default is None.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped. Default is True.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+            Default is False.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples). Default is False.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            a container with the log_softmax function applied element-wise.
+
+        Examples
+        --------
+        >>> x = ivy.Container(a=ivy.array([-1.0, -0.98, 2.3]))
+        >>> y = ivy.Container.static_log_softmax(x)
+        >>> print(y)
+        {
+            a: ivy.array([-3.37, -3.35, -0.0719])
+        }
+
+        >>> x = ivy.Container(a=ivy.array([1.0, 2.4]), b=ivy.array([-0.2, -1.0]))
+        >>> y = ivy.Container.static_log_softmax(x)
+        >>> print(y)
+        {
+            a: ivy.array([-1.62, -0.22]),
+            b: ivy.array([-0.371, -1.17])
+        }
+        """
+        return ContainerBase.multi_map_in_static_method(
+            "log_softmax",
+            x,
+            axis=axis,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
+
+    def log_softmax(
+        self: ivy.Container,
+        /,
+        *,
+        axis: Optional[Union[int, ivy.Container]] = None,
+        key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+        to_apply: bool = True,
+        prune_unapplied: bool = False,
+        map_sequences: bool = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        ivy.Container instance method variant of ivy.log_softmax.
+        This method simply wraps the function, and so the docstring
+        for ivy.log_softmax also applies to this method with minimal changes.
+
+        Parameters
+        ----------
+        self
+            input container.
+        axis
+            the axis along which the log_softmax should be computed.
+        key_chains
+            The key-chains to apply or not apply the method to. Default is None.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped. Default is True.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+            Default is False.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples). Default is False.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            a container with the log_softmax function applied element-wise.
+
+        Examples
+        --------
+        >>> x = ivy.Container(a=ivy.array([-1.0, -0.98, 2.3]))
+        >>> y = x.log_softmax()
+        >>> print(y)
+        {
+            a: ivy.array([-3.37, -3.35, -0.0719])
+        }
+
+        >>> x = ivy.Container(a=ivy.array([1.0, 2.4]), b=ivy.array([-0.2, -1.0]))
+        >>> y = x.log_softmax()
+        >>> print(y)
+        {
+            a: ivy.array([-1.62, -0.22]),
+            b: ivy.array([-0.371, -1.17])
+        }
+        """
+        return self.static_log_softmax(
+            self,
+            axis=axis,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py
index 311d625ecde58..c48acb8fd734e 100644
--- a/ivy/functional/backends/jax/activations.py
+++ b/ivy/functional/backends/jax/activations.py
@@ -66,3 +66,11 @@ def softplus(
     if threshold is not None:
         return jnp.where(x_beta > threshold, x, res)
     return res
+
+
+def log_softmax(
+    x: JaxArray, /, *, axis: Optional[int] = None, out: Optional[JaxArray] = None
+) -> JaxArray:
+    if axis is None:
+        axis = -1
+    return jax.nn.log_softmax(x, axis)
diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py
index 70cd1a200b14b..223e599a7ad79 100644
--- a/ivy/functional/backends/numpy/activations.py
+++ b/ivy/functional/backends/numpy/activations.py
@@ -90,3 +90,25 @@ def softplus(
 
 
 softplus.support_native_out = True
+
+
+@_handle_0_dim_output
+def log_softmax(
+    x: np.ndarray, /, *, axis: Optional[int] = None, out: Optional[np.ndarray] = None
+) -> np.ndarray:
+    x_max = np.max(x, axis=axis, keepdims=True)
+    if x_max.ndim > 0:
+        x_max[~np.isfinite(x_max)] = 0
+    elif not np.isfinite(x_max):
+        x_max = 0
+    exp_tmp = np.exp(x - x_max)
+
+    with np.errstate(divide="ignore"):
+        s = np.sum(exp_tmp, axis=axis, keepdims=True)
+        ret = np.log(s)
+
+    ret = x - x_max - ret
+    return ret
+
+
+log_softmax.support_native_out = True
diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py
index 4f8a9dc3e9e81..a66a0cc0a3dcc 100644
--- a/ivy/functional/backends/tensorflow/activations.py
+++ b/ivy/functional/backends/tensorflow/activations.py
@@ -58,3 +58,9 @@ def softplus(
     if threshold is not None:
         return tf.where(x_beta > threshold, x, res)
     return res
+
+
+def log_softmax(
+    x: Tensor, /, *, axis: Optional[int] = None, out: Optional[Tensor] = None
+) -> Tensor:
+    return tf.nn.log_softmax(x, axis)
diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py
index cb1c3b013cdd2..a64ee2fe89dc8 100644
--- a/ivy/functional/backends/torch/activations.py
+++ b/ivy/functional/backends/torch/activations.py
@@ -85,3 +85,16 @@ def softplus(
 
 
 softplus.unsupported_dtypes = ("float16", "bfloat16")
+
+
+def log_softmax(
+    x: torch.Tensor,
+    /,
+    *,
+    axis: Optional[int] = None,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.nn.functional.log_softmax(x, axis)
+
+
+log_softmax.unsupported_dtypes = ("float16", "bfloat16")
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index eb048f08c967e..0baf1c8d3ca18 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -385,3 +385,73 @@ def softplus(
     """
     return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out)
+
+
+@to_native_arrays_and_back
+@handle_out_argument
+@handle_nestable
+@handle_exceptions
+def log_softmax(
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    axis: Optional[int] = -1,
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    """Applies the log_softmax function element-wise.
+
+    Parameters
+    ----------
+    x
+        Input array.
+    axis
+        The dimension along which log_softmax is computed. The default is -1,
+        which indicates the last dimension.
+    out
+        optional output array, for writing the result to. It must have a shape that the
+        inputs broadcast to.
+
+    Returns
+    -------
+    ret
+        The output array with log_softmax applied element-wise to the input.
+
+    Examples
+    --------
+    With :code:`ivy.Array` input:
+
+    >>> x = ivy.array([-1.0, -0.98])
+    >>> y = ivy.log_softmax(x)
+    >>> print(y)
+    ivy.array([-0.703, -0.683])
+
+    >>> x = ivy.array([1.0, 2.0, 3.0])
+    >>> y = ivy.log_softmax(x)
+    >>> print(y)
+    ivy.array([-2.41, -1.41, -0.408])
+
+    With :code:`ivy.NativeArray` input:
+
+    >>> x = ivy.native_array([1.5, 0.5, 1.0])
+    >>> y = ivy.log_softmax(x)
+    >>> print(y)
+    ivy.array([-0.68, -1.68, -1.18])
+
+    With :code:`ivy.Container` input:
+
+    >>> x = ivy.Container(a=ivy.array([1.5, 0.5, 1.0]))
+    >>> y = ivy.log_softmax(x)
+    >>> print(y)
+    {
+        a: ivy.array([-0.68, -1.68, -1.18])
+    }
+
+    >>> x = ivy.Container(a=ivy.array([1.0, 2.0]), b=ivy.array([0.4, -0.2]))
+    >>> y = ivy.log_softmax(x)
+    >>> print(y)
+    {
+        a: ivy.array([-1.31, -0.313]),
+        b: ivy.array([-0.437, -1.04])
+    }
+    """
+    return current_backend(x).log_softmax(x, axis=axis, out=out)
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index 87c7bde4997a0..817610c671a47 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -264,3 +264,46 @@ def test_softplus(
         beta=beta,
         threshold=threshold,
     )
+
+
+# log_softmax
+@handle_cmd_line_args
+@given(
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_num_dims=1,
+        large_abs_safety_factor=8,
+        small_abs_safety_factor=8,
+        safety_factor_scale="log",
+    ),
+    axis=helpers.ints(min_value=-1, max_value=0),
+    num_positional_args=helpers.num_positional_args(fn_name="log_softmax"),
+)
+def test_log_softmax(
+    *,
+    dtype_and_x,
+    as_variable,
+    axis,
+    with_out,
+    num_positional_args,
+    container,
+    instance_method,
+    native_array,
+    fw,
+):
+    dtype, x = dtype_and_x
+    helpers.test_function(
+        input_dtypes=dtype,
+        as_variable_flags=as_variable,
+        with_out=with_out,
+        native_array_flags=native_array,
+        fw=fw,
+        num_positional_args=num_positional_args,
+        container_flags=container,
+        instance_method=instance_method,
+        fn_name="log_softmax",
+        rtol_=1e-02,
+        atol_=1e-02,
+        x=np.asarray(x, dtype=dtype),
+        axis=axis,
+    )
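
For reviewers who want to sanity-check the docstring values outside of ivy: the NumPy backend above uses the standard log-sum-exp trick, log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x)))), which is mathematically equal to x - log(sum(exp(x))) but avoids overflow for large inputs. The snippet below is a standalone sketch, not part of this diff, and the helper name reference_log_softmax is made up here; it reproduces the example outputs with plain NumPy.

import numpy as np


def reference_log_softmax(x, axis=-1):
    # shift by the maximum along `axis`; the shift cancels in the final result
    # but keeps np.exp from overflowing for large inputs
    x = np.asarray(x, dtype=float)
    x_max = np.max(x, axis=axis, keepdims=True)
    shifted = x - x_max
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


# agrees with ivy.array([-3.37, -3.35, -0.0719]) from the ivy.Array docstring example
print(np.round(reference_log_softmax([-1.0, -0.98, 2.3]), 4))
# agrees with ivy.array([-2.41, -1.41, -0.408]) from the functional docstring example
print(np.round(reference_log_softmax([1.0, 2.0, 3.0]), 3))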