Skip to content

Commit

Permalink
Allocate temp data on the fly for some casting operations (apache#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjolivier01 authored and eric-haibin-lin committed Aug 9, 2017
1 parent 253ae57 commit b2ad302
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/operator/tensor/cast_storage-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ struct FillRspValsKernel {
}
};

/*!
 * \brief Request temp space on demand and expose it as a typed tensor.
 *
 * Asks the ResourceManager for a kTempSpace resource on the context found in
 * op_ctx.run_ctx.ctx, then carves a tensor of the requested shape out of it
 * using the stream obtained from op_ctx.run_ctx.
 *
 * \tparam xpu device type of the returned tensor and of the stream
 * \tparam ndim number of dimensions of the returned tensor
 * \tparam DType element type of the returned tensor
 * \param op_ctx operator context supplying the run context (device context and stream)
 * \param shape shape of the tensor to obtain
 * \return tensor of the given shape obtained from the requested temp space
 * \note NOTE(review): the Resource handle is a local of this function; presumably the
 *       underlying space stays valid after return - confirm against ResourceManager.
 */
template<typename xpu, int ndim, typename DType>
inline mshadow::Tensor<xpu, ndim, DType> AllocateTempDataForCast(const OpContext& op_ctx,
                                                                 const mshadow::Shape<ndim>& shape) {
  // Request the temp-space resource first, then look up the stream (same order as before).
  Resource temp_space = ResourceManager::Get()->Request(
      op_ctx.run_ctx.ctx, ResourceRequest(ResourceRequest::kTempSpace));
  mshadow::Stream<xpu> *s = op_ctx.run_ctx.get_stream<xpu>();
  return temp_space.get_space_typed<xpu, ndim, DType>(shape, s);
}

/*!
* \brief GPU implementation of casting a dense (dns) tensor to row-sparse (rsp) storage type.
*/
Expand Down Expand Up @@ -245,8 +254,8 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
mshadow::Stream<gpu>::GetStream(s));

// Allocate temp storage for marking non-zero rows and for cub's prefix sum
mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(num_rows*sizeof(RType)+temp_storage_bytes), s);
auto workspace = AllocateTempDataForCast<gpu, 1, char>(ctx, Shape1(num_rows*sizeof(RType)
+ temp_storage_bytes));
row_flg = reinterpret_cast<RType*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + num_rows*sizeof(RType);

Expand Down Expand Up @@ -652,8 +661,8 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
mshadow::Stream<gpu>::GetStream(s));

// Allocate temporary storage
mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
auto workspace = AllocateTempDataForCast<gpu, 1, char>(ctx, Shape1(temp_storage_bytes));

d_temp_storage = workspace.dptr_;

// Compute indptr through inclusive prefix sum
Expand Down

0 comments on commit b2ad302

Please sign in to comment.