Skip to content

Commit

Permalink
Ivy fix (#6838)
Browse files Browse the repository at this point in the history
* ivy_fix

* fixes

Co-authored-by: nassimberrada <Nassim>
  • Loading branch information
nassimberrada authored Nov 10, 2022
1 parent a19d53e commit 7c9cb4d
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 0 deletions.
31 changes: 31 additions & 0 deletions ivy/array/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,34 @@ def allclose(
return ivy.allclose(
self._data, x2, rtol=rtol, atol=atol, equal_nan=equal_nan, out=out
)

def fix(
self: ivy.Array,
/,
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.fix. This method
simply wraps the function, and so the docstring for ivy.fix also
applies to this method with minimal changes.
Parameters
----------
self
Array input.
out
optional output array, for writing the result to.
Returns
-------
ret
Array of floats with elements corresponding to input elements
rounded to nearest integer towards zero, element-wise.
Examples
--------
>>> x = ivy.array([2.1, 2.9, -2.1])
>>> x.fix()
ivy.array([ 2., 2., -2.])
"""
return ivy.fix(self._data, out=out)
84 changes: 84 additions & 0 deletions ivy/container/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,3 +1686,87 @@ def allclose(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def static_fix(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.fix. This method simply wraps
the function, and so the docstring for ivy.fix also applies to this
method with minimal changes.
Parameters
----------
x
input container with array items.
out
optional output container, for writing the result to.
Returns
-------
ret
Container including arrays with element-wise rounding of
input arrays elements.
Examples
--------
>>> x = ivy.Container(a=ivy.array([2.1, 2.9, -2.1]),\
b=ivy.array([3.14]))
>>> ivy.Container.static_fix(x)
{
a: ivy.array([ 2., 2., -2.])
b: ivy.array([ 3.0 ])
}
"""
return ContainerBase.multi_map_in_static_method(
"fix",
x,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def fix(
self: ivy.Container,
/,
*,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.fix. This method simply
wraps the function, and so the docstring for ivy.fix also applies to
this method with minimal changes.
Parameters
----------
self
input container with array items.
out
optional output container, for writing the result to.
Returns
-------
ret
Container including arrays with element-wise rounding of
input arrays elements.
Examples
--------
>>> x = ivy.Container(a=ivy.array([2.1, 2.9, -2.1]),\
b=ivy.array([3.14]))
>>> x.fix()
{
a: ivy.array([ 2., 2., -2.])
b: ivy.array([ 3.0 ])
}
"""
return self.static_fix(self, out=out)
9 changes: 9 additions & 0 deletions ivy/functional/backends/jax/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,12 @@ def allclose(
out: Optional[JaxArray] = None,
) -> bool:
return jnp.allclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan)


def fix(
x: JaxArray,
/,
*,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.fix(x, out=out)
12 changes: 12 additions & 0 deletions ivy/functional/backends/numpy/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,15 @@ def allclose(


isclose.support_native_out = False


def fix(
x: np.ndarray,
/,
*,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return np.fix(x, out=out)


fix.support_native_out = True
12 changes: 12 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,15 @@ def allclose(
return tf.experimental.numpy.allclose(
x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan
)


@with_unsupported_dtypes(
{"2.9.1 and below": ("bfloat16,")}, backend_version
)
def fix(
x: Union[tf.Tensor, tf.Variable],
/,
*,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
return tf.experimental.numpy.fix(x)
12 changes: 12 additions & 0 deletions ivy/functional/backends/torch/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,15 @@ def allclose(
out: Optional[torch.Tensor] = None,
) -> bool:
return torch.allclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan)


def fix(
x: torch.Tensor,
/,
*,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.fix(x, out=out)


fix.support_native_out = True
34 changes: 34 additions & 0 deletions ivy/functional/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,37 @@ def allclose(
return ivy.current_backend().allclose(
a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, out=out
)


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
def fix(
x: Union[ivy.Array, ivy.NativeArray, float, int, list, tuple],
/,
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Round an array of floats element-wise to nearest integer towards zero.
The rounded values are returned as floats.
Parameters
----------
x
Array input.
out
optional output array, for writing the result to.
Returns
-------
ret
Array of floats with elements corresponding to input elements
rounded to nearest integer towards zero, element-wise.
Examples
--------
>>> x = ivy.array([2.1, 2.9, -2.1])
>>> ivy.fix(x)
ivy.array([ 2., 2., -2.])
"""
return ivy.current_backend(x).fix(x, out=out)
Original file line number Diff line number Diff line change
Expand Up @@ -729,3 +729,40 @@ def test_allclose(
atol=atol,
equal_nan=equal_nan,
)


# fix
@handle_cmd_line_args
@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float", index=2),
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
num_positional_args=helpers.num_positional_args(fn_name="fix"),
)
def test_fix(
dtype_and_x,
with_out,
as_variable,
num_positional_args,
native_array,
container,
instance_method,
fw,
):
input_dtype, x = dtype_and_x
helpers.test_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=with_out,
num_positional_args=num_positional_args,
native_array_flags=native_array,
container_flags=container,
instance_method=instance_method,
fw=fw,
fn_name="fix",
x=np.asarray(x[0], dtype=input_dtype[0]),
)

0 comments on commit 7c9cb4d

Please sign in to comment.