From e1015bc1c0967599185d7d8d5fa792fa780c2639 Mon Sep 17 00:00:00 2001
From: Mohammed Ayman
Date: Tue, 11 Jul 2023 12:35:23 +0300
Subject: [PATCH] ivy.fill_diagonal (#19072)

Co-authored-by: sherry30
---
 .../array/experimental/manipulation.py        | 15 +++++++
 .../container/experimental/manipulation.py    | 42 +++++++++++++++++++
 .../backends/jax/experimental/manipulation.py | 21 ++++++++++
 .../numpy/experimental/manipulation.py        | 11 +++++
 .../paddle/experimental/manipulation.py       | 36 ++++++++++++++++
 .../tensorflow/experimental/manipulation.py   | 25 +++++++++++
 .../torch/experimental/manipulation.py        | 31 ++++++++++++++
 .../ivy/experimental/manipulation.py          | 31 ++++++++++++++
 .../test_core/test_manipulation.py            | 39 +++++++++++++++++
 9 files changed, 251 insertions(+)

diff --git a/ivy/data_classes/array/experimental/manipulation.py b/ivy/data_classes/array/experimental/manipulation.py
index e59be37c29c95..bc6d61c2c67cc 100644
--- a/ivy/data_classes/array/experimental/manipulation.py
+++ b/ivy/data_classes/array/experimental/manipulation.py
@@ -1065,3 +1065,18 @@ def unique_consecutive(
         changes.
         """
         return ivy.unique_consecutive(self._data, axis=axis)
+
+    def fill_diagonal(
+        self: ivy.Array,
+        v: Union[int, float],
+        /,
+        *,
+        wrap: bool = False,
+    ) -> ivy.Array:
+        """
+        ivy.Array instance method variant of ivy.fill_diagonal.
+
+        This method simply wraps the function, and so the docstring for
+        ivy.fill_diagonal also applies to this method with minimal changes.
+        """
+        return ivy.fill_diagonal(self._data, v, wrap=wrap)
diff --git a/ivy/data_classes/container/experimental/manipulation.py b/ivy/data_classes/container/experimental/manipulation.py
index 573e55963bdaa..a95433bb1bf2b 100644
--- a/ivy/data_classes/container/experimental/manipulation.py
+++ b/ivy/data_classes/container/experimental/manipulation.py
@@ -2940,3 +2940,45 @@ def unique_consecutive(
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
         )
+
+    @staticmethod
+    def _static_fill_diagonal(
+        a: Union[ivy.Array, ivy.NativeArray, ivy.Container],
+        v: Union[ivy.Array, ivy.NativeArray],
+        /,
+        *,
+        wrap: bool = False,
+    ) -> ivy.Container:
+        """
+        ivy.Container static method variant of ivy.fill_diagonal.
+
+        This method simply wraps the function, and so the docstring for
+        ivy.fill_diagonal also applies to this method with minimal
+        changes.
+        """
+        return ContainerBase.cont_multi_map_in_function(
+            "fill_diagonal",
+            a,
+            v,
+            wrap=wrap,
+        )
+
+    def fill_diagonal(
+        self: ivy.Container,
+        v: Union[int, float],
+        /,
+        *,
+        wrap: bool = False,
+    ) -> ivy.Container:
+        """
+        ivy.Container instance method variant of ivy.fill_diagonal.
+
+        This method simply wraps the function, and so the docstring for
+        ivy.fill_diagonal also applies to this method with minimal
+        changes.
+ """ + return self._static_fill_diagonal( + self, + v, + wrap=wrap, + ) diff --git a/ivy/functional/backends/jax/experimental/manipulation.py b/ivy/functional/backends/jax/experimental/manipulation.py index 5de33b1d11dca..df0cd6b628f2c 100644 --- a/ivy/functional/backends/jax/experimental/manipulation.py +++ b/ivy/functional/backends/jax/experimental/manipulation.py @@ -395,3 +395,24 @@ def unique_consecutive( inverse_indices, counts, ) + + +def fill_diagonal( + a: JaxArray, + v: Union[int, float], + /, + *, + wrap: bool = False, +) -> jnp.DeviceArray: + shape = jnp.array(a.shape) + end = None + if len(shape) == 2: + step = shape[1] + 1 + if not wrap: + end = shape[1] * shape[1] + else: + step = 1 + (jnp.cumprod(shape[:-1])).sum() + a = jnp.reshape(a, (-1,)) + a = a.at[:end:step].set(jnp.array(v).astype(a.dtype)) + a = jnp.reshape(a, shape) + return a diff --git a/ivy/functional/backends/numpy/experimental/manipulation.py b/ivy/functional/backends/numpy/experimental/manipulation.py index 09049b65d80b7..3cbd7a3d471d6 100644 --- a/ivy/functional/backends/numpy/experimental/manipulation.py +++ b/ivy/functional/backends/numpy/experimental/manipulation.py @@ -494,3 +494,14 @@ def unique_consecutive( inverse_indices, counts, ) + + +def fill_diagonal( + a: np.ndarray, + v: Union[int, float], + /, + *, + wrap: bool = False, +) -> np.ndarray: + np.fill_diagonal(a, v, wrap=wrap) + return a diff --git a/ivy/functional/backends/paddle/experimental/manipulation.py b/ivy/functional/backends/paddle/experimental/manipulation.py index f92e3ef36082c..15d78666a8de7 100644 --- a/ivy/functional/backends/paddle/experimental/manipulation.py +++ b/ivy/functional/backends/paddle/experimental/manipulation.py @@ -1,6 +1,8 @@ from collections import namedtuple from typing import Optional, Union, Sequence, Tuple, NamedTuple, List from numbers import Number + + from .. 
 from ivy.func_wrapper import with_unsupported_device_and_dtypes
 import paddle
@@ -592,3 +594,37 @@ def unique_consecutive(
         inverse_indices,
         counts,
     )
+
+
+@with_unsupported_device_and_dtypes(
+    {"2.5.0 and below": {"cpu": ("int8", "int16", "uint8", "float16")}}, backend_version
+)
+def fill_diagonal(
+    a: paddle.Tensor,
+    v: Union[int, float],
+    /,
+    *,
+    wrap: bool = False,
+) -> paddle.Tensor:
+    shape = a.shape
+    max_end = paddle.prod(paddle.to_tensor(shape))
+    end = max_end
+    if len(shape) == 2:
+        step = shape[1] + 1
+        if not wrap:
+            end = shape[1] * shape[1]
+    else:
+        step = 1 + (paddle.cumprod(paddle.to_tensor(shape[:-1]), dim=0)).sum()
+    end = max_end if end > max_end else end
+    a = paddle.reshape(a, (-1,))
+    w = paddle.zeros(a.shape, dtype=bool)
+    ins = paddle.arange(0, max_end)
+    steps = paddle.arange(0, end, step)
+
+    for i in steps:
+        i = ins == i
+        w = paddle.logical_or(w, i)
+    v = paddle.to_tensor(v, dtype=a.dtype)
+    a = paddle.where(w, v, a)
+    a = paddle.reshape(a, shape)
+    return a
diff --git a/ivy/functional/backends/tensorflow/experimental/manipulation.py b/ivy/functional/backends/tensorflow/experimental/manipulation.py
index 0f5a03553ba8f..ec8b4bd367ee3 100644
--- a/ivy/functional/backends/tensorflow/experimental/manipulation.py
+++ b/ivy/functional/backends/tensorflow/experimental/manipulation.py
@@ -342,3 +342,28 @@ def unique_consecutive(
         tf.cast(inverse_indices, tf.int64),
         tf.cast(counts, tf.int64),
     )
+
+
+def fill_diagonal(
+    a: tf.Tensor,
+    v: Union[int, float],
+    /,
+    *,
+    wrap: bool = False,
+):
+    shape = tf.shape(a)
+    max_end = tf.math.reduce_prod(shape)
+    end = max_end
+    if len(shape) == 2:
+        step = shape[1] + 1
+        if not wrap:
+            end = shape[1] * shape[1]
+    else:
+        step = 1 + tf.reduce_sum(tf.math.cumprod(shape[:-1]))
+    a = tf.reshape(a, (-1,))
+    end = min(end, max_end)
+    indices = [[i] for i in range(0, end, step)]
+    ups = tf.convert_to_tensor([v] * len(indices), dtype=a.dtype)
+    a = tf.tensor_scatter_nd_update(a, indices, ups)
+    a = tf.reshape(a, shape)
+    return a
diff --git a/ivy/functional/backends/torch/experimental/manipulation.py b/ivy/functional/backends/torch/experimental/manipulation.py
index 26b862846ade3..e0d07dc57020c 100644
--- a/ivy/functional/backends/torch/experimental/manipulation.py
+++ b/ivy/functional/backends/torch/experimental/manipulation.py
@@ -364,3 +364,34 @@ def unique_consecutive(
         inverse_indices,
         counts,
     )
+
+
+def fill_diagonal(
+    a: torch.Tensor,
+    v: Union[int, float],
+    /,
+    *,
+    wrap: bool = False,
+) -> torch.Tensor:
+    shape = a.shape
+    max_end = torch.prod(torch.tensor(shape))
+    end = max_end
+    if len(shape) == 2:
+        step = shape[1] + 1
+        if not wrap:
+            end = shape[1] * shape[1]
+    else:
+        step = 1 + (torch.cumprod(torch.tensor(shape[:-1]), 0)).sum()
+
+    end = max_end if end > max_end else end
+    a = torch.reshape(a, (-1,))
+    w = torch.zeros(a.shape, dtype=bool).to(a.device)
+    ins = torch.arange(0, max_end).to(a.device)
+    steps = torch.arange(0, end, step).to(a.device)
+
+    for i in steps:
+        i = ins == i
+        w = torch.logical_or(w, i)
+    a = torch.where(w, v, a)
+    a = torch.reshape(a, shape)
+    return a
diff --git a/ivy/functional/ivy/experimental/manipulation.py b/ivy/functional/ivy/experimental/manipulation.py
index 3ee3a6293d069..7f8879487c909 100644
--- a/ivy/functional/ivy/experimental/manipulation.py
+++ b/ivy/functional/ivy/experimental/manipulation.py
@@ -2095,3 +2095,34 @@ def unique_consecutive(
         counts=ivy.array([2, 2, 1, 2, 1]))
     """
     return ivy.current_backend(x).unique_consecutive(x, axis=axis)
+
+
+@handle_exceptions
+@handle_nestable
+@handle_array_like_without_promotion
+@to_native_arrays_and_back
+@handle_array_function
+def fill_diagonal(
+    a: Union[ivy.Array, ivy.NativeArray],
+    v: Union[int, float],
+    /,
+    *,
+    wrap: bool = False,
+) -> Union[ivy.Array, ivy.NativeArray]:
+    """
+    Fill the main diagonal of the given array of any dimensionality.
+
+    Parameters
+    ----------
+    a
+        Array, at least 2-D, whose main diagonal is to be filled.
+    v
+        The value to write on the diagonal.
+    wrap
+        If True, the diagonal is ‘wrapped’ after N columns for tall matrices.
+    Returns
+    -------
+    ret
+        Array with the diagonal filled.
+    """
+    return ivy.current_backend(a).fill_diagonal(a, v, wrap=wrap)
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py
index bb91f34f070fc..2f55e27b39d57 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py
@@ -970,3 +970,42 @@ def test_unique_consecutive(
         x=x[0],
         axis=axis,
     )
+
+
+# fill_diagonal
+@handle_test(
+    fn_tree="fill_diagonal",
+    dt_a=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_num_dims=2,
+        max_num_dims=4,
+        min_dim_size=3,
+        max_dim_size=3,
+    ),
+    v=st.sampled_from([1, 2, 3, 10]),
+    wrap=st.booleans(),
+    test_with_out=st.just(False),
+)
+def test_fill_diagonal(
+    *,
+    dt_a,
+    v,
+    wrap,
+    test_flags,
+    backend_fw,
+    fn_name,
+    on_device,
+    ground_truth_backend,
+):
+    dt, a = dt_a
+    helpers.test_function(
+        ground_truth_backend=ground_truth_backend,
+        input_dtypes=dt,
+        test_flags=test_flags,
+        on_device=on_device,
+        fw=backend_fw,
+        fn_name=fn_name,
+        a=a[0],
+        v=v,
+        wrap=wrap,
+    )
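
A minimal usage sketch of the API added by this patch. It assumes an ivy build that already includes the change, picks the numpy backend purely for illustration, and uses shapes and values that are illustrative rather than taken from the test suite:

import ivy

ivy.set_backend("numpy")  # illustration only; any backend touched by this patch should behave the same

# ivy.fill_diagonal writes `v` along the main diagonal and returns the result.
x = ivy.zeros((3, 3))
print(ivy.fill_diagonal(x, 5))          # diagonal entries become 5, the rest stay 0

# Instance-method variant added to ivy.Array by this patch.
y = ivy.zeros((3, 3)).fill_diagonal(2)

# Tall matrix with wrap=True: the diagonal restarts after every N columns,
# mirroring numpy.fill_diagonal's `wrap` flag.
print(ivy.fill_diagonal(ivy.zeros((6, 3)), 7, wrap=True))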