
[TF]Refine LSTMBlockCell to support dynamic rnn #5963

Merged: 1 commit merged into apache:master on Jul 16, 2020

Conversation

@lixiaoquan (Contributor) commented on Jun 30, 2020:

  1. Refine the conversion of LSTMBlockCell

    1. Make its outputs follow the definition in TensorFlow.
    2. Avoid introducing variables that don't match any Placeholder node in the TensorFlow graph.
  2. About the change in test_forward_ptb
    The state nodes of LSTMBlockCell in this PB file are actually Constant nodes.
    TF can feed data to those Constant nodes (see the sketch below this list) but Relay can't, so the current conversion of LSTMBlockCell introduces extra variables to work around that.
    But this makes the Relay IR no longer match the original TF graph. This PR solves the issue by converting those state nodes into Placeholders.
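
To make item 2 concrete: in TF 1.x, feed_dict can override any feedable tensor in a frozen graph, including the output of a Constant node, which is how the PTB test drives the LSTM states. The sketch below is a hedged illustration; the file name and tensor names are hypothetical placeholders, not the actual names in the PTB graph.

import numpy as np
import tensorflow as tf

# Hypothetical illustration only: load a frozen graph and feed both the regular
# input placeholder and a Constant state node through feed_dict.
with tf.gfile.GFile("ptb_model.pb", "rb") as f:          # hypothetical .pb path
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name="")
    with tf.Session(graph=g) as sess:
        out = g.get_tensor_by_name("Model/probabilities:0")             # hypothetical
        result = sess.run(out, feed_dict={
            "Model/input:0": np.zeros((1, 1), dtype=np.int32),          # hypothetical
            "Model/LSTM/zeros:0": np.zeros((1, 200), dtype=np.float32)  # Constant state node (hypothetical)
        })

Relay has no equivalent way to feed a Constant, so the frontend previously replaced those nodes with extra variables; this PR instead treats them as Placeholders so the Relay IR matches the TF graph.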

cc @kevinthesun @srkreddy1238 @joyalbin Please help review, thanks

@lixiaoquan lixiaoquan changed the title Refine LSTMBlockCell to support dynamic rnn [TF]Refine LSTMBlockCell to support dynamic rnn Jun 30, 2020
@lixiaoquan lixiaoquan force-pushed the lstm branch 3 times, most recently from 7460648 to 9acbe2c on June 30, 2020 at 23:22
@lixiaoquan (Contributor, Author) commented:
also cc @zhiics @yongwww Could you please help to review? Thanks

# Return dummy for those unused values
dummy = tvm.relay.const(0)
return tvm.relay.TupleWrapper(
    tvm.relay.Tuple([dummy, next_c, dummy, dummy, dummy, dummy, next_h]), 7)
A reviewer (Member) commented:

I am not sure whether the dummy nodes might still be used in some cases. It would be good to generate the real nodes, as TF does, instead of the dummies.

A Member commented, quoting the TensorFlow documentation for lstm_block_cell:

def lstm_block_cell(x, cs_prev, h_prev, w, wci, wcf, wco, b, forget_bias=1, cell_clip=3, use_peephole=False, name=None):
  r"""Computes the LSTM cell forward propagation for 1 time step.

  This implementation uses 1 weight matrix and 1 bias vector, and there's an
  optional peephole connection.

  This kernel op implements the following mathematical equations:

  xh = [x, h_prev]
  [i, f, ci, o] = xh * w + b
  f = f + forget_bias

  if not use_peephole:
    wci = wcf = wco = 0

  i = sigmoid(cs_prev * wci + i)
  f = sigmoid(cs_prev * wcf + f)
  ci = tanh(ci)

  cs = ci .* i + cs_prev .* f
  cs = clip(cs, cell_clip)

  o = sigmoid(cs * wco + o)
  co = tanh(cs)
  h = co .* o
  ...
  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).  
 ...
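
Putting those equations into Relay terms, a hedged sketch of emitting all seven outputs (i, cs, f, o, ci, co, h) without dummy nodes could look like the following. This is an illustrative sketch under assumed shapes (w: [input_size + num_units, 4 * num_units], no peephole), not the frontend code in this PR.

# Illustrative sketch only (not this PR's frontend code): emit all seven
# LSTMBlockCell outputs with Relay ops, assuming no peephole (wci = wcf = wco = 0)
# and w shaped [input_size + num_units, 4 * num_units] as in TensorFlow.
from tvm import relay

def lstm_block_cell_relay(x, cs_prev, h_prev, w, b, forget_bias=1.0, cell_clip=3.0):
    # xh = [x, h_prev]; [i, f, ci, o] = xh * w + b
    xh = relay.concatenate([x, h_prev], axis=1)
    gates = relay.nn.bias_add(relay.nn.dense(xh, relay.transpose(w)), b)
    splits = relay.split(gates, indices_or_sections=4, axis=1)
    i, f, ci, o = splits[0], splits[1], splits[2], splits[3]

    i = relay.sigmoid(i)
    f = relay.sigmoid(f + relay.const(forget_bias, "float32"))
    ci = relay.tanh(ci)

    cs = ci * i + cs_prev * f
    cs = relay.clip(cs, a_min=-cell_clip, a_max=cell_clip)

    o = relay.sigmoid(o)
    co = relay.tanh(cs)
    h = co * o

    # Same output order as TensorFlow: (i, cs, f, o, ci, co, h)
    return relay.TupleWrapper(relay.Tuple([i, cs, f, o, ci, co, h]), 7)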

@lixiaoquan (Contributor, Author) replied:

I added all the return objects, but I just can't test them, because tensorflow.contrib.rnn.LSTMBlockCell.call() only returns h and new_state.

I think if users use the API to construct the graph, the other values won't be used.

(cs_prev, h_prev) = state
(_, cs, _, _, _, _, h) = _lstm_block_cell(
    inputs,
    cs_prev,
    h_prev,
    self._kernel,
    self._bias,
    wci=wci,
    wcf=wcf,
    wco=wco,
    forget_bias=self._forget_bias,
    cell_clip=self._cell_clip,
    use_peephole=self._use_peephole)

new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
return h, new_state
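
For context, a hedged TF 1.x usage sketch (assumed setup, not taken from this PR or its tests) showing that a dynamic RNN built from LSTMBlockCell only consumes h and the (cs, h) state tuple:

# Assumed TF 1.x example (not from this PR): a dynamic RNN built from
# LSTMBlockCell only ever consumes h and the (cs, h) state tuple.
import numpy as np
import tensorflow as tf

batch, time_steps, num_inputs, num_units = 4, 10, 8, 16
x = tf.placeholder(tf.float32, [batch, time_steps, num_inputs], name="input")
cell = tf.contrib.rnn.LSTMBlockCell(num_units)
outputs, final_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    h_seq, (c_final, h_final) = sess.run(
        [outputs, final_state],
        feed_dict={x: np.random.rand(batch, time_steps, num_inputs).astype("float32")})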

@lixiaoquan (Contributor, Author) commented:

cc @yongwww Could you take another look? Thanks

@yongwww (Member) left a comment:

LGTM

@lixiaoquan (Contributor, Author) commented:
ping @zhiics Could you help to review and merge? Thanks

@kevinthesun (Contributor) left a comment:

LGTM

@kevinthesun kevinthesun merged commit b85c239 into apache:master Jul 16, 2020
@kevinthesun (Contributor) commented:

Thanks @lixiaoquan @yongwww

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Sep 2, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 3, 2020
@apivovarov (Contributor) commented:

TensorFlow 2 LSTM conversion fails; the discussion is here: https://discuss.tvm.apache.org/t/tensorflow2-lstm-failed/11174
