This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add qr backward for wide matrices with m < n
- Loading branch information
Showing
2 changed files
with
261 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -483,19 +483,53 @@ struct assign_helper { | |
} | ||
}; | ||
|
||
// backprop helper to get y, v | ||
struct QrBackHelper_G1 { | ||
template<typename DType> | ||
MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data, | ||
const int ldin, DType *out_data, const int ldout) { | ||
const int offin(k * m * ldin); | ||
const int offout(k * m * ldout); | ||
for (index_t i = 0; i < m; ++i) { | ||
for (index_t j = 0; j < n - m; ++j) { | ||
out_data[offout + i * ldout + j] = in_data[offin + m + i * ldin + j]; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
// backprop helper to get da from dx, dy | ||
struct QrBackHelper_G2 { | ||
template<typename DType> | ||
MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data_x, | ||
const int ldinx, const DType *in_data_y, const int ldiny, | ||
DType *out_data, const int ldout) { | ||
const int offiny(k * m * ldiny); | ||
const int offinx(k * m * ldinx); | ||
const int offout(k * m * ldout); | ||
for (index_t i = 0; i < m; ++i) { | ||
for (index_t j = 0; j < n - m; ++j) { | ||
out_data[offout + m + i * ldout + j] = in_data_y[offiny + i * ldiny + j]; | ||
} | ||
for (index_t j = 0; j < m; ++j) { | ||
out_data[offout + i * ldout + j] = in_data_x[offinx + i * ldinx + j]; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
// Reference https://journals.aps.org/prx/pdf/10.1103/PhysRevX.9.031041 | ||
struct qr_backward { | ||
template<typename xpu, typename DType> | ||
static void op(const Tensor<xpu, 3, DType>& dA, | ||
const Tensor<xpu, 3, DType>& dQ, | ||
const Tensor<xpu, 3, DType>& dR, | ||
const Tensor<xpu, 3, DType>& A, | ||
const Tensor<xpu, 3, DType>& Q, | ||
const Tensor<xpu, 3, DType>& R, | ||
const Tensor<xpu, 3, DType>& M, | ||
const OpContext& ctx, const nnvm::NodeAttrs& attrs) { | ||
// Implements case m >= n; da = [dq + q@copyltu(M))]@r**(-T) | ||
// Implements da = [dq + q@copyltu(M))]@r**(-T) | ||
// Where M = r@(dr**T) - (dq**T)@q | ||
// Reference: https://arxiv.org/abs/1710.08717 | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
if (dQ.dptr_ != dA.dptr_) Copy(dA, dQ, s); | ||
// M = R@dR_T | ||
|
@@ -514,15 +548,30 @@ struct qr_backward { | |
|
||
template<typename xpu> | ||
size_t QrBackwardWorkspaceSize(const TBlob& a, | ||
const TBlob& q, | ||
const TBlob& r, | ||
const TBlob& grad_a) { | ||
const mxnet::TShape& a_shape = a.shape_; | ||
const int a_ndim = a_shape.ndim(); | ||
const int n = a.size(a_ndim - 1); | ||
const int m = a.size(a_ndim - 2); | ||
|
||
if (0U == a.Size()) { return 0U; } | ||
|
||
MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, { | ||
size_t work_space_size = 0; | ||
// for grad a and M | ||
work_space_size += a.Size(); | ||
work_space_size += r.Size(); | ||
if (m >= n) { | ||
work_space_size += r.Size(); | ||
} else { | ||
const mxnet::TShape& q_shape = q.shape_; | ||
mxnet::TShape v_shape(q_shape); | ||
v_shape[a_ndim - 1] = n - m; | ||
// allocate space for: m, u, dq_prime, du, dx (shaped like Q) | ||
work_space_size += 5 * q.Size(); | ||
// allocate space for: y, dv (shaped like V, the partition of R) | ||
work_space_size += 2 * v_shape.Size(); | ||
} | ||
return work_space_size * sizeof(DType); | ||
}); | ||
LOG(FATAL) << "InternalError: cannot reach here"; | ||
|
@@ -542,36 +591,116 @@ void QrBackwardImpl(const TBlob& grad_a, | |
const nnvm::NodeAttrs& attrs) { | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
const mxnet::TShape& a_shape = a.shape_; | ||
const mxnet::TShape& q_shape = q.shape_; | ||
const mxnet::TShape& r_shape = r.shape_; | ||
const int a_ndim = a_shape.ndim(); | ||
const int m = a.size(a_ndim - 2); | ||
const int n = a.size(a_ndim - 1); | ||
|
||
if (kNullOp == req[0]) { return; } | ||
|
||
if (0U == a_shape.Size()) { return; } | ||
|
||
MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, { | ||
// case m >= n; Q of same shape with A and R is (n, n) | ||
DType *m_ptr = reinterpret_cast<DType*>(workspace.dptr_); | ||
DType *grad_a_ptr = m_ptr + r_shape.Size(); | ||
TBlob temp_m(m_ptr, r_shape, xpu::kDevMask); | ||
// common for all shapes (m, n) | ||
DType *grad_a_ptr = reinterpret_cast<DType*>(workspace.dptr_); | ||
TBlob grad_a_data(grad_a_ptr, a_shape, xpu::kDevMask); | ||
// dR_T | ||
mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch( | ||
s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n); | ||
|
||
qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s), | ||
grad_q.FlatToKD<xpu, 3, DType>(s), | ||
grad_r.FlatToKD<xpu, 3, DType>(s), | ||
a.FlatToKD<xpu, 3, DType>(s), | ||
q.FlatToKD<xpu, 3, DType>(s), | ||
r.FlatToKD<xpu, 3, DType>(s), | ||
temp_m.FlatToKD<xpu, 3, DType>(s), | ||
ctx, attrs); | ||
|
||
if (m >= n) { | ||
// Q of same shape with A (m, n) and R is (n, n) | ||
DType *m_ptr = grad_a_ptr + a_shape.Size(); | ||
TBlob temp_m(m_ptr, r_shape, xpu::kDevMask); | ||
// dR_T | ||
mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch( | ||
s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n); | ||
qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s), | ||
grad_q.FlatToKD<xpu, 3, DType>(s), | ||
grad_r.FlatToKD<xpu, 3, DType>(s), | ||
q.FlatToKD<xpu, 3, DType>(s), | ||
r.FlatToKD<xpu, 3, DType>(s), | ||
temp_m.FlatToKD<xpu, 3, DType>(s), | ||
ctx, attrs); | ||
} else { | ||
// R is same shape with A (m, n) and Q is (m, m) | ||
// Partition A = (X | Y); R = (U | V) | ||
// X and U are (m, m); Y and V are (m, n - m) | ||
mxnet::TShape v_shape(q_shape); | ||
v_shape[a_ndim - 1] = n - m; | ||
|
||
DType *m_ptr = grad_a_ptr + a_shape.Size(); | ||
DType *u_ptr = m_ptr + q_shape.Size(); | ||
DType *dq_prime_ptr = u_ptr + q_shape.Size(); | ||
DType *dv_ptr = dq_prime_ptr + q_shape.Size(); | ||
DType *y_ptr = dv_ptr + v_shape.Size(); | ||
DType *du_ptr = y_ptr + v_shape.Size(); | ||
DType *dx_ptr = du_ptr + q_shape.Size(); | ||
|
||
TBlob temp_m(m_ptr, q_shape, xpu::kDevMask); | ||
TBlob u_data(u_ptr, q_shape, xpu::kDevMask); | ||
TBlob dq_prime_data(dq_prime_ptr, q_shape, xpu::kDevMask); | ||
TBlob dv_data(dv_ptr, v_shape, xpu::kDevMask); | ||
TBlob y_data(y_ptr, v_shape, xpu::kDevMask); | ||
TBlob du_data(du_ptr, q_shape, xpu::kDevMask); | ||
TBlob dx_data(dx_ptr, q_shape, xpu::kDevMask); | ||
|
||
Tensor<xpu, 3, DType> R = r.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> dR = grad_r.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> Q = q.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> dQ = grad_q.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> dQ_prime = dq_prime_data.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> A = a.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> dA = grad_a_data.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> U = u_data.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> dU = du_data.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> dV = dv_data.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> Y = y_data.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> dX = dx_data.FlatToKD<xpu, 3, DType>(s); | ||
Tensor<xpu, 3, DType> M = temp_m.FlatToKD<xpu, 3, DType>(s); | ||
|
||
// U | ||
for (index_t i = 0; i < R.size(0); ++i) { | ||
const Tensor<xpu, 2, DType>& Ri = R[i]; | ||
const Tensor<xpu, 2, DType>& Ui = U[i]; | ||
Tensor<xpu, 2, DType> Um(Ri.dptr_, Shape2(m, m), Ri.stride_, s); | ||
Copy(Ui, Um, s); | ||
} | ||
// dU | ||
for (index_t i = 0; i < dR.size(0); ++i) { | ||
const Tensor<xpu, 2, DType>& dRi = dR[i]; | ||
const Tensor<xpu, 2, DType>& dUi = dU[i]; | ||
Tensor<xpu, 2, DType> dUm(dRi.dptr_, Shape2(m, m), dRi.stride_, s); | ||
Copy(dUi, dUm, s); | ||
} | ||
// Y | ||
mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch( | ||
s, A.size(0), m, n, A.dptr_, A.stride_, Y.dptr_, Y.stride_); | ||
// dV | ||
mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch( | ||
s, dR.size(0), m, n, dR.dptr_, dR.stride_, dV.dptr_, dV.stride_); | ||
// store dU_T in M | ||
mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch( | ||
s, q_shape.Size(), dU.dptr_, m_ptr, m, m, m * m); | ||
// dq_prime = dQ | ||
Copy(dQ_prime, dQ, s); | ||
// dq_prime = [email protected] | ||
gemm::op(Y, dV, dQ_prime, DType(1.0), DType(1.0), false, true, s); | ||
// dX = op call | ||
qr_backward::op(dX, | ||
dQ_prime, | ||
dU, | ||
Q, | ||
U, | ||
M, | ||
ctx, attrs); | ||
// dY = Q@dV; reuse Y memory for dY | ||
gemm::op(Q, dV, Y, DType(1.0), DType(0.0), false, false, s); | ||
// copy dX and dY to dA | ||
mxnet_op::Kernel<QrBackHelper_G2, xpu>::Launch( | ||
s, dA.size(0), m, n, dX.dptr_, dX.stride_, Y.dptr_, Y.stride_, dA.dptr_, dA.stride_); | ||
} | ||
// common for all shapes | ||
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { | ||
mxnet_op::Kernel<assign_helper<req_type>, xpu>::Launch( | ||
s, a_shape.Size(), grad_a_data.dptr<DType>(), grad_a.dptr<DType>()); | ||
mxnet_op::Kernel<assign_helper<req_type>, xpu>::Launch( | ||
s, a_shape.Size(), grad_a_data.dptr<DType>(), grad_a.dptr<DType>()); | ||
}); | ||
}); | ||
} | ||
|
@@ -594,14 +723,8 @@ void NumpyLaQrBackward(const nnvm::NodeAttrs& attrs, | |
const TBlob& q = inputs[3]; | ||
const TBlob& r = inputs[4]; | ||
const TBlob& grad_a = outputs[0]; | ||
const int a_ndim = a.shape_.ndim(); | ||
const int n = a.size(a_ndim - 1); | ||
const int m = a.size(a_ndim - 2); | ||
|
||
CHECK_LE(n, m) | ||
<< "QrBackward not implemented when ncols > nrows"; | ||
|
||
size_t workspace_size = QrBackwardWorkspaceSize<xpu>(a, r, grad_a); | ||
size_t workspace_size = QrBackwardWorkspaceSize<xpu>(a, q, r, grad_a); | ||
Tensor<xpu, 1, char> workspace = ctx.requested[0] | ||
.get_space_typed<xpu, 1, char>(Shape1(workspace_size), ctx.get_stream<xpu>()); | ||
QrBackwardImpl<xpu>(grad_a, grad_q, grad_r, a, q, r, req, workspace, ctx, attrs); | ||
|
Oops, something went wrong.