This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

New operators linalg_syrk, linalg_gelqf #7741

Merged
merged 1 commit into from Sep 6, 2017
2 changes: 2 additions & 0 deletions docs/api/python/symbol.md
@@ -526,6 +526,8 @@ Composite multiple symbols into a new one by an operator.
linalg_trmm
linalg_trsm
linalg_sumlogdiag
linalg_syrk
linalg_gelqf
Member:
I assume ndarray.md should be updated, too?


### Miscellaneous
76 changes: 53 additions & 23 deletions python/mxnet/test_utils.py
@@ -447,7 +447,7 @@ def simple_forward(sym, ctx=None, is_train=False, **inputs):
return outputs


def _parse_location(sym, location, ctx):
def _parse_location(sym, location, ctx, dtype=default_dtype()):
"""Parses the given location to a dictionary.

Arguments of the provided op `sym` are used as dictionary keys
@@ -468,6 +468,8 @@ def _parse_location(sym, location, ctx):
*In either case, value of all the arguments must be provided.*
ctx : Context
Device context.
dtype: np.float32 or np.float64
Datatype for mx.nd.array.

Returns
-------
@@ -489,19 +491,20 @@ def _parse_location(sym, location, ctx):
ValueError: Symbol arguments and keys of the given location do not match.
"""
assert isinstance(location, (dict, list, tuple))
assert dtype == np.float32 or dtype == np.float64
if isinstance(location, dict):
if set(location.keys()) != set(sym.list_arguments()):
raise ValueError("Symbol arguments and keys of the given location do not match."
"symbol args:%s, location.keys():%s"
% (str(set(sym.list_arguments())), str(set(location.keys()))))
else:
location = {k: v for k, v in zip(sym.list_arguments(), location)}
location = {k: mx.nd.array(v, ctx=ctx) if isinstance(v, np.ndarray) \
location = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) if isinstance(v, np.ndarray) \
else v for k, v in location.items()}
return location


def _parse_aux_states(sym, aux_states, ctx):
def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
"""Parses the given auxiliary states to a dictionary.

Auxiliary states of the provided op `sym` are used as dictionary
@@ -520,6 +523,10 @@ def _parse_aux_states(sym, aux_states, ctx):
- if type is dict of str -> `np.ndarray`
maps the name of arguments to the corresponding `np.ndarray`.
*In either case, all aux states of `sym` must be provided.*
ctx : Context
Device context.
dtype: np.float32 or np.float64
Datatype for mx.nd.array.

Returns
-------
@@ -543,6 +550,7 @@ def _parse_aux_states(sym, aux_states, ctx):
>>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None)
ValueError: Symbol aux_states names and given aux_states do not match.
"""
assert dtype == np.float32 or dtype == np.float64
if aux_states is not None:
if isinstance(aux_states, dict):
if set(aux_states.keys()) != set(sym.list_auxiliary_states()):
@@ -553,11 +561,12 @@ def _parse_aux_states(sym, aux_states, ctx):
elif isinstance(aux_states, (list, tuple)):
aux_names = sym.list_auxiliary_states()
aux_states = {k:v for k, v in zip(aux_names, aux_states)}
aux_states = {k: mx.nd.array(v, ctx=ctx) for k, v in aux_states.items()}
aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in aux_states.items()}
return aux_states


def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_train=True):
def numeric_grad(executor, location, aux_states=None, eps=1e-4,
use_forward_train=True, dtype=default_dtype()):
"""Calculates a numeric gradient via finite difference method.

Class based on Theano's `theano.gradient.numeric_grad` [1]
@@ -578,11 +587,15 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai
Epsilon for the finite-difference method.
use_forward_train : bool, optional
Whether to use `is_train=True` in testing.
dtype: np.float32 or np.float64
Datatype for mx.nd.array.

References
---------
.. [1] https://github.com/Theano/Theano/blob/master/theano/gradient.py
"""
approx_grads = {k: np.zeros(v.shape, dtype=np.float32)
assert dtype == np.float32 or dtype == np.float64
approx_grads = {k: np.zeros(v.shape, dtype=dtype)
for k, v in location.items()}
for k, v in location.items():
executor.arg_dict[k][:] = v
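The dtype argument threaded through this function and the helpers above lets the finite differencing run in double precision, which matters for checking numerically sensitive operators. A minimal numpy sketch (illustration only, not part of this PR) of why: with eps around 1e-4, a central difference such as (f(x+eps) - f(x-eps)) / (2*eps) is dominated by rounding noise in float32 but stays accurate in float64.

```python
import numpy as np

def central_diff(f, x, eps):
    # Central difference of the kind numeric_grad relies on.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = 0.7
for dt in (np.float32, np.float64):
    xd, eps = dt(x), dt(1e-4)
    err = abs(central_diff(np.sin, xd, eps) - np.cos(xd))
    print(dt.__name__, err)  # the float32 error is typically orders of magnitude larger
```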
@@ -619,7 +632,7 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai

def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2,
atol=None, grad_nodes=None, use_forward_train=True, ctx=None,
grad_stype_dict=None):
grad_stype_dict=None, dtype=default_dtype()):
"""Verify an operation by checking backward pass via finite difference method.

Based on Theano's `theano.gradient.verify_grad` [1]
@@ -650,10 +663,14 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
Check the gradient computation on the specified device.
grad_stype_dict : dict of str->str, optional
Storage type dictionary for gradient ndarrays.
dtype: np.float32 or np.float64
Datatype for mx.nd.array.

References
---------
.. [1] https://github.com/Theano/Theano/blob/master/theano/gradient.py
"""
assert dtype == np.float32 or dtype == np.float64
if ctx is None:
ctx = default_context()

@@ -669,9 +686,10 @@ def random_projection(shape):
plain = _rng.rand(*shape) + 0.1
return plain

location = _parse_location(sym=sym, location=location, ctx=ctx)
location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype)
location_npy = {k:v.asnumpy() for k, v in location.items()}
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx)
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx,
dtype=dtype)
if aux_states is not None:
aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()}
else:
@@ -695,11 +713,12 @@ def random_projection(shape):
out = mx.sym.MakeLoss(out)

location = dict(list(location.items()) +
[("__random_proj", mx.nd.array(random_projection(out_shape[0]), ctx=ctx))])
[("__random_proj", mx.nd.array(random_projection(out_shape[0]),
ctx=ctx, dtype=dtype))])
args_grad_npy = dict([(k, _rng.normal(0, 0.01, size=location[k].shape)) for k in grad_nodes]
+ [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))])

args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()}
args_grad = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in args_grad_npy.items()}
if grad_stype_dict is not None:
assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict"
for k, v in grad_stype_dict.items():
@@ -722,8 +741,9 @@ def random_projection(shape):
executor.backward()
symbolic_grads = {k:executor.grad_dict[k].asnumpy() for k in grad_nodes}

numeric_gradients = numeric_grad(executor, location_npy, aux_states_npy,
eps=numeric_eps, use_forward_train=use_forward_train)
numeric_gradients = numeric_grad(
executor, location_npy, aux_states_npy, eps=numeric_eps,
use_forward_train=use_forward_train, dtype=dtype)
for name in grad_nodes:
fd_grad = numeric_gradients[name]
orig_grad = args_grad_npy[name]
@@ -742,7 +762,7 @@ def random_projection(shape):


def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
aux_states=None, ctx=None):
aux_states=None, ctx=None, dtype=default_dtype()):
"""Compares a symbol's forward results with the expected ones.
Prints error messages if the forward results are not the same as the expected ones.

@@ -773,6 +793,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
Contains the mapping between names of auxiliary states and their values.
ctx : Context, optional
running context
dtype: np.float32 or np.float64
Datatype for mx.nd.array.

Example
-------
@@ -785,14 +807,16 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
>>> ret_expected = np.array([[19, 22], [43, 50]])
>>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])
"""
assert dtype == np.float32 or dtype == np.float64
if ctx is None:
ctx = default_context()

location = _parse_location(sym=sym, location=location, ctx=ctx)
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx)
location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype)
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx,
dtype=dtype)
if isinstance(expected, dict):
expected = [expected[k] for k in sym.list_outputs()]
args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx) for k, v in location.items()}
args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=dtype) for k, v in location.items()}

executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
for g in executor.grad_arrays:
@@ -807,7 +831,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,


def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=None,
aux_states=None, grad_req='write', ctx=None, grad_stypes=None):
aux_states=None, grad_req='write', ctx=None, grad_stypes=None,
dtype=default_dtype()):
"""Compares a symbol's backward results with the expected ones.
Prints error messages if the backward results are not the same as the expected results.

@@ -845,6 +870,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
Running context.
grad_stypes: dict of str->str
dictionary of mapping argument name to stype for the gradient
dtype: np.float32 or np.float64
Datatype for mx.nd.array.

Example
-------
@@ -862,17 +889,19 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
>>> grad_expected = ograd.copy().asnumpy()
>>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])
"""
assert dtype == np.float32 or dtype == np.float64
if ctx is None:
ctx = default_context()

location = _parse_location(sym=sym, location=location, ctx=ctx)
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx)
location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype)
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx,
dtype=dtype)
if isinstance(expected, (list, tuple)):
expected = {k:v for k, v in zip(sym.list_arguments(), expected)}
args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()}
args_grad_data = {}
for k, v in args_grad_npy.items():
nd = mx.nd.array(v, ctx=ctx)
nd = mx.nd.array(v, ctx=ctx, dtype=dtype)
if grad_stypes is not None and k in grad_stypes:
args_grad_data[k] = nd.tostype(grad_stypes[k])
else:
@@ -888,9 +917,10 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
executor.forward(is_train=True)

if isinstance(out_grads, (tuple, list)):
out_grads = [mx.nd.array(v, ctx=ctx) for v in out_grads]
out_grads = [mx.nd.array(v, ctx=ctx, dtype=dtype) for v in out_grads]
elif isinstance(out_grads, (dict)):
out_grads = {k:mx.nd.array(v, ctx=ctx) for k, v in out_grads.items()}
out_grads = {k:mx.nd.array(v, ctx=ctx, dtype=dtype)
for k, v in out_grads.items()}
else:
assert out_grads is None
executor.backward(out_grads)
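With the new dtype argument in place, the docstring example for check_symbolic_forward above can be run in double precision. A minimal sketch (the lhs/rhs setup is assumed here, since the docstring elides it):

```python
import numpy as np
import mxnet as mx
from mxnet.test_utils import check_symbolic_forward

lhs = mx.sym.Variable('lhs')
rhs = mx.sym.Variable('rhs')
sym_dot = mx.sym.dot(lhs, rhs)
mat1 = np.array([[1, 2], [3, 4]])
mat2 = np.array([[5, 6], [7, 8]])
ret_expected = np.array([[19, 22], [43, 50]])

# Same as the docstring example, except inputs, gradient buffers, and
# expected outputs are now created as float64 NDArrays internally.
check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected], dtype=np.float64)
```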
66 changes: 66 additions & 0 deletions src/operator/c_lapack_api.h
@@ -86,6 +86,24 @@ extern "C" {

void sposv_(char *uplo, int *n, int *nrhs,
float *a, int *lda, float *b, int *ldb, int *info);

// Note: GELQF in row-major (MXNet) becomes GEQRF in column-major (LAPACK).
// Also, m and n are flipped, compared to the row-major version
#define MXNET_LAPACK_FSIG_GEQRF(func, dtype) \
void func##_(int *m, int *n, dtype *a, int *lda, dtype *tau, dtype *work, \
int *lwork, int *info);

MXNET_LAPACK_FSIG_GEQRF(sgeqrf, float)
MXNET_LAPACK_FSIG_GEQRF(dgeqrf, double)

// Note: ORGLQ in row-major (MXNet) becomes ORGQR in column-major (LAPACK)
// Also, m and n are flipped, compared to the row-major version
#define MXNET_LAPACK_FSIG_ORGQR(func, dtype) \
void func##_(int *m, int *n, int *k, dtype *a, int *lda, dtype *tau, \
dtype *work, int *lwork, int *info);

MXNET_LAPACK_FSIG_ORGQR(sorgqr, float)
MXNET_LAPACK_FSIG_ORGQR(dorgqr, double)
}

#define MXNET_LAPACK_ROW_MAJOR 101
@@ -178,6 +196,42 @@ inline void flip<cpu, double>(int m, int n,
return info;
}

// Note: Both MXNET_LAPACK_*gelqf, MXNET_LAPACK_*orglq can only be called with
// row-major format (MXNet). Internally, the QR variants are done in column-major.
// In particular, the matrix dimensions m and n are flipped.
#define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype* tau, \
dtype* work, int lwork) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
int info(0); \
prefix##geqrf_(&n, &m, a, &lda, tau, work, &lwork, &info); \
return info; \
} else { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "gelqf implemented for row-major layout only"; \
return 1; \
} \
}
MXNET_LAPACK_CWRAP_GELQF(s, float)
MXNET_LAPACK_CWRAP_GELQF(d, double)

// Note: The k argument (rank) is equal to m as well
#define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype* tau, \
dtype* work, int lwork) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
int info(0); \
prefix##orgqr_(&n, &m, &m, a, &lda, tau, work, &lwork, &info); \
return info; \
} else { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "orglq implemented for row-major layout only"; \
return 1; \
} \
}
MXNET_LAPACK_CWRAP_ORGLQ(s, float)
MXNET_LAPACK_CWRAP_ORGLQ(d, double)

#else

// use pragma message instead of warning
@@ -192,6 +246,13 @@ inline void flip<cpu, double>(int m, int n,
return 1; \
}

#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_UNAVAILABLE(func) \
inline int mxnet_lapack_##func(...) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
@@ -206,6 +267,11 @@ inline void flip<cpu, double>(int m, int n,
MXNET_LAPACK_UNAVAILABLE(sposv)
MXNET_LAPACK_UNAVAILABLE(dposv)

MXNET_LAPACK_CWRAPPER2(sgelqf, float)
MXNET_LAPACK_CWRAPPER2(dgelqf, double)
MXNET_LAPACK_CWRAPPER2(sorglq, float)
MXNET_LAPACK_CWRAPPER2(dorglq, double)

#endif

template <typename DType>
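The row-major/column-major flip that the wrappers above rely on is easy to verify outside LAPACK: an m x n row-major buffer reinterpreted in column-major order is the n x m transpose, so QR on the transposed, dimension-flipped input is exactly LQ on the original. A short numpy sketch (illustration only, not part of this PR):

```python
import numpy as np

rng = np.random.RandomState(0)
m, n = 3, 5                  # m <= n, the shape gelqf expects
a = rng.randn(m, n)

# Reading the row-major m x n buffer in column-major order yields a^T (n x m),
# so GEQRF with m and n flipped factors a^T = Q R ...
q, r = np.linalg.qr(a.T)     # q: n x m, r: m x m

# ... which transposes back into the row-major LQ factorization a = L Q_lq,
# the factorization that gelqf computes and orglq completes.
l, q_lq = r.T, q.T
assert np.allclose(l @ q_lq, a)
assert np.allclose(q_lq @ q_lq.T, np.eye(m))  # rows of Q_lq are orthonormal
```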
41 changes: 41 additions & 0 deletions src/operator/linalg.h
@@ -123,6 +123,47 @@ void linalg_potri(const Tensor<xpu, 2, DType>& A, bool lower, Stream<xpu> *s = 0
template<typename xpu, typename DType>
void linalg_batch_potri(const Tensor<xpu, 3, DType>& A, bool lower, Stream<xpu> *s = 0);

//////////////////////////////// SYRK ////////////////////////////////////////////

// CPU/GPU-versions of BLAS3 function "syrk". Please refer to the BLAS3-documentation
// for further information about the function and its parameters.
// Note that this is B = syrk(A, B), i.e. B is both an input and an output parameter.

template<typename xpu, typename DType>
void linalg_syrk(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
DType alpha, DType beta, bool tA, Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_batch_syrk(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, DType alpha, DType beta,
bool tA, Stream<xpu> *s = 0);
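For reference, the semantics B = alpha * A * A^T + beta * B (or A^T * A when tA is set) can be reproduced with SciPy's raw BLAS wrapper. A sketch assuming scipy.linalg.blas.dsyrk with its standard (alpha, a, beta, c, trans, lower) arguments; note that plain BLAS syrk writes only one triangle of B:

```python
import numpy as np
from scipy.linalg import blas

rng = np.random.RandomState(0)
a = rng.randn(4, 3)
b = rng.randn(4, 4)
b = b + b.T                           # B is treated as symmetric
alpha, beta = 2.0, 0.5

# tA=False case: B <- alpha * A * A^T + beta * B
out = blas.dsyrk(alpha, a, beta=beta, c=b)
ref = alpha * (a @ a.T) + beta * b

# BLAS syrk only writes the requested triangle (upper, by default).
assert np.allclose(np.triu(out), np.triu(ref))
# tA=True corresponds to trans=1 and yields the 3 x 3 result alpha * A^T * A + beta * B.
```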

//////////////////////////////// GELQF ////////////////////////////////////////////

// CPU/GPU-versions of LAPACK functions "gelqf", "orglq". Please refer to the
// LAPACK documentation for further details.
// Note:
// - The current implementation works for CPU only
// - Both functions have A as input and output parameter
// - Both functions require extra workspace, passed as 1D tensor
// - We call orglq after gelqf. Apart from A, they also communicate via the
// first part of the workspace.

template<typename xpu, typename DType>
void linalg_gelqf(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& work, Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_orglq(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& work, Stream<xpu> *s = 0);

// This function determines the amount of workspace needed for linalg_gelqf,
// linalg_orglq. The workspace can be used for both. The first m entries are
// used to communicate information from gelqf to orglq.
template<typename xpu, typename DType>
int linalg_gelqf_workspace_query(const Tensor<xpu, 2, DType>& A,
Stream<xpu> *s = 0);
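The two-step workflow described above (gelqf factorizes, orglq then overwrites the reflectors with an explicit Q, with the scalar factors tau carried between the calls much like the first m workspace entries here) can be sketched with SciPy's LAPACK wrappers, assuming scipy.linalg.lapack exposes dgelqf/dorglq with the usual return values:

```python
import numpy as np
from scipy.linalg import lapack

rng = np.random.RandomState(0)
m, n = 3, 5                          # m <= n
a = rng.randn(m, n)

# Step 1: gelqf overwrites A with L (lower triangle) plus the Householder
# reflectors and returns the scalar factors tau.
lq, tau, work, info = lapack.dgelqf(a)
assert info == 0
l = np.tril(lq[:, :m])

# Step 2: orglq consumes (lq, tau) and materializes the m x n Q with
# orthonormal rows.
q, work, info = lapack.dorglq(lq, tau)
assert info == 0

assert np.allclose(l @ q, a)             # A = L * Q
assert np.allclose(q @ q.T, np.eye(m))   # Q Q^T = I
```

In the raw LAPACK convention, the optimal lwork that linalg_gelqf_workspace_query determines is obtained by calling the routine with lwork = -1 and reading the first workspace entry.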

#include "linalg_impl.h"

#endif // MXNET_OPERATOR_LINALG_H_