diff --git a/src/operator/numpy/linalg/np_qr-inl.h b/src/operator/numpy/linalg/np_qr-inl.h
index 0f332e4a661a..c2045205038f 100644
--- a/src/operator/numpy/linalg/np_qr-inl.h
+++ b/src/operator/numpy/linalg/np_qr-inl.h
@@ -483,19 +483,53 @@ struct assign_helper {
   }
 };
 
+// backprop helper to get y, v
+struct QrBackHelper_G1 {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data,
+                                  const int ldin, DType *out_data, const int ldout) {
+    const int offin(k * m * ldin);
+    const int offout(k * m * ldout);
+    for (index_t i = 0; i < m; ++i) {
+      for (index_t j = 0; j < n - m; ++j) {
+        out_data[offout + i * ldout + j] = in_data[offin + m + i * ldin + j];
+      }
+    }
+  }
+};
+
+// backprop helper to get da from dx, dy
+struct QrBackHelper_G2 {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data_x,
+                                  const int ldinx, const DType *in_data_y, const int ldiny,
+                                  DType *out_data, const int ldout) {
+    const int offiny(k * m * ldiny);
+    const int offinx(k * m * ldinx);
+    const int offout(k * m * ldout);
+    for (index_t i = 0; i < m; ++i) {
+      for (index_t j = 0; j < n - m; ++j) {
+        out_data[offout + m + i * ldout + j] = in_data_y[offiny + i * ldiny + j];
+      }
+      for (index_t j = 0; j < m; ++j) {
+        out_data[offout + i * ldout + j] = in_data_x[offinx + i * ldinx + j];
+      }
+    }
+  }
+};
+
+// Reference: https://journals.aps.org/prx/pdf/10.1103/PhysRevX.9.031041
 struct qr_backward {
   template<typename xpu, typename DType>
   static void op(const Tensor<xpu, 3, DType>& dA,
                  const Tensor<xpu, 3, DType>& dQ,
                  const Tensor<xpu, 3, DType>& dR,
-                 const Tensor<xpu, 3, DType>& A,
                  const Tensor<xpu, 3, DType>& Q,
                  const Tensor<xpu, 3, DType>& R,
                  const Tensor<xpu, 3, DType>& M,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
-    // Implements case m >= n; da = [dq + q@copyltu(M)]@r**(-T)
+    // Implements da = [dq + q@copyltu(M)]@r**(-T)
     // Where M = r@(dr**T) - (dq**T)@q
-    // Reference: https://arxiv.org/abs/1710.08717
     Stream<xpu> *s = ctx.get_stream<xpu>();
     if (dQ.dptr_ != dA.dptr_) Copy(dA, dQ, s);
     // M = R@dR_T
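The two comment lines above are the entire m >= n rule, so a plain NumPy transcription may help when reading the kernels. This is a sketch only, not part of the patch; `copyltu` and `qr_backward_m_ge_n` are illustrative names, and batched `(..., m, n)` inputs are assumed:

```python
import numpy as np

def copyltu(M):
    # copy the lower triangle of each (..., k, k) matrix onto its upper
    # triangle, keeping the diagonal: copyltu(M) = tril(M) + tril(M, -1)^T
    return np.tril(M) + np.tril(M, k=-1).swapaxes(-1, -2)

def qr_backward_m_ge_n(q, r, dq, dr):
    # M = r @ dr^T - dq^T @ q
    M = r @ dr.swapaxes(-1, -2) - dq.swapaxes(-1, -2) @ q
    # da = (dq + q @ copyltu(M)) @ r^{-T}
    return (dq + q @ copyltu(M)) @ np.linalg.inv(r).swapaxes(-1, -2)
```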
@@ -514,15 +548,30 @@
 
 template<typename xpu>
 size_t QrBackwardWorkspaceSize(const TBlob& a,
+                               const TBlob& q,
                                const TBlob& r,
                                const TBlob& grad_a) {
+  const mxnet::TShape& a_shape = a.shape_;
+  const int a_ndim = a_shape.ndim();
+  const int n = a.size(a_ndim - 1);
+  const int m = a.size(a_ndim - 2);
+
   if (0U == a.Size()) { return 0U; }
 
   MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
     size_t work_space_size = 0;
-    // for grad a and M
     work_space_size += a.Size();
-    work_space_size += r.Size();
+    if (m >= n) {
+      work_space_size += r.Size();
+    } else {
+      const mxnet::TShape& q_shape = q.shape_;
+      mxnet::TShape v_shape(q_shape);
+      v_shape[a_ndim - 1] = n - m;
+      // allocate space for: m, u, dq_prime, du, dx (shaped like Q)
+      work_space_size += 5 * q.Size();
+      // allocate space for: y, dv (shaped like V, the partition of R)
+      work_space_size += 2 * v_shape.Size();
+    }
     return work_space_size * sizeof(DType);
   });
   LOG(FATAL) << "InternalError: cannot reach here";
@@ -542,8 +591,10 @@ void QrBackwardImpl(const TBlob& grad_a,
                     const nnvm::NodeAttrs& attrs) {
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const mxnet::TShape& a_shape = a.shape_;
+  const mxnet::TShape& q_shape = q.shape_;
   const mxnet::TShape& r_shape = r.shape_;
   const int a_ndim = a_shape.ndim();
+  const int m = a.size(a_ndim - 2);
   const int n = a.size(a_ndim - 1);
 
   if (kNullOp == req[0]) { return; }
@@ -551,27 +602,105 @@ void QrBackwardImpl(const TBlob& grad_a,
   if (0U == a_shape.Size()) { return; }
 
   MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
-    // case m >= n; Q of same shape with A and R is (n, n)
-    DType *m_ptr = reinterpret_cast<DType*>(workspace.dptr_);
-    DType *grad_a_ptr = m_ptr + r_shape.Size();
-    TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+    // common for all shapes (m, n)
+    DType *grad_a_ptr = reinterpret_cast<DType*>(workspace.dptr_);
     TBlob grad_a_data(grad_a_ptr, a_shape, xpu::kDevMask);
-    // dR_T
-    mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
-      s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
-
-    qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
-                    grad_q.FlatToKD<xpu, 3, DType>(s),
-                    grad_r.FlatToKD<xpu, 3, DType>(s),
-                    a.FlatToKD<xpu, 3, DType>(s),
-                    q.FlatToKD<xpu, 3, DType>(s),
-                    r.FlatToKD<xpu, 3, DType>(s),
-                    temp_m.FlatToKD<xpu, 3, DType>(s),
-                    ctx, attrs);
-
+    if (m >= n) {
+      // Q of same shape with A (m, n) and R is (n, n)
+      DType *m_ptr = grad_a_ptr + a_shape.Size();
+      TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+      // dR_T
+      mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
+        s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
+      qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
+                      grad_q.FlatToKD<xpu, 3, DType>(s),
+                      grad_r.FlatToKD<xpu, 3, DType>(s),
+                      q.FlatToKD<xpu, 3, DType>(s),
+                      r.FlatToKD<xpu, 3, DType>(s),
+                      temp_m.FlatToKD<xpu, 3, DType>(s),
+                      ctx, attrs);
+    } else {
+      // R is same shape with A (m, n) and Q is (m, m)
+      // Partition A = (X | Y); R = (U | V)
+      // X and U are (m, m); Y and V are (m, n - m)
+      mxnet::TShape v_shape(q_shape);
+      v_shape[a_ndim - 1] = n - m;
+
+      DType *m_ptr = grad_a_ptr + a_shape.Size();
+      DType *u_ptr = m_ptr + q_shape.Size();
+      DType *dq_prime_ptr = u_ptr + q_shape.Size();
+      DType *dv_ptr = dq_prime_ptr + q_shape.Size();
+      DType *y_ptr = dv_ptr + v_shape.Size();
+      DType *du_ptr = y_ptr + v_shape.Size();
+      DType *dx_ptr = du_ptr + q_shape.Size();
+
+      TBlob temp_m(m_ptr, q_shape, xpu::kDevMask);
+      TBlob u_data(u_ptr, q_shape, xpu::kDevMask);
+      TBlob dq_prime_data(dq_prime_ptr, q_shape, xpu::kDevMask);
+      TBlob dv_data(dv_ptr, v_shape, xpu::kDevMask);
+      TBlob y_data(y_ptr, v_shape, xpu::kDevMask);
+      TBlob du_data(du_ptr, q_shape, xpu::kDevMask);
+      TBlob dx_data(dx_ptr, q_shape, xpu::kDevMask);
+
+      Tensor<xpu, 3, DType> R = r.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dR = grad_r.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> Q = q.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dQ = grad_q.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dQ_prime = dq_prime_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> A = a.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dA = grad_a_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> U = u_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dU = du_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dV = dv_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> Y = y_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dX = dx_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> M = temp_m.FlatToKD<xpu, 3, DType>(s);
+
+      // U
+      for (index_t i = 0; i < R.size(0); ++i) {
+        const Tensor<xpu, 2, DType>& Ri = R[i];
+        const Tensor<xpu, 2, DType>& Ui = U[i];
+        Tensor<xpu, 2, DType> Um(Ri.dptr_, Shape2(m, m), Ri.stride_, s);
+        Copy(Ui, Um, s);
+      }
+      // dU
+      for (index_t i = 0; i < dR.size(0); ++i) {
+        const Tensor<xpu, 2, DType>& dRi = dR[i];
+        const Tensor<xpu, 2, DType>& dUi = dU[i];
+        Tensor<xpu, 2, DType> dUm(dRi.dptr_, Shape2(m, m), dRi.stride_, s);
+        Copy(dUi, dUm, s);
+      }
+      // Y
+      mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
+        s, A.size(0), m, n, A.dptr_, A.stride_, Y.dptr_, Y.stride_);
+      // dV
+      mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
+        s, dR.size(0), m, n, dR.dptr_, dR.stride_, dV.dptr_, dV.stride_);
+      // store dU_T in M
+      mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
+        s, q_shape.Size(), dU.dptr_, m_ptr, m, m, m * m);
+      // dq_prime = dQ
+      Copy(dQ_prime, dQ, s);
+      // dq_prime = dQ + Y@dV.T
+      gemm::op(Y, dV, dQ_prime, DType(1.0), DType(1.0), false, true, s);
+      // dX = op call
+      qr_backward::op(dX,
+                      dQ_prime,
+                      dU,
+                      Q,
+                      U,
+                      M,
+                      ctx, attrs);
+      // dY = Q@dV; reuse Y memory for dY
+      gemm::op(Q, dV, Y, DType(1.0), DType(0.0), false, false, s);
+      // copy dX and dY to dA
+      mxnet_op::Kernel<QrBackHelper_G2, xpu>::Launch(
+        s, dA.size(0), m, n, dX.dptr_, dX.stride_, Y.dptr_, Y.stride_, dA.dptr_, dA.stride_);
+    }
+    // common for all shapes
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-      mxnet_op::Kernel<assign_helper<DType, req_type>, xpu>::Launch(
-        s, a_shape.Size(), grad_a_data.dptr<DType>(), grad_a.dptr<DType>());
+      mxnet_op::Kernel<assign_helper<DType, req_type>, xpu>::Launch(
+        s, a_shape.Size(), grad_a_data.dptr<DType>(), grad_a.dptr<DType>());
     });
   });
 }
@@ -594,14 +723,8 @@ void NumpyLaQrBackward(const nnvm::NodeAttrs& attrs,
   const TBlob& q = inputs[3];
   const TBlob& r = inputs[4];
   const TBlob& grad_a = outputs[0];
 
-  const int a_ndim = a.shape_.ndim();
-  const int n = a.size(a_ndim - 1);
-  const int m = a.size(a_ndim - 2);
-
-  CHECK_LE(n, m)
-    << "QrBackward not implemented when ncols > nrows";
-  size_t workspace_size = QrBackwardWorkspaceSize<xpu>(a, r, grad_a);
+  size_t workspace_size = QrBackwardWorkspaceSize<xpu>(a, q, r, grad_a);
   Tensor<xpu, 1, char> workspace = ctx.requested[0]
     .get_space_typed<xpu, 1, char>(Shape1(workspace_size), ctx.get_stream<xpu>());
   QrBackwardImpl<xpu>(grad_a, grad_q, grad_r, a, q, r, req, workspace, ctx, attrs);
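The m < n branch above reduces the problem to the m >= n rule by partitioning A = (X | Y) and R = (U | V), exactly as the comments describe. The same reduction for a single (non-batched) matrix, as a NumPy sketch with illustrative names:

```python
import numpy as np

def qr_backward_wide(a, q, r, dq, dr):
    # a is (m, n) with m < n; q is (m, m), r is (m, n)
    m, n = a.shape
    copyltu = lambda M: np.tril(M) + np.tril(M, k=-1).T
    u = r[:, :m]                      # R = (U | V), U is (m, m)
    y = a[:, m:]                      # A = (X | Y), Y is (m, n - m)
    du, dv = dr[:, :m], dr[:, m:]
    dy = q @ dv                       # dY = Q @ dV
    dq_prime = dq + y @ dv.T          # dQ' = dQ + Y @ dV^T
    M = u @ du.T - dq_prime.T @ q
    dx = (dq_prime + q @ copyltu(M)) @ np.linalg.inv(u).T
    return np.concatenate([dx, dy], axis=1)   # dA = (dX | dY)
```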
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 165b0f23ed2e..5fb5c687881b 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -5134,42 +5134,102 @@ def __init__(self):
         def hybrid_forward(self, F, data):
             return F.np.linalg.qr(data)
 
-    def get_expected_grad(a, q, r):
+    def get_expected_grad(a, q, r, dq, dr):
+        # works for all shapes (..., m, n)
+        # allows feeding different dq and dr values
         if 0 in r.shape:
            return r
-        def copyltu(M):
-            # shape of M is [batch, m, m]
-            eye = _np.array([_np.eye(M.shape[-1]) for i in range(M.shape[0])])
-            lower = _np.tril(M) - eye * M
-            lower_mask = _np.tril(_np.ones_like(M))
-            ret = lower_mask * M + lower.swapaxes(-1, -2)
-            return ret
-        shape_r = r.shape
-        shape_q = q.shape
-        shape_a = a.shape
-        r = r.reshape(-1, shape_r[-2], shape_r[-1])
-        q = q.reshape(-1, shape_q[-2], shape_q[-1])
-        dq = _np.ones_like(q)
-        dr = _np.ones_like(r)
-        dq_t = dq.swapaxes(-1, -2)
-        dr_t = dr.swapaxes(-1, -2)
-        r_inv = _np.linalg.inv(r)
-        r_inv_t = r_inv.swapaxes(-1, -2)
-        r_t = r.swapaxes(-1, -2)
-        # Get M
-        M = _np.matmul(r, dr_t) - _np.matmul(dq_t, q)
-        da = _np.matmul(dq + _np.matmul(q, copyltu(M)), r_inv_t)
-        return da.reshape(a.shape)
+        def _copyltu(M):
+            eye = _np.array([_np.eye(M.shape[-1]) for i in range(M.shape[0])])
+            lower = _np.tril(M) - eye * M
+            lower_mask = _np.tril(_np.ones_like(M))
+            ret = lower_mask * M + lower.swapaxes(-1, -2)
+            return ret
+        def _case_m_ge_n(a, q, r, dq, dr):
+            dq_t = dq.swapaxes(-1, -2)
+            dr_t = dr.swapaxes(-1, -2)
+            r_inv = _np.linalg.inv(r)
+            r_inv_t = r_inv.swapaxes(-1, -2)
+            r_t = r.swapaxes(-1, -2)
+            # Get M
+            M = _np.matmul(r, dr_t) - _np.matmul(dq_t, q)
+            da = _np.matmul(dq + _np.matmul(q, _copyltu(M)), r_inv_t)
+            return da
+        m, n = a.shape[-2], a.shape[-1]
+        x = a[..., :, :m]
+        x_shape = x.shape
+        y = a[..., :, m:]
+        y_shape = y.shape
+        u = r[..., :, :m]
+        v = r[..., :, m:]
+        dv = dr[..., :, m:]
+        du = dr[..., :, :m]
+        q = q.reshape(-1, q.shape[-2], q.shape[-1])
+        u = u.reshape(-1, u.shape[-2], u.shape[-1])
+        dq = dq.reshape(-1, q.shape[-2], q.shape[-1])
+        du = du.reshape(-1, du.shape[-2], du.shape[-1])
+        if m >= n:
+            dx = _case_m_ge_n(x, q, u, dq, du).reshape(x_shape)
+            return dx
+        else:
+            dv = dv.reshape(-1, dv.shape[-2], dv.shape[-1])
+            y = y.reshape(-1, y.shape[-2], y.shape[-1])
+            dy = _np.matmul(q, dv).reshape(y_shape)
+            dq_prime = dq + _np.matmul(y, dv.swapaxes(-1, -2))
+            dx = _case_m_ge_n(x, q, u, dq_prime, du).reshape(x_shape)
+            da = _np.concatenate([dx, dy], axis=-1)
+            return da
+
+    def _analytical_jacobian(x, dy, Q, R, Q_, R_, k):
+        x_size = _np.prod(x.shape)
+        dy_size = _np.prod(dy.shape)
+        # the jacobian has one row per element of data_np and one column per element of dQ or dR
+        jacobian = _np.zeros((x_size, dy_size))
+        # dQ and dR have all elements equal to zero to begin with
+        dy_data = _np.zeros(dy.shape)
+        dy_data_flat = dy_data.ravel()
+        for col in range(dy_size):
+            # feed dQ or dR with a single element set to 1 at a time
+            dy_data_flat[col] = 1
+            ret_ = dy_data_flat.reshape(dy.shape)
+            if k == 0:
+                # k is 0 when dy is dQ
+                jacobian[:, col] = get_expected_grad(x, dy, R, ret_, R_).ravel()
+            else:
+                # k is 1 when dy is dR
+                jacobian[:, col] = get_expected_grad(x, Q, dy, Q_, ret_).ravel()
+            dy_data_flat[col] = 0
+        return jacobian
+
+    def _numerical_jacobian(x, y, delta, k, dtype):
+        # compute the jacobian by central differences
+        x_size = _np.prod(x.shape)
+        y_size = _np.prod(y.shape)
+        scale = _np.asarray(2 * delta)[()]
+        # the jacobian has one row per element of data_np and one column per element of Q or R
+        jacobian_num = _np.zeros((x_size, y_size))
+        for row in range(x_size):
+            x_pos = x.copy()
+            x_neg = x.copy()
+            # add delta to one element of data_np at a time
+            x_pos.ravel().view(dtype)[row] += delta
+            # get the qr decomposition of the input with one changed element
+            ret_pos = np.linalg.qr(np.array(x_pos))[k]
+            # subtract delta from the same element of data_np
+            x_neg.ravel().view(dtype)[row] -= delta
+            # get the qr decomposition of the input with one changed element
+            ret_neg = np.linalg.qr(np.array(x_neg))[k]
+            # take central differences
+            diff = (ret_pos - ret_neg) / scale
+            jacobian_num[row, :] = diff.asnumpy().ravel().view(dtype)
+        return jacobian_num
 
     def well_conditioned_rectang_matrix_2D(shape, max_cond=4):
         m, n = shape[-2], shape[-1]
         while 1:
-            M1 = _np.random.uniform(-10, 10, (m, n))
-            Q1, R1 = _np.linalg.qr(M1)
-            s = _np.ones(n)
-            D = _np.diag(s)
-            M2 =_np.random.uniform(-10, 10, (n, n))
-            Q2, R2 = _np.linalg.qr(M2)
+            Q1, R1 = _np.linalg.qr(_np.random.uniform(-10, 10, (m, m)))
+            D = _np.eye(m, n)
+            Q2, R2 = _np.linalg.qr(_np.random.uniform(-10, 10, (n, n)))
             a = _np.matmul(_np.matmul(Q1, D), _np.swapaxes(Q2, -1, -2))
             if (_np.linalg.cond(a, 2) < max_cond):
                 return a
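Why this generator works: with `D = _np.eye(m, n)` the product `Q1 @ D @ Q2.T` is an SVD whose singular values are all 1, so the 2-norm condition number is 1 up to roundoff and the `while` loop effectively always returns on the first draw. That keeps R (and its inverse in the backward pass) well behaved for every tested shape, including the new wide ones. A standalone check of that claim:

```python
import numpy as np

m, n = 4, 6
Q1, _ = np.linalg.qr(np.random.uniform(-10, 10, (m, m)))
Q2, _ = np.linalg.qr(np.random.uniform(-10, 10, (n, n)))
a = Q1 @ np.eye(m, n) @ Q2.T
print(np.linalg.cond(a, 2))  # ~1.0: every singular value of a equals 1
```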
@@ -5209,7 +5269,6 @@ def check_qr(q, r, a_np):
         (3, 3),
         (5, 5),
         (8, 8),
-        (4, 5),
         (4, 6),
         (5, 4),
         (6, 5),
@@ -5223,19 +5282,16 @@
         (4, 2, 2, 1),
         (2, 3, 4, 3)
     ]
-    dtypes = ['float32', 'float64']
+    dtypes = ['float64', 'float32']
     for hybridize, shape, dtype in itertools.product([False, True], shapes, dtypes):
         rtol = atol = 0.01
         test_qr = TestQR()
         if hybridize:
             test_qr.hybridize()
-
         if 0 in shape:
             data_np = _np.ones(shape)
-        elif shape[-2] >= shape[-1]:
-            data_np = well_conditioned_rectang_matrix_nD(shape, max_cond=4)
         else:
-            data_np = _np.random.uniform(-10.0, 10.0, shape)
+            data_np = well_conditioned_rectang_matrix_nD(shape, max_cond=4)
         data_np = _np.array(data_np, dtype=dtype)
         data = np.array(data_np, dtype=dtype)
 
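On the step size used in the jacobian check below: central differences have O(delta**2) truncation error and O(eps/delta) rounding error, so the total error is minimized around delta ~ eps**(1/3); with the 0.1 factor the float64 step is roughly 6e-7. The same check, written as a generic, MXNet-independent sketch with illustrative names:

```python
import numpy as np

def numerical_jacobian(f, x, delta):
    # one row per input element, one column per output element
    x = np.asarray(x, dtype=np.float64)
    y0 = np.asarray(f(x))
    jac = np.zeros((x.size, y0.size))
    for i in range(x.size):
        x_pos, x_neg = x.copy(), x.copy()
        x_pos.ravel()[i] += delta
        x_neg.ravel()[i] -= delta
        jac[i, :] = (np.asarray(f(x_pos)) - np.asarray(f(x_neg))).ravel() / (2 * delta)
    return jac

delta = 0.1 * np.finfo(np.float64).eps ** (1.0 / 3.0)
jac_r = numerical_jacobian(lambda a: np.linalg.qr(a)[1], np.random.rand(3, 3), delta)
```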
@@ -5245,13 +5301,24 @@
         Q, R = ret[0], ret[1]
         check_qr(Q, R, data_np)
 
-        # Only shapes m >= n have gradient
-        if 0 not in R.shape and shape[-2] >= shape[-1]:
+        if 0 not in R.shape:
             assert data.grad.shape == data_np.shape
-            backward_expected = get_expected_grad(data_np, Q.asnumpy(), R.asnumpy())
+            backward_expected = get_expected_grad(data_np, Q.asnumpy(), R.asnumpy(),
+                                                  _np.ones(Q.shape), _np.ones(R.shape))
             mx.autograd.backward(ret)
             assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
-
+            # for a few cases, check that the analytical jacobian is equal to the
+            # numerical jacobian computed via central differences;
+            # restrict this check to float64 for numerical precision
+            if dtype == 'float64' and len(shape) == 2:
+                epsilon = _np.finfo(dtype).eps
+                delta = 0.1 * epsilon**(1.0 / 3.0)  # optimal step for central differences
+                for k, b in enumerate(ret):
+                    qr_num = _numerical_jacobian(data_np, b.asnumpy(), delta, k, dtype)
+                    qr_a = _analytical_jacobian(x=data_np, dy=b.asnumpy(), Q=Q.asnumpy(),
+                                                R=R.asnumpy(), Q_=_np.zeros(Q.shape),
+                                                R_=_np.zeros(R.shape), k=k)
+                    assert_almost_equal(qr_num, qr_a, rtol=rtol, atol=atol)
         # check imperative once more; mode='reduced' is default
         # behavior and optional parameter in original numpy
         ret = np.linalg.qr(data, mode='reduced')
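An end-to-end usage sketch, assuming an MXNet build that includes this patch: wide inputs now receive gradients too, so the following runs instead of hitting the removed `QrBackward not implemented when ncols > nrows` check:

```python
import mxnet as mx
from mxnet import np, npx
npx.set_np()

data = np.random.uniform(-1, 1, (2, 4))  # m < n exercises the new branch
data.attach_grad()
with mx.autograd.record():
    q, r = np.linalg.qr(data)
mx.autograd.backward([q, r])             # head gradients default to ones
print(data.grad.shape)                   # (2, 4), same shape as the input
```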