From e9ab7f74f6462460d08ced4f7a8076336c726049 Mon Sep 17 00:00:00 2001 From: Matthias Seeger Date: Wed, 16 Aug 2017 17:36:58 +0200 Subject: [PATCH] New operators linalg_syrk, linalg_gelqf. Numerical unit tests can run in float64 now --- docs/api/python/symbol.md | 2 + python/mxnet/test_utils.py | 76 ++-- src/operator/c_lapack_api.h | 66 ++++ src/operator/linalg.h | 41 +++ src/operator/linalg_impl.h | 133 +++++++ src/operator/tensor/la_op.cc | 324 +++++++++++------ src/operator/tensor/la_op.cu | 2 +- src/operator/tensor/la_op.h | 159 +++++++-- src/operator/tensor/la_op_inline.h | 284 +++++++++++++-- tests/python/unittest/test_operator.py | 460 ++++++++++++++++--------- 10 files changed, 1217 insertions(+), 330 deletions(-) diff --git a/docs/api/python/symbol.md b/docs/api/python/symbol.md index d7b735932703..05d946889375 100644 --- a/docs/api/python/symbol.md +++ b/docs/api/python/symbol.md @@ -526,6 +526,8 @@ Composite multiple symbols into a new one by an operator. linalg_trmm linalg_trsm linalg_sumlogdiag + linalg_syrk + linalg_gelqf ``` ### Miscellaneous diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 439417200692..946314f59856 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -447,7 +447,7 @@ def simple_forward(sym, ctx=None, is_train=False, **inputs): return outputs -def _parse_location(sym, location, ctx): +def _parse_location(sym, location, ctx, dtype=default_dtype()): """Parses the given location to a dictionary. Arguments of the provided op `sym` are used as dictionary keys @@ -468,6 +468,8 @@ def _parse_location(sym, location, ctx): *In either case, value of all the arguments must be provided.* ctx : Context Device context. + dtype: np.float32 or np.float64 + Datatype for mx.nd.array. Returns ------- @@ -489,6 +491,7 @@ def _parse_location(sym, location, ctx): ValueError: Symbol arguments and keys of the given location do not match. """ assert isinstance(location, (dict, list, tuple)) + assert dtype == np.float32 or dtype == np.float64 if isinstance(location, dict): if set(location.keys()) != set(sym.list_arguments()): raise ValueError("Symbol arguments and keys of the given location do not match." @@ -496,12 +499,12 @@ def _parse_location(sym, location, ctx): % (str(set(sym.list_arguments())), str(set(location.keys())))) else: location = {k: v for k, v in zip(sym.list_arguments(), location)} - location = {k: mx.nd.array(v, ctx=ctx) if isinstance(v, np.ndarray) \ + location = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) if isinstance(v, np.ndarray) \ else v for k, v in location.items()} return location -def _parse_aux_states(sym, aux_states, ctx): +def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()): """Parses the given auxiliary states to a dictionary. Auxiliary states of the provided op `sym` are used as dictionary @@ -520,6 +523,10 @@ def _parse_aux_states(sym, aux_states, ctx): - if type is dict of str -> `np.ndarray` maps the name of arguments to the corresponding `np.ndarray`. *In either case, all aux states of `sym` must be provided.* + ctx : Context + Device context. + dtype: np.float32 or np.float64 + Datatype for mx.nd.array. Returns ------- @@ -543,6 +550,7 @@ def _parse_aux_states(sym, aux_states, ctx): >>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None) ValueError: Symbol aux_states names and given aux_states do not match. """ + assert dtype == np.float32 or dtype == np.float64 if aux_states is not None: if isinstance(aux_states, dict): if set(aux_states.keys()) != set(sym.list_auxiliary_states()): @@ -553,11 +561,12 @@ def _parse_aux_states(sym, aux_states, ctx): elif isinstance(aux_states, (list, tuple)): aux_names = sym.list_auxiliary_states() aux_states = {k:v for k, v in zip(aux_names, aux_states)} - aux_states = {k: mx.nd.array(v, ctx=ctx) for k, v in aux_states.items()} + aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in aux_states.items()} return aux_states -def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_train=True): +def numeric_grad(executor, location, aux_states=None, eps=1e-4, + use_forward_train=True, dtype=default_dtype()): """Calculates a numeric gradient via finite difference method. Class based on Theano's `theano.gradient.numeric_grad` [1] @@ -578,11 +587,15 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai Epsilon for the finite-difference method. use_forward_train : bool, optional Whether to use `is_train=True` in testing. + dtype: np.float32 or np.float64 + Datatype for mx.nd.array. + References --------- ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py """ - approx_grads = {k: np.zeros(v.shape, dtype=np.float32) + assert dtype == np.float32 or dtype == np.float64 + approx_grads = {k: np.zeros(v.shape, dtype=dtype) for k, v in location.items()} for k, v in location.items(): executor.arg_dict[k][:] = v @@ -619,7 +632,7 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2, atol=None, grad_nodes=None, use_forward_train=True, ctx=None, - grad_stype_dict=None): + grad_stype_dict=None, dtype=default_dtype()): """Verify an operation by checking backward pass via finite difference method. Based on Theano's `theano.gradient.verify_grad` [1] @@ -650,10 +663,14 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto Check the gradient computation on the specified device. grad_stype_dict : dict of str->str, optional Storage type dictionary for gradient ndarrays. + dtype: np.float32 or np.float64 + Datatype for mx.nd.array. + References --------- ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py """ + assert dtype == np.float32 or dtype == np.float64 if ctx is None: ctx = default_context() @@ -669,9 +686,10 @@ def random_projection(shape): plain = _rng.rand(*shape) + 0.1 return plain - location = _parse_location(sym=sym, location=location, ctx=ctx) + location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype) location_npy = {k:v.asnumpy() for k, v in location.items()} - aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) + aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx, + dtype=dtype) if aux_states is not None: aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()} else: @@ -695,11 +713,12 @@ def random_projection(shape): out = mx.sym.MakeLoss(out) location = dict(list(location.items()) + - [("__random_proj", mx.nd.array(random_projection(out_shape[0]), ctx=ctx))]) + [("__random_proj", mx.nd.array(random_projection(out_shape[0]), + ctx=ctx, dtype=dtype))]) args_grad_npy = dict([(k, _rng.normal(0, 0.01, size=location[k].shape)) for k in grad_nodes] + [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))]) - args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + args_grad = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in args_grad_npy.items()} if grad_stype_dict is not None: assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" for k, v in grad_stype_dict.items(): @@ -722,8 +741,9 @@ def random_projection(shape): executor.backward() symbolic_grads = {k:executor.grad_dict[k].asnumpy() for k in grad_nodes} - numeric_gradients = numeric_grad(executor, location_npy, aux_states_npy, - eps=numeric_eps, use_forward_train=use_forward_train) + numeric_gradients = numeric_grad( + executor, location_npy, aux_states_npy, eps=numeric_eps, + use_forward_train=use_forward_train, dtype=dtype) for name in grad_nodes: fd_grad = numeric_gradients[name] orig_grad = args_grad_npy[name] @@ -742,7 +762,7 @@ def random_projection(shape): def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, - aux_states=None, ctx=None): + aux_states=None, ctx=None, dtype=default_dtype()): """Compares a symbol's forward results with the expected ones. Prints error messages if the forward results are not the same as the expected ones. @@ -773,6 +793,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, Contains the mapping between names of auxiliary states and their values. ctx : Context, optional running context + dtype: np.float32 or np.float64 + Datatype for mx.nd.array. Example ------- @@ -785,14 +807,16 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, >>> ret_expected = np.array([[19, 22], [43, 50]]) >>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected]) """ + assert dtype == np.float32 or dtype == np.float64 if ctx is None: ctx = default_context() - location = _parse_location(sym=sym, location=location, ctx=ctx) - aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) + location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype) + aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx, + dtype=dtype) if isinstance(expected, dict): expected = [expected[k] for k in sym.list_outputs()] - args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx) for k, v in location.items()} + args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=dtype) for k, v in location.items()} executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) for g in executor.grad_arrays: @@ -807,7 +831,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=None, - aux_states=None, grad_req='write', ctx=None, grad_stypes=None): + aux_states=None, grad_req='write', ctx=None, grad_stypes=None, + dtype=default_dtype()): """Compares a symbol's backward results with the expected ones. Prints error messages if the backward results are not the same as the expected results. @@ -845,6 +870,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= Running context. grad_stypes: dict of str->str dictionary of mapping argument name to stype for the gradient + dtype: np.float32 or np.float64 + Datatype for mx.nd.array. Example ------- @@ -862,17 +889,19 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= >>> grad_expected = ograd.copy().asnumpy() >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected]) """ + assert dtype == np.float32 or dtype == np.float64 if ctx is None: ctx = default_context() - location = _parse_location(sym=sym, location=location, ctx=ctx) - aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) + location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype) + aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx, + dtype=dtype) if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()} args_grad_data = {} for k, v in args_grad_npy.items(): - nd = mx.nd.array(v, ctx=ctx) + nd = mx.nd.array(v, ctx=ctx, dtype=dtype) if grad_stypes is not None and k in grad_stypes: args_grad_data[k] = nd.tostype(grad_stypes[k]) else: @@ -888,9 +917,10 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= executor.forward(is_train=True) if isinstance(out_grads, (tuple, list)): - out_grads = [mx.nd.array(v, ctx=ctx) for v in out_grads] + out_grads = [mx.nd.array(v, ctx=ctx, dtype=dtype) for v in out_grads] elif isinstance(out_grads, (dict)): - out_grads = {k:mx.nd.array(v, ctx=ctx) for k, v in out_grads.items()} + out_grads = {k:mx.nd.array(v, ctx=ctx, dtype=dtype) + for k, v in out_grads.items()} else: assert out_grads is None executor.backward(out_grads) diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h index 96a9b3a23709..d80915158e80 100644 --- a/src/operator/c_lapack_api.h +++ b/src/operator/c_lapack_api.h @@ -86,6 +86,24 @@ extern "C" { void sposv_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info); + + // Note: GELQF in row-major (MXNet) becomes GEQRF in column-major (LAPACK). + // Also, m and n are flipped, compared to the row-major version + #define MXNET_LAPACK_FSIG_GEQRF(func, dtype) \ + void func##_(int *m, int *n, dtype *a, int *lda, dtype *tau, dtype *work, \ + int *lwork, int *info); + + MXNET_LAPACK_FSIG_GEQRF(sgeqrf, float) + MXNET_LAPACK_FSIG_GEQRF(dgeqrf, double) + + // Note: ORGLQ in row-major (MXNet) becomes ORGQR in column-major (LAPACK) + // Also, m and n are flipped, compared to the row-major version + #define MXNET_LAPACK_FSIG_ORGQR(func, dtype) \ + void func##_(int *m, int *n, int *k, dtype *a, int *lda, dtype *tau, \ + dtype *work, int *lwork, int *info); + + MXNET_LAPACK_FSIG_ORGQR(sorgqr, float) + MXNET_LAPACK_FSIG_ORGQR(dorgqr, double) } #define MXNET_LAPACK_ROW_MAJOR 101 @@ -178,6 +196,42 @@ inline void flip(int m, int n, return info; } + // Note: Both MXNET_LAPACK_*gelqf, MXNET_LAPACK_*orglq can only be called with + // row-major format (MXNet). Internally, the QR variants are done in column-major. + // In particular, the matrix dimensions m and n are flipped. + #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \ + dtype *a, int lda, dtype* tau, \ + dtype* work, int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + int info(0); \ + prefix##geqrf_(&n, &m, a, &lda, tau, work, &lwork, &info); \ + return info; \ + } else { \ + CHECK(false) << "MXNET_LAPACK_" << #prefix << "gelqf implemented for row-major layout only"; \ + return 1; \ + } \ + } + MXNET_LAPACK_CWRAP_GELQF(s, float) + MXNET_LAPACK_CWRAP_GELQF(d, double) + + // Note: The k argument (rank) is equal to m as well + #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \ + dtype *a, int lda, dtype* tau, \ + dtype* work, int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + int info(0); \ + prefix##orgqr_(&n, &m, &m, a, &lda, tau, work, &lwork, &info); \ + return info; \ + } else { \ + CHECK(false) << "MXNET_LAPACK_" << #prefix << "orglq implemented for row-major layout only"; \ + return 1; \ + } \ + } + MXNET_LAPACK_CWRAP_ORGLQ(s, float) + MXNET_LAPACK_CWRAP_ORGLQ(d, double) + #else // use pragma message instead of warning @@ -192,6 +246,13 @@ inline void flip(int m, int n, return 1; \ } + #define MXNET_LAPACK_CWRAPPER2(func, dtype) \ + inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \ + int lda, dtype* tau, dtype* work, int lwork) { \ + LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ + return 1; \ + } + #define MXNET_LAPACK_UNAVAILABLE(func) \ inline int mxnet_lapack_##func(...) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ @@ -206,6 +267,11 @@ inline void flip(int m, int n, MXNET_LAPACK_UNAVAILABLE(sposv) MXNET_LAPACK_UNAVAILABLE(dposv) + MXNET_LAPACK_CWRAPPER2(sgelqf, float) + MXNET_LAPACK_CWRAPPER2(dgelqf, double) + MXNET_LAPACK_CWRAPPER2(sorglq, float) + MXNET_LAPACK_CWRAPPER2(dorglq, double) + #endif template diff --git a/src/operator/linalg.h b/src/operator/linalg.h index 76acf7b98f41..1a3cfe29e89d 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -123,6 +123,47 @@ void linalg_potri(const Tensor& A, bool lower, Stream *s = 0 template void linalg_batch_potri(const Tensor& A, bool lower, Stream *s = 0); +//////////////////////////////// SYRK //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "syrk". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is B = syrk(A, B), so that B is input and output parameter. + +template +void linalg_syrk(const Tensor& A, const Tensor& B, + DType alpha, DType beta, bool tA, Stream *s = 0); + +template +void linalg_batch_syrk(const Tensor& A, + const Tensor& B, DType alpha, DType beta, + bool tA, Stream *s = 0); + +//////////////////////////////// GELQF //////////////////////////////////////////// + +// CPU/GPU-versions of LAPACK functions "gelqf", "orglq". Please refer to the +// LAPACK documentation for further details. +// Note: +// - The current implementation works for CPU only +// - Both functions have A as input and output parameter +// - Both functions require extra workspace, passed as 1D tensor +// - We call orglq after gelqf. Apart from A, they also communicate via the +// first part of the workspace. + +template +void linalg_gelqf(const Tensor& A, + const Tensor& work, Stream *s = 0); + +template +void linalg_orglq(const Tensor& A, + const Tensor& work, Stream *s = 0); + +// This function determines the amount of workspace needed for linalg_gelqf, +// linalg_orglq. The workspace can be used for both. The first m entries are +// used to communicate information from gelqf to orglq. +template +int linalg_gelqf_workspace_query(const Tensor& A, + Stream *s = 0); + #include "linalg_impl.h" #endif // MXNET_OPERATOR_LINALG_H_ diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 28c9496798ef..e8f64e15302e 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -719,4 +719,137 @@ LINALG_GPU_BATCH_POTRI(double) #endif +//////////////////////////////// SYRK //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "syrk". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is B = syrk(A, B), so B is input and output parameter. + +#if MSHADOW_USE_CBLAS == 1 + +template inline +void check_syrk(const Tensor& A, const Tensor& B, + DType alpha, DType beta, bool tA) { + // Any checking that helps user debug potential problems. + CHECK_EQ(B.size(0), B.size(1)) + << "B must be square symmetric matrix for syrk"; + CHECK_EQ((tA ? A.size(1) : A.size(0)), B.size(0)) + << "Non compatible matrix dimensions between inputs A and B for syrk"; +} + +#define LINALG_CPU_SYRK(fname, DType) \ +template<> inline \ +void linalg_syrk(const Tensor& A, \ + const Tensor& B, DType alpha, \ + DType beta, bool tA, Stream *s) { \ + check_syrk(A, B, alpha, beta, tA); \ + cblas_##fname(CblasRowMajor, CblasLower, (tA ? CblasTrans : CblasNoTrans), \ + B.size(0), (tA ? A.size(0) : A.size(1)), alpha, \ + A.dptr_, A.stride_, beta, B.dptr_, B.stride_); \ +} + +#define LINALG_CPU_BATCH_SYRK(DType) \ +template<> inline \ +void linalg_batch_syrk(const Tensor& A, \ + const Tensor& B, DType alpha, DType beta, \ + bool tA, Stream *s) { \ + linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_syrk(A[i], B[i], alpha, beta, tA); \ + } \ +} + +#else + +#define LINALG_CPU_SYRK(fname, DType) \ +template<> inline \ +void linalg_syrk(const Tensor& A, \ + const Tensor& B, DType alpha, \ + DType beta, bool tA, Stream *s) { \ + LOG(FATAL) << "linalg_syrk not implemented by mxnet for cpu, needs cblas!"; \ +} + +#define LINALG_CPU_BATCH_SYRK(DType) \ +template<> inline \ +void linalg_batch_syrk(const Tensor& A, \ + const Tensor& B, DType alpha, DType beta, \ + bool tA, Stream *s) { \ + LOG(FATAL) << "linalg_batch_syrk not implemented by mxnet for cpu, needs cblas!"; \ +} + +#endif // MSHADOW_USE_CBLAS == 1 + +LINALG_CPU_SYRK(ssyrk, float) +LINALG_CPU_SYRK(dsyrk, double) +LINALG_CPU_BATCH_SYRK(float) +LINALG_CPU_BATCH_SYRK(double) + +//////////////////////////////// GELQF //////////////////////////////////////////// + +// CPU/GPU-versions of LAPACK functions "gelqf", "orglq". + +template inline +void check_gelqf(const Tensor& A, + const Tensor& work) { + // Any checking that helps user debug potential problems. + CHECK_LE(A.size(0), A.size(1)) + << "A must have num(rows) <= num(columns)"; + CHECK_LT(A.size(0), work.size(0)) + << "Size of work is too small"; +} + +#define LINALG_CPU_GELQF(fname, DType) \ +template<> inline \ +void linalg_gelqf(const Tensor& A, \ + const Tensor& work, \ + Stream *s) { \ + check_gelqf(A, work); \ + int m(A.size(0)); \ + int lwork(work.size(0) - m); \ + int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, m, A.size(1), \ + A.dptr_ , A.stride_, work.dptr_, \ + work.dptr_ + m, lwork)); \ + CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ +} +LINALG_CPU_GELQF(sgelqf, float) +LINALG_CPU_GELQF(dgelqf, double) + +#define LINALG_CPU_ORGLQ(fname, DType) \ +template<> inline \ +void linalg_orglq(const Tensor& A, \ + const Tensor& work, \ + Stream *s) { \ + check_gelqf(A, work); \ + int m(A.size(0)); \ + int lwork(work.size(0) - m); \ + int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, m, A.size(1), \ + A.dptr_ , A.stride_, work.dptr_, \ + work.dptr_ + m, lwork)); \ + CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ +} +LINALG_CPU_ORGLQ(sorglq, float) +LINALG_CPU_ORGLQ(dorglq, double) + +#define LINALG_CPU_GELQF_WORKSPACE_QUERY(prefix, DType) \ +template<> inline \ +int linalg_gelqf_workspace_query(const Tensor& A, \ + Stream *s) { \ + int m(A.size(0)); \ + DType work; \ + int ret(MXNET_LAPACK_##prefix##gelqf(MXNET_LAPACK_ROW_MAJOR, m, \ + A.size(1), A.dptr_ , A.stride_, &work, \ + &work, -1)); \ + CHECK_EQ(ret, 0) << #prefix << "gelqf: Workspace query failed on CPU."; \ + int ws_size(static_cast(work)); \ + ret = MXNET_LAPACK_##prefix##orglq(MXNET_LAPACK_ROW_MAJOR, m, \ + A.size(1), A.dptr_ , \ + A.stride_, &work, &work, -1); \ + CHECK_EQ(ret, 0) << #prefix << "orglq: Workspace query failed on CPU."; \ + int wsz2(static_cast(work)); \ + if (wsz2 > ws_size) ws_size = wsz2; \ + return ws_size + m; \ +} +LINALG_CPU_GELQF_WORKSPACE_QUERY(s, float) +LINALG_CPU_GELQF_WORKSPACE_QUERY(d, double) + #endif // MXNET_OPERATOR_LINALG_IMPL_H_ diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 9b94603b5fdc..63fcb02da739 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -30,24 +30,25 @@ namespace op { DMLC_REGISTER_PARAMETER(LaMatrixMacParam); DMLC_REGISTER_PARAMETER(LaMatrixMultParam); DMLC_REGISTER_PARAMETER(LaTriangMatrixMultParam); +DMLC_REGISTER_PARAMETER(LaSyrkParam); NNVM_REGISTER_OP(_linalg_gemm) .add_alias("linalg_gemm") .describe(R"code(Performs general matrix multiplication and accumulation. -Input are three tensors *A*, *B*, *C* each of dimension *n >= 2* and each -having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let -*A*\ :sub:`i`\ , *B*\ :sub:`i`\ , *C*\ :sub:`i` be the matrices given by the last *2* dimensions. -The operator performs the BLAS3 function *gemm* +Input are tensors *A*, *B*, *C*, each of dimension *n >= 2* and having the same shape +on the leading *n-2* dimensions. - *out*\ :sub:`i` = *alpha* \* *op*\ (*A*\ :sub:`i`\ ) \* *op*\ (*B*\ :sub:`i`\ ) + *beta* \* *C*\ :sub:`i` +If *n=2*, the BLAS3 function *gemm* is performed: -on all such triples of matrices. Here *alpha* and *beta* are scalar operator parameters and *op()* -is either the identity or the matrix transposition. + *out* = *alpha* \* *op*\ (*A*) \* *op*\ (*B*) + *beta* \* *C* -In case of *n=2*, a single *gemm* function is performed on the matrices *A*, *B*, *C*. +Here, *alpha* and *beta* are scalar parameters, and *op()* is either the identity or +matrix transposition (depending on *transpose_a*, *transpose_b*). -.. note:: The operator does only support float32 and float64 data types and provides - proper backward gradients. +If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs +(batch mode). + +.. note:: The operator supports float32 and float64 data types only. Examples:: @@ -55,14 +56,14 @@ Examples:: A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - gemm(A, B, C, transpose_b = 1, alpha = 2.0 , beta = 10.0) + gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]] // Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] - gemm(A, B, C, transpose_b = 1, alpha = 2.0 , beta = 10.0) + gemm(A, B, C, transpose_b=True, alpha=2.0 , beta=10.0) = [[[104.0]], [[0.14]]] )code" ADD_FILELINE) .set_num_inputs(3) @@ -95,33 +96,33 @@ NNVM_REGISTER_OP(_backward_linalg_gemm) NNVM_REGISTER_OP(_linalg_gemm2) .add_alias("linalg_gemm2") .describe(R"code(Performs general matrix multiplication. -Input are two tensors *A*, *B* each of dimension *n >= 2* and each -having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let -*A*\ :sub:`i`\ , *B*\ :sub:`i`\ be the matrices given by the last *2* dimensions. -The operator performs the BLAS3 function *gemm* (restricted to two arguments) +Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same shape +on the leading *n-2* dimensions. + +If *n=2*, the BLAS3 function *gemm* is performed: - *out*\ :sub:`i` = *alpha* \* *op*\ (*A*\ :sub:`i`\ ) \* *op*\ (*B*\ :sub:`i`\ ) + *out* = *alpha* \* *op*\ (*A*) \* *op*\ (*B*) -on all such pairs of matrices. Here *alpha* is a scalar operator parameter and *op()* is either -the identity or the matrix transposition. +Here *alpha* is a scalar parameter and *op()* is either the identity or the matrix +transposition (depending on *transpose_a*, *transpose_b*). -In case of *n=2*, a single *gemm* function is performed on the matrices *A*, *B*. +If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs +(batch mode). -.. note:: The operator does only support float32 and float64 data types and provides - proper backward gradients. +.. note:: The operator supports float32 and float64 data types only. Examples:: // Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] - gemm2(A, B, transpose_b = 1, alpha = 2.0) + gemm2(A, B, transpose_b=True, alpha=2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]] // Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] - gemm2(A, B, transpose_b = 1, alpha = 2.0 ) + gemm2(A, B, transpose_b=True, alpha=2.0) = [[[4.0]], [[0.04 ]]] )code" ADD_FILELINE) .set_num_inputs(2) @@ -151,22 +152,18 @@ NNVM_REGISTER_OP(_backward_linalg_gemm2) NNVM_REGISTER_OP(_linalg_potrf) .add_alias("linalg_potrf") .describe(R"code(Performs Cholesky factorization of a symmetric positive-definite matrix. -Input is a tensor *A* of dimension *n >= 2*. For every *n-2* dimensional index *i* let -*A*\ :sub:`i`\ be the matrix given by the last *2* dimensions. -The operator performs the Cholesky factorization (LAPACK function *potrf*) -on each *A*\ :sub:`i`\ , -i.e. it computes a lower triangular matrix *U*\ :sub:`i` such that +Input is a tensor *A* of dimension *n >= 2*. - *A*\ :sub:`i`\ = *U*\ :sub:`i`\ \* *U*\ :sub:`i`\ \ :sup:`T` +If *n=2*, the Cholesky factor *L* of the symmetric, positive definite matrix *A* is +computed. *L* is lower triangular (entries of upper triangle are all zero), has +positive diagonal entries, and: -for all such matrices. The matrices *A*\ :sub:`i` must be all symmetric and positive-definite. -The resulting matrices *U*\ :sub:`i` will contain zeros in the upper triangle -apart from the diagonal. + *A* = *L* \* *L*\ :sup:`T` -In case of *n=2*, a single Cholesky factorization is performed on the matrix *A*. +If *n>2*, *potrf* is performed separately on the trailing two dimensions for all inputs +(batch mode). -.. note:: The operator does only support float32 and float64 data types and provides - proper backward gradients. +.. note:: The operator supports float32 and float64 data types only. Examples:: @@ -204,21 +201,26 @@ NNVM_REGISTER_OP(_backward_linalg_potrf) NNVM_REGISTER_OP(_linalg_potri) .add_alias("linalg_potri") .describe(R"code(Performs matrix inversion from a Cholesky factorization. -Input is a tensor *A* of dimension *n >= 2*. For every *n-2* dimensional index *i* let -*A*\ :sub:`i`\ be the matrix given by the last *2* dimensions. -The operator assumes that each *A*\ :sub:`i` is the Cholesky factorization of some symmetric -positive-definite matrix *B*\ :sub:`i` given as a lower triangular matrix -(so *A* is the output of a prior call to operator *linalg_potrf*). The operator computes the -inverse of each *B*\ :sub:`i` from this decomposition, i.e +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, *A* is a lower triangular matrix (entries of upper triangle are all zero) +with positive diagonal. We compute: + + *out* = *A*\ :sup:`-T` \* *A*\ :sup:`-1` + +In other words, if *A* is the Cholesky factor of a symmetric positive definite matrix +*B*, then - *out*\ :sub:`i` = *B*\ :sub:`i`\ \ :sup:`-1` + *out* = *B*\ :sup:`-1` -for all such matrices. +If *n>2*, *potri* is performed separately on the trailing two dimensions for all inputs +(batch mode). -In case of *n=2*, the operation is performed on the matrix *A* itself. +.. note:: The operator supports float32 and float64 data types only. -.. note:: The operator does only support float32 and float64 data types and provides - proper backward gradients. +.. note:: Use this operator only if you are certain you need the inverse of *B*, and + cannot use the Cholesky factor alone. The latter is more numerically + stable and cheaper. Examples:: @@ -229,7 +231,7 @@ Examples:: // Batch matrix inverse A = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] potri(A) = [[[0.26563, -0.0625], [-0.0625, 0.25]], - [[0.06641, -0.01562], [-0.01562, 0,0625]]] + [[0.06641, -0.01562], [-0.01562, 0,0625]]] )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -253,41 +255,40 @@ NNVM_REGISTER_OP(_backward_linalg_potri) NNVM_REGISTER_OP(_linalg_trmm) .add_alias("linalg_trmm") -.describe(R"code(Performs multiplication with a triangular matrix. -Input are two tensors *A*, *B* each of dimension *n >= 2* and each -having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let -*A*\ :sub:`i`\ , *B*\ :sub:`i`\ be the matrices given by the last *2* dimensions. -The operator performs the BLAS3 function *trmm* +.describe(R"code(Performs multiplication with a lower triangular matrix. +Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same shape +on the leading *n-2* dimensions. - *out*\ :sub:`i` = *alpha* \* *op*\ (*A*\ :sub:`i`\ ) \* *B*\ :sub:`i` +If *n=2*, *A* must be lower triangular. The operator performs the BLAS3 function +*trmm*: -or + *out* = *alpha* \* *op*\ (*A*) \* *B* - *out*\ :sub:`i` = *alpha* \* *B*\ :sub:`i` \* *op*\ (*A*\ :sub:`i`\ ) +if *rightside=False*, or -on all such pairs of matrices. Here *alpha* is a scalar operator parameter, *op()* is either -the identity or the matrix transposition (depending on the parameter *transpose*) and the -order of matrix multiplication depends on the parameter *rightside*. -All matrices *A*\ :sub:`i` must be lower triangular. + *out* = *alpha* \* *B* \* *op*\ (*A*) -In case of *n=2*, a single *trmm* function is performed on the matrices *A*, *B*. +if *rightside=True*. Here, *alpha* is a scalar parameter, and *op()* is either the +identity or the matrix transposition (depending on *transpose*). + +If *n>2*, *trmm* is performed separately on the trailing two dimensions for all inputs +(batch mode). + +.. note:: The operator supports float32 and float64 data types only. -.. note:: The operator does only support float32 and float64 data types and provides - proper backward gradients. Examples:: - // Single matrix multiply + // Single triangular matrix multiply A = [[1.0, 0], [1.0, 1.0]] B = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - trmm(A, B, alpha = 2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] + trmm(A, B, alpha=2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] - // Batch matrix multiply + // Batch triangular matrix multiply A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]] - trmm(A, B, alpha = 2.0 ) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], - [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]] - + trmm(A, B, alpha=2.0) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]] )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -299,13 +300,13 @@ Examples:: .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { return std::vector>{{1, 0}}; }) .set_attr("FCompute", LaOpForward) -.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_linalg_trmm"}) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_linalg_trmm"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of lower triangular matrices") .add_argument("B", "NDArray-or-Symbol", "Tensor of matrices") .add_arguments(LaTriangMatrixMultParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_linalg_trmm) -.set_num_inputs(4) +.set_num_inputs(3) .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("FInplaceOption", [](const NodeAttrs& attrs) @@ -313,45 +314,44 @@ NNVM_REGISTER_OP(_backward_linalg_trmm) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("TIsBackward", true) -.set_attr("FCompute", LaOpBackward); +.set_attr("FCompute", LaOpBackward); NNVM_REGISTER_OP(_linalg_trsm) .add_alias("linalg_trsm") -.describe(R"code(Solves matrix equations involving a triangular matrix. -Input are two tensors *A*, *B* each of dimension *n >= 2* and each -having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let -*A*\ :sub:`i`\ , *B*\ :sub:`i`\ be the matrices given by the last *2* dimensions. -The operator performs the BLAS3 function *trsm*, i.e. it solves the equation +.describe(R"code(Solves matrix equation involving a lower triangular matrix. +Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same shape +on the leading *n-2* dimensions. - *op*\ (*A*\ :sub:`i`\ ) \* *X*\ :sub:`i` = *alpha* \* *B*\ :sub:`i` +If *n=2*, *A* must be lower triangular. The operator performs the BLAS3 function +*trsm*, solving for *out* in: -or + *op*\ (*A*) \* *out* = *alpha* \* *B* - *X*\ :sub:`i` \* *op*\ (*A*\ :sub:`i`\ ) = *alpha* \* *B*\ :sub:`i` +if *rightside=False*, or -on all such pairs of matrices. Here *alpha* is a scalar operator parameter, *op()* is either -the identity or the matrix transposition (depending on the parameter *transpose*) and the -order of multiplication on the left depends on the parameter *rightside*. -All matrices *A*\ :sub:`i` must be lower triangular. + *out* \* *op*\ (*A*) = *alpha* \* *B* -In case of *n=2*, a single *trsm* function is performed on the matrices *A*, *B*. +if *rightside=True*. Here, *alpha* is a scalar parameter, and *op()* is either the +identity or the matrix transposition (depending on *transpose*). -.. note:: The operator does only support float32 and float64 data types and provides - proper backward gradients. +If *n>2*, *trsm* is performed separately on the trailing two dimensions for all inputs +(batch mode). + +.. note:: The operator supports float32 and float64 data types only. Examples:: // Single matrix solve A = [[1.0, 0], [1.0, 1.0]] B = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] - trsm(A, B, alpha = 0.5) = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] + trsm(A, B, alpha=0.5) = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] // Batch matrix solve A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], [[4.0, 4.0, 4.0], [8.0, 8.0, 8.0]]] - trsm(A, B, alpha = 0.5 ) = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[2.0, 2.0, 2.0 ], [2.0, 2.0, 2.0]]] + trsm(A, B, alpha=0.5) = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]] )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -381,16 +381,16 @@ NNVM_REGISTER_OP(_backward_linalg_trsm) NNVM_REGISTER_OP(_linalg_sumlogdiag) .add_alias("linalg_sumlogdiag") -.describe(R"code(Computes the sum of the logarithms of all diagonal elements in a matrix. -Input is a tensor *A* of dimension *n >= 2*. For every *n-2* dimensional index *i* let -*A*\ :sub:`i`\ be the matrix given by the last *2* dimensions. -The operator performs a reduction of each such matrix to a scalar by summing up the logarithms -of all diagonal elements. All matrices must be square and all diagonal elements must be positive. +.describe(R"code(Computes the sum of the logarithms of the diagonal elements of a square matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, *A* must be square with positive diagonal entries. We sum the natural +logarithms of the diagonal elements, the result has shape (1,). -In case of *n=2*, *A* represents a single matrix on which the reduction will be performed. +If *n>2*, *sumlogdiag* is performed separately on the trailing two dimensions for all +inputs (batch mode). -.. note:: The operator does only support float32 and float64 data types and provides - proper backward gradients. +.. note:: The operator supports float32 and float64 data types only. Examples:: @@ -420,5 +420,133 @@ NNVM_REGISTER_OP(_backward_linalg_sumlogdiag) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); +NNVM_REGISTER_OP(_linalg_syrk) +.add_alias("linalg_syrk") +.describe(R"code(Multiplication of matrix with its transpose. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, the operator performs the BLAS3 function *syrk*: + + *out* = *alpha* \* *A* \* *A*\ :sup:`T` + +if *transpose=False*, or + + *out* = *alpha* \* *A*\ :sup:`T` \ \* *A* + +if *transpose=True*. + +If *n>2*, *syrk* is performed separately on the trailing two dimensions for all +inputs (batch mode). + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single matrix multiply + A = [[1., 2., 3.], [4., 5., 6.]] + syrk(A, alpha=1., transpose=False) + = [[14., 32.], + [32., 77.]] + syrk(A, alpha=1., transpose=True) + = [[17., 22., 27.], + [22., 29., 36.], + [27., 36., 45.]] + + // Batch matrix multiply + A = [[[1., 1.]], [[0.1, 0.1]]] + syrk(A, alpha=2., transpose=False) = [[[4.]], [[0.04]]] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; } ) +.set_attr("FInferShape", LaSyrkShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", LaOpForward) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_linalg_syrk"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices") +.add_arguments(LaSyrkParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_linalg_syrk) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpBackward); + +NNVM_REGISTER_OP(_linalg_gelqf) +.add_alias("linalg_gelqf") +.describe(R"code(LQ factorization for general matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, we compute the LQ factorization (LAPACK *gelqf*, followed by *orglq*). *A* +must have shape *(x, y)* with *x <= y*, and must have full rank *=x*. The LQ +factorization consists of *L* with shape *(x, x)* and *Q* with shape *(x, y)*, so +that: + + *A* = *L* \* *Q* + +Here, *L* is lower triangular (upper triangle equal to zero) with nonzero diagonal, +and *Q* is row-orthonormal, meaning that + + *Q* \* *Q*\ :sup:`T` + +is equal to the identity matrix of shape *(x, x)*. + +If *n>2*, *gelqf* is performed separately on the trailing two dimensions for all +inputs (batch mode). + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single LQ factorization + A = [[1., 2., 3.], [4., 5., 6.]] + Q, L = gelqf(A) + Q = [[-0.26726124, -0.53452248, -0.80178373], + [0.87287156, 0.21821789, -0.43643578]] + L = [[-3.74165739, 0.], + [-8.55235974, 1.96396101]] + + // Batch LQ factorization + A = [[[1., 2., 3.], [4., 5., 6.]], + [[7., 8., 9.], [10., 11., 12.]]] + Q, L = gelqf(A) + Q = [[[-0.26726124, -0.53452248, -0.80178373], + [0.87287156, 0.21821789, -0.43643578]], + [[-0.50257071, -0.57436653, -0.64616234], + [0.7620735, 0.05862104, -0.64483142]]] + L = [[[-3.74165739, 0.], + [-8.55235974, 1.96396101]], + [[-13.92838828, 0.], + [-19.09768702, 0.52758934]]] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(2) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; } ) +.set_attr("FInferShape", LaLQFactShape) +.set_attr("FInferType", ElemwiseType<1, 2>) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs) + { return std::vector>{{0, 0}}; }) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("FCompute", LaOpForward) +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_gelqf"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices to be factorized"); + +NNVM_REGISTER_OP(_backward_linalg_gelqf) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs) + { return std::vector >{{0, 0}}; }) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index e5d5b272c08a..963471cef31e 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -43,7 +43,7 @@ NNVM_REGISTER_OP(_linalg_trmm) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_trmm) -.set_attr("FCompute", LaOpBackward); +.set_attr("FCompute", LaOpBackward); NNVM_REGISTER_OP(_linalg_trsm) .set_attr("FCompute", LaOpForward); diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index dd5fab985e3c..b4093f6c2636 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -90,6 +90,20 @@ struct LaTriangMatrixMultParam : public dmlc::Parameter } }; +// Parameters for syrk +struct LaSyrkParam : public dmlc::Parameter { + bool transpose; + double alpha; + DMLC_DECLARE_PARAMETER(LaSyrkParam) { + DMLC_DECLARE_FIELD(transpose) + .set_default(false) + .describe("Use transpose of input matrix."); + DMLC_DECLARE_FIELD(alpha) + .set_default(1.0) + .describe("Scalar factor to be applied to the result."); + } +}; + // Common function for shape inference for matrix mult and matrix mac. inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, @@ -112,7 +126,8 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, std::vector oshape(ndim); for ( int i = 0; i < ndim-2; ++i ) { // Both inputs must have same shape except for last two dimensions. - if ( (*in_attrs)[0][i] != (*in_attrs)[1][i] ) return false; + CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i]) + << "Shapes of inputs 0, 1 must be the same, except on last two dimensions"; oshape[i] = (*in_attrs)[0][i]; } CHECK_EQ((transpose_a ? (*in_attrs)[0][ndim-2] : (*in_attrs)[0][ndim-1]), @@ -146,7 +161,8 @@ inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs, std::vector oshape(ndim); for ( int i = 0; i < ndim-2; ++i ) { // Must have same shape except for last two dimensions. - if ( (*in_attrs)[0][i] != (*in_attrs)[1][i] ) return false; + CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i]) + << "Shapes of inputs 0, 1 must be the same, except on last two dimensions"; oshape[i] = (*in_attrs)[0][i]; } if ( param.rightside ) { @@ -200,8 +216,8 @@ inline bool LaReduceShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 1); const int ndim((*in_attrs)[0].ndim()); - if ( ndim < dim ) { - return false; + if (ndim < dim) { + return false; } std::vector oshape(std::max(1, ndim-dim)); oshape[0] = 1; @@ -214,13 +230,81 @@ inline bool LaReduceShape(const nnvm::NodeAttrs& attrs, return true; } +// Shape inference function for linalg_syrk +inline bool LaSyrkShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + const TShape& in_attr = (*in_attrs)[0]; + bool transpose = nnvm::get(attrs.parsed).transpose; + const int ndim = in_attr.ndim(); + if ( ndim >= 2 ) { + // Forward shape inference. + std::vector oshape(ndim); + for ( int i = 0; i < ndim-2; ++i ) { + oshape[i] = in_attr[i]; + } + oshape[ndim-2] = (transpose ? in_attr[ndim-1] : in_attr[ndim-2]); + oshape[ndim-1] = oshape[ndim-2]; + TShape tshape(oshape.begin(), oshape.end()); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); + return true; + } + // Can't do backward inference of shapes for this operator. + return false; +} + +// Shape inference function for linalg_gelqf +// Inputs: A. Outputs: Q, L +inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 2); + const TShape& in_a = (*in_attrs)[0]; + const TShape& out_q = (*out_attrs)[0]; + const TShape& out_l = (*out_attrs)[1]; + if ( in_a.ndim() >= 2 ) { + // Forward shape inference. + const int ndim(in_a.ndim()); + CHECK_LE(in_a[ndim-2], in_a[ndim-1]) + << "Input A shape wrong: Last dimension must be >= than second to last"; + // Q must have same shape as A + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_a); + std::vector oshape_l(ndim); + for ( int i = 0; i < ndim-1; ++i ) { + oshape_l[i] = in_a[i]; + } + oshape_l[ndim-1] = in_a[ndim-2]; + TShape tshape_l(oshape_l.begin(), oshape_l.end()); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, tshape_l); + return true; + } + if ( out_q.ndim() >= 2 && out_q.ndim() == out_l.ndim() ) { + // Backward shape inference. + const int ndim(out_q.ndim()); + for ( int i = 0; i < ndim-1; ++i ) { + CHECK_EQ(out_q[i], out_l[i]) + << "Outputs Q, L must have same dimensions except for last"; + } + CHECK_LE(out_q[ndim-2], out_q[ndim-1]) + << "Output Q shape wrong: Last dimension must be >= than second to last"; + CHECK_EQ(out_l[ndim-2], out_l[ndim-1]) + << "Output L shape wrong: Last two dimensions must be equal"; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_q); + return true; + } + return false; +} + // Adapters for calling the various operators with appropriate signatures. template struct LaOpCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, - mshadow::Stream *s) { + const OpContext& ctx) { CHECK(false) << "no specialized LaOpCaller defined for template parameters"; } }; @@ -229,9 +313,22 @@ struct LaOpCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, - mshadow::Stream *s) { + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + outputs[0].FlatToKD(s), ctx, attrs); + } +}; +template +struct LaOpCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); laop::op(inputs[0].FlatToKD(s), - outputs[0].FlatToKD(s), s, attrs); + outputs[0].FlatToKD(s), + outputs[1].FlatToKD(s), ctx, attrs); } }; template @@ -239,10 +336,11 @@ struct LaOpCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, - mshadow::Stream *s) { + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); laop::op(inputs[0].FlatToKD(s), inputs[1].FlatToKD(s), - outputs[0].FlatToKD(s), s, attrs); + outputs[0].FlatToKD(s), ctx, attrs); } }; template @@ -250,11 +348,12 @@ struct LaOpCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, - mshadow::Stream *s) { + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); laop::op(inputs[0].FlatToKD(s), inputs[1].FlatToKD(s), inputs[2].FlatToKD(s), - outputs[0].FlatToKD(s), s, attrs); + outputs[0].FlatToKD(s), ctx, attrs); } }; template @@ -262,12 +361,27 @@ struct LaOpCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, - mshadow::Stream *s) { + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); laop::op(inputs[0].FlatToKD(s), inputs[1].FlatToKD(s), inputs[2].FlatToKD(s), outputs[0].FlatToKD(s), - outputs[1].FlatToKD(s), s, attrs); + outputs[1].FlatToKD(s), ctx, attrs); + } +}; +template +struct LaOpCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + inputs[1].FlatToKD(s), + inputs[2].FlatToKD(s), + inputs[3].FlatToKD(s), + outputs[0].FlatToKD(s), ctx, attrs); } }; template @@ -275,13 +389,14 @@ struct LaOpCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, - mshadow::Stream *s) { + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); laop::op(inputs[0].FlatToKD(s), inputs[1].FlatToKD(s), inputs[2].FlatToKD(s), inputs[3].FlatToKD(s), outputs[0].FlatToKD(s), - outputs[1].FlatToKD(s), s, attrs); + outputs[1].FlatToKD(s), ctx, attrs); } }; template @@ -289,14 +404,15 @@ struct LaOpCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, - mshadow::Stream *s) { + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); laop::op(inputs[0].FlatToKD(s), inputs[1].FlatToKD(s), inputs[2].FlatToKD(s), inputs[3].FlatToKD(s), outputs[0].FlatToKD(s), outputs[1].FlatToKD(s), - outputs[2].FlatToKD(s), s, attrs); + outputs[2].FlatToKD(s), ctx, attrs); } }; @@ -308,11 +424,11 @@ void LaOpForward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mshadow; - Stream *s = ctx.get_stream(); CHECK_EQ(inputs.size(), inum); CHECK_EQ(outputs.size(), onum); MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - LaOpCaller::op(inputs, outputs, attrs, s); + LaOpCaller::op(inputs, outputs, + attrs, ctx); }); } @@ -331,11 +447,12 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs, std::vector tspace(outputs); for ( int i = 0; i < onum; ++i ) { if ( req[i] == kAddTo ) { - tspace[i].dptr_ = ctx.requested[ResourceRequest::kTempSpace] + tspace[i].dptr_ = ctx.requested[0] .get_space_typed(Shape1(outputs[i].Size()), s).dptr_; } } - LaOpCaller::op(inputs, tspace, attrs, s); + LaOpCaller::op(inputs, tspace, + attrs, ctx); for ( int i = 0; i < onum; ++i ) { if ( req[i] == kAddTo ) { Tensor out = outputs[i].FlatTo1D(s); diff --git a/src/operator/tensor/la_op_inline.h b/src/operator/tensor/la_op_inline.h index 34fb441f53f7..aa7b0a736694 100644 --- a/src/operator/tensor/la_op_inline.h +++ b/src/operator/tensor/la_op_inline.h @@ -55,12 +55,14 @@ struct Scale { }; // Forward computations (always using batched processing) +// CHANGE: Added xyz::op(..., ctx, attrs), which calls xyz::op(..., s, attrs) // D = gemm(A,B,C) struct gemm { template static void op(const Tensor& A, const Tensor& B, - const Tensor& C, DType alpha, DType beta, bool tA, bool tB, Stream *s) { + const Tensor& C, DType alpha, DType beta, + bool tA, bool tB, Stream *s) { linalg_batch_gemm(A, B, C, alpha, beta, tA, tB, s); } template @@ -69,8 +71,15 @@ struct gemm { Stream *s, const nnvm::NodeAttrs& attrs) { if ( C.dptr_ != D.dptr_ ) Copy(D, C, s); const LaMatrixMacParam& param = nnvm::get(attrs.parsed); - gemm::op(A, B, D, DType(param.alpha), DType(param.beta), - param.transpose_a, param.transpose_b, s); + op(A, B, D, DType(param.alpha), DType(param.beta), param.transpose_a, + param.transpose_b, s); + } + template + static void op(const Tensor& A, const Tensor& B, + const Tensor& C, const Tensor& D, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(A, B, C, D, s, attrs); } }; @@ -78,9 +87,18 @@ struct gemm { struct gemm2 { template static void op(const Tensor& A, const Tensor& B, - const Tensor& C, Stream *s, const nnvm::NodeAttrs& attrs) { + const Tensor& C, Stream *s, + const nnvm::NodeAttrs& attrs) { const LaMatrixMultParam& param = nnvm::get(attrs.parsed); - gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a, param.transpose_b, s); + gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a, + param.transpose_b, s); + } + template + static void op(const Tensor& A, const Tensor& B, + const Tensor& C, const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(A, B, C, s, attrs); } }; @@ -94,6 +112,12 @@ struct potrf { using namespace mxnet_op; Kernel::Launch(s, L.MSize(), L.size(1)*L.stride_, L.stride_, L.dptr_); } + template + static void op(const Tensor& A, const Tensor& L, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(A, L, s, attrs); + } }; // A = potri(L). @@ -106,6 +130,12 @@ struct potri { using namespace mxnet_op; Kernel::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_, A.dptr_); } + template + static void op(const Tensor& L, const Tensor& A, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(L, A, s, attrs); + } }; // B = trsm(L,A) @@ -123,6 +153,13 @@ struct trsm { const LaTriangMatrixMultParam& param = nnvm::get(attrs.parsed); op(L, B, DType(param.alpha), param.rightside, param.transpose, s); } + template + static void op(const Tensor& L, const Tensor& A, + const Tensor& B, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(L, A, B, s, attrs); + } }; // B = trmm(L,A) @@ -134,11 +171,19 @@ struct trmm { } template static void op(const Tensor& L, const Tensor& A, - const Tensor& B, Stream *s, const nnvm::NodeAttrs& attrs) { + const Tensor& B, Stream *s, + const nnvm::NodeAttrs& attrs) { if ( A.dptr_ != B.dptr_ ) Copy(B, A, s); const LaTriangMatrixMultParam& param = nnvm::get(attrs.parsed); op(L, B, DType(param.alpha), param.rightside, param.transpose, s); } + template + static void op(const Tensor& L, const Tensor& A, + const Tensor& B, const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(L, A, B, s, attrs); + } }; // Useful operator that is not part of BLAS/LAPACK. @@ -161,6 +206,82 @@ struct sumlogdiag { using namespace mxnet_op; Kernel::Launch(s, A.size(0), A.size(1), A.stride_, A.dptr_, B.dptr_); } + template + static void op(const Tensor& A, const Tensor& B, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(A, B, s, attrs); + } +}; + +// B = syrk(A) +struct syrk { + template + static void op(const Tensor& A, const Tensor& B, + DType alpha, DType beta, bool tA, Stream *s) { + linalg_batch_syrk(A, B, alpha, beta, tA, s); + // Symmetric B is in lower triangle: Copy to upper + using namespace mxnet_op; + Kernel::Launch(s, B.MSize(), B.size(1)*B.stride_, + B.stride_, B.dptr_); + } + template + static void op(const Tensor& A, const Tensor& B, + Stream *s, const nnvm::NodeAttrs& attrs) { + const LaSyrkParam& param = nnvm::get(attrs.parsed); + op(A, B, DType(param.alpha), DType(0), param.transpose, s); + } + template + static void op(const Tensor& A, const Tensor& B, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(A, B, s, attrs); + } +}; + +// (Q, L) = gelqf(A) [LQ factorization] +// More complex than the other cases: +// - Has to reserve workspace, whose size can only be determined by workspace +// queries. This is done once, and then the workspace is used for all items +// of the batch +// - Two different LAPACK functions are called (the first, gelqf, returns an +// internal representation, which has to be converted into Q, L) +struct gelqf { + template + static void op(const Tensor& A, const Tensor& Q, + const Tensor& L, const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + if (A.dptr_ != Q.dptr_) Copy(Q, A, s); + // From here on, we work on Q only + // Reserve workspace + // The size is determined by workspace queries, done on the first items + // of the batch + int ws_size(linalg_gelqf_workspace_query(Q[0], s)); + Tensor work = ctx.requested[0] + .get_space_typed(Shape1(ws_size), s); + // Loop over items in batch + linalg_check_batch_size(A.size(0), Q.size(0), L.size(0)); + int m = Q.size(1); // Q[i] has shape (m, n) + for (index_t i = 0; i < A.size(0); ++i) { + const Tensor& Qi = Q[i]; + const Tensor& Li = L[i]; + // Call gelqf: Overwrites Qi and part of work. Afterwards, L matrix is + // in lower triangle of Qi + linalg_gelqf(Qi, work, s); + // Copy lower triangle & diagonal of Qi ==> Li. + // Also, zero the upper triangle. + // QLeft: First m columns of Qi + Tensor QLeft(Qi.dptr_, Shape2(m, m), Qi.stride_, s); + Copy(Li, QLeft, s); + using namespace mxnet_op; + Kernel::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_, + Li.dptr_); + // Call orglq: Input is Qi and part of work. Overwrites Qi by final Q + // matrix (conversion from internal representation) + linalg_orglq(Qi, work, s); + } + } }; // Backward operators (always using batch processing) @@ -182,6 +303,15 @@ struct gemm_backward { using namespace mxnet_op; Kernel::Launch(s, dC.MSize(), DType(param.beta), dC.dptr_); } + template + static void op(const Tensor& dD, const Tensor& A, + const Tensor& B, const Tensor& C, + const Tensor& dA, const Tensor& dB, + const Tensor& dC, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dD, A, B, C, dA, dB, dC, s, attrs); + } }; struct gemm2_backward { @@ -197,6 +327,14 @@ struct gemm2_backward { (tB ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, tA, s) : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !tA, false, s)); } + template + static void op(const Tensor& dC, const Tensor& A, + const Tensor& B, const Tensor& dA, + const Tensor& dB, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dC, A, B, dA, dB, s, attrs); + } }; struct potrf_backward { @@ -205,11 +343,10 @@ struct potrf_backward { const Tensor& dA, Stream* s, const nnvm::NodeAttrs& attrs) { // Backward of L = potrf(A). - // dA = 0.5 * L**T * symm(L**T * dL # E) * L**(-1) where - // '#' denotes Hadamard product - // E is the matrix having 1 on diagonal, 0 on upper and 2 on lower triagle - // symm(X) = 0.5 * (X + X**T) - // Hadamard product and symm can be realized by a single copy from lower to upper triangle. + // dA = 0.5 * L**T * copyLTU(L**T * dL) * L**(-1) + // Here, copyLTU(M) creates a symmetric matrix from the square matrix M + // by setting the upper triangle to be equal to the lower triangle, leaving + // lower triangle and diagonal unchanged. if ( dL.dptr_ != dA.dptr_ ) { Copy(dA, dL, s); } @@ -220,6 +357,13 @@ struct potrf_backward { trsm::op(L, dA, DType(1.0), false, true, s); trsm::op(L, dA, DType(0.5), true, false, s); } + template + static void op(const Tensor& dL, const Tensor& L, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dL, L, dA, s, attrs); + } }; struct potri_backward { @@ -228,11 +372,23 @@ struct potri_backward { const Tensor& A, const Tensor& dL, Stream* s, const nnvm::NodeAttrs& attrs) { // Backward of A = potri(L). - // dL = -2 * tril(A * dA * L**(-T)), where tril() extracts lower triangle and diagonal. - gemm::op(A, dA, dL, DType(1.0), DType(0), false, false, s); - trsm::op(L, dL, DType(-2.0), true, true, s); + // dL = -tril( A * (dA + dA**T) * L**(-T)), where tril() extracts lower triangle + // and diagonal. We must not assume that dA is symmetric. + // Note: Calling gemm twice here is a bit wasteful, but otherwise the symmetrization + // of dA would require temporary memory. + gemm::op(A, dA, dL, DType(1.), DType(0.), false, false, s); + gemm::op(A, dA, dL, DType(1.), DType(1.), false, true, s); + trsm::op(L, dL, DType(-1.), true, true, s); using namespace mxnet_op; - Kernel::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_); + Kernel::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, + dL.dptr_); + } + template + static void op(const Tensor& dA, const Tensor& L, + const Tensor& A, const Tensor& dL, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dA, L, A, dL, s, attrs); } }; @@ -255,27 +411,46 @@ struct trsm_backward { using namespace mxnet_op; Kernel::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_); } + template + static void op(const Tensor& dB, const Tensor& L, + const Tensor& A, const Tensor& B, + const Tensor& dL, const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dB, L, A, B, dL, dA, s, attrs); + } }; struct trmm_backward { template static void op(const Tensor& dB, const Tensor& L, - const Tensor& A, const Tensor& B, - const Tensor& dL, const Tensor& dA, - Stream* s, const nnvm::NodeAttrs& attrs) { + const Tensor& A, const Tensor& dL, + const Tensor& dA, Stream* s, + const nnvm::NodeAttrs& attrs) { // Backward of B = trmm(L,A). const LaTriangMatrixMultParam& param = nnvm::get(attrs.parsed); // Compute dL - const bool db_left(param.rightside == param.transpose); DType scale(param.alpha); - (db_left ? gemm::op(dB, A, dL, scale, DType(0), param.transpose, !param.transpose, s) - : gemm::op(A, dB, dL, scale, DType(0), !param.transpose, param.transpose, s)); + if (param.rightside == param.transpose) { + gemm::op(dB, A, dL, scale, DType(0.), param.transpose, !param.transpose, s); + } else { + gemm::op(A, dB, dL, scale, DType(0.), !param.transpose, param.transpose, s); + } using namespace mxnet_op; - Kernel::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_); + Kernel::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, + dL.dptr_); // Compute dA - if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s); + if (dA.dptr_ != dB.dptr_) Copy(dA, dB, s); trmm::op(L, dA, scale, param.rightside, !param.transpose, s); } + template + static void op(const Tensor& dB, const Tensor& L, + const Tensor& A, const Tensor& dL, + const Tensor& dA, const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dB, L, A, dL, dA, s, attrs); + } }; struct BackwardSumLogDiag { @@ -302,6 +477,69 @@ struct sumlogdiag_backward { Kernel::Launch (s, A.size(0), A.size(1), A.stride_, dB.dptr_, A.dptr_, dA.dptr_); } + template + static void op(const Tensor& dB, const Tensor& A, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dB, A, dA, s, attrs); + } +}; + +struct syrk_backward { + template + static void op(const Tensor& dB, const Tensor& A, + const Tensor& dA, Stream* s, + const nnvm::NodeAttrs& attrs) { + const LaSyrkParam& param = nnvm::get(attrs.parsed); + // Note: Calling gemm twice is a bit wasteful, but the symmetrization of dB + // would otherwise need temporary memory + if (param.transpose) { + gemm::op(A, dB, dA, DType(param.alpha), DType(0.), false, false, s); + gemm::op(A, dB, dA, DType(param.alpha), DType(1.), false, true, s); + } else { + gemm::op(dB, A, dA, DType(param.alpha), DType(0.), false, false, s); + gemm::op(dB, A, dA, DType(param.alpha), DType(1.), true, false, s); + } + } + template + static void op(const Tensor& dB, const Tensor& A, + const Tensor& dA, const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + op(dB, A, dA, s, attrs); + } +}; + +// Have to reserve temporary storage tempM, same shape as dL +struct gelqf_backward { + template + static void op(const Tensor& dQ, + const Tensor& dL, + const Tensor& Q, + const Tensor& L, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + // Backward of (Q, L) = gelqf(A): + // dA = L**(-T) * (dQ + copyLTU(M) * Q), M = L**T * dL - dQ * Q**T + // Here, copyLTU(M) creates a symmetric matrix from the square matrix M + // by setting the upper triangle to be equal to the lower triangle, leaving + // lower triangle and diagonal unchanged. + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + if (dQ.dptr_ != dA.dptr_) Copy(dA, dQ, s); + // Need temporal space, same shape as dL + Tensor tempM = ctx.requested[0] + .get_space_typed(dL.shape_, s); + Copy(tempM, dL, s); + trmm::op(L, tempM, DType(1.0), false, true, s); + gemm::op(dA, Q, tempM, DType(-1.0), DType(1.0), false, true, s); + Kernel::Launch + (s, tempM.MSize(), tempM.size(1)*tempM.stride_, tempM.stride_, + tempM.dptr_); + gemm::op(tempM, Q, dA, DType(1.0), DType(1.0), false, false, s); + trsm::op(L, dA, DType(1.0), false, true, s); + } }; } // namespace op diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 41e9fbd52224..efc01862ffdc 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3578,15 +3578,53 @@ def test_deformable_psroipooling(): grad_nodes=grad_nodes, ctx=mx.gpu(0)) -def test_laop(): +# Helper functions for test_laop + +def _make_symm_symbol(a, ndims): + assert ndims >= 2 + tr_shape = list(range(ndims)) + tr_shape[-1] = ndims-2 + tr_shape[-2] = ndims-1 + tr_shape = tuple(tr_shape) + return 0.5 * (a + mx.sym.transpose(a, axes=tr_shape)) + +def _make_lower_triangle_symm(a, ndims, m, dtype=np.float32): + assert ndims >= 2 + # The last two dimensions must both be m + # Create mask for lower triangle and diagonal + index = mx.sym.arange(start=0, stop=m, step=1, dtype=np.int32) + lt_mask = mx.sym.one_hot(index, depth=m, dtype=dtype) + for j in range(1, m): + part1 = mx.sym.zeros(shape=(j, m), dtype=dtype) + index = mx.sym.arange(start=0, stop=m-j, step=1, dtype=np.int32) + part2 = mx.sym.one_hot(index, depth=m, dtype=dtype) + lt_mask = lt_mask + mx.sym.concat(*[part1, part2], dim=0) + shp = tuple([1]*(ndims-2) + [m, m]) + lt_mask = mx.sym.reshape(lt_mask, shape=shp) + return mx.sym.broadcast_mul(a, lt_mask) +def test_laop(): + dtype = np.float64 + rtol_fw = 1e-7 + atol_fw = 1e-9 + num_eps = 1e-6 + rtol_bw = 1e-5 + atol_bw = 1e-6 # enable numerical checking of gradients grad_check = 1 data1 = mx.symbol.Variable('data1') data2 = mx.symbol.Variable('data2') data3 = mx.symbol.Variable('data3') - data4 = mx.symbol.Variable('data4') + + check_fw = lambda sym, location, expected :\ + check_symbolic_forward(sym, location, expected, rtol=rtol_fw, + atol=atol_fw, dtype=dtype) + check_grad = lambda sym, location:\ + check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, + atol=atol_bw, dtype=dtype) + rep_3x = lambda a, m, n :\ + np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n)) # Test gemm separately from other la-operators. shape1 = (2, 3) @@ -3602,222 +3640,316 @@ def test_laop(): # Check all transpositions of gemm operator. data_in1_t = np.transpose(data_in1) data_in2_t = np.transpose(data_in2) - res_gemm = 4*np.dot(data_in1,data_in2)+7*data_in4 - test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7) - check_symbolic_forward(test_gemm, [data_in1, data_in2, data_in4], [res_gemm]) + res_gemm = 4. * np.dot(data_in1, data_in2) + 7. * data_in4 + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7.) + check_fw(test_gemm, [data_in1, data_in2, data_in4], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in2, data_in4], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - res_gemm = 4*np.dot(data_in1_t,data_in2_t)+7*data_in3 - test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1, transpose_b = 1) - check_symbolic_forward(test_gemm, [data_in1, data_in2, data_in3], [res_gemm]) + check_grad(test_gemm, [data_in1, data_in2, data_in4]) + res_gemm = 4. * np.dot(data_in1_t, data_in2_t) + 7. * data_in3 + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7., + transpose_a=True, transpose_b=True) + check_fw(test_gemm, [data_in1, data_in2, data_in3], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in2, data_in3], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - res_gemm = 4*np.dot(data_in1_t,data_in1)+7*data_in3 - test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1) - check_symbolic_forward(test_gemm, [data_in1, data_in1, data_in3], [res_gemm]) + check_grad(test_gemm, [data_in1, data_in2, data_in3]) + res_gemm = 4. * np.dot(data_in1_t, data_in1) + 7. * data_in3 + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7., + transpose_a=True) + check_fw(test_gemm, [data_in1, data_in1, data_in3], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in1, data_in3], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - res_gemm = 4*np.dot(data_in1,data_in1_t)+7*data_in4 - test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_b = 1) - check_symbolic_forward(test_gemm, [data_in1, data_in1, data_in4], [res_gemm]) + check_grad(test_gemm, [data_in1, data_in1, data_in3]) + res_gemm = 4. * np.dot(data_in1, data_in1_t) + 7. * data_in4 + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7., + transpose_b=True) + check_fw(test_gemm, [data_in1, data_in1, data_in4], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in1, data_in4], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) + check_grad(test_gemm, [data_in1, data_in1, data_in4]) # Check batch of gemm. - a = np.tile(np.array(data_in1).flatten(),3) - a = np.reshape(a,(3,1,2,3)) - b = np.tile(np.array(data_in2).flatten(),3) - b = np.reshape(b,(3,1,3,2)) - c = np.tile(np.array(data_in4).flatten(),3) - c = np.reshape(c,(3,1,2,2)) - r = 4*np.dot(data_in1,data_in2)+7*data_in4 - r = np.tile(r.flatten(),3) - r = np.reshape(r,(3,1,2,2)) - test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7) - check_symbolic_forward(test_gemm, [a, b, c], [r]) + a = rep_3x(data_in1, 2, 3) + b = rep_3x(data_in2, 3, 2) + c = rep_3x(data_in4, 2, 2) + r = 4. * np.dot(data_in1, data_in2) + 7. * data_in4 + r = rep_3x(r, 2, 2) + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7.) + check_fw(test_gemm, [a, b, c], [r]) if grad_check == 1: - check_numeric_gradient(test_gemm, [a, b, c], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) + check_grad(test_gemm, [a, b, c]) # Check gemm2 operator same way as gemm. - res_gemm = 4*np.dot(data_in1,data_in2) - test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4) - check_symbolic_forward(test_gemm, [data_in1, data_in2], [res_gemm]) + res_gemm = 4. * np.dot(data_in1, data_in2) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4.) + check_fw(test_gemm, [data_in1, data_in2], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in2], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - res_gemm = 4*np.dot(data_in1_t, data_in2_t) - test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4, transpose_a = 1, transpose_b = 1) - check_symbolic_forward(test_gemm, [data_in1, data_in2], [res_gemm]) + check_grad(test_gemm, [data_in1, data_in2]) + res_gemm = 4. * np.dot(data_in1_t, data_in2_t) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4., transpose_a=True, + transpose_b=True) + check_fw(test_gemm, [data_in1, data_in2], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in2], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - res_gemm = 4*np.dot(data_in1_t,data_in1) - test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4, transpose_a = 1) - check_symbolic_forward(test_gemm, [data_in1, data_in1], [res_gemm]) + check_grad(test_gemm, [data_in1, data_in2]) + res_gemm = 4. * np.dot(data_in1_t, data_in1) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4., transpose_a=True) + check_fw(test_gemm, [data_in1, data_in1], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in1], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - res_gemm = 4*np.dot(data_in1,data_in1_t) - test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4, transpose_b = 1) - check_symbolic_forward(test_gemm, [data_in1, data_in1], [res_gemm]) + check_grad(test_gemm, [data_in1, data_in1]) + res_gemm = 4. * np.dot(data_in1, data_in1_t) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4., transpose_b=True) + check_fw(test_gemm, [data_in1, data_in1], [res_gemm]) if grad_check == 1: - check_numeric_gradient(test_gemm, [data_in1, data_in1], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) + check_grad(test_gemm, [data_in1, data_in1]) # Check batch of gemm2. - a = np.tile(np.array(data_in1).flatten(),3) - a = np.reshape(a,(3,1,2,3)) - b = np.tile(np.array(data_in2).flatten(),3) - b = np.reshape(b,(3,1,3,2)) - r = 4*np.dot(data_in1,data_in2) - r = np.tile(r.flatten(),3) - r = np.reshape(r,(3,1,2,2)) - test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4) - check_symbolic_forward(test_gemm, [a, b], [r]) + a = rep_3x(data_in1, 2, 3) + b = rep_3x(data_in2, 3, 2) + r = rep_3x(4. * np.dot(data_in1, data_in2), 2, 2) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4.) + check_fw(test_gemm, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_gemm, [a, b], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) + check_grad(test_gemm, [a, b]) # Now test all the other operators. # Tests with trivial 1x1 matrices. - shape = (4, 4, 1, 1 ) + shape = (4, 4, 1, 1) data_in = np.random.uniform(1, 10, shape) # test potrf + # Note: Have to symmetrize input, for gradient test to work res_potrf = np.sqrt(data_in) - test_potrf = mx.sym.linalg.potrf(data1) - check_symbolic_forward(test_potrf, [data_in], [res_potrf]) + test_potrf = mx.sym.linalg_potrf(data1) + check_fw(test_potrf, [data_in], [res_potrf]) if grad_check == 1: - check_numeric_gradient(test_potrf, [data_in]) + check_grad(test_potrf, [data_in]) # test potri ones = mx.nd.ones(shape).asnumpy() - res_potri = np.divide(ones,data_in*data_in) - test_potri = mx.sym.linalg.potri(data1) - check_symbolic_forward(test_potri, [data_in], [res_potri]) + res_potri = np.divide(ones, data_in * data_in) + test_potri = mx.sym.linalg_potri(data1) + check_fw(test_potri, [data_in], [res_potri]) if grad_check == 1: - check_numeric_gradient(test_potri, [data_in], atol = 0.01, rtol = 1.5) + check_grad(test_potri, [data_in]) # test trsm - trian_in = data_in *7 - test_trsm = mx.sym.linalg.trsm(data1,data2,alpha = 7) - check_symbolic_forward(test_trsm, [trian_in,data_in], [ones]) + trian_in = data_in * 7. + test_trsm = mx.sym.linalg_trsm(data1, data2, alpha=7.) + check_fw(test_trsm, [trian_in, data_in], [ones]) if grad_check == 1: - check_numeric_gradient(test_trsm, [trian_in,data_in], atol = 0.02, rtol = 2.0) + check_grad(test_trsm, [trian_in,data_in]) # test trmm - trian_in = np.divide(ones,trian_in) - test_trmm = mx.sym.linalg.trmm(data1,data2,alpha = 7, transpose = 1, rightside = 1) - check_symbolic_forward(test_trmm, [trian_in,data_in], [ones]) + trian_in = np.divide(ones, trian_in) + test_trmm = mx.sym.linalg_trmm(data1, data2, alpha=7., transpose=True, + rightside=True) + check_fw(test_trmm, [trian_in, data_in], [ones]) if grad_check == 1: - check_numeric_gradient(test_trmm, [trian_in,data_in], atol = 0.02, rtol = 2.0) + check_grad(test_trmm, [trian_in, data_in]) # test sumlogdiag - res_sumlogdiag = np.reshape(np.log(data_in),(4,4)) - test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1) - check_symbolic_forward(test_sumlogdiag, [data_in], [res_sumlogdiag]) + res_sumlogdiag = np.reshape(np.log(data_in), (4, 4)) + test_sumlogdiag = mx.sym.linalg_sumlogdiag(data1) + check_fw(test_sumlogdiag, [data_in], [res_sumlogdiag]) if grad_check == 1: - check_numeric_gradient(test_sumlogdiag, [data_in], atol = 0.01, rtol = 2.0) - - # more elaborate example of cholesky factorization - matrix = [ 9, 3, -6, 12, 3, 26, -7, -11, -6, -7, 9, 7, 12, -11, 7, 65 ] - trian = [ 3, 0, 0, 0, 1, 5, 0, 0, -2, -1, 2, 0, 4, -3, 6, 2 ] - pow = [ 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 8, 1, 1, 1, 1, 16 ] - inv = [ 2.98333, 0.01667, 2.65, -0.83333, 0.01667, 0.05, 0.05, 0, 2.65, 0.05, 2.5, -0.75, -0.83333, 0, -0.75, 0.25 ] - ident = [ 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 ] - - # Tests for numeric gradients for potrf/potri/trmm/trsm are suppressed by default - # as they are very volatile and may often report false negatives which - # have to be excluded by manual inspection. - grad_check = 0 + check_grad(test_sumlogdiag, [data_in]) + + # more elaborate example of Cholesky factorization + matrix = np.array([[9., 3., -6., 12.], + [3., 26., -7., -11.], + [-6., -7., 9., 7.], + [12., -11., 7., 65.]]) + trian = np.array([[3., 0., 0., 0.], + [1., 5., 0., 0.], + [-2., -1., 2., 0.], + [4., -3., 6., 2.]]) + pow = np.array([[2., 1., 1., 1.], + [1., 4., 1., 1.], + [1., 1., 8., 1.], + [1., 1., 1., 16.]]) + inv = np.array([[8.95/3., 0.05/3., 2.65, -2.5/3.], + [0.05/3., 0.05, 0.05, 0.], + [2.65, 0.05, 2.5, -0.75], + [-2.5/3., 0., -0.75, 0.25]]) + ident = np.eye(4) # test potrf - a = np.tile(np.array(matrix),3) - a = np.reshape(a,(3,1,4,4)) - r = np.tile(np.array(trian),3) - r = np.reshape(r,(3,1,4,4)) - check_symbolic_forward(test_potrf, [a], [r]) + test_potrf = mx.sym.linalg_potrf(_make_symm_symbol(data1, ndims=4)) + a = rep_3x(matrix, 4, 4) + r = rep_3x(trian, 4, 4) + check_fw(test_potrf, [a], [r]) if grad_check == 1: - check_numeric_gradient(test_potrf, [a], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) + check_grad(test_potrf, [a]) #test potri - a = np.tile(np.array(trian),3) - a = np.reshape(a,(3,1,4,4)) - r = np.tile(np.array(inv),3) - r = np.reshape(r,(3,1,4,4)) - check_symbolic_forward(test_potri, [a], [r], atol=0.01) + data1_ltri = _make_lower_triangle_symm( + data1, ndims=4, m=4, dtype=dtype) + test_potri = mx.sym.linalg_potri(data1_ltri) + a = rep_3x(trian, 4, 4) + r = rep_3x(inv, 4, 4) + check_fw(test_potri, [a], [r]) if grad_check == 1: - check_numeric_gradient(test_potri, [a], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - - #test trsm - a = np.tile(np.array(trian),3) - a = np.reshape(a,(3,1,4,4)) - b = np.tile(np.array(matrix),3) - b = np.reshape(b,(3,1,4,4)) - r = 7*np.transpose(np.reshape(np.array(trian),(4,4))) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trsm, [a,b], [r]) + check_grad(test_potri, [a]) + + # test trsm + test_trsm = mx.sym.linalg_trsm(data1_ltri, data2, alpha=7.) + a = rep_3x(trian, 4, 4) + b = rep_3x(matrix, 4, 4) + r = rep_3x(7. * np.transpose(trian), 4, 4) + check_fw(test_trsm, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trsm, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) + check_grad(test_trsm, [a, b]) - test_trsm2 = mx.sym.linalg.trsm(data1,data2,alpha = -2, rightside = 1, transpose = 1) - r = -2*np.reshape(np.array(trian),(4,4)) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trsm2, [a,b], [r]) + test_trsm2 = mx.sym.linalg_trsm( + data1_ltri, data2, alpha=-2., rightside=True, transpose=True) + r = rep_3x(-2. * trian, 4, 4) + check_fw(test_trsm2, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trsm2, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - - test_trsm3 = mx.sym.linalg.trsm(data1,data2,alpha = 0.50, transpose = 1) - b = np.transpose(np.reshape(np.array(trian),(4,4))) - b = np.reshape(np.tile(np.reshape(b,(16)),3),(3,1,4,4)) - r = 0.5*np.reshape(np.array(ident),(4,4)) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trsm3, [a,b], [r]) + check_grad(test_trsm2, [a, b]) + + test_trsm3 = mx.sym.linalg_trsm( + data1_ltri, data2, alpha=0.5, transpose=True) + b = rep_3x(np.transpose(trian), 4, 4) + r = rep_3x(0.5 * ident, 4, 4) + check_fw(test_trsm3, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trsm3, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - - test_trsm4 = mx.sym.linalg.trsm(data1,data2,alpha = -0.5, rightside = 1) - b = np.tile(np.array(trian),3) - b = np.reshape(b,(3,1,4,4)) - r = -0.5*np.reshape(np.array(ident),(4,4)) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trsm4, [a,b], [r]) + check_grad(test_trsm3, [a, b]) + + test_trsm4 = mx.sym.linalg_trsm( + data1_ltri, data2, alpha=-0.5, rightside=True) + b = rep_3x(trian, 4, 4) + r = rep_3x(-0.5 * ident, 4, 4) + check_fw(test_trsm4, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trsm4, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - - #test trmm - a = np.tile(np.array(trian),3) - a = np.reshape(a,(3,1,4,4)) - b = np.tile(np.array(matrix),3) - b = np.reshape(b,(3,1,4,4)) - r = 7*np.dot(np.reshape(np.array(matrix),(4,4)),np.transpose(np.reshape(np.array(trian),(4,4)))) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trmm, [a,b], [r]) + check_grad(test_trsm4, [a, b]) + + # test trmm + test_trmm = mx.sym.linalg_trmm( + data1_ltri, data2, alpha=7., transpose=True, rightside=True) + a = rep_3x(trian, 4, 4) + b = rep_3x(matrix, 4, 4) + r = rep_3x(7. * np.dot(matrix, trian.T), 4, 4) + check_fw(test_trmm, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trmm, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) + check_grad(test_trmm, [a, b]) - test_trmm2 = mx.sym.linalg.trmm(data1,data2,alpha = -2) - r = -2*np.dot(np.reshape(np.array(trian),(4,4)),np.reshape(np.array(matrix),(4,4))) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trmm2, [a,b], [r]) + test_trmm2 = mx.sym.linalg_trmm(data1_ltri, data2, alpha=-2.) + r = rep_3x(-2. * np.dot(trian, matrix), 4, 4) + check_fw(test_trmm2, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trmm2, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) + check_grad(test_trmm2, [a, b]) - test_trmm3 = mx.sym.linalg.trmm(data1,data2,rightside = 1) - r = np.dot(np.reshape(np.array(matrix),(4,4)),np.reshape(np.array(trian),(4,4))) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trmm3, [a,b], [r]) + test_trmm3 = mx.sym.linalg_trmm(data1_ltri, data2, rightside=True) + r = rep_3x(np.dot(matrix, trian), 4, 4) + check_fw(test_trmm3, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trmm3, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) + check_grad(test_trmm3, [a, b]) - test_trmm4 = mx.sym.linalg.trmm(data1,data2,alpha = 1.2,transpose = 1) - r = 1.2*np.dot(np.transpose(np.reshape(np.array(trian),(4,4))),np.reshape(np.array(matrix),(4,4))) - r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) - check_symbolic_forward(test_trmm4, [a,b], [r]) + test_trmm4 = mx.sym.linalg_trmm( + data1_ltri, data2, alpha=1.2, transpose=True) + r = rep_3x(1.2 * np.dot(trian.T, matrix), 4, 4) + check_fw(test_trmm4, [a, b], [r]) if grad_check == 1: - check_numeric_gradient(test_trmm4, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) + check_grad(test_trmm4, [a, b]) # test sumlogdiag - a = np.array(pow) - a = np.tile(a,3) - a = np.reshape(a,(3,1,4,4)) - r = 10*np.log(np.array([2])) - r = np.tile(r,3) - r = np.reshape(r,(3)) - check_symbolic_forward(test_sumlogdiag, [a], [r]) + a = rep_3x(pow, 4, 4) + r = np.reshape(np.tile(10. * np.log(np.array([2.])), 3), (3,)) + check_fw(test_sumlogdiag, [a], [r]) if grad_check == 1: - check_numeric_gradient(test_sumlogdiag, [a]) + check_grad(test_sumlogdiag, [a]) + + +# Tests for new operators linalg_syrk, linalg_gelqf + +def _gelqf_combined_symbol(a): + q, l = mx.sym.linalg_gelqf(a) + q_qt = mx.sym.linalg_syrk(q, transpose=False, alpha=1., name='Q_times_Qt') + l_q = mx.sym.linalg_trmm(l, q, alpha=1., name='L_times_Q') + return mx.sym.Group([q_qt, l_q]) + +# NOTE: If we leave the unused output dangling, things break if dtype=np.float64. Namely, the +# backward gradient for the unused output is of dtype np.float32 then. +# ==> Very annoying! +def _gelqf_first_output(a): + q, l = mx.sym.linalg_gelqf(a) + bogus_scal = mx.sym.sum(mx.sym.BlockGrad(l), axis=(), keepdims=True) * 0.0 + return mx.sym.broadcast_add(q, bogus_scal) + +def _gelqf_second_output(a): + q, l = mx.sym.linalg_gelqf(a) + bogus_scal = mx.sym.sum(mx.sym.BlockGrad(q), axis=(), keepdims=True) * 0.0 + return mx.sym.broadcast_add(l, bogus_scal) + +def test_laop_2(): + # Operators implemented for CPU only currently + if default_context() != mx.cpu(): + return + np.random.seed(1896893923) + dtype = np.float64 + rtol_fw = 1e-7 + atol_fw = 1e-9 + num_eps = 1e-6 + rtol_bw = 1e-5 + atol_bw = 1e-6 + # enable numerical checking of gradients + grad_check = 1 + + data1 = mx.symbol.Variable('data1') + + check_fw = lambda sym, location, expected :\ + check_symbolic_forward(sym, location, expected, rtol=rtol_fw, + atol=atol_fw, dtype=dtype) + check_grad = lambda sym, location:\ + check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, + atol=atol_bw, dtype=dtype) + rep_3x = lambda a, m, n :\ + np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n)) + + # Tests for linalg_syrk + mnalpha_lst = [(2, 3, 1.), (5, 3, -2.), (1, 6, 5.), (3, 3, 0.5), (4, 1, 10.), (1, 1, 1.)] + for m, n, alpha in mnalpha_lst: + #print('m={}, n={}, alpha={}'.format(m, n, alpha)) + data_in1 = np.random.uniform(1, 10, (m, n)) + res_syrk1 = alpha * np.dot(data_in1, data_in1.T) + test_syrk1 = mx.sym.linalg_syrk(data1, transpose=False, alpha=alpha) + check_fw(test_syrk1, [data_in1], [res_syrk1]) + if grad_check == 1: + check_grad(test_syrk1, [data_in1]) + res_syrk2 = alpha * np.dot(data_in1.T, data_in1) + test_syrk2 = mx.sym.linalg_syrk(data1, transpose=True, alpha=alpha) + check_fw(test_syrk2, [data_in1], [res_syrk2]) + if grad_check == 1: + check_grad(test_syrk2, [data_in1]) + # Batch mode (3x the same thing) + a_batch = rep_3x(data_in1, m, n) + r1_batch = rep_3x(res_syrk1, m, m) + check_fw(test_syrk1, [a_batch], [r1_batch]) + if grad_check == 1: + check_grad(test_syrk1, [a_batch]) + r2_batch = rep_3x(res_syrk2, n, n) + check_fw(test_syrk2, [a_batch], [r2_batch]) + if grad_check == 1: + check_grad(test_syrk2, [a_batch]) + + # Tests for linalg_gelqf + test_gelqf2 = _gelqf_combined_symbol(data1) # Outputs (dot(Q, Q.T), dot(L, Q)) + test_gelqf_q = _gelqf_first_output(data1) # Output Q (L is not dangling) + test_gelqf_l = _gelqf_second_output(data1) # Output L (Q is not dangling) + mn_lst = [(4, 4), (1, 1), (5, 20), (1, 10), (15, 50)] + for m, n in mn_lst: + #print('m={}, n={}'.format(m, n)) + data_in1 = np.random.normal(0., 10., (m, n)) + res_eye = np.eye(m) + res_a = data_in1 + check_fw(test_gelqf2, [data_in1], [res_eye, res_a]) + if grad_check == 1: + # A => Q + check_grad(test_gelqf_q, [data_in1]) + # A => L + check_grad(test_gelqf_l, [data_in1]) + # Batch mode (3x the same thing) + a_batch = rep_3x(data_in1, m, n) + reye_batch = rep_3x(res_eye, m, m) + ra_batch = a_batch + check_fw(test_gelqf2, [a_batch], [reye_batch, ra_batch]) + if grad_check == 1: + # A => Q + check_grad(test_gelqf_q, [a_batch]) + # A => L + check_grad(test_gelqf_l, [a_batch]) def test_stack():