Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] SVD outputs tuple (#16530)
Browse files Browse the repository at this point in the history
* Convert output to tuple

* Fix hyper link in doc
  • Loading branch information
hzfan authored and reminisce committed Oct 20, 2019
1 parent cdfaf39 commit 1648f4c
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 93 deletions.
65 changes: 0 additions & 65 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,71 +480,6 @@ def _np_reshape(a, newshape, order='C', out=None):
"""


def _np__linalg_svd(a):
r"""
Singular Value Decomposition.
When `a` is a 2D array, it is factorized as ``ut @ np.diag(s) @ v``,
where `ut` and `v` are 2D orthonormal arrays and `s` is a 1D
array of `a`'s singular values. When `a` is higher-dimensional, SVD is
applied in stacked mode as explained below.
Parameters
----------
a : (..., M, N) ndarray
A real or complex array with ``a.ndim >= 2`` and ``M <= N``.
Returns
-------
ut: (..., M, M) ndarray
Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the same
size as those of the input `a`.
s : (..., M) ndarray
Vector(s) with the singular values, within each vector sorted in
descending order. The first ``a.ndim - 2`` dimensions have the same
size as those of the input `a`.
v : (..., M, N) ndarray
Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the same
size as those of the input `a`.
Notes
-----
The decomposition is performed using LAPACK routine ``_gesvd``.
SVD is usually described for the factorization of a 2D matrix :math:`A`.
The higher-dimensional case will be discussed below. In the 2D case, SVD is
written as :math:`A = U^T S V`, where :math:`A = a`, :math:`U^T = ut`,
:math:`S= \mathtt{np.diag}(s)` and :math:`V = v`. The 1D array `s`
contains the singular values of `a` and `ut` and `v` are orthonormal. The rows
of `v` are the eigenvectors of :math:`A^T A` and the columns of `ut` are
the eigenvectors of :math:`A A^T`. In both cases the corresponding
(possibly non-zero) eigenvalues are given by ``s**2``.
If `a` has more than two dimensions, then broadcasting rules apply.
This means that SVD is working in "stacked" mode: it iterates over
all indices of the first ``a.ndim - 2`` dimensions and for each
combination SVD is applied to the last two indices. The matrix `a`
can be reconstructed from the decomposition with either
``(ut * s[..., None, :]) @ v`` or
``ut @ (s[..., None] * v)``. (The ``@`` operator denotes batch matrix multiplication)
Examples
--------
>>> a = np.arange(54).reshape(6, 9)
>>> ut, s, v = np.linalg.svd(a)
>>> ut.shape, s.shape, v.shape
((6, 6), (6,), (6, 9))
>>> s = s.reshape(6, 1)
>>> ret = np.dot(ut, s * v)
>>> (ret - a > 1e-3).sum()
array(0.)
>>> (ret - a < -1e-3).sum()
array(0.)
"""
pass


def _np_roll(a, shift, axis=None):
"""
Roll array elements along a given axis.
Expand Down
77 changes: 76 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

from __future__ import absolute_import
from . import _op as _mx_nd_np
from . import _internal as _npi

__all__ = ['norm']
__all__ = ['norm', 'svd']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -66,3 +67,77 @@ def norm(x, ord=None, axis=None, keepdims=False):
if ord == 'fro' and x.ndim > 2 and axis is None:
raise ValueError('Improper number of dimensions to norm')
return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims))


