From 8892e6de6af26b9e7c802a526fa2b98016961836 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Tue, 7 Nov 2017 13:52:15 -0800 Subject: [PATCH] cpu sparse embedding op (#8460) * cpu embedding draft * clean up * fix omp thread call * add sparse embedding example * check bound with signel thread * add note * add comments * add operator note * support rsp weight sharing for bucketing * improve workload balance in take add grad rsp kernel * use MSHADOW_CINLINE for cpu kernel * review comments. add unit test for shared rsp weight * remove indexing op-inl.h * Trigger * Trigger --- example/sparse/get_data.py | 60 ++- example/sparse/matrix_fact_model.py | 51 +++ example/sparse/matrix_factorization.py | 111 +++++ example/sparse/readme.md | 11 +- src/executor/graph_executor.cc | 76 ++-- src/ndarray/ndarray.cc | 8 +- src/ndarray/ndarray_function.cu | 1 + src/operator/tensor/cast_storage-inl.cuh | 2 + src/operator/tensor/dot-inl.cuh | 1 + src/operator/tensor/indexing_op.cc | 82 ++++ src/operator/tensor/indexing_op.h | 388 +++++++++++++++++- src/operator/tensor/util/tensor_util-inl.cuh | 26 -- src/operator/tensor/util/tensor_util-inl.h | 82 ++++ tests/python/unittest/test_module.py | 101 +++-- tests/python/unittest/test_sparse_operator.py | 40 ++ 15 files changed, 926 insertions(+), 114 deletions(-) create mode 100644 example/sparse/matrix_fact_model.py create mode 100644 example/sparse/matrix_factorization.py create mode 100644 src/operator/tensor/util/tensor_util-inl.h diff --git a/example/sparse/get_data.py b/example/sparse/get_data.py index 21db06d8e746..e96fd804b683 100644 --- a/example/sparse/get_data.py +++ b/example/sparse/get_data.py @@ -15,9 +15,28 @@ # specific language governing permissions and limitations # under the License. -# pylint: skip-file import os, gzip import sys +import mxnet as mx + +class DummyIter(mx.io.DataIter): + "A dummy iterator that always return the same batch, used for speed testing" + def __init__(self, real_iter): + super(DummyIter, self).__init__() + self.real_iter = real_iter + self.provide_data = real_iter.provide_data + self.provide_label = real_iter.provide_label + self.batch_size = real_iter.batch_size + + for batch in real_iter: + self.the_batch = batch + break + + def __iter__(self): + return self + + def next(self): + return self.the_batch def get_libsvm_data(data_dir, data_name, url): if not os.path.isdir(data_dir): @@ -31,3 +50,42 @@ def get_libsvm_data(data_dir, data_name, url): os.system("bzip2 -d %r" % data_name + ".bz2") print("Dataset " + data_name + " is now present.") os.chdir("..") + +def get_movielens_data(prefix): + if not os.path.exists("%s.zip" % prefix): + print("Dataset MovieLens 10M not present. Downloading now ...") + os.system("wget http://files.grouplens.org/datasets/movielens/%s.zip" % prefix) + os.system("unzip %s.zip" % prefix) + os.system("cd ml-10M100K; sh split_ratings.sh; cd -;") + +def get_movielens_iter(filename, batch_size, dummy_iter): + """Not particularly fast code to parse the text file and load into NDArrays. + return two data iters, one for train, the other for validation. + """ + print("Preparing data iterators for " + filename + " ... 
") + user = [] + item = [] + score = [] + with file(filename) as f: + num_samples = 0 + for line in f: + tks = line.strip().split('::') + if len(tks) != 4: + continue + num_samples += 1 + user.append((tks[0])) + item.append((tks[1])) + score.append((tks[2])) + if dummy_iter and num_samples > batch_size * 10: + break + # convert to ndarrays + user = mx.nd.array(user, dtype='int32') + item = mx.nd.array(item) + score = mx.nd.array(score) + # prepare data iters + data_train = {'user':user, 'item':item} + label_train = {'score':score} + iter_train = mx.io.NDArrayIter(data=data_train,label=label_train, + batch_size=batch_size, shuffle=True) + iter_train = DummyIter(iter_train) if dummy_iter else iter_train + return mx.io.PrefetchingIter(iter_train) diff --git a/example/sparse/matrix_fact_model.py b/example/sparse/matrix_fact_model.py new file mode 100644 index 000000000000..d2d8de5dd33c --- /dev/null +++ b/example/sparse/matrix_fact_model.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx + +def matrix_fact_net(factor_size, num_hidden, max_user, max_item, sparse_embed=True): + # input + user = mx.symbol.Variable('user') + item = mx.symbol.Variable('item') + score = mx.symbol.Variable('score') + if sparse_embed: + # user feature lookup + user_weight = mx.symbol.Variable('user_weight', stype='row_sparse') + user = mx.symbol.contrib.SparseEmbedding(data=user, weight=user_weight, + input_dim=max_user, output_dim=factor_size) + # item feature lookup + item_weight = mx.symbol.Variable('item_weight', stype='row_sparse') + item = mx.symbol.contrib.SparseEmbedding(data=item, weight=item_weight, + input_dim=max_item, output_dim=factor_size) + else: + # user feature lookup + user = mx.symbol.Embedding(data=user, input_dim=max_user, output_dim=factor_size) + # item feature lookup + item = mx.symbol.Embedding(data=item, input_dim=max_item, output_dim=factor_size) + # non-linear transformation of user features + user = mx.symbol.Activation(data=user, act_type='relu') + user = mx.symbol.FullyConnected(data=user, num_hidden=num_hidden) + # non-linear transformation of item features + item = mx.symbol.Activation(data=item, act_type='relu') + item = mx.symbol.FullyConnected(data=item, num_hidden=num_hidden) + # predict by the inner product, which is elementwise product and then sum + pred = user * item + pred = mx.symbol.sum(data=pred, axis = 1) + pred = mx.symbol.Flatten(data=pred) + # loss layer + pred = mx.symbol.LinearRegressionOutput(data=pred, label=score) + return pred diff --git a/example/sparse/matrix_factorization.py b/example/sparse/matrix_factorization.py new file mode 100644 index 000000000000..cdb61643d3a4 --- /dev/null +++ b/example/sparse/matrix_factorization.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation 
(ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import logging +import time +import mxnet as mx +import numpy as np +from get_data import get_movielens_iter, get_movielens_data +from matrix_fact_model import matrix_fact_net +logging.basicConfig(level=logging.DEBUG) + +parser = argparse.ArgumentParser(description="Run matrix factorization with sparse embedding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--num-epoch', type=int, default=3, + help='number of epochs to train') +parser.add_argument('--batch-size', type=int, default=128, + help='number of examples per batch') +parser.add_argument('--print-every', type=int, default=100, + help='logging frequency') +parser.add_argument('--factor-size', type=int, default=128, + help="the factor size of the embedding operation") +parser.add_argument('--use-dense', action='store_true', + help="use the dense embedding operator") +parser.add_argument('--dummy-iter', action='store_true', + help="use the dummy data iterator for speed test") + +MOVIELENS = { + 'dataset': 'ml-10m', + 'train': './ml-10M100K/r1.train', + 'val': './ml-10M100K/r1.test', + 'max_user': 71569, + 'max_movie': 65135, +} + +if __name__ == '__main__': + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.INFO, format=head) + + # arg parser + args = parser.parse_args() + logging.info(args) + num_epoch = args.num_epoch + batch_size = args.batch_size + optimizer = 'sgd' + use_sparse = not args.use_dense + factor_size = args.factor_size + dummy_iter = args.dummy_iter + print_every = args.print_every + + momentum = 0.9 + ctx = mx.cpu(0) + learning_rate = 0.1 + + # prepare dataset and iterators + max_user = MOVIELENS['max_user'] + max_movies = MOVIELENS['max_movie'] + get_movielens_data(MOVIELENS['dataset']) + train_iter = get_movielens_iter(MOVIELENS['train'], batch_size, dummy_iter) + val_iter = get_movielens_iter(MOVIELENS['val'], batch_size, dummy_iter) + + # construct the model + net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, sparse_embed=use_sparse) + a = time.time() + + # initialize the module + mod = mx.module.Module(symbol=net, context=ctx, data_names=['user', 'item'], + label_names=['score']) + mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) + mod.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34)) + optim = mx.optimizer.create(optimizer, learning_rate=learning_rate, momentum=momentum, + wd=1e-4, rescale_grad=1.0/batch_size) + mod.init_optimizer(optimizer=optim) + # use MSE as the metric + metric = mx.metric.create(['MSE']) + speedometer = mx.callback.Speedometer(batch_size, print_every) + logging.info('Training started ...') + for epoch in range(num_epoch): + nbatch = 0 + metric.reset() + for batch in train_iter: + nbatch += 1 + 
mod.forward_backward(batch) + # update all parameters + mod.update() + # update training metric + mod.update_metric(metric, batch.label) + speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch, + eval_metric=metric, locals=locals()) + speedometer(speedometer_param) + # evaluate metric on validation dataset + score = mod.score(val_iter, ['MSE']) + logging.info('epoch %d, eval MSE = %s ' % (epoch, score[0][1])) + # reset the iterator for next pass of data + train_iter.reset() + val_iter.reset() + logging.info('Training completed.') diff --git a/example/sparse/readme.md b/example/sparse/readme.md index 73f6747115dd..e443bfa2d5f9 100644 --- a/example/sparse/readme.md +++ b/example/sparse/readme.md @@ -1,8 +1,8 @@ Example =========== -This folder contains examples using the sparse feature in MXNet. +This folder contains examples using the sparse feature in MXNet. They are for demonstration purpose only. -## Linear Classification +## Linear Classification Using Sparse Matrix Multiplication The example demonstrates the basic usage of the sparse feature in MXNet to speedup computation. It utilizes the sparse data loader, sparse operators and a sparse gradient updater to train a linear model on the [Avazu](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#avazu) click-through-prediction dataset. @@ -12,3 +12,10 @@ Notes on Distributed Training: - For distributed training, please use the `../../tools/launch.py` script to launch a cluster. - For example, to run two workers and two servers with one machine, run `../../tools/launch.py -n 2 --launcher=local python linear_classification.py --kvstore=dist_async` + +## Matrix Factorization Using Sparse Embedding + +The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted based on @leopd's recommender examples. + +- `python matrix_factorization.py` +- To compare the train speed with (dense) Embedding, run `python matrix_factorization.py --use-dense` diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 90b2b269a5a2..dd4867559d5a 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -676,38 +676,50 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, /*! * \brief If the requested ndarray's shape size is less than * the corresponding shared_data_array's shape size and the - * storage type is default storage, reuse the memory allocation + * storage type is shareable, reuse the memory allocation * in shared_buffer; otherwise, create a zero ndarray. + * Shareable storages include both default storage and row_sparse storage + * if enable_row_sparse_sharing is `True`, otherwise default storage only. 
*/ NDArray ReshapeOrCreate(const std::string& name, const TShape& dest_arg_shape, const int dest_arg_dtype, const NDArrayStorageType dest_arg_stype, const Context& ctx, - std::unordered_map* shared_buffer) { - if (dest_arg_dtype != kDefaultStorage) { - return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + std::unordered_map* shared_buffer, + bool enable_row_sparse_sharing) { + bool stype_shareable = dest_arg_stype == kDefaultStorage; + if (enable_row_sparse_sharing) { + stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage; } auto it = shared_buffer->find(name); if (it != shared_buffer->end()) { - if (it->second.shape().Size() >= dest_arg_shape.Size()) { // memory can be reused + // check if size is large enough for sharing + bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size(); + if (size_shareable && stype_shareable) { // memory can be reused CHECK_EQ(it->second.dtype(), dest_arg_dtype) - << "Requested arg array's dtype does not match the reusable ndarray"; - CHECK_EQ(it->second.storage_type(), kDefaultStorage) - << "shared_buffer should only contain NDArrays with default storage type."; + << "Requested arg array's dtype does not match that of the reusable ndarray"; + CHECK_EQ(it->second.storage_type(), dest_arg_stype) + << "Requested arg array's stype does not match that of the reusable ndarray"; return it->second.Reshape(dest_arg_shape); - } else { + } else if (stype_shareable) { LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape << ", which is larger than already allocated shape " << it->second.shape() << ". Need to re-allocate. Consider putting default bucket key to be " << "the bucket taking the largest input for better memory sharing."; - // the NDArrays in shared_buffer are guaranteed to be of default storage + // size is not large enough, creating a larger one for sharing + // the NDArrays in shared_buffer are guaranteed to be of shareable storages it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); return it->second; - } // arg_array.shape().Size() >= arg_shape.Size() + } else { + // not shareable storage + return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + } } else { auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); - shared_buffer->emplace(name, ret); + if (stype_shareable) { + shared_buffer->emplace(name, ret); + } return ret; } // if (it != shared_buffer->end()) } @@ -745,18 +757,21 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const std::string& arg_name = idx[nid].source->attrs.name; // aux_states if (mutable_nodes.count(nid)) { - if (nullptr != shared_exec && inferred_stype == kDefaultStorage && - shared_exec->aux_state_map().at(arg_name).storage_type() == kDefaultStorage) { + if (nullptr != shared_exec) { const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); + CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage) + << "Non-default storage type detected when creating auxilliary NDArray. The allocated " + << "memory of shared_exec.aux_array cannot be resued for argument: " + << arg_name << " for the current executor"; CHECK_EQ(inferred_shape, aux_nd.shape()) << "Inferred shape does not match shared_exec.aux_array's shape." 
" Therefore, the allocated memory for shared_exec.aux_array cannot" - " be resued for creating auxilliary NDArray of the argument" + " be resued for creating auxilliary NDArray of the argument: " << arg_name << " for the current executor"; CHECK_EQ(inferred_dtype, aux_nd.dtype()) << "Inferred dtype does not match shared_exec.aux_array's dtype." " Therefore, the allocated memory for shared_exec.aux_array cannot" - " be resued for creating auxilliary NDArray of the argument" + " be resued for creating auxilliary NDArray of the argument: " << arg_name << " for the current executor"; aux_state_vec->emplace_back(aux_nd); } else { @@ -769,10 +784,21 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } else { // in_args and grad for in_args if (shared_arg_names.count(arg_name)) { // model parameter // model parameter - if (nullptr != shared_exec && inferred_stype == kDefaultStorage && - shared_exec->in_arg_map().at(arg_name).storage_type() == kDefaultStorage) { - // try to reuse memory from shared_exec + if (nullptr != shared_exec) { const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); + auto arg_nd_stype = in_arg_nd.storage_type(); + // for model parameter, both default storage and row_sparse storage can be shared + bool shareable_arg_stype = inferred_stype == kDefaultStorage || + inferred_stype == kRowSparseStorage; + // try to reuse memory from shared_exec + CHECK(shareable_arg_stype) << "Inferred storage type " + << common::stype_string(inferred_stype) + << " does not support memory sharing with shared_exec.arg_array"; + CHECK_EQ(inferred_stype, arg_nd_stype) + << "Inferred stype does not match shared_exec.arg_array's stype" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument" + << arg_name << " for the current executor"; CHECK_EQ(inferred_shape, in_arg_nd.shape()) << "Inferred shape does not match shared_exec.arg_array's shape" " Therefore, the allocated memory for shared_exec.arg_array cannot" @@ -801,26 +827,30 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, // try to reuse memory from shared_exec arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); } else { + // no need to reuse memory from shared_exec for gradient of non-default storage EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], inferred_dtype, arg_grad_vec); } grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); } } else { // !shared_arg_names.count(arg_name) - // model parameter + // model parameter, row_sparse ndarray sharing enabled + bool enable_row_sparse_sharing = true; in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, inferred_stype, in_arg_ctxes[arg_top], - shared_buffer)); - // gradient for model parameter + shared_buffer, enable_row_sparse_sharing)); + // gradient for model parameter, row_sparse ndarray sharing disabled if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { auto grad_oid = grad_store_.size() + num_forward_outputs_; auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + bool enable_row_sparse_sharing = false; arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape, inferred_dtype, grad_stype, - arg_grad_ctxes[arg_top], shared_buffer)); + arg_grad_ctxes[arg_top], shared_buffer, + enable_row_sparse_sharing)); grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); } 
// if (kNullOp == grad_req_types[arg_top]) } // if (shared_arg_names.count(arg_name)) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index dd43338f6dfd..275cf4038071 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -65,10 +65,10 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { NDArray NDArray::Reshape(const TShape &shape) const { CHECK(!is_none()) << "NDArray is not initialized"; - CHECK(storage_type() == kDefaultStorage) << "Reshape for storage type " << - storage_type() << " is not implemented yet"; - CHECK(storage_type() == kDefaultStorage) << "Reshape for storage type " << - storage_type() << " is not implemented yet"; + auto stype = storage_type(); + // reshape is not supported for non-default ndarray with dismatching shapes + CHECK((shape_ == shape) || stype == kDefaultStorage) + << "Reshape for storage type " << stype << " is not implemented yet"; CHECK_GE(shape_.Size(), shape.Size()) << "NDArray.Reshape: target shape size is larger current shape"; NDArray ret = this->Detach(); diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 8accc2b41cfd..445f8459aef2 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -27,6 +27,7 @@ #include #include "../operator/mxnet_op.h" #include "../operator/tensor/init_op.h" +#include "../operator/tensor/util/tensor_util-inl.h" #include "../operator/tensor/util/tensor_util-inl.cuh" #include "../common/cuda_utils.h" #include "./ndarray_function.h" diff --git a/src/operator/tensor/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh index fb75438250db..c441341eafd0 100644 --- a/src/operator/tensor/cast_storage-inl.cuh +++ b/src/operator/tensor/cast_storage-inl.cuh @@ -28,6 +28,8 @@ #include #include #include +#include +#include "./util/tensor_util-inl.h" #include "../mxnet_op.h" #include "./util/tensor_util-inl.cuh" diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index fd668b38eb69..2b346bfaf299 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -27,6 +27,7 @@ #include #include +#include "./util/tensor_util-inl.h" #include "./util/tensor_util-inl.cuh" namespace mxnet { diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index cd44eb8c26c5..bbcad70d695d 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -95,6 +95,77 @@ Examples:: .add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") .add_arguments(EmbeddingParam::__FIELDS__()); +NNVM_REGISTER_OP(_contrib_SparseEmbedding) +.describe(R"code(Maps integer indices to vector representations (embeddings). + +This operator maps words to real-valued vectors in a high-dimensional space, +called word embeddings. These embeddings can capture semantic and syntactic properties of the words. +For example, it has been noted that in the learned embedding spaces, similar words tend +to be close to each other and dissimilar words far apart. + +For an input array of shape (d1, ..., dK), +the shape of an output array is (d1, ..., dK, output_dim). +All the input values should be integers in the range [0, input_dim). + +If the input_dim is ip0 and output_dim is op0, then shape of the embedding weight matrix must be +(ip0, op0). + +The storage type of weight must be `row_sparse`, and the gradient of the weight will be of +`row_sparse` storage type, too. + +.. Note:: + + `SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 
100k). + The `row_sparse` weight cannot be used in a `BucketingModule`. + The operator is only available on CPU. + +Examples:: + + input_dim = 4 + output_dim = 5 + + // Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3) + y = [[ 0., 1., 2., 3., 4.], + [ 5., 6., 7., 8., 9.], + [ 10., 11., 12., 13., 14.], + [ 15., 16., 17., 18., 19.]] + + // Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)] + x = [[ 1., 3.], + [ 0., 2.]] + + // Mapped input x to its vector representation y. + SparseEmbedding(x, y, 4, 5) = [[[ 5., 6., 7., 8., 9.], + [ 15., 16., 17., 18., 19.]], + + [[ 0., 1., 2., 3., 4.], + [ 10., 11., 12., 13., 14.]]] + +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "weight"}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FInferShape", EmbeddingOpShape) +.set_attr("FInferType", EmbeddingOpType) +.set_attr("FInferStorageType", SparseEmbeddingOpForwardStorageType) +.set_attr("FComputeEx", SparseEmbeddingOpForwardEx) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds, + {n->inputs[0]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.") +.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") +.add_arguments(EmbeddingParam::__FIELDS__()); + NNVM_REGISTER_OP(_backward_Embedding) .set_num_inputs(2) .set_num_outputs(2) @@ -105,6 +176,17 @@ NNVM_REGISTER_OP(_backward_Embedding) .set_attr("TIsBackward", true) .set_attr("FCompute", EmbeddingOpBackward); +NNVM_REGISTER_OP(_backward_SparseEmbedding) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FInferStorageType", SparseEmbeddingOpBackwardStorageType) +.set_attr("TIsBackward", true) +.set_attr("FComputeEx", SparseEmbeddingOpBackwardEx); + NNVM_REGISTER_OP(take) .describe(R"code(Takes elements from an input array along the given axis. 
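As a hedged usage sketch of the operator registered above (values mirror the docstring example; the imperative wrapper name `mx.nd.contrib.SparseEmbedding` is assumed to follow from the `_contrib_SparseEmbedding` registration, while the tests in this patch only exercise the symbolic `mx.sym.contrib.SparseEmbedding`)::

    import mxnet as mx

    input_dim, output_dim = 4, 5
    # row_sparse weight; in practice only the rows that are actually used are stored
    weight = mx.nd.arange(input_dim * output_dim).reshape(
        (input_dim, output_dim)).tostype('row_sparse')
    data = mx.nd.array([[1, 3], [0, 2]])
    out = mx.nd.contrib.SparseEmbedding(data=data, weight=weight,
                                        input_dim=input_dim,
                                        output_dim=output_dim)
    # out has shape (2, 2, 5) and default (dense) storage, matching the
    # dns, rsp -> dns storage inference added in indexing_op.h below

Since the operator is registered for CPU only, this sketch is expected to run on `mx.cpu()`.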
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 4eb0e46908f0..262a43100762 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -38,6 +38,7 @@ #include "../operator_common.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" +#include "./util/tensor_util-inl.h" #include "../mxnet_op.h" #include "./sort_op.h" #include "./dot-inl.h" @@ -53,6 +54,7 @@ enum EmbeddingOpOutputs {kOut}; enum EmbeddingOpResource {kTempSpace}; } // namespace embedding + struct EmbeddingParam: public dmlc::Parameter { int input_dim; int output_dim; @@ -173,6 +175,62 @@ inline bool EmbeddingOpType(const nnvm::NodeAttrs& attrs, return true; } +inline bool SparseEmbeddingOpForwardStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int& data_stype = in_attrs->at(embedding::kData); + const int& weight_stype = in_attrs->at(embedding::kWeight); + int& out_stype = out_attrs->at(embedding::kOut); + bool dispatched = false; + const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; + if (!dispatched && data_stype == kDefaultStorage && + weight_stype == kRowSparseStorage && !invalid_ctx) { + // dns, rsp -> dns + dispatched = storage_type_assign(&out_stype, kDefaultStorage, + dispatch_mode, DispatchMode::kFComputeEx); + } + if (!dispatched) { + // nothing to fallback on + LOG(FATAL) << "Not implemented: " + << operator_stype_string(attrs, dev_mask, *in_attrs, *out_attrs); + } + return true; +} + + +inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + const int ograd_stype = in_attrs->at(0); + const int data_stype = in_attrs->at(1); + int& data_grad_stype = out_attrs->at(0); + int& weight_grad_stype = out_attrs->at(1); + bool dispatched = false; + const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; + if (!dispatched && ograd_stype == kDefaultStorage && + data_stype == kDefaultStorage && !invalid_ctx) { + // dns, dns -> dns, rsp + if (type_assign(&data_grad_stype, kDefaultStorage) && + type_assign(&weight_grad_stype, kRowSparseStorage) && + dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx)) { + dispatched = true; + } + } + if (!dispatched) { + // nothing to fallback on + LOG(FATAL) << "Not implemented: " + << operator_stype_string(attrs, dev_mask, *in_attrs, *out_attrs); + } + return true; +} /*! 
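An illustrative symbol-level sketch of what the storage-type inference above implies when binding; the names, shapes, and `grad_req` settings are borrowed from the unit test added later in this patch::

    import mxnet as mx

    data = mx.sym.Variable('data')
    weight = mx.sym.Variable('embed_weight', stype='row_sparse')
    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight,
                                           input_dim=50, output_dim=3)
    # data gradient is not supported, so request 'null' for it
    exe = embed.simple_bind(mx.cpu(), data=(8,),
                            grad_req={'data': 'null', 'embed_weight': 'write'})
    # forward:  default-storage data + row_sparse weight -> default-storage output
    # backward: default-storage ograd + data             -> row_sparse weight gradient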
\brief name the struct Take instead of take * to avoid conflict with the take function in mshadow */ @@ -192,13 +250,164 @@ struct Take { } }; +// Embedding forward implementation with dense weight +template +void EmbeddingOpForwardDnsImpl(mshadow::Stream* s, + const TBlob& data, + const TBlob& weight, + const OpReqType req, + const TBlob& output) { + using namespace mxnet_op; + const TShape& ishape = data.shape_; + const TShape& oshape = output.shape_; + + MSHADOW_TYPE_SWITCH(output.type_flag_, DType, { + MSHADOW_TYPE_SWITCH(data.type_flag_, IType, { + Tensor idx = data.get_with_shape( + Shape1(ishape.ProdShape(0, ishape.ndim())), s); + Tensor wmat = weight.get(s); + Tensor out = output.get_with_shape( + Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); + Kernel::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_, + idx.dptr_, wmat.shape_[1], wmat.shape_[0]); + }); + }); +} + + +template +struct TakeRspKernel { + /*! + * \brief + * \param i thread id + * \param data input data + * \param out output + * \param weight_idx indices of rsp weight + * \param weight_data data of rsp weight + * \param row_length number of elements per row + * \param nnr number of non-zero rows + */ + template + MSHADOW_XINLINE static void Map(int i, + const IType* data, + DType* out, + const RType* weight_idx, + const DType* weight_data, + const nnvm::dim_t row_length, + const nnvm::dim_t nnr) { + using nnvm::dim_t; + const dim_t val = static_cast(data[i]); + const DType zero = 0; + // Use binary search to find the lower_bound of val in weight_idx array + // (adapted based on the binary search in dot kernel) + const RType* first = weight_idx; + const RType* last = weight_idx + nnr; + const RType* it; + dim_t count = last - first, step; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (*it < val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + // end of binary search + const dim_t idx_offset = first - weight_idx; + const dim_t out_offset = i * row_length; + const dim_t weight_offset = idx_offset * row_length; + // target idx might be missing in weight.idx. For example, + // weight.idx = [5,10] and data = [3,7], so binary search fails to + // find any matching indices in weight_idx. 
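A NumPy sketch of the lookup this kernel performs for each input index (a reference illustration, not part of the patched sources): a lower_bound binary search into the row-index array of the `row_sparse` weight, with all-zero output rows whenever the index is not stored::

    import numpy as np

    def take_rsp_rows(data, weight_idx, weight_data):
        """One output row per input index; zeros for rows missing from the rsp weight."""
        row_length = weight_data.shape[1]
        out = np.zeros((data.size, row_length), dtype=weight_data.dtype)
        for i, val in enumerate(data.ravel()):
            pos = np.searchsorted(weight_idx, val)   # lower_bound, as in the kernel
            if pos < len(weight_idx) and weight_idx[pos] == val:
                out[i] = weight_data[pos]            # row is present in the sparse weight
            # otherwise the row is missing and stays all-zero
        return out

    # weight_idx = [5, 10], data = [3, 7]: neither index is stored, so both output
    # rows are zero -- exactly the case described in the comment above.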
+ if (idx_offset >= nnr || *(weight_idx + idx_offset) > val) { + // val not found, fill zeros + for (int j = 0; j < row_length; j++) { + KERNEL_ASSIGN(out[out_offset + j], req, zero); + } + } else { + for (int j = 0; j < row_length; j++) { + KERNEL_ASSIGN(out[out_offset + j], req, weight_data[weight_offset + j]); + } + } + } +}; + +inline void EmbeddingOpForwardRspImpl(mshadow::Stream* s, + const cpu& cpu_dev, + const TBlob& data, + const NDArray& weight, + const OpReqType req, + const TBlob& output) { + using namespace mxnet_op; + using namespace rowsparse; + MSHADOW_TYPE_SWITCH(output.type_flag_, DType, { + MSHADOW_TYPE_SWITCH(data.type_flag_, IType, { + MSHADOW_TYPE_SWITCH(weight.aux_type(kIdx), RType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_t, { + size_t data_size = data.shape_.Size(); + // only using the second dim since weight.ndim() == 2 + const nnvm::dim_t row_length = weight.shape()[1]; + Kernel, cpu>::Launch(s, data_size, data.dptr(), + output.dptr(), + weight.aux_data(kIdx).dptr(), + weight.data().dptr(), + row_length, weight.aux_shape(kIdx)[0]); + }); + }); + }); + }); +} + + +// Embedding forward implementation with row_sparse weight +template +void SparseEmbeddingOpForwardRspImpl(mshadow::Stream* s, + const TBlob& data, + const NDArray& weight, + const OpReqType req, + const TBlob& output) { + if (req == kNullOp) return; + CHECK((std::is_same::value)) << "SparseEmbedding is only implemented for CPU"; + using namespace rowsparse; + using namespace mxnet_op; + // zeros weight + if (req == kWriteTo && !weight.storage_initialized()) { + size_t out_size = output.shape_.Size(); + MSHADOW_TYPE_SWITCH(output.type_flag_, DType, { + Kernel::Launch(s, out_size, output.dptr()); + }) + return; + } + // check out-of-bound indices + bool is_valid = true; + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + DType min = 0; + DType max = static_cast(weight.shape()[0] - 1); + // check with single thread is faster since data is small + DType* data_ptr = data.dptr(); + size_t data_size = data.shape_.Size(); + for (size_t i = 0; i < data_size; i++) { + if (data_ptr[i] > max || data_ptr[i] < min) is_valid = false; + } + }) + CHECK(is_valid) << "SparseEmbedding input contains data out of bound"; + // the weight is actually dense + if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) { + EmbeddingOpForwardDnsImpl(s, data, weight.data(), req, output); + } else { + EmbeddingOpForwardRspImpl(s, xpu(), data, weight, req, output); + } +} + template void EmbeddingOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - using namespace mxnet_op; CHECK_EQ(req[embedding::kOut], kWriteTo); CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); @@ -206,22 +415,36 @@ void EmbeddingOpForward(const nnvm::NodeAttrs& attrs, << "Embedding layer expects its weight to be two-dimensional. 
" << inputs[embedding::kWeight].ndim() << " dimensional input is given instead"; + mshadow::Stream *s = ctx.get_stream(); + EmbeddingOpForwardDnsImpl(s, inputs[embedding::kData], inputs[embedding::kWeight], + req[embedding::kOut], outputs[embedding::kOut]); +} - const TShape& ishape = inputs[embedding::kData].shape_; - const TShape& oshape = outputs[embedding::kOut].shape_; - - Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { - Tensor data = inputs[embedding::kData].get_with_shape( - Shape1(ishape.ProdShape(0, ishape.ndim())), s); - Tensor wmat = inputs[embedding::kWeight].get(s); - Tensor out = outputs[embedding::kOut].get_with_shape( - Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); - Kernel::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_, - data.dptr_, wmat.shape_[1], wmat.shape_[0]); - }); - }); +template +void SparseEmbeddingOpForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(req[embedding::kOut], kWriteTo); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + const NDArray& data = inputs[embedding::kData]; + const NDArray& weight = inputs[embedding::kWeight]; + const NDArray& out = outputs[embedding::kOut]; + CHECK_EQ(weight.shape().ndim(), 2U) + << "Embedding layer expects its weight to be two-dimensional. " + << weight.shape().ndim() << " dimensional input is given instead"; + const auto data_stype = data.storage_type(); + const auto weight_stype = weight.storage_type(); + const auto out_stype = out.storage_type(); + mshadow::Stream *s = ctx.get_stream(); + if (data_stype == kDefaultStorage && weight_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + SparseEmbeddingOpForwardRspImpl(s, data.data(), weight, req[0], out.data()); + } else { + LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); + } } // Returns integer log2(a) rounded up @@ -337,6 +560,139 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, }); } +struct AddTakeGradRspKernel { + /*! 
+ * \brief Each thread i is responsible for row slices in [segment_start, segment_end) + of the result gradient + * \param tid global thread id + * \param grad the gradient to calculate + * \param prefix_sum the inclusive prefix sum of row ids of the gradient + * \param ograd output gradient + * \param row_length the length of the row slices of the gradient + * \param data_val the values of input data + * \param data_size number of values of input data + * \param segment_length the length of row segment to process for each thread + * \param nnr total number of non-zero rows of result gradient + */ + template + MSHADOW_CINLINE static void Map(int tid, + DType* grad, + const nnvm::dim_t* prefix_sum, + const DType* ograd, + const nnvm::dim_t row_length, + const IType* data_val, + const nnvm::dim_t data_size, + const nnvm::dim_t segment_length, + const nnvm::dim_t nnr) { + using nnvm::dim_t; + dim_t segment_start = tid * segment_length; + dim_t segment_end = std::min(nnr, segment_start + segment_length); + // scan all data + for (dim_t data_i = 0; data_i < data_size; data_i++) { + dim_t data = static_cast(data_val[data_i]); + dim_t grad_row_id = prefix_sum[data] - 1; + if (grad_row_id < segment_start || grad_row_id >= segment_end) continue; + // no projection is performed + dim_t ograd_i = data_i * row_length; + dim_t grad_i = grad_row_id * row_length; + for (dim_t offset = 0; offset < row_length; offset++) { + grad[grad_i + offset] += ograd[ograd_i + offset]; + } + } + } +}; + +inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx, + const cpu& cpu_dev, + const TBlob& ograd, + const TBlob& data, + const OpReqType req, + const NDArray& output) { + using namespace mshadow; + using namespace mxnet_op; + using namespace mshadow::expr; + using namespace rowsparse; + using nnvm::dim_t; + if (req == kNullOp) return; + CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support " + << "weight gradient calculation with req != write"; + + // Request temporary storage for marking non-zero rows and prefix sum + Stream *s = ctx.get_stream(); + dim_t num_rows = output.shape()[0]; + dim_t row_length = output.shape()[1]; + // TODO(haibin) request less storage to save space in the future + size_t workspace_size = 2 * (num_rows * sizeof(dim_t)); + Tensor workspace = + ctx.requested[embedding::kTempSpace].get_space_typed( + Shape1(workspace_size), s); + dim_t* row_flg = reinterpret_cast(workspace.dptr_); + dim_t* prefix_sum = row_flg + num_rows; + dim_t data_size = static_cast(data.shape_.Size()); + + MSHADOW_TYPE_SWITCH(data.type_flag_, IType, { + MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, { + MSHADOW_TYPE_SWITCH(output.aux_type(kIdx), RType, { + // mark row flags + Fill(s, TBlob(row_flg, mshadow::Shape1(num_rows), cpu::kDevMask), kWriteTo, 0); + Kernel::Launch(s, data_size, row_flg, data.dptr()); + // calculate inclusive prefix sum + // TODO(haibin) ideally this is should be done in parallel + prefix_sum[0] = row_flg[0]; + for (dim_t i = 1; i < num_rows; i++) { + prefix_sum[i] = prefix_sum[i - 1] + row_flg[i]; + } + // total number of non-zero rows + dim_t nnr = prefix_sum[num_rows - 1]; + if (nnr == 0) { + FillZerosRspImpl(s, output); + return; + } + output.CheckAndAlloc({Shape1(nnr)}); + RType* grad_row_idx = output.aux_data(kIdx).dptr(); + // fill row_idx array of output matrix, using the row_flg values + Kernel::Launch(s, num_rows, + grad_row_idx, prefix_sum, num_rows); + // prefill with zeros + DType* grad_data = output.data().dptr(); + Kernel::Launch(s, nnr * row_length, grad_data); 
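Before the final accumulation launch below, a NumPy sketch of the complete backward procedure implemented here (a reference illustration, not part of the patched sources): mark the rows that receive gradients, take an inclusive prefix sum, fill the row-index array of the `row_sparse` gradient, then accumulate output gradients into their target rows::

    import numpy as np

    def sparse_embedding_grad(ograd, data, num_rows):
        """Row-sparse gradient of the embedding weight (reference implementation)."""
        row_length = ograd.shape[-1]
        ograd = ograd.reshape(-1, row_length)
        data = data.ravel().astype(np.int64)
        # 1. mark rows that received gradients (MarkRowFlgKernel)
        row_flg = np.zeros(num_rows, dtype=np.int64)
        row_flg[data] = 1
        # 2. inclusive prefix sum assigns each used row a slot in the output
        prefix_sum = np.cumsum(row_flg)
        nnr = int(prefix_sum[-1])                    # number of non-zero rows
        # 3. row indices of the row_sparse gradient (FillRspRowIdxKernel)
        grad_row_idx = np.flatnonzero(row_flg)
        # 4. accumulate output gradients into their rows (AddTakeGradRspKernel)
        grad = np.zeros((nnr, row_length), dtype=ograd.dtype)
        for i, idx in enumerate(data):
            grad[prefix_sum[idx] - 1] += ograd[i]
        return grad_row_idx, grad

The launch that follows parallelises step 4 by giving each OpenMP thread a contiguous segment of the `nnr` output rows, so threads never write to the same row and no atomic additions are needed.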
+ // add the final gradients + int num_threads = Engine::Get()->num_omp_threads_per_worker(); + dim_t segment_len = (nnr + num_threads - 1) / num_threads; + Kernel::Launch(s, num_threads, grad_data, prefix_sum, + ograd.dptr(), row_length, + data.dptr(), data_size, segment_len, + num_rows); + }); + }); + }); +} + +template +void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + const NDArray& weight_grad = outputs[1]; + const NDArray& ograd = inputs[0]; + const NDArray& data = inputs[1]; + // check dtype + CHECK_EQ(weight_grad.dtype(), ograd.dtype()); + // check req + CHECK_EQ(req[embedding::kData], kNullOp) + << "SparseEmbedding layer doesn't support calculate data gradient"; + if (data.storage_type() == kDefaultStorage && ograd.storage_type() == kDefaultStorage && + weight_grad.storage_type() == kRowSparseStorage) { + SparseEmbeddingOpBackwardRspImpl(ctx, xpu(), ograd.data(), data.data(), + req[embedding::kWeight], weight_grad); + } else { + LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); + } +} + namespace take_ { // to avoid name conflict enum TakeOpInputs {kArr, kIdx}; enum TakeOpOutputs {kOut}; diff --git a/src/operator/tensor/util/tensor_util-inl.cuh b/src/operator/tensor/util/tensor_util-inl.cuh index 8bbee2522c76..f38e8e117c94 100644 --- a/src/operator/tensor/util/tensor_util-inl.cuh +++ b/src/operator/tensor/util/tensor_util-inl.cuh @@ -194,32 +194,6 @@ struct IndexRspRowFlgKernel { } }; -/*! - * \brief GPU kernel for filling the row index array of an rsp tensor. - * Parallelized by tensor rows: 1 thread/row - */ -struct FillRspRowIdxKernel { - /*! - * \brief - * \param tid global thread id - * \param row_idx row index array to store indices of non-zero rows - * \param row_flg_sum inclusive prefix sum array over 0/1 marked row flag array - * \param num_rows rsp tensor number of rows (shape) - */ - template - __device__ __forceinline__ static void Map(int tid, - RType* row_idx, - const nnvm::dim_t* row_flg_sum, - const nnvm::dim_t num_rows) { - if (tid < num_rows) { - nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1]; - if (row_flg_sum[tid] > prev) { - row_idx[prev] = static_cast(tid); - } - } - } -}; - /*! * \brief GPU kernel for marking non-zero columns of a csr matrix. * Parallelized by matrix rows: 1 warp/row diff --git a/src/operator/tensor/util/tensor_util-inl.h b/src/operator/tensor/util/tensor_util-inl.h new file mode 100644 index 000000000000..45b12730318a --- /dev/null +++ b/src/operator/tensor/util/tensor_util-inl.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2017 by Contributors + * \file tensor_util-inl.h + * \brief commonly utilized tensor operator CPU kernels + */ +#ifndef MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_H_ +#define MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_H_ + +#include +#include + +namespace mxnet { +namespace op { + +/*! + * \brief kernel to flag indices that appear in row_idx array with 1. + */ +struct MarkRowFlgKernel { + /*! + * \brief + * \param tid global thread id + * \param row_flg flag array for indices + * \param row_idx row index array storing indices of rows + */ + template + MSHADOW_XINLINE static void Map(int tid, + DType* row_flg, + const IType* row_idx) { + nnvm::dim_t idx = static_cast(row_idx[tid]); + row_flg[idx] = 1; + } +}; + +/*! + * \brief kernel for filling the row index array of an rsp tensor. + * Parallelized by tensor rows: 1 thread/row + */ +struct FillRspRowIdxKernel { + /*! + * \brief + * \param tid global thread id + * \param row_idx row index array to store indices of non-zero rows + * \param row_flg_sum inclusive prefix sum array over 0/1 marked row flag array + * \param num_rows rsp tensor number of rows (shape) + */ + template + MSHADOW_XINLINE static void Map(int tid, + RType* row_idx, + const nnvm::dim_t* row_flg_sum, + const nnvm::dim_t num_rows) { + if (tid < num_rows) { + nnvm::dim_t prev = (tid == 0) ? 0 : row_flg_sum[tid-1]; + if (row_flg_sum[tid] > prev) { + row_idx[prev] = static_cast(tid); + } + } + } +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_H_ diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index d79657a85af5..180d2ee05242 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -349,14 +349,20 @@ def mean_abs(x): assert(mon_result_counts == [2, 2, 1, 6, 6, 4]) def test_executor_group(): - def get_rnn_sym(num_layers, num_words, num_hidden, num_embed, seq_len): + def get_rnn_sym(num_layers, num_words, num_hidden, num_embed, seq_len, sparse_embedding): stack = mx.rnn.SequentialRNNCell() for i in range(num_layers): stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i)) data = mx.sym.Variable('data') label = mx.sym.Variable('softmax_label') - embed = mx.sym.Embedding(data=data, input_dim=num_words, - output_dim=num_embed, name='embed') + if sparse_embedding: + embed_weight = mx.sym.Variable('embed_weight', stype='row_sparse') + embed = mx.sym.contrib.SparseEmbedding(data=data, input_dim=num_words, + weight=embed_weight, output_dim=num_embed, + name='embed') + else: + embed = mx.sym.Embedding(data=data, input_dim=num_words, + output_dim=num_embed, name='embed') stack.reset() outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True) @@ -368,7 +374,8 @@ def get_rnn_sym(num_layers, num_words, num_hidden, num_embed, seq_len): pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax') return pred - def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=None, extra_args=None): + def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=None, + extra_args=None, check_shared_grad=True): # Test shared data arrays for i in range(len(exec_grp_shared.execs)): # test same shared_data_arrays for two exec groups @@ -404,12 +411,14 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N assert mx.test_utils.same_array(exec_shared.arg_dict[arg_name], exec_created.arg_dict[arg_name]), \ "Shared argument '%s' does not share memory." 
% (arg_name) # test shared argument gradients - for arg_name in shared_arg_names: - assert arg_name in exec_created.grad_dict, \ - "Shared argument gradient '%s' is not in " \ - "grad_dict of created executor group." % (arg_name) - assert mx.test_utils.same_array(exec_shared.grad_dict[arg_name], exec_created.grad_dict[arg_name]), \ - "Shared argument gradient '%s' does not sharing memory." % (arg_name) + if check_shared_grad: + for arg_name in shared_arg_names: + assert arg_name in exec_created.grad_dict, \ + "Shared argument gradient '%s' is not in " \ + "grad_dict of created executor group." % (arg_name) + assert mx.test_utils.same_array(exec_shared.grad_dict[arg_name], \ + exec_created.grad_dict[arg_name]), \ + "Shared argument gradient '%s' does not share memory." % (arg_name) for arg_name, grad in exec_grp_shared.grad_req.items(): assert grad == exec_grp_created.grad_req[arg_name], \ @@ -417,6 +426,43 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N "Shared executor group requires '%s' while created executor group requires '%s'" \ %(arg_name, grad, exec_grp_created.grad_req[arg_name]) + def check_shared_exec_group(sparse_embedding): + # generate an rnn sym with #layers=5 + sym = get_rnn_sym(num_layers=3, num_words=num_words, num_hidden=num_hidden, + num_embed=num_embed, seq_len=max_bucket_size, + sparse_embedding=sparse_embedding) + arg_names1 = sym.list_arguments() + input_names = [name[0] for name in data_shapes] + [name[0] for name in label_shapes] + shared_arg_names = [name for name in arg_names1 if name not in input_names] + exec_group1 = DataParallelExecutorGroup(symbol=sym, contexts=contexts, + workload=workload, data_shapes=data_shapes, + label_shapes=label_shapes, param_names=shared_arg_names, + for_training=True, inputs_need_grad=False) + + # shared_data_arrays should only have input "data" and "softmax_label" arrays + for i in range(len(contexts)): + assert len(exec_group1.shared_data_arrays[i]) == len(input_names),\ + "exec_group1.shared_data_arrays[%d] should have the same number of names as in input_names" % i + for name in input_names: + assert name in exec_group1.shared_data_arrays[i],\ + "arg %s should be in exec_group1.shared_data_arrays[%d]" % (name, i) + + # generate an rnn sym with #layers=5 + sym = get_rnn_sym(num_layers=5, num_words=num_words, num_hidden=num_hidden, + num_embed=num_embed, seq_len=max_bucket_size, + sparse_embedding=sparse_embedding) + arg_names2 = sym.list_arguments() + exec_group2 = DataParallelExecutorGroup(symbol=sym, contexts=contexts, + workload=workload, data_shapes=data_shapes, + label_shapes=label_shapes, param_names=shared_arg_names, + for_training=True, inputs_need_grad=False, + shared_group=exec_group1) + extra_args = [name for name in arg_names2 if name not in shared_arg_names] + check_shared_grad = not sparse_embedding + test_shared_exec_group(exec_grp_shared=exec_group1, exec_grp_created=exec_group2, + shared_arg_names=shared_arg_names, extra_args=extra_args, + check_shared_grad=check_shared_grad) + contexts = [mx.cpu(0), mx.cpu(1)] workload = [1] * len(contexts) batch_size = 32 @@ -426,38 +472,9 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N num_embed = 200 data_shapes = [('data', (batch_size, max_bucket_size))] label_shapes = [('softmax_label', (batch_size, max_bucket_size))] - - # generate an rnn sym with #layers=5 - sym = get_rnn_sym(num_layers=3, num_words=num_words, num_hidden=num_hidden, - num_embed=num_embed, seq_len=max_bucket_size) - arg_names1 = 
sym.list_arguments() - input_names = [name[0] for name in data_shapes] + [name[0] for name in label_shapes] - shared_arg_names = [name for name in arg_names1 if name not in input_names] - exec_group1 = DataParallelExecutorGroup(symbol=sym, contexts=contexts, - workload=workload, data_shapes=data_shapes, - label_shapes=label_shapes, param_names=shared_arg_names, - for_training=True, inputs_need_grad=False) - - # shared_data_arrays should only have input "data" and "softmax_label" arrays - for i in range(len(contexts)): - assert len(exec_group1.shared_data_arrays[i]) == len(input_names),\ - "exec_group1.shared_data_arrays[%d] should have the same number of names as in input_names" % i - for name in input_names: - assert name in exec_group1.shared_data_arrays[i],\ - "arg %s should be in exec_group1.shared_data_arrays[%d]" % (name, i) - - # generate an rnn sym with #layers=5 - sym = get_rnn_sym(num_layers=5, num_words=num_words, num_hidden=num_hidden, - num_embed=num_embed, seq_len=max_bucket_size) - arg_names2 = sym.list_arguments() - exec_group2 = DataParallelExecutorGroup(symbol=sym, contexts=contexts, - workload=workload, data_shapes=data_shapes, - label_shapes=label_shapes, param_names=shared_arg_names, - for_training=True, inputs_need_grad=False, - shared_group=exec_group1) - extra_args = [name for name in arg_names2 if name not in shared_arg_names] - test_shared_exec_group(exec_grp_shared=exec_group1, exec_grp_created=exec_group2, - shared_arg_names=shared_arg_names, extra_args=extra_args) + sparse_embedding_opt = [True, False] + for opt in sparse_embedding_opt: + check_shared_exec_group(opt) def test_factorization_machine_module(verbose=False): diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 2f0ebd812a1d..4c9ce9c9e848 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1543,6 +1543,46 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n): shape = tuple(np.random.randint(5, 10, size=dim)) check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9)) +def test_sparse_embedding(): + ''' test sparse embedding op on cpu ''' + def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density): + # update weight based on density + weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density) + # check forward + exe_test.forward(is_train=True) + assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(data_onehot, weight.asnumpy())) + # check backward + executor.backward([grad]) + assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(data_onehot.T, grad.asnumpy())) + + if default_context().device_type == 'cpu': + densities = [0, 0.5, 1] + in_dim = 50 + out_dim = 3 + batch = 8 + # init executor + data = mx.sym.Variable("data") + weight = mx.sym.Variable("embed_weight", stype='row_sparse') + embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim, + output_dim=out_dim, name="embed") + grad_req = {'data': 'null', 'embed_weight': 'write'} + exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,)) + arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) + grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) + # init data + np_data = np.random.randint(low=0, high=in_dim, size=batch) + np_onehot = np.zeros((batch, in_dim)) + np_onehot[np.arange(batch), np_data] = 1.0 + arg_map["data"][:] = np_data + # init grad + np_grad = 
np.random.uniform(-1, 1, exe_test.outputs[0].shape) + grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape) + grad[:] = np_grad + # weight + weight = arg_map["embed_weight"] + for density in densities: + check_sparse_embedding(exe_test, weight, np_onehot, grad, density) + def test_scatter_ops(): def csr_get_seen_points(name, csr_array, verbose=False): """Get a unique list of points in the CSR array as well as a