Skip to content

Commit

Permalink
ivy.fill_diagonal (#19072)
Browse files Browse the repository at this point in the history
Co-authored-by: sherry30 <[email protected]>
  • Loading branch information
mohame54 and sherry30 authored Jul 11, 2023
1 parent aaba3f3 commit e1015bc
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 0 deletions.
15 changes: 15 additions & 0 deletions ivy/data_classes/array/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,3 +1065,18 @@ def unique_consecutive(
changes.
"""
return ivy.unique_consecutive(self._data, axis=axis)

def fill_diagonal(
self: ivy.Array,
v: Union[int, float],
/,
*,
wrap: bool = False,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.fill_diag.
This method simply wraps the function, and so the docstring for
ivy.fill_diag also applies to this method with minimal changes.
"""
return ivy.fill_diagonal(self._data, v, wrap=wrap)
42 changes: 42 additions & 0 deletions ivy/data_classes/container/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2940,3 +2940,45 @@ def unique_consecutive(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

@staticmethod
def _static_fill_diagonal(
a: Union[ivy.Array, ivy.NativeArray, ivy.Container],
v: Union[ivy.Array, ivy.NativeArray],
/,
*,
wrap: bool = False,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.fill_diagonal.
This method simply wraps the function, and so the docstring for
ivy.fill_diagonal also applies to this method with minimal
changes.
"""
return ContainerBase.cont_multi_map_in_function(
"fill_diagonal",
a,
v,
wrap=wrap,
)

def fill_diagonal(
self: ivy.Container,
v: Union[int, float],
/,
*,
wrap: bool = False,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.fill_diagonal.
This method simply wraps the function, and so the docstring for
ivy.fill_diagonal also applies to this method with minimal
changes.
"""
return self._static_fill_diagonal(
self,
v,
wrap=wrap,
)
21 changes: 21 additions & 0 deletions ivy/functional/backends/jax/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,24 @@ def unique_consecutive(
inverse_indices,
counts,
)


def fill_diagonal(
a: JaxArray,
v: Union[int, float],
/,
*,
wrap: bool = False,
) -> jnp.DeviceArray:
shape = jnp.array(a.shape)
end = None
if len(shape) == 2:
step = shape[1] + 1
if not wrap:
end = shape[1] * shape[1]
else:
step = 1 + (jnp.cumprod(shape[:-1])).sum()
a = jnp.reshape(a, (-1,))
a = a.at[:end:step].set(jnp.array(v).astype(a.dtype))
a = jnp.reshape(a, shape)
return a
11 changes: 11 additions & 0 deletions ivy/functional/backends/numpy/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,14 @@ def unique_consecutive(
inverse_indices,
counts,
)


def fill_diagonal(
a: np.ndarray,
v: Union[int, float],
/,
*,
wrap: bool = False,
) -> np.ndarray:
np.fill_diagonal(a, v, wrap=wrap)
return a
36 changes: 36 additions & 0 deletions ivy/functional/backends/paddle/experimental/manipulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections import namedtuple
from typing import Optional, Union, Sequence, Tuple, NamedTuple, List
from numbers import Number


from .. import backend_version
from ivy.func_wrapper import with_unsupported_device_and_dtypes
import paddle
Expand Down Expand Up @@ -592,3 +594,37 @@ def unique_consecutive(
inverse_indices,
counts,
)


@with_unsupported_device_and_dtypes(
{"2.5.0 and below": {"cpu": ("int8", "int16", "uint8", "float16")}}, backend_version
)
def fill_diagonal(
a: paddle.Tensor,
v: Union[int, float],
/,
*,
wrap: bool = False,
) -> paddle.Tensor:
shape = a.shape
max_end = paddle.prod(paddle.to_tensor(shape))
end = max_end
if len(shape) == 2:
step = shape[1] + 1
if not wrap:
end = shape[1] * shape[1]
else:
step = 1 + (paddle.cumprod(paddle.to_tensor(shape[:-1]), dim=0)).sum()
end = max_end if end > max_end else end
a = paddle.reshape(a, (-1,))
w = paddle.zeros(a.shape, dtype=bool)
ins = paddle.arange(0, max_end)
steps = paddle.arange(0, end, step)

for i in steps:
i = ins == i
w = paddle.logical_or(w, i)
v = paddle.to_tensor(v, dtype=a.dtype)
a = paddle.where(w, v, a)
a = paddle.reshape(a, shape)
return a
25 changes: 25 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,28 @@ def unique_consecutive(
tf.cast(inverse_indices, tf.int64),
tf.cast(counts, tf.int64),
)


def fill_diagonal(
a: tf.Tensor,
v: Union[int, float],
/,
*,
wrap: bool = False,
):
shape = tf.shape(a)
max_end = tf.math.reduce_prod(shape)
end = max_end
if len(shape) == 2:
step = shape[1] + 1
if not wrap:
end = shape[1] * shape[1]
else:
step = 1 + tf.reduce_sum(tf.math.cumprod(shape[:-1]))
a = tf.reshape(a, (-1,))
end = min(end, max_end)
indices = [[i] for i in range(0, end, step)]
ups = tf.convert_to_tensor([v] * len(indices), dtype=a.dtype)
a = tf.tensor_scatter_nd_update(a, indices, ups)
a = tf.reshape(a, shape)
return a
31 changes: 31 additions & 0 deletions ivy/functional/backends/torch/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,34 @@ def unique_consecutive(
inverse_indices,
counts,
)


def fill_diagonal(
a: torch.Tensor,
v: Union[int, float],
/,
*,
wrap: bool = False,
) -> torch.Tensor:
shape = a.shape
max_end = torch.prod(torch.tensor(shape))
end = max_end
if len(shape) == 2:
step = shape[1] + 1
if not wrap:
end = shape[1] * shape[1]
else:
step = 1 + (torch.cumprod(torch.tensor(shape[:-1]), 0)).sum()

end = max_end if end > max_end else end
a = torch.reshape(a, (-1,))
w = torch.zeros(a.shape, dtype=bool).to(a.device)
ins = torch.arange(0, max_end).to(a.device)
steps = torch.arange(0, end, step).to(a.device)

for i in steps:
i = ins == i
w = torch.logical_or(w, i)
a = torch.where(w, v, a)
a = torch.reshape(a, shape)
return a
31 changes: 31 additions & 0 deletions ivy/functional/ivy/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,3 +2095,34 @@ def unique_consecutive(
counts=ivy.array([2, 2, 1, 2, 1]))
"""
return ivy.current_backend(x).unique_consecutive(x, axis=axis)


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@to_native_arrays_and_back
@handle_array_function
def fill_diagonal(
a: Union[ivy.Array, ivy.NativeArray],
v: Union[int, float],
/,
*,
wrap: bool = False,
) -> Union[ivy.Array, ivy.NativeArray]:
"""
Fill the main diagonal of the given array of any dimensionality..
Parameters
----------
a
Array at least 2D.
v
The value to write on the diagonal.
wrap
The diagonal ‘wrapped’ after N columns for tall matrices.
Returns
-------
ret
Array with the diagonal filled.
"""
return ivy.current_backend(a).fill_diag(a, v, wrap=wrap)
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,42 @@ def test_unique_consecutive(
x=x[0],
axis=axis,
)


# fill_diag
@handle_test(
fn_tree="fill_diagonal",
dt_a=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_num_dims=2,
max_num_dims=4,
min_dim_size=3,
max_dim_size=3,
),
v=st.sampled_from([1, 2, 3, 10]),
wrap=st.booleans(),
test_with_out=st.just(False),
)
def test_fill_diagonal(
*,
dt_a,
v,
wrap,
test_flags,
backend_fw,
fn_name,
on_device,
ground_truth_backend,
):
dt, a = dt_a
helpers.test_function(
ground_truth_backend=ground_truth_backend,
input_dtypes=dt,
test_flags=test_flags,
on_device=on_device,
fw=backend_fw,
fn_name=fn_name,
a=a[0],
v=v,
wrap=wrap,
)

0 comments on commit e1015bc

Please sign in to comment.