Conversation
Thank you for working on this op!
src/executor/graph_executor.cc
Outdated
it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
return it->second;
} else {
  // not shareable storage
  return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
}  // arg_array.shape().Size() >= arg_shape.Size()
This comment probably makes sense after the else if clause instead of here.
Line 704 is still performing sharing by updating it->second; it's just that its size is too small, so a bigger ndarray is created for sharing. Maybe the name size_shareable is misleading?
Yes, the comment "// arg_array.shape().Size() >= arg_shape.Size()" probably makes sense after the if instead of the else. I don't think the name size_shareable is misleading.
Oh, I see you're referring to this comment. I'll move it.
src/executor/graph_executor.cc
Outdated
in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype,
                                         inferred_stype, in_arg_ctxes[arg_top],
                                         shared_buffer, true));
// gradient for model parameter
Using a bool variable and passing it as an argument to ReshapeOrCreate here and below would improve readability.
Yes sure
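For illustration, a minimal sketch of the suggested change; the ReshapeOrCreate signature mirrors the hunk above, and the flag name is hypothetical:

```cpp
// Illustrative only: a named flag documents intent at the call site
// better than a bare `true`. The name below is hypothetical.
const bool enable_row_sparse_sharing = true;
in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype,
                                         inferred_stype, in_arg_ctxes[arg_top],
                                         shared_buffer, enable_row_sparse_sharing));
```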
MSHADOW_XINLINE static void Map(int tid,
                                DType* row_flg,
                                const IType* row_idx) {
  nnvm::dim_t idx = static_cast<nnvm::dim_t>(row_idx[tid]);
Can we use IType here instead of nnvm::dim_t?
I'm afraid not. For embedding, the data/row_idx could be float, which cannot be used with operator[]; that's why it's always cast.
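For context, a self-contained sketch of the pattern under discussion: a floating-point index type must be cast to an integer type before it can subscript an array. The kernel shape follows the hunk above; the flag-marking semantics are assumed.

```cpp
#include <cstdint>

using dim_t = int64_t;  // stand-in for nnvm::dim_t

// Toy version of the Map kernel above: row_idx may hold float values
// (IType = float for embedding data), so each value is cast to an
// integer index before being used to subscript row_flg.
template <typename DType, typename IType>
static void Map(int tid, DType* row_flg, const IType* row_idx) {
  const dim_t idx = static_cast<dim_t>(row_idx[tid]);  // float -> integer index
  row_flg[idx] = 1;  // mark this row as present (assumed semantics)
}
```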
const nnvm::dim_t* row_flg_sum,
const nnvm::dim_t num_rows) {
  if (tid < num_rows) {
    nnvm::dim_t prev = (tid == 0) ? 0 : row_flg_sum[tid-1];
Can RType be used here instead of nnvm::dim_t?
This kernel is used by other operators too (such as dot and elemwise_sum). I'm moving it from its original place to this file so that it can be reused for embedding. Changing that interface is out of scope for this PR, since row_flg_sum comes from temp storage, which is always allocated as nnvm::dim_t type...
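To make the kernel's contract concrete, here is a self-contained sketch of the prefix-sum compaction it implements; names follow the hunk above, and the output layout is assumed:

```cpp
#include <cstdint>

using dim_t = int64_t;  // stand-in for nnvm::dim_t

// row_flg_sum is an inclusive prefix sum over per-row occupancy flags.
// A row tid is present exactly when its sum exceeds the previous entry,
// and that previous sum is its position in the compacted row_idx output.
template <typename RType>
void FillRowIdx(const dim_t* row_flg_sum, dim_t num_rows, RType* row_idx) {
  for (dim_t tid = 0; tid < num_rows; ++tid) {
    const dim_t prev = (tid == 0) ? 0 : row_flg_sum[tid - 1];
    if (row_flg_sum[tid] > prev) {
      row_idx[prev] = static_cast<RType>(tid);  // assumed output layout
    }
  }
}
```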
[ 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., 18., 19.]]
Can we use an rsp weight in the example here?
Maybe I'll add one more example with an rsp weight.
I'm actually not sure about that. Showing weights with 0s is kind of confusing.
@@ -95,6 +95,77 @@ Examples::
.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.")
.add_arguments(EmbeddingParam::__FIELDS__());

NNVM_REGISTER_OP(_contrib_SparseEmbedding)
Stupid question: why are we registering a new operator here instead of extending the existing embedding op?
Good question! If we overrode the existing op, we wouldn't know whether to infer an rsp grad or a dense grad for the weight, because the weight ndarray is not visible to the backward pass.
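A self-contained toy illustrating that constraint (all names here are illustrative, not the PR's code): backward storage-type inference only sees the gradients' storage types, not the forward weight array, so a dedicated sparse op can hard-code the row_sparse weight gradient while the existing dense op keeps its dense one.

```cpp
#include <vector>

enum StorageType { kDefaultStorage, kRowSparseStorage };

// Backward inference for a dedicated sparse embedding op: with no access
// to the forward weight NDArray, the only safe way to get an rsp weight
// gradient is to make it unconditional for this op. Output slot order is
// assumed: 0 = grad w.r.t. data, 1 = grad w.r.t. weight.
bool SparseEmbeddingBackwardStorageType(std::vector<StorageType>* out_attrs) {
  out_attrs->at(0) = kDefaultStorage;    // grad w.r.t. data indices (dense)
  out_attrs->at(1) = kRowSparseStorage;  // grad w.r.t. weight (row_sparse)
  return true;
}
```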
const dim_t idx_offset = first - weight_idx;
const dim_t out_offset = i * row_length;
const dim_t weight_offset = idx_offset * row_length;
if (idx_offset >= nnr || *(weight_idx + idx_offset) > val) {
It is not very obvious to me why the element is not found in the case where *(weight_idx + idx_offset) > val. Maybe add a comment here.
Sure. It's possible that weight.idx = [5,10] and data = [3,7], so no matching indices can be found in weight_idx.
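A self-contained sketch of that lookup and both not-found cases; variable names follow the hunk above:

```cpp
#include <algorithm>
#include <cstdint>

using dim_t = int64_t;

// Binary-search a data index val in the sorted row indices of the rsp
// weight. std::lower_bound returns the first element >= val, so the row
// is absent either when the search runs off the end (idx_offset >= nnr)
// or when the element found is larger than val, e.g. weight_idx = [5, 10]
// and val = 7. Returns -1 when the row is not found.
dim_t FindRowOffset(const dim_t* weight_idx, dim_t nnr, dim_t val) {
  const dim_t* first = std::lower_bound(weight_idx, weight_idx + nnr, val);
  const dim_t idx_offset = first - weight_idx;
  if (idx_offset >= nnr || weight_idx[idx_offset] > val) return -1;
  return idx_offset;
}
```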
using namespace rowsparse;
using namespace mxnet_op;
// zeros weight
if (req == kWriteTo && !weight.storage_initialized()) {
What happens when the input data storage is not initialized?
Input data is always dense.
|
[[ 0., 1., 2., 3., 4.],
[ 10., 11., 12., 13., 14.]]]
A frontend Python example, similar to the one here, would also be good to have: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol_doc.py#L128
I thought we stopped using SymbolDoc with doc tests. Is that still working?
I see the symbol doc example here: https://mxnet.incubator.apache.org/versions/master/api/python/symbol/symbol.html#mxnet.symbol.Embedding. I am not sure if it is deprecated.
Cool, will add this.
Looks like that doesn't work for contrib ops.. :(
Commits
* cpu embedding draft
* clean up
* fix omp thread call
* add sparse embedding example
* check bound with single thread
* add note
* add comments
* add operator note
* support rsp weight sharing for bucketing
* improve workload balance in take add grad rsp kernel
* use MSHADOW_CINLINE for cpu kernel
* review comments. add unit test for shared rsp weight
* remove indexing op-inl.h
* Trigger
* Trigger
Description
The SparseEmbedding op takes indices and a (row_sparse) weight as input and produces a dense result in the forward pass. In the backward pass it outputs a (row_sparse) gradient for the weight, which is useful for sparse gradient updates. It is ~8x faster on a c4.8xlarge machine (36 cores) compared to dense embedding.

With sparse embedding:
OMP_NUM_THREADS=32 python matrix_factorization.py --print-every=1000

With embedding:
OMP_NUM_THREADS=32 python matrix_factorization.py --print-every=1000 --use-dense

Note: SparseEmbedding checks whether any input index is out of bounds and throws an exception if one is found.
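To make the forward semantics concrete, here is a self-contained sketch of what the op computes under the description above; it mirrors the behavior, not the PR's actual kernel:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

using dim_t = int64_t;

// Dense gather from a row_sparse weight: for each input index, copy the
// matching weight row into the dense output, or leave zeros when the row
// is absent from the sparse weight.
void SparseEmbeddingForward(const std::vector<dim_t>& data,        // input indices
                            const std::vector<dim_t>& weight_idx,  // sorted rsp row ids
                            const std::vector<float>& weight_val,  // nnr x row_length values
                            dim_t row_length,
                            std::vector<float>* out) {
  out->assign(data.size() * row_length, 0.0f);
  const dim_t nnr = static_cast<dim_t>(weight_idx.size());
  for (size_t i = 0; i < data.size(); ++i) {
    auto first = std::lower_bound(weight_idx.begin(), weight_idx.end(), data[i]);
    const dim_t off = first - weight_idx.begin();
    if (off < nnr && weight_idx[off] == data[i]) {
      std::copy_n(weight_val.begin() + off * row_length, row_length,
                  out->begin() + i * row_length);
    }
  }
}
```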
Checklist

Essentials
* Passed code style checking (make lint)

Changes
Comments