Higher Level API for RNN #3930
Comments
Great! Thanks a lot! We should have had this built into the standard package long ago instead of the bare-bone symbols. Also, if I understand correctly, the class here is a symbol constructor?
@pluskid Yes, the RNN class here is a symbol constructor. I'm also thinking about the naming issue. Maybe call it "RNNFactory" to distinguish it from the cudnn version?
@leezu @jennyzhang0215 Let's PR this.
Great! I remember the CuDNN RNN cell needs the data transposed before input, so what is the input shape of this operator? I also find the current non-CuDNN version only reaches 30%–50% GPU utilization, and it's not easy to hit 100%. You mentioned you have done speed tests against the CuDNN version; are there any reports?
@xlvector I find that cudnn will be 3/6 times faster than the original implementation. The input shape is chosen to be the same as cudnn, i.e., time-major: (seq_length, batch_size, input_size).
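[Editor's note] Relatedly, regarding @xlvector's transpose question above: a minimal sketch, assuming the data symbol starts out batch-major, of converting it to the time-major layout described above. The variable names are illustrative, not from the thread:

```python
import mxnet as mx

# Batch-major input: (batch_size, seq_length, input_size)
data = mx.sym.Variable("data")

# Swap the first two axes to obtain the time-major layout the operator expects:
# (seq_length, batch_size, input_size)
time_major = mx.sym.SwapAxis(data, dim1=0, dim2=1)
```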
The iterators defined in https://github.com/dmlc/mxnet/blob/master/example/rnn-time-major/bucket_io.py are helpful for more easily getting the correct (time-major) input shape. One could also add a more general version to https://github.com/dmlc/mxnet/blob/master/python/mxnet/io.py . What do you think?
My understanding is that the RNN input and output shapes are already defined by the InferShape API in rnn-inl.h, so I am using them in my RNN implementation.
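[Editor's note] Those inferred shapes are also queryable from Python. A small sketch using `infer_shape` (with `FullyConnected` as a stand-in, since the exact RNN symbol signature is not shown in this thread):

```python
import mxnet as mx

# Let shape inference propagate from a known input shape to the outputs.
data = mx.sym.Variable("data")
net = mx.sym.FullyConnected(data=data, num_hidden=128, name="fc")

arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(32, 100))
print(out_shapes)  # [(32, 128)]
```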
@sxjscience: you should be explicit about which dropout you are referring to. Is it the version from "A Theoretically Grounded Application of Dropout in Recurrent Neural Networks"?
@sbodenstein Currently we've only implemented the old version. The dropout method from the NIPS paper should be added later. Also, the performance boost in that paper is partially due to its dropout method for the embedding layer, which we can also support.
@ZhenlinGuo The problem may be that by doing this we cannot support stacked RNNs whose layers have different hidden state sizes. The API will be easier to use if we keep the weights, biases and states separate.
Hi all, I have one question: if I run the RNN layer and the return is a combination of kOut, kStateOut and kStateCellOut, how can I get the top-right (last time step) entries of kStateOut and kStateCellOut in Python? What structure or API can I use? In rnn_cell_demo, this line is used after calling RNN, but I don't quite understand it. For seq2seq, h, y and c will be output, but how can I get just the top-right of h and c from the stream?
@sxjscience Hi, what do you mean by a "3/6 times faster" speed-up with CuDNN? A 3–6x speed-up?
@magic282 Yes, cudnn is faster.
@sxjscience I strongly suggest that we have a benchmark for RNN-based models, such as RNN LMs, s2s, and attention, compared with the popular Theano and TF. Actually, I have implemented an s2s + attention NMT model with mxnet and got state-of-the-art results on the IWSLT data. However, the training speed is not very fast. According to this benchmark, https://github.com/guolinke/deep-learning-benchmarks/blob/master/result.md , mxnet is faster than other tools for the FCN models but slower for LSTM. I am really curious about the reason, though this is just a reference since I don't really believe that Theano is much faster than Torch. About the CUDNN speed-up: our group has built a DL framework from scratch and we are planning to leverage CUDNN to speed up training. We can only get a 2x speed-up at most, since our tool is already very fast for RNN models (because of some RNN-specific optimizations). So I am wondering whether mxnet can apply some optimizations to RNN models without relying on CUDNN?
@magic282 Yes, we need to include such a benchmark. Would you mind sharing your implementation? In fact, we don't have an s2s + attention example yet, and it will be easier for us to investigate why the speed is unsatisfactory if we have the example code.
@sxjscience Sure, I will refactor the code and share it, and I hope we can speed it up.
@magic282, do you have early code that calls the RNN layer rather than the lstm_unroll Python version? I have already implemented an MKL-based RNN for the CPU and am going to release it later, but I lack an s2s model to test the perf.
@zhenlinluo Nope.
@magic282 since you have s2s model based on cudnn RNN layer, could you pls help to answer my question about what need to be returned after call mx.sym.RNN(xxxx, state_output=True) in encoder and decoder func? Since output Tblob will include kOut, kStateOut and kCell, but I just need to return 2D array kStateOut and kCell on top-right cell as input to decoder, how to do that? |
@zhenlinluo I don't have s2s base on cudnn RNN symbol since I implemented LSTM or GRU using the basic OP. |
@mli @piiswrong Do you know how to get specific data when outputs are multiple in python? |
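[Editor's note] Not the thread authors' code, just a minimal illustration of two ways to pick one output from a multi-output symbol, using `SliceChannel` as a stand-in for the RNN operator's grouped outputs:

```python
import mxnet as mx

# Any symbol with multiple outputs can be indexed in Python.
x = mx.sym.Variable("x")
sliced = mx.sym.SliceChannel(x, num_outputs=3)
first = sliced[0]              # pick one output at symbol-construction time
print(sliced.list_outputs())   # names of all three outputs

# At runtime, an executor exposes all outputs as a list of NDArrays.
exe = sliced.simple_bind(ctx=mx.cpu(), x=(3, 6))
exe.forward(x=mx.nd.ones((3, 6)))
second = exe.outputs[1].asnumpy()  # the second output
```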
@sxjscience Hi, I have uploaded the s2s + attention code. It might have some bugs, since this time I can only get a 44 BLEU score. Link: https://github.com/magic282/MXNMT
@magic282 Is it possible to revise the code to run on the WMT14 or WMT15 dataset? We'd better do a paired test against TensorFlow.
Actually, I am not very familiar with the MT task. Could you please provide the data format? Is it the reference-number issue?
@magic282 We may need to refer to their seq2seq example: https://www.tensorflow.org/versions/r0.12/tutorials/seq2seq/index.html
Regarding this discussion (re: Yarin Gal's RNN implementation):
According to Yarin, his implementation is already in Keras, TensorFlow, and Torch. I think it basically uses the same dropout mask across timesteps in RNNs. The implementations might include examples handling RNN layers of different sizes, and I also speculate about the potential of allowing for uncertainty in RNN predictions as that research continues.
@predict-r Thanks very much! I'm busy with my PQE exam and need to work on this after I finish it. I find that the "dropout" used in the paper is actually a type of "DropConnect" (http://www.jmlr.org/proceedings/papers/v28/wan13.pdf) that directly masks the weight matrix. I'm thinking of keeping its original name.
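[Editor's note] To make the distinction concrete, a toy NumPy sketch of the two masking schemes (my own illustration, not code from either paper; both masks are shown in one recurrence for brevity, though a real implementation would pick one scheme):

```python
import numpy as np

rng = np.random.RandomState(0)
seq_len, batch, num_hidden = 5, 4, 8
p = 0.5  # drop probability

x = rng.randn(seq_len, batch, num_hidden)        # toy inputs
W_hh = rng.randn(num_hidden, num_hidden) * 0.1   # recurrent weights
h = np.zeros((batch, num_hidden))

# Variational (Gal) dropout: sample ONE activation mask per sequence
# and reuse it at every timestep, instead of resampling each step.
h_mask = (rng.rand(batch, num_hidden) > p) / (1.0 - p)

# DropConnect (Wan et al.): mask the weight matrix itself.
w_mask = (rng.rand(*W_hh.shape) > p) / (1.0 - p)

for t in range(seq_len):
    h = np.tanh(x[t] + (h * h_mask).dot(W_hh * w_mask))
```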
@sxjscience Sorry for the late reply. I think the code will run if we have a parallel corpus. Did it fail when you tried it?
@sxjscience It might be called DropConnect, but apparently it's even older. Here is Yann LeCun's write-up on the history: https://www.facebook.com/yann.lecun/posts/10154058859142143 .
@sxjscience Hey, what is the progress on the API for RNN? Is it complete or still in progress?
@karishmamalkan Still in progress; I've decided to also add a batchnorm example. I will finish it after the PQE. You can view some of the code here: https://github.com/ML-HK/mxnet/blob/master/python/mxnet/recurrent.py
Thanks @sxjscience. I wanted to know whether this is a working version of the code. When I try to import recurrent.py, I get an error about importing "utils". Is something missing?
We've created a higher-level API for recurrent neural networks and have completed gradient tests, forward tests and a speed comparison against CuDNN. The class definition and key methods look like this:
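[Editor's note] The original snippet is in the recurrent.py linked earlier in the thread; below is only a minimal sketch of what such a symbol-constructor class can look like, unrolling a single-layer LSTM over time-major input. The class name, parameter prefixes and method names are illustrative assumptions, not the actual class:

```python
import mxnet as mx

class LSTM(object):
    """Hypothetical symbol constructor: unrolls a single-layer LSTM
    over time-major input of shape (seq_len, batch, input_dim)."""

    def __init__(self, num_hidden, prefix="lstm_"):
        self.num_hidden = num_hidden
        # Parameters are created once and shared across all time steps.
        self.i2h_weight = mx.sym.Variable(prefix + "i2h_weight")
        self.i2h_bias = mx.sym.Variable(prefix + "i2h_bias")
        self.h2h_weight = mx.sym.Variable(prefix + "h2h_weight")
        self.h2h_bias = mx.sym.Variable(prefix + "h2h_bias")

    def step(self, x, h, c):
        """One LSTM step; returns (h_next, c_next)."""
        gates = mx.sym.FullyConnected(data=x, weight=self.i2h_weight,
                                      bias=self.i2h_bias,
                                      num_hidden=4 * self.num_hidden) \
              + mx.sym.FullyConnected(data=h, weight=self.h2h_weight,
                                      bias=self.h2h_bias,
                                      num_hidden=4 * self.num_hidden)
        sliced = mx.sym.SliceChannel(gates, num_outputs=4, axis=1)
        in_gate = mx.sym.Activation(sliced[0], act_type="sigmoid")
        forget_gate = mx.sym.Activation(sliced[1], act_type="sigmoid")
        out_gate = mx.sym.Activation(sliced[2], act_type="sigmoid")
        in_transform = mx.sym.Activation(sliced[3], act_type="tanh")
        c_next = forget_gate * c + in_gate * in_transform
        h_next = out_gate * mx.sym.Activation(c_next, act_type="tanh")
        return h_next, c_next

    def unroll(self, data, seq_len, init_h, init_c):
        """Unroll over time; returns (per-step outputs, last h, last c)."""
        slices = mx.sym.SliceChannel(data, num_outputs=seq_len,
                                     axis=0, squeeze_axis=1)
        h, c = init_h, init_c
        outputs = []
        for t in range(seq_len):
            h, c = self.step(slices[t], h, c)
            outputs.append(h)
        return outputs, h, c
```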
We've decided to submit a pull request for this feature, @leezu.
Should we create a new directory under "python/mxnet", like "operators", to store these kinds of composed symbols? What do you think? @pluskid @piiswrong @tqchen @xlvector @sbodenstein