Skip to content

Commit

Permalink
* impl - linalg matrix_rank for cpu and gpu implemented (apache#18020)
Browse files Browse the repository at this point in the history
* fix - python interface

* impl - ffi for matrix_rank

* impl - ffi benchmark

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
2 people authored and sxjscience committed Jul 1, 2020
1 parent be6623e commit ff2dbab
Show file tree
Hide file tree
Showing 12 changed files with 956 additions and 15 deletions.
9 changes: 5 additions & 4 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def prepare_workloads():
OpArgMngr.add_workload("linalg.eigh", pool['3x3'])
OpArgMngr.add_workload("linalg.det", pool['3x3'])
OpArgMngr.add_workload("linalg.slogdet", pool['3x3'])
OpArgMngr.add_workload("linalg.matrix_rank", pool['3x3'], pool['1'], hermitian=False)
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("linalg.cholesky", pool['1x1'])
OpArgMngr.add_workload("linalg.qr", pool['3x3'])
Expand Down Expand Up @@ -123,10 +124,10 @@ def prepare_workloads():
out=dnp.array([False, False], dtype=bool), keepdims=False)
OpArgMngr.add_workload("roll", pool["2x2"], 1, axis=0)
OpArgMngr.add_workload("rot90", pool["2x2"], 2)
OpArgMngr.add_workload("array_split", pool['2X2'], 2, axis=1)
OpArgMngr.add_workload("vsplit", pool['2X2'], 2)
OpArgMngr.add_workload("hsplit", pool['2X2'], 2)
OpArgMngr.add_workload("dsplit", pool['2X2x2'], 2)
OpArgMngr.add_workload("array_split", pool['2x2'], 2, axis=1)
OpArgMngr.add_workload("vsplit", pool['2x2'], 2)
OpArgMngr.add_workload("hsplit", pool['2x2'], 2)
OpArgMngr.add_workload("dsplit", pool['2x2x2'], 2)
OpArgMngr.add_workload("arange", 10)
OpArgMngr.add_workload("concatenate", (pool['1x2'], pool['1x2'], pool['1x2']), axis=0)
OpArgMngr.add_workload("append", pool['2x2'], pool['1x2'], axis=0)
Expand Down
47 changes: 46 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,52 @@
from . import _api_internal

__all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve',
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq']
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank']


def matrix_rank(M, tol=None, hermitian=False):
"""
Return matrix rank of array using SVD method
Rank of the array is the number of singular values of the array that are
greater than `tol`.
Parameters
M : {(M,), (..., M, N)} ndarray
Input vector or stack of matrices.
tol : (...) ndarray, float, optional
Threshold below which SVD values are considered zero. If `tol` is
None, and ``S`` is an array with singular values for `M`, and
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
set to ``S.max() * max(M.shape) * eps``.
hermitian : bool, optional
If True, `M` is assumed to be Hermitian (symmetric if real-valued),
enabling a more efficient method for finding singular values.
Defaults to False.
Returns
-------
rank : (...) ndarray
Rank of M.
Examples
--------
>>> from mxnet import np
>>> np.matrix_rank(np.eye(4)) # Full rank matrix
4
>>> I=np.eye(4); I[-1,-1] = 0. # rank deficient matrix
>>> np.matrix_rank(I)
3
>>> np.matrix_rank(np.ones((4,))) # 1 dimension - rank 1 unless all 0
1
>>> np.matrix_rank(np.zeros((4,)))
0
"""
finfo_eps_32 = _np.finfo(_np.float32).eps
finfo_eps_64 = _np.finfo(_np.float64).eps
if hermitian is True:
raise NotImplementedError("hermitian is not supported yet...")
return _api_internal.matrix_rank(M, tol, hermitian, finfo_eps_32, finfo_eps_64)


def lstsq(a, b, rcond='warn'):
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/numpy/fallback_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@
__all__ = [
'cond',
'matrix_power',
'matrix_rank',
'multi_dot'
]

cond = onp.linalg.cond
matrix_power = onp.linalg.matrix_power
matrix_rank = onp.linalg.matrix_rank
multi_dot = onp.linalg.multi_dot
43 changes: 42 additions & 1 deletion python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,51 @@
from . import fallback_linalg

