Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added function eigvals #8945

Merged
merged 10 commits into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions ivy/array/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,20 @@ def eig(
)
"""
return ivy.eig(self._data)

def eigvals(
self: ivy.Array,
/,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.eigvals.
This method simply wraps the function, and so the docstring for
ivy.eigvals also applies to this method with minimal changes.

Examples
--------
>>> x = ivy.array([[1,2], [3,4]])
>>> x.eigvals()
ivy.array([-0.37228132+0.j, 5.37228132+0.j])
"""
return ivy.eigvals(self._data)
94 changes: 94 additions & 0 deletions ivy/container/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,97 @@ def eig(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

@staticmethod
def static_eigvals(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.eigvals.
This method simply wraps the function, and so the docstring for
ivy.eigvals also applies to this method with minimal changes.

Parameters
----------
x
container with input arrays.

Returns
-------
ret
container including array corresponding
to eigenvalues of input array

Examples
--------
>>> x = ivy.array([[1,2], [3,4]])
>>> c = ivy.Container({'x':{'xx':x}})
>>> ivy.Container.eigvals(c)
{
x: {
xx: ivy.array([-0.37228132+0.j, 5.37228132+0.j])
}
}
>>> ivy.Container.eigvals(c)['x']['xx']
ivy.array([-0.37228132+0.j, 5.37228132+0.j])
"""
return ContainerBase.cont_multi_map_in_function(
"eigvals",
x,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

def eigvals(
self: ivy.Container,
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.eigvals.
This method simply wraps the function, and so the docstring for
ivy.eigvals also applies to this method with minimal changes.

Parameters
----------
x
container with input arrays.

Returns
-------
ret
container including array corresponding
to eigenvalues of input array

Examples
--------
>>> x = ivy.array([[1,2], [3,4]])
>>> c = ivy.Container({'x':{'xx':x}})
>>> c.eigvals()
{
x: {
xx: ivy.array([-0.37228132+0.j, 5.37228132+0.j])
}
}
>>> c.eigvals()['x']['xx']
ivy.array([-0.37228132+0.j, 5.37228132+0.j])
"""
return self.static_eigvals(
self,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)
6 changes: 6 additions & 0 deletions ivy/functional/backends/jax/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,9 @@ def matrix_exp(

def eig(x: JaxArray, /) -> Tuple[JaxArray]:
return jnp.linalg.eig(x)


def eigvals(x: JaxArray, /) -> JaxArray:
if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
x = x.astype(jnp.float64)
return jnp.linalg.eigvals(x)
10 changes: 10 additions & 0 deletions ivy/functional/backends/numpy/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,13 @@ def eig(x: np.ndarray, /) -> Tuple[np.ndarray]:


eig.support_native_out = False


def eigvals(x: np.ndarray, /) -> np.ndarray:
if ivy.dtype(x) == ivy.float16:
x = x.astype(np.float32)
e = np.linalg.eigvals(x)
return e.astype(complex)


eigvals.support_native_out = False
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,12 @@ def eig(
if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
return tf.linalg.eig(tf.cast(x, tf.float64))
return tf.linalg.eig(x)


def eigvals(
x: Union[tf.Tensor],
/,
) -> Union[tf.Tensor, tf.Variable]:
if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
return tf.linalg.eigvals(tf.cast(x, tf.float64))
return tf.linalg.eigvals(x)
9 changes: 9 additions & 0 deletions ivy/functional/backends/torch/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,12 @@ def eig(x: torch.Tensor, /) -> Tuple[torch.Tensor]:


eig.support_native_out = False


def eigvals(x: torch.Tensor, /) -> torch.Tensor:
if not torch.is_complex(x):
x = x.to(torch.complex128)
return torch.linalg.eigvals(x)


eigvals.support_native_out = False
40 changes: 40 additions & 0 deletions ivy/functional/ivy/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,43 @@ def eig(
])
"""
return current_backend(x).eig(x)


@to_native_arrays_and_back
@handle_nestable
@handle_exceptions
def eigvals(
x: Union[ivy.Array, ivy.NativeArray],
/,
) -> ivy.Array:
"""Computes eigenvalues of x. Returns a set of eigenvalues.

Parameters
----------
x
An array of shape (..., N, N).

Returns
-------
w
Not necessarily ordered array(..., N) of eigenvalues in complex type.

Functional Examples
------------------
With :class:`ivy.Array` inputs:
>>> x = ivy.array([[1,2], [3,4]])
>>> w = ivy.eigvals(x)
>>> w
ivy.array([-0.37228132+0.j, 5.37228132+0.j])

>>> x = ivy.array([[[1,2], [3,4]], [[5,6], [5,6]]])
>>> w = ivy.eigvals(x)
>>> w
ivy.array(
[
[-0.37228132+0.j, 5.37228132+0.j],
[ 0. +0.j, 11. +0.j]
]
)
"""
return current_backend(x).eigvals(x)
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,44 @@ def test_eig(
test_values=False,
x=x[0],
)


@handle_test(
fn_tree="functional.ivy.experimental.eigvals",
dtype_x=helpers.dtype_and_values(
available_dtypes=(
ivy.float32,
ivy.float64,
ivy.int32,
ivy.int64,
ivy.complex64,
ivy.complex128,
),
min_num_dims=2,
max_num_dims=3,
min_dim_size=10,
max_dim_size=10,
min_value=1.0,
max_value=1.0e5,
shared_dtype=True,
),
test_with_out=st.just(False),
test_gradients=st.just(False),
)
def test_eigvals(
dtype_x,
test_flags,
backend_fw,
fn_name,
ground_truth_backend,
):
dtype, x = dtype_x
helpers.test_function(
ground_truth_backend=ground_truth_backend,
input_dtypes=dtype,
test_flags=test_flags,
fw=backend_fw,
fn_name=fn_name,
test_values=False,
x=x[0],
)