def svd(a):
    r"""
    Compute the singular value decomposition of one or more matrices.

    A 2D input `a` is factorized as ``ut @ np.diag(s) @ v``, where `ut` and
    `v` are 2D orthonormal arrays and `s` is a 1D array holding the singular
    values of `a`. Inputs with more than two dimensions are decomposed in
    stacked mode, as explained in the Notes.

    Parameters
    ----------
    a : (..., M, N) ndarray
        A real array with ``a.ndim >= 2`` and ``M <= N``.

    Returns
    -------
    ut : (..., M, M) ndarray
        Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the
        same size as those of the input `a`.
    s : (..., M) ndarray
        Vector(s) of singular values, each vector sorted in descending
        order. The first ``a.ndim - 2`` dimensions have the same size as
        those of the input `a`.
    v : (..., M, N) ndarray
        Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the
        same size as those of the input `a`.

    Notes
    -----
    The decomposition is performed using LAPACK routine ``_gesvd``.

    SVD is usually described for the factorization of a 2D matrix
    :math:`A`. In that case it is written as :math:`A = U^T S V`, where
    :math:`A = a`, :math:`U^T = ut`, :math:`S = \mathtt{np.diag}(s)` and
    :math:`V = v`. The 1D array `s` contains the singular values of `a`,
    and `ut` and `v` are orthonormal. The rows of `v` are the eigenvectors
    of :math:`A^T A` and the columns of `ut` are the eigenvectors of
    :math:`A A^T`. In both cases the corresponding (possibly non-zero)
    eigenvalues are given by ``s**2``.

    The signs of the rows of :math:`U` and :math:`V` (that is, of the
    columns of `ut` and the rows of `v`) are determined as described in
    `Auto-Differentiating Linear Algebra <https://arxiv.org/pdf/1710.08717.pdf>`_.

    If `a` has more than two dimensions, SVD works in "stacked" mode: it
    iterates over all indices of the first ``a.ndim - 2`` dimensions and,
    for each combination, applies SVD to the last two indices. The matrix
    `a` can be reconstructed from the decomposition with either
    ``(ut * s[..., None, :]) @ v`` or ``ut @ (s[..., None] * v)``, where
    ``@`` denotes batch matrix multiplication.

    This function differs from the original `numpy.linalg.svd
    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.svd.html>`_ in
    the following way(s):

    - The signs of the rows of :math:`U` and :math:`V` may differ.
    - Complex input is not supported.

    Examples
    --------
    >>> a = np.arange(54).reshape(6, 9)
    >>> ut, s, v = np.linalg.svd(a)
    >>> ut.shape, s.shape, v.shape
    ((6, 6), (6,), (6, 9))
    >>> s = s.reshape(6, 1)
    >>> ret = np.dot(ut, s * v)
    >>> (ret - a > 1e-3).sum()
    array(0.)
    >>> (ret - a < -1e-3).sum()
    array(0.)
    """
    # Hand the operator's outputs back to the caller as a plain tuple.
    outputs = _npi.svd(a)
    return tuple(outputs)
76 changes: 75 additions & 1 deletion python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np

