-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TF]Refine LSTMBlockCell to support dynamic rnn #5963
Conversation
7460648
to
9acbe2c
Compare
# Return dummy for those unused values | ||
dummy = tvm.relay.const(0) | ||
return tvm.relay.TupleWrapper( | ||
tvm.relay.Tuple([dummy, next_c, dummy, dummy, dummy, dummy, next_h]), 7) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if the dummy node will be used somehow in some cases. It would be good to generate the real node as TF does for the dummy nodes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def lstm_block_cell(x, cs_prev, h_prev, w, wci, wcf, wco, b, forget_bias=1, cell_clip=3, use_peephole=False, name=None):
r"""Computes the LSTM cell forward propagation for 1 time step.
This implementation uses 1 weight matrix and 1 bias vector, and there's an
optional peephole connection.
This kernel op implements the following mathematical equations:
xh = [x, h_prev]
[i, f, ci, o] = xh * w + b
f = f + forget_bias
if not use_peephole:
wci = wcf = wco = 0
i = sigmoid(cs_prev * wci + i)
f = sigmoid(cs_prev * wcf + f)
ci = tanh(ci)
cs = ci .* i + cs_prev .* f
cs = clip(cs, cell_clip)
o = sigmoid(cs * wco + o)
co = tanh(cs)
h = co .* o
...
Returns:
A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added all return objects, but I just can't test them, because tensorflow.contrib.rnn.LSTMBlockCell.call() only return h and new_states
I think if users uses API to contruct graph, other values won't be used.
(cs_prev, h_prev) = state
(_, cs, _, _, _, _, h) = _lstm_block_cell(
inputs,
cs_prev,
h_prev,
self._kernel,
self._bias,
wci=wci,
wcf=wcf,
wco=wco,
forget_bias=self._forget_bias,
cell_clip=self._cell_clip,
use_peephole=self._use_peephole)
new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
return h, new_state
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @yongwww Could you take another look? Thanks
1. Refine conversion of `LSTMBlockCell` 1) Make its output follows definition in TensorFlow 2) Avoid introducing variables which doesn't match any placeholder nodes in TensorFlow graph 2. About change in test_forward_ptb States nodes of LSTMBlockCell in this PB file are actually Constant node. TF can feed data to those Constant nodes but relay can't do that, so current conversion of LSTMBockCell introduces extra variables to solve this issue. But this causes that relay IR doesn't match original TF graph. This PR solves this issue by convert those states node into placeholders.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
ping @zhiics Could you help to review and merge? Thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks @lixiaoquan @yongwww |
1. Refine conversion of `LSTMBlockCell` 1) Make its output follows definition in TensorFlow 2) Avoid introducing variables which doesn't match any placeholder nodes in TensorFlow graph 2. About change in test_forward_ptb States nodes of LSTMBlockCell in this PB file are actually Constant node. TF can feed data to those Constant nodes but relay can't do that, so current conversion of LSTMBockCell introduces extra variables to solve this issue. But this causes that relay IR doesn't match original TF graph. This PR solves this issue by convert those states node into placeholders.
1. Refine conversion of `LSTMBlockCell` 1) Make its output follows definition in TensorFlow 2) Avoid introducing variables which doesn't match any placeholder nodes in TensorFlow graph 2. About change in test_forward_ptb States nodes of LSTMBlockCell in this PB file are actually Constant node. TF can feed data to those Constant nodes but relay can't do that, so current conversion of LSTMBockCell introduces extra variables to solve this issue. But this causes that relay IR doesn't match original TF graph. This PR solves this issue by convert those states node into placeholders.
1. Refine conversion of `LSTMBlockCell` 1) Make its output follows definition in TensorFlow 2) Avoid introducing variables which doesn't match any placeholder nodes in TensorFlow graph 2. About change in test_forward_ptb States nodes of LSTMBlockCell in this PB file are actually Constant node. TF can feed data to those Constant nodes but relay can't do that, so current conversion of LSTMBockCell introduces extra variables to solve this issue. But this causes that relay IR doesn't match original TF graph. This PR solves this issue by convert those states node into placeholders.
1. Refine conversion of `LSTMBlockCell` 1) Make its output follows definition in TensorFlow 2) Avoid introducing variables which doesn't match any placeholder nodes in TensorFlow graph 2. About change in test_forward_ptb States nodes of LSTMBlockCell in this PB file are actually Constant node. TF can feed data to those Constant nodes but relay can't do that, so current conversion of LSTMBockCell introduces extra variables to solve this issue. But this causes that relay IR doesn't match original TF graph. This PR solves this issue by convert those states node into placeholders.
Tensorflow2 LSTM failed - the discussion is here https://discuss.tvm.apache.org/t/tensorflow2-lstm-failed/11174 |
Refine conversion of
LSTMBlockCell
About change in test_forward_ptb
States nodes of LSTMBlockCell in this PB file are actually Constant node.
TF can feed data to those Constant nodes but relay can't do that, so current conversion of LSTMBockCell introduces extra variables to solve this issue.
But this causes that relay IR doesn't match original TF graph. This PR solves this issue by converting those states node into placeholders.
cc @kevinthesun @srkreddy1238 @joyalbin Please help to review, thanks