This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

cpu sparse embedding op #8460

Merged
merged 20 commits into from
Nov 7, 2017

Conversation

@eric-haibin-lin eric-haibin-lin commented Oct 28, 2017

Description

The SparseEmbedding op takes indices and a (rowsparse) weight as inputs and produces a dense result in the forward pass. In the backward pass it outputs a (rowsparse) gradient for the weight, which is useful for sparse gradient updates.
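To make the data flow concrete, here is a minimal NumPy sketch of the op's semantics (illustrative only, not the MXNet implementation; the function names are made up): the forward pass gathers rows of the weight, and the backward pass only touches the gathered rows, so the weight gradient fits naturally in row-sparse form.

```python
import numpy as np

def sparse_embedding_forward(indices, weight):
    # Gather the selected rows; the output is dense.
    return weight[indices]

def sparse_embedding_backward(indices, out_grad):
    # Only rows referenced by `indices` receive gradient, so the result
    # is returned as (row indices, values) of a row-sparse gradient.
    rows = np.unique(indices)                      # sorted unique row ids
    data = np.zeros((len(rows), out_grad.shape[1]))
    for i, idx in enumerate(indices):
        data[np.searchsorted(rows, idx)] += out_grad[i]
    return rows, data

weight = np.arange(20.0).reshape(4, 5)
out = sparse_embedding_forward(np.array([1, 3, 1]), weight)
rows, data = sparse_embedding_backward(np.array([1, 3, 1]), np.ones((3, 5)))
# only rows 1 and 3 appear in the gradient; row 1 accumulates twice
```

An optimizer can then apply a sparse update by touching only `rows` of the weight, which is where the speedup below comes from.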

~8x faster on a c4.8xlarge machine (36 cores) compared to the dense Embedding op:
With sparse embedding: OMP_NUM_THREADS=32 python matrix_factorization.py --print-every=1000

INFO:root:Epoch[0] Batch [1000] Speed: 81742.24 samples/sec     mse=1.143560
INFO:root:Epoch[0] Batch [2000] Speed: 84291.38 samples/sec     mse=0.980251
INFO:root:Epoch[0] Batch [3000] Speed: 82812.25 samples/sec     mse=0.950991
INFO:root:Epoch[0] Batch [4000] Speed: 82077.48 samples/sec     mse=0.911922
INFO:root:Epoch[0] Batch [5000] Speed: 80918.93 samples/sec     mse=0.898004
INFO:root:Epoch[0] Batch [6000] Speed: 73614.71 samples/sec     mse=0.872740
INFO:root:Epoch[0] Batch [7000] Speed: 83037.03 samples/sec     mse=0.863862
INFO:root:Epoch[0] Batch [8000] Speed: 83170.19 samples/sec     mse=0.852437
INFO:root:Epoch[0] Batch [9000] Speed: 83669.75 samples/sec     mse=0.858010

With embedding: OMP_NUM_THREADS=32 python matrix_factorization.py --print-every=1000 --use-dense

INFO:root:Epoch[0] Batch [1000] Speed: 9179.29 samples/sec      mse=1.101776
INFO:root:Epoch[0] Batch [2000] Speed: 10797.47 samples/sec     mse=0.928593
INFO:root:Epoch[0] Batch [3000] Speed: 11022.51 samples/sec     mse=0.903449
INFO:root:Epoch[0] Batch [4000] Speed: 9529.76 samples/sec      mse=0.885801
INFO:root:Epoch[0] Batch [5000] Speed: 10027.75 samples/sec     mse=0.876870
INFO:root:Epoch[0] Batch [6000] Speed: 9447.15 samples/sec      mse=0.857204
INFO:root:Epoch[0] Batch [7000] Speed: 9982.74 samples/sec      mse=0.847336
INFO:root:Epoch[0] Batch [8000] Speed: 10335.85 samples/sec     mse=0.840450
INFO:root:Epoch[0] Batch [9000] Speed: 9890.59 samples/sec      mse=0.840442

Note: SparseEmbedding checks whether any input index is out of bounds, and throws an exception if one is found.
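The bounds check described in the note can be sketched as follows (hypothetical Python, not the actual C++ kernel): scan the index array once and raise if any value falls outside `[0, num_rows)`.

```python
import numpy as np

def check_bounds(indices, num_rows):
    # Raise if any embedding index is outside the weight's row range.
    if ((indices < 0) | (indices >= num_rows)).any():
        raise ValueError("index out of bounds for embedding weight")

check_bounds(np.array([0, 3]), num_rows=4)   # passes silently
```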

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • For user-facing API changes, API doc string has been updated.
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@eric-haibin-lin eric-haibin-lin changed the title [WIP] Cpu sparse embedding op cpu sparse embedding op Oct 31, 2017
@formath formath mentioned this pull request Oct 31, 2017

@anirudh2290 anirudh2290 left a comment


Thank you for working on this op!

it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
return it->second;
} else {
// not shareable storage
return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
} // arg_array.shape().Size() >= arg_shape.Size()
anirudh2290:

This comment here probably makes sense after the else if clause instead of here.

eric-haibin-lin (Author):

line 704 is still performing sharing by updating it->second; it's just that its size is too small, so a bigger ndarray is created for sharing. Maybe the name size_shareable is misleading?

anirudh2290:

Yes, the comment "// arg_array.shape().Size() >= arg_shape.Size()" probably makes sense after if instead of else. I don't think the name size_shareable is misleading.

eric-haibin-lin (Author):

Oh I see, you're referring to this comment. I'll move it.

in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype,
inferred_stype, in_arg_ctxes[arg_top],
shared_buffer));
// gradient for model parameter
shared_buffer, true));
anirudh2290:

Using a bool variable and passing it as an argument to ReshapeOrCreate here and below would improve readability.

eric-haibin-lin (Author):

Yes sure

MSHADOW_XINLINE static void Map(int tid,
DType* row_flg,
const IType* row_idx) {
nnvm::dim_t idx = static_cast<nnvm::dim_t>(row_idx[tid]);
anirudh2290:

Can we use IType here instead of nnvm::dim_t?

eric-haibin-lin (Author):

I'm afraid not. For embedding, the data/row_idx could be float, which cannot be used with [], so it's always cast.
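The point about casting can be illustrated in Python (an analogy, not the C++ kernel): an Embedding input's index array may be float32, and floats cannot be used as array subscripts, so each value is cast to an integer row id first.

```python
import numpy as np

row_idx = np.array([2.0, 0.0, 3.0], dtype=np.float32)  # float indices
weight = np.arange(20.0).reshape(4, 5)

# weight[row_idx] would raise an IndexError: floats are not valid subscripts.
rows = row_idx.astype(np.int64)   # cast: float row id -> integer offset
out = weight[rows]
```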

const nnvm::dim_t* row_flg_sum,
const nnvm::dim_t num_rows) {
if (tid < num_rows) {
nnvm::dim_t prev = (tid == 0) ? 0 : row_flg_sum[tid-1];
anirudh2290:

Can RType be used here instead of nnvm::dim_t?

eric-haibin-lin (Author):

This kernel is used by other operators (such as dot and elemwise_sum), too. I'm moving it from its original place to this file so that it can be reused for embedding. Changing that interface is out of scope for this PR, since row_flg_sum comes from temp storage which is always allocated as nnvm::dim_t...
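What the snippet above computes can be sketched in Python (hypothetical, not the actual kernel): flag the retained rows, take an inclusive prefix sum of the flags, then use `prev = row_flg_sum[tid-1]` (0 for `tid == 0`) to place each retained row at its compacted position.

```python
import numpy as np

def compact_rows(row_flg):
    # Inclusive prefix sum of the 0/1 flags; the last entry is the
    # number of retained rows.
    row_flg_sum = np.cumsum(row_flg)
    out = np.empty(int(row_flg_sum[-1]), dtype=np.int64)
    for tid in range(len(row_flg)):
        prev = 0 if tid == 0 else row_flg_sum[tid - 1]
        if row_flg_sum[tid] > prev:   # flag was set: row tid is retained
            out[prev] = tid           # prev is its compacted position
    return out

retained = compact_rows(np.array([0, 1, 0, 1, 1]))  # rows 1, 3, 4 retained
```

Because `prev` differs between adjacent flagged rows, every retained row gets a unique slot, which is why the per-thread body needs no synchronization.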

[ 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., 18., 19.]]

anirudh2290:

Can we use a rsp weight in the example here?

eric-haibin-lin (Author):

Maybe I'll add one more example with rsp weight..

eric-haibin-lin (Author):

I'm actually not sure about that. Showing weights with 0s is kind of confusing.

@@ -95,6 +95,77 @@ Examples::
.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.")
.add_arguments(EmbeddingParam::__FIELDS__());

NNVM_REGISTER_OP(_contrib_SparseEmbedding)
anirudh2290:

Stupid question: why are we registering a new operator here instead of extending the existing Embedding op?

eric-haibin-lin (Author):

Good question! If we overrode the existing op, we wouldn't know whether to infer a rsp or dense gradient for the weight, because the weight ndarray is not visible in the backward pass.

const dim_t idx_offset = first - weight_idx;
const dim_t out_offset = i * row_length;
const dim_t weight_offset = idx_offset * row_length;
if (idx_offset >= nnr || *(weight_idx + idx_offset) > val) {
anirudh2290:

It is not obvious to me why the element is not found in the case where *(weight_idx + idx_offset) > val. Maybe add a comment here.

eric-haibin-lin (Author):

Sure. It's possible that weight.idx = [5,10] while data = [3,7], so no matching indices can be found in weight_idx.
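The miss case from this example can be sketched with Python's `bisect` (an illustration of the logic, not the C++ code): a lower-bound search for a data value can land past the end of the rsp index array (`idx_offset >= nnr`) or on a strictly larger index (`weight_idx[idx_offset] > val`); either way the row is absent from the rsp weight.

```python
import bisect

weight_idx = [5, 10]          # sorted row indices of the rsp weight
nnr = len(weight_idx)         # number of non-zero rows

def find_row(val):
    # Lower bound: first position whose index is >= val.
    idx_offset = bisect.bisect_left(weight_idx, val)
    if idx_offset >= nnr or weight_idx[idx_offset] > val:
        return None           # row not present in the rsp weight
    return idx_offset
```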

using namespace rowsparse;
using namespace mxnet_op;
// zeros weight
if (req == kWriteTo && !weight.storage_initialized()) {
anirudh2290:

What happens when the input data storage is not initialized?

eric-haibin-lin (Author):

input data is always dense.


[[ 0., 1., 2., 3., 4.],
[ 10., 11., 12., 13., 14.]]]

anirudh2290:

Also, a frontend Python example, similar to the one here: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol_doc.py#L128 would be good to have.

eric-haibin-lin (Author):

I thought we stopped using SymbolDoc with doc tests. Is that still working?

anirudh2290:

I see the symbol doc example here: https://mxnet.incubator.apache.org/versions/master/api/python/symbol/symbol.html#mxnet.symbol.Embedding. I am not sure if it is deprecated.

eric-haibin-lin (Author):

Cool, will add this.

eric-haibin-lin (Author):

Looks like that doesn't work for contrib ops.. :(


Ubuntu and others added 2 commits November 7, 2017 04:26
@piiswrong piiswrong merged commit 4862c41 into apache:master Nov 7, 2017
cjolivier01 pushed a commit to cjolivier01/mxnet that referenced this pull request Nov 9, 2017
* cpu embedding draft

* clean up

* fix omp thread call

* add sparse embedding example

* check bound with signel thread

* add note

* add comments

* add operator note

* support rsp weight sharing for bucketing

* improve workload balance in take add grad rsp kernel

* use MSHADOW_CINLINE for cpu kernel

* review comments. add unit test for shared rsp weight

* remove indexing op-inl.h

* Trigger

* Trigger
@eric-haibin-lin eric-haibin-lin deleted the cpu-embed branch November 14, 2017 05:57
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018