From 7c9cb4d3c3d7b9a8aaea17264f78275a3a7ae57b Mon Sep 17 00:00:00 2001 From: Nassim Berrada <112006029+nassimberrada@users.noreply.github.com> Date: Thu, 10 Nov 2022 12:31:19 +0100 Subject: [PATCH] Ivy fix (#6838) * ivy_fix * fixes Co-authored-by: nassimberrada --- ivy/array/experimental/elementwise.py | 31 +++++++ ivy/container/experimental/elementwise.py | 84 +++++++++++++++++++ .../backends/jax/experimental/elementwise.py | 9 ++ .../numpy/experimental/elementwise.py | 12 +++ .../tensorflow/experimental/elementwise.py | 12 +++ .../torch/experimental/elementwise.py | 12 +++ ivy/functional/experimental/elementwise.py | 34 ++++++++ .../test_core/test_elementwise.py | 37 ++++++++ 8 files changed, 231 insertions(+) diff --git a/ivy/array/experimental/elementwise.py b/ivy/array/experimental/elementwise.py index fbcffe42351c9..7f51118005df0 100644 --- a/ivy/array/experimental/elementwise.py +++ b/ivy/array/experimental/elementwise.py @@ -671,3 +671,34 @@ def allclose( return ivy.allclose( self._data, x2, rtol=rtol, atol=atol, equal_nan=equal_nan, out=out ) + + def fix( + self: ivy.Array, + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ivy.Array instance method variant of ivy.fix. This method + simply wraps the function, and so the docstring for ivy.fix also + applies to this method with minimal changes. + + Parameters + ---------- + self + Array input. + out + optional output array, for writing the result to. + + Returns + ------- + ret + Array of floats with elements corresponding to input elements + rounded to nearest integer towards zero, element-wise. + + Examples + -------- + >>> x = ivy.array([2.1, 2.9, -2.1]) + >>> x.fix() + ivy.array([ 2., 2., -2.]) + """ + return ivy.fix(self._data, out=out) diff --git a/ivy/container/experimental/elementwise.py b/ivy/container/experimental/elementwise.py index c7c6a1f93e84a..1d80e78f8f12f 100644 --- a/ivy/container/experimental/elementwise.py +++ b/ivy/container/experimental/elementwise.py @@ -1686,3 +1686,87 @@ def allclose( map_sequences=map_sequences, out=out, ) + + @staticmethod + def static_fix( + x: Union[ivy.Array, ivy.NativeArray, ivy.Container], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.fix. This method simply wraps + the function, and so the docstring for ivy.fix also applies to this + method with minimal changes. + + Parameters + ---------- + x + input container with array items. + out + optional output container, for writing the result to. + + Returns + ------- + ret + Container including arrays with element-wise rounding of + input arrays elements. + + Examples + -------- + >>> x = ivy.Container(a=ivy.array([2.1, 2.9, -2.1]),\ + b=ivy.array([3.14])) + >>> ivy.Container.static_fix(x) + { + a: ivy.array([ 2., 2., -2.]) + b: ivy.array([ 3.0 ]) + } + """ + return ContainerBase.multi_map_in_static_method( + "fix", + x, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def fix( + self: ivy.Container, + /, + *, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ivy.Container instance method variant of ivy.fix. This method simply + wraps the function, and so the docstring for ivy.fix also applies to + this method with minimal changes. + + Parameters + ---------- + self + input container with array items. + out + optional output container, for writing the result to. + + Returns + ------- + ret + Container including arrays with element-wise rounding of + input arrays elements. + + Examples + -------- + >>> x = ivy.Container(a=ivy.array([2.1, 2.9, -2.1]),\ + b=ivy.array([3.14])) + >>> x.fix() + { + a: ivy.array([ 2., 2., -2.]) + b: ivy.array([ 3.0 ]) + } + """ + return self.static_fix(self, out=out) diff --git a/ivy/functional/backends/jax/experimental/elementwise.py b/ivy/functional/backends/jax/experimental/elementwise.py index 2f8c274e1f729..9091a1c346ddf 100644 --- a/ivy/functional/backends/jax/experimental/elementwise.py +++ b/ivy/functional/backends/jax/experimental/elementwise.py @@ -158,3 +158,12 @@ def allclose( out: Optional[JaxArray] = None, ) -> bool: return jnp.allclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def fix( + x: JaxArray, + /, + *, + out: Optional[JaxArray] = None, +) -> JaxArray: + return jnp.fix(x, out=out) diff --git a/ivy/functional/backends/numpy/experimental/elementwise.py b/ivy/functional/backends/numpy/experimental/elementwise.py index d9c3ad6ea1b9d..8a47ca35496eb 100644 --- a/ivy/functional/backends/numpy/experimental/elementwise.py +++ b/ivy/functional/backends/numpy/experimental/elementwise.py @@ -232,3 +232,15 @@ def allclose( isclose.support_native_out = False + + +def fix( + x: np.ndarray, + /, + *, + out: Optional[np.ndarray] = None, +) -> np.ndarray: + return np.fix(x, out=out) + + +fix.support_native_out = True diff --git a/ivy/functional/backends/tensorflow/experimental/elementwise.py b/ivy/functional/backends/tensorflow/experimental/elementwise.py index 7d138cf6c6812..f611c53d33858 100644 --- a/ivy/functional/backends/tensorflow/experimental/elementwise.py +++ b/ivy/functional/backends/tensorflow/experimental/elementwise.py @@ -225,3 +225,15 @@ def allclose( return tf.experimental.numpy.allclose( x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan ) + + +@with_unsupported_dtypes( + {"2.9.1 and below": ("bfloat16,")}, backend_version +) +def fix( + x: Union[tf.Tensor, tf.Variable], + /, + *, + out: Optional[Union[tf.Tensor, tf.Variable]] = None, +) -> Union[tf.Tensor, tf.Variable]: + return tf.experimental.numpy.fix(x) diff --git a/ivy/functional/backends/torch/experimental/elementwise.py b/ivy/functional/backends/torch/experimental/elementwise.py index ce3a6543a0c40..c871010b1a018 100644 --- a/ivy/functional/backends/torch/experimental/elementwise.py +++ b/ivy/functional/backends/torch/experimental/elementwise.py @@ -235,3 +235,15 @@ def allclose( out: Optional[torch.Tensor] = None, ) -> bool: return torch.allclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def fix( + x: torch.Tensor, + /, + *, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.fix(x, out=out) + + +fix.support_native_out = True diff --git a/ivy/functional/experimental/elementwise.py b/ivy/functional/experimental/elementwise.py index 0f592f4b4333a..3be6f9de3ee6a 100644 --- a/ivy/functional/experimental/elementwise.py +++ b/ivy/functional/experimental/elementwise.py @@ -776,3 +776,37 @@ def allclose( return ivy.current_backend().allclose( a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, out=out ) + + +@to_native_arrays_and_back +@handle_out_argument +@handle_nestable +def fix( + x: Union[ivy.Array, ivy.NativeArray, float, int, list, tuple], + /, + *, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """Round an array of floats element-wise to nearest integer towards zero. + The rounded values are returned as floats. + + Parameters + ---------- + x + Array input. + out + optional output array, for writing the result to. + + Returns + ------- + ret + Array of floats with elements corresponding to input elements + rounded to nearest integer towards zero, element-wise. + + Examples + -------- + >>> x = ivy.array([2.1, 2.9, -2.1]) + >>> ivy.fix(x) + ivy.array([ 2., 2., -2.]) + """ + return ivy.current_backend(x).fix(x, out=out) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py index b12bc95f2fdf0..5077ef601fb3c 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py @@ -729,3 +729,40 @@ def test_allclose( atol=atol, equal_nan=equal_nan, ) + + +# fix +@handle_cmd_line_args +@given( + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", index=2), + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + num_positional_args=helpers.num_positional_args(fn_name="fix"), +) +def test_fix( + dtype_and_x, + with_out, + as_variable, + num_positional_args, + native_array, + container, + instance_method, + fw, +): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + container_flags=container, + instance_method=instance_method, + fw=fw, + fn_name="fix", + x=np.asarray(x[0], dtype=input_dtype[0]), + )