Skip to content

Commit

Permalink
[Relay][Frontend] Fix tensorflow frontend lstm forget bias adding ord…
Browse files Browse the repository at this point in the history
…er (#3410)
  • Loading branch information
ttyang1018 authored and tqchen committed Jun 27, 2019
1 parent 6c43019 commit 7db5779
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,9 +1437,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
in_gate = _op.sigmoid(gate_list[0])
in_transform = _op.tanh(gate_list[1])
forget_gate = _op.sigmoid(gate_list[2])
forget_gate = _op.add(forget_gate,
tvm.relay.const(forget_bias, attr['T'].name))
forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
forget_gate = _op.sigmoid(forget_gate)
out_gate = _op.sigmoid(gate_list[3])
next_c = _op.add(_op.multiply(forget_gate, in_state_c),
_op.multiply(in_gate, in_transform))
Expand Down
2 changes: 1 addition & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,7 @@ def _get_tensorflow_output():

def test_forward_lstm():
'''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.0, 'float32')
_test_lstm_cell(1, 2, 1, 0.5, 'float32')



Expand Down

0 comments on commit 7db5779

Please sign in to comment.