__all__ = ['norm']
__all__ = ['norm', 'svd']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -60,3 +60,77 @@ def norm(x, ord=None, axis=None, keepdims=False):
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
"""
return _mx_nd_np.linalg.norm(x, ord, axis, keepdims)


def svd(a):
    r"""
    Compute the singular value decomposition of one or more matrices.

    A 2D input `a` is factorized as ``ut @ np.diag(s) @ v``, where `ut` and
    `v` are 2D orthonormal arrays and `s` is a 1D array holding the singular
    values of `a`. Inputs with more than two dimensions are decomposed in
    stacked mode, as explained in the Notes.

    Parameters
    ----------
    a : (..., M, N) ndarray
        A real array with ``a.ndim >= 2`` and ``M <= N``.

    Returns
    -------
    ut : (..., M, M) ndarray
        Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the
        same size as those of the input `a`.
    s : (..., M) ndarray
        Vector(s) of singular values, each vector sorted in descending
        order. The first ``a.ndim - 2`` dimensions have the same size as
        those of the input `a`.
    v : (..., M, N) ndarray
        Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the
        same size as those of the input `a`.

    Notes
    -----
    The decomposition is performed using LAPACK routine ``_gesvd``.

    SVD is usually described for the factorization of a 2D matrix
    :math:`A`. In that case it is written as :math:`A = U^T S V`, where
    :math:`A = a`, :math:`U^T = ut`, :math:`S = \mathtt{np.diag}(s)` and
    :math:`V = v`. The 1D array `s` contains the singular values of `a`,
    and `ut` and `v` are orthonormal. The rows of `v` are the eigenvectors
    of :math:`A^T A` and the columns of `ut` are the eigenvectors of
    :math:`A A^T`. In both cases the corresponding (possibly non-zero)
    eigenvalues are given by ``s**2``.

    The signs of the rows of :math:`U` and :math:`V` (that is, of the
    columns of `ut` and the rows of `v`) are determined as described in
    `Auto-Differentiating Linear Algebra <https://arxiv.org/pdf/1710.08717.pdf>`_.

    If `a` has more than two dimensions, SVD works in "stacked" mode: it
    iterates over all indices of the first ``a.ndim - 2`` dimensions and,
    for each combination, applies SVD to the last two indices. The matrix
    `a` can be reconstructed from the decomposition with either
    ``(ut * s[..., None, :]) @ v`` or ``ut @ (s[..., None] * v)``, where
    ``@`` denotes batch matrix multiplication.

    This function differs from the original `numpy.linalg.svd
    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.svd.html>`_ in
    the following way(s):

    - The signs of the rows of :math:`U` and :math:`V` may differ.
    - Complex input is not supported.

    Examples
    --------
    >>> a = np.arange(54).reshape(6, 9)
    >>> ut, s, v = np.linalg.svd(a)
    >>> ut.shape, s.shape, v.shape
    ((6, 6), (6,), (6, 9))
    >>> s = s.reshape(6, 1)
    >>> ret = np.dot(ut, s * v)
    >>> (ret - a > 1e-3).sum()
    array(0.)
    >>> (ret - a < -1e-3).sum()
    array(0.)
    """
    # Delegate to the implementation in the ndarray-level linalg module.
    nd_linalg = _mx_nd_np.linalg
    return nd_linalg.svd(a)
64 changes: 63 additions & 1 deletion python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from __future__ import absolute_import
from . import _symbol
from . import _op as _mx_sym_np
from . import _internal as _npi

__all__ = ['norm']
__all__ = ['norm', 'svd']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -66,3 +67,64 @@ def norm(x, ord=None, axis=None, keepdims=False):
raise ValueError('Improper number of dimensions to norm')
# TODO(junwu): When ord = 'fro', axis = None, and x.ndim > 2, raise exception
return _symbol.sqrt(_mx_sym_np.sum(x * x, axis=axis, keepdims=keepdims))


def svd(a):
    r"""
    Compute the singular value decomposition of one or more matrices.

    A 2D input `a` is factorized as ``ut @ np.diag(s) @ v``, where `ut` and
    `v` are 2D orthonormal arrays and `s` is a 1D array holding the singular
    values of `a`. Inputs with more than two dimensions are decomposed in
    stacked mode, as explained in the Notes.

    Parameters
    ----------
    a : (..., M, N) _Symbol
        A real array with ``a.ndim >= 2`` and ``M <= N``.

    Returns
    -------
    ut : (..., M, M) _Symbol
        Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the
        same size as those of the input `a`.
    s : (..., M) _Symbol
        Vector(s) of singular values, each vector sorted in descending
        order. The first ``a.ndim - 2`` dimensions have the same size as
        those of the input `a`.
    v : (..., M, N) _Symbol
        Orthonormal array(s). The first ``a.ndim - 2`` dimensions have the
        same size as those of the input `a`.

    Notes
    -----
    The decomposition is performed using LAPACK routine ``_gesvd``.

    SVD is usually described for the factorization of a 2D matrix
    :math:`A`. In that case it is written as :math:`A = U^T S V`, where
    :math:`A = a`, :math:`U^T = ut`, :math:`S = \mathtt{np.diag}(s)` and
    :math:`V = v`. The 1D array `s` contains the singular values of `a`,
    and `ut` and `v` are orthonormal. The rows of `v` are the eigenvectors
    of :math:`A^T A` and the columns of `ut` are the eigenvectors of
    :math:`A A^T`. In both cases the corresponding (possibly non-zero)
    eigenvalues are given by ``s**2``.

    The signs of the rows of :math:`U` and :math:`V` (that is, of the
    columns of `ut` and the rows of `v`) are determined as described in
    `Auto-Differentiating Linear Algebra <https://arxiv.org/pdf/1710.08717.pdf>`_.

    If `a` has more than two dimensions, SVD works in "stacked" mode: it
    iterates over all indices of the first ``a.ndim - 2`` dimensions and,
    for each combination, applies SVD to the last two indices. The matrix
    `a` can be reconstructed from the decomposition with either
    ``(ut * s[..., None, :]) @ v`` or ``ut @ (s[..., None] * v)``, where
    ``@`` denotes batch matrix multiplication.

    This function differs from the original `numpy.linalg.svd
    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.svd.html>`_ in
    the following way(s):

    - The signs of the rows of :math:`U` and :math:`V` may differ.
    - Complex input is not supported.
    """
    # Return the operator's result unchanged.
    decomposition = _npi.svd(a)
    return decomposition
6 changes: 3 additions & 3 deletions src/operator/numpy/linalg/np_gesvd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ inline bool NumpyLaGesvdShape(const nnvm::NodeAttrs& attrs,
return false;
}

NNVM_REGISTER_OP(_np__linalg_svd)
NNVM_REGISTER_OP(_npi_svd)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(3)
Expand All @@ -102,10 +102,10 @@ NNVM_REGISTER_OP(_np__linalg_svd)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<FCompute>("FCompute<cpu>", NumpyLaGesvdForward<cpu, gesvd>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_np_linalg_svd"})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_npi_svd"})
.add_argument("A", "NDArray-or-Symbol", "Input matrices to be factorized");

NNVM_REGISTER_OP(_backward_np_linalg_svd)
NNVM_REGISTER_OP(_backward_npi_svd)
.set_num_inputs(6)
.set_num_outputs(1)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs) {
Expand Down
4 changes: 2 additions & 2 deletions src/operator/numpy/linalg/np_gesvd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ namespace op {

#if MXNET_USE_CUSOLVER == 1

NNVM_REGISTER_OP(_np__linalg_svd)
NNVM_REGISTER_OP(_npi_svd)
.set_attr<FCompute>("FCompute<gpu>", NumpyLaGesvdForward<gpu, gesvd>);

NNVM_REGISTER_OP(_backward_np_linalg_svd)
NNVM_REGISTER_OP(_backward_npi_svd)
.set_attr<FCompute>("FCompute<gpu>", NumpyLaGesvdBackward<gpu, gesvd_backward>);

#endif
Expand Down
Loading

0 comments on commit 1648f4c

Please sign in to comment.