support random_uniform/normal/gamma with row_sparse output #155
Changes from all commits: 8da42c2, 8aef7a5, 3b969ac, 90f4bd2, 0dc5611
@@ -232,29 +232,75 @@ struct SampleGenNegBinomialParam : public dmlc::Parameter<SampleGenNegBinomialParam> {
   }
 };
 
+using FSampleCompute = std::function<void (const nnvm::NodeAttrs& attrs,
+                                           const OpContext& ctx,
+                                           const OpReqType& req,
+                                           TBlob* outputs)>;
+
 template<typename xpu>
-void SampleUniform_(const nnvm::NodeAttrs& attrs,
-                    const OpContext& ctx,
-                    const std::vector<TBlob>& inputs,
-                    const std::vector<OpReqType>& req,
-                    const std::vector<TBlob>& outputs) {
+void SampleComputeEx_(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs,
+                      FSampleCompute fcomp) {
+  NDArray output = outputs[0];
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (output.storage_type() == kRowSparseStorage) {
+    // indices
+    nnvm::dim_t nnr = output.shape()[0];
+    output.CheckAndAlloc({mshadow::Shape1(nnr)});
+    PopulateFullIdxRspImpl(s, &output);
+    // data
+    TBlob out_blob = output.data();
+    fcomp(attrs, ctx, req[0], &out_blob);
+  } else {
+    LOG(FATAL) << "Unexpected storage type for SampleComputeEx_: "
+               << output.storage_type();
+  }
+}
+
+template<typename xpu>
+void SampleUniformDnsImpl(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const OpReqType& req,
+                          TBlob* output) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(output->type_flag_, DType, {
     mshadow::Random<xpu, DType> *prnd = ctx.requested[0].get_random<xpu, DType>(s);
-    mshadow::Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    mshadow::Tensor<xpu, 2, DType> out = output->FlatTo2D<xpu, DType>(s);
Review comment (on the FlatTo2D line): did we determine that FlatTo2D was bad?
Reply: It looks like the original dense operator uses this. The sampler op is defined in mshadow, and since it performs sampling per element, IMO it won't matter what shape it is flattened to, for these random operators.
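To make the reply concrete, here is a minimal standalone sketch (illustrative only, not MXNet code; the function name and parameters are hypothetical) of why an element-wise sampler is indifferent to the flattened shape:

#include <cstddef>

// An element-wise sampler visits each of the rows * cols underlying elements
// exactly once, so viewing the same buffer as 1xN, 2x(N/2), or any other
// flattening produces the same set of independent per-element draws.
template <typename DType, typename Rng>
void fill_elementwise(DType* data, std::size_t rows, std::size_t cols, Rng& draw) {
  for (std::size_t i = 0; i < rows * cols; ++i) {
    data[i] = draw();  // one independent draw per element
  }
}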
     prnd->SampleUniform(&out, param.low, param.high);
   });
 }
 
 template<typename xpu>
-void SampleNormal_(const nnvm::NodeAttrs& attrs,
-                   const OpContext& ctx,
-                   const std::vector<TBlob>& inputs,
-                   const std::vector<OpReqType>& req,
-                   const std::vector<TBlob>& outputs) {
+void SampleUniform_(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleUniformDnsImpl<xpu>(attrs, ctx, req[0], &out);
+}
+
+template<typename xpu>
+void SampleUniformEx_(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleUniformDnsImpl<xpu>);
+}
+
+template<typename xpu>
+void SampleNormalDnsImpl(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const OpReqType& req,
+                         TBlob* outputs) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
@@ -268,11 +314,29 @@ void SampleNormal_(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu>
-void SampleGamma_(const nnvm::NodeAttrs& attrs,
+void SampleNormal_(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleNormalDnsImpl<xpu>(attrs, ctx, req[0], &out);
+}
+
+template<typename xpu>
+void SampleNormalEx_(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleNormalDnsImpl<xpu>);
+}
+
+template<typename xpu>
+void SampleGammaDnsImpl(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const OpReqType& req,
+                        TBlob* outputs) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

@@ -286,6 +350,25 @@ void SampleGamma_(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename xpu>
+void SampleGamma_(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleGammaDnsImpl<xpu>(attrs, ctx, req[0], &out);
+}
+
+template<typename xpu>
+void SampleGammaEx_(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<NDArray>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<NDArray>& outputs) {
+  SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleGammaDnsImpl<xpu>);
+}
+
 template<typename xpu>
 void SampleExponential_(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
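The pattern in this file: each sampler is split into a storage-agnostic dense kernel (`SampleUniformDnsImpl` and friends), a thin wrapper keeping the old `TBlob` signature, and an `Ex_` wrapper that routes `NDArray` outputs through `SampleComputeEx_`, which allocates row-sparse storage and then delegates to the dense kernel via the `FSampleCompute` callback. A minimal self-contained sketch of this dispatch shape, with hypothetical stand-in types (`Blob`, `Array`, `SampleEx`) in place of `TBlob`/`NDArray` and the real functions:

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for TBlob / NDArray, for illustration only.
struct Blob { std::vector<float>* data; };
struct Array {
  std::vector<float> values;  // one value per row, to keep the sketch small
  bool row_sparse;
  std::vector<int> idx;       // row indices (aux data of a row-sparse array)
};

using FSample = std::function<void(Blob*)>;

// Dense kernel: knows only how to fill a flat buffer.
void UniformDnsImpl(Blob* out) {
  for (auto& v : *out->data) v = 0.5f;  // a fixed value stands in for the RNG
}

// Storage-aware wrapper, mirroring SampleComputeEx_: allocate the full index
// set for a row-sparse output, then reuse the dense kernel on the data blob.
void SampleEx(Array* out, FSample fcomp) {
  if (out->row_sparse) {
    out->idx.resize(out->values.size());
    for (std::size_t i = 0; i < out->idx.size(); ++i)
      out->idx[i] = static_cast<int>(i);  // cf. PopulateFullIdxRspImpl
  }
  Blob b{&out->values};
  fcomp(&b);
}

int main() {
  Array a{std::vector<float>(4), /*row_sparse=*/true, {}};
  SampleEx(&a, UniformDnsImpl);
  std::cout << a.values[0] << ", indexed rows: " << a.idx.size() << "\n";
}

The payoff of the callback design is that only one storage-aware function is needed; adding a new distribution only requires a new dense kernel plus two one-line wrappers.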
@@ -167,6 +167,26 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
   });
 }
 
+struct PopulateFullIdxRspKernel {
Review comment (on PopulateFullIdxRspKernel): Can the normal identity op be used instead?
Reply: Which identity op? What's the input for the identity op?
Reply: nvm, you don't have it in your branch.
Reply: Oh I think there's already
+  template<typename IType>
+  MSHADOW_XINLINE static void Map(int i, IType* out) {
+    KERNEL_ASSIGN(out[i], kWriteTo, i);
+  }
+};
+
+// Populate a row-sparse NDArray's idx aux array with the full set of row
+// indices [0, num_rows), allocating the aux data as needed.
+template<typename xpu>
+void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
+  using namespace rowsparse;
+  CHECK_EQ(dst->storage_type(), kRowSparseStorage);
+  nnvm::dim_t nnr = dst->shape()[0];
+  dst->CheckAndAllocAuxData(kIdx, mshadow::Shape1(nnr));
+  MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
+    IType* idx = dst->aux_data(kIdx).dptr<IType>();
+    mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx);
+  });
+}
+
 // Fill a rsp NDArray with zeros by updating the aux shape.
 template<typename xpu>
 void FillZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
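For reference, the fill that PopulateFullIdxRspKernel performs is just an iota over the row ids: thread i writes its own index. A plain C++ equivalent of the end result (illustrative only, not MXNet code; `nnr = 5` is a made-up row count standing in for `dst->shape()[0]`):

#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  const std::int64_t nnr = 5;            // hypothetical row count
  std::vector<std::int64_t> idx(nnr);
  std::iota(idx.begin(), idx.end(), 0);  // idx = {0, 1, 2, 3, 4}
  // Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx) has each
  // thread i perform the equivalent of: idx[i] = i;
}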
Review comment (on SampleComputeEx_): why the trailing underscore on a publicly-accessible function?
Reply: I think the underscore on SampleGaussian_ is there because SampleGaussian is already defined somewhere else. But yeah, I can remove the underscore from SampleComputeEx_.
Reply: If it's consistent with a preexisting function, then I am fine with it.