__all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve',
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq']
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank']
__all__ += fallback_linalg.__all__


def matrix_rank(M, tol=None, hermitian=False):
"""
Return matrix rank of array using SVD method
Rank of the array is the number of singular values of the array that are
greater than `tol`.
Parameters
M : {(M,), (..., M, N)} ndarray
Input vector or stack of matrices.
tol : (...) ndarray, float, optional
Threshold below which SVD values are considered zero. If `tol` is
None, and ``S`` is an array with singular values for `M`, and
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
set to ``S.max() * max(M.shape) * eps``.
hermitian : bool, optional
If True, `M` is assumed to be Hermitian (symmetric if real-valued),
enabling a more efficient method for finding singular values.
Defaults to False.
Returns
-------
rank : (...) ndarray
Rank of M.
Examples
--------
>>> from mxnet import np
>>> np.matrix_rank(np.eye(4)) # Full rank matrix
4
>>> I=np.eye(4); I[-1,-1] = 0. # rank deficient matrix
>>> np.matrix_rank(I)
3
>>> np.matrix_rank(np.ones((4,))) # 1 dimension - rank 1 unless all 0
1
>>> np.matrix_rank(np.zeros((4,)))
0
"""
return _mx_nd_np.linalg.matrix_rank(M, tol, hermitian)


def lstsq(a, b, rcond='warn'):
r"""
Return the least-squares solution to a linear matrix equation.
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'linalg.eigvalsh',
'linalg.eigh',
'linalg.qr',
'linalg.matrix_rank',
'shape',
'trace',
'tril',
Expand Down
35 changes: 34 additions & 1 deletion python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,40 @@
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve',
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq']
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank']


def matrix_rank(M, tol=None, hermitian=False):
"""
Return matrix rank of array using SVD method
Rank of the array is the number of singular values of the array that are
greater than `tol`.
Parameters
M : {(M,), (..., M, N)} _Symbol
Input vector or stack of matrices.
tol : (...) _Symbol, float, optional
Threshold below which SVD values are considered zero. If `tol` is
None, and ``S`` is an array with singular values for `M`, and
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
set to ``S.max() * max(M.shape) * eps``.
hermitian : bool, optional
If True, `M` is assumed to be Hermitian (symmetric if real-valued),
enabling a more efficient method for finding singular values.
Defaults to False.
Returns
-------
rank : (...) _Symbol
Rank of M.
"""
finfo_eps_32 = _np.finfo(_np.float32).eps
finfo_eps_64 = _np.finfo(_np.float64).eps
if tol is None:
return _npi.matrix_rank_none_tol(M, finfo_eps_32, finfo_eps_64, hermitian)
else:
return _npi.matrix_rank(M, tol, hermitian)


def lstsq(a, b, rcond='warn'):
Expand Down
76 changes: 76 additions & 0 deletions src/api/operator/numpy/linalg/np_matrix_rank.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_pinv.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_matrix_rank.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_matrix_rank-inl.h"

namespace mxnet {

inline static void _npi_matrix_rank_none_tol(runtime::MXNetArgs args,
runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_matrix_rank_none_tol");
op::MatrixRankNoneTolParam param;
nnvm::NodeAttrs attrs;
param.hermitian = args[2].operator bool();
param.finfoEps32 = args[3].operator double();
param.finfoEps64 = args[4].operator double();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::MatrixRankNoneTolParam>(&attrs);
int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

inline static void _npi_matrix_rank(runtime::MXNetArgs args,
runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_matrix_rank");
op::MatrixRankParam param;
nnvm::NodeAttrs attrs;
param.hermitian = args[2].operator bool();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::MatrixRankParam>(&attrs);
int num_inputs = 2;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

MXNET_REGISTER_API("_npi.matrix_rank")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
if (args[1].type_code() == kNull) {
_npi_matrix_rank_none_tol(args, ret);
} else {
_npi_matrix_rank(args, ret);
}
});

} // namespace mxnet
Loading

0 comments on commit ff2dbab

Please sign in to comment.