Commit 14506a8

Fix some recursive functions (e.g., reorder_incremental_state) to only touch each module once (facebookresearch#379)

Summary:
A module can be touched more than once when it is registered in more than one place in the network (e.g., a shared submodule), since functions that recurse via self.apply reach it once per registration.
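
A hedged, self-contained sketch of the failure mode (the names below are illustrative, not fairseq code): PyTorch's named_children deduplicates only among the direct children of a single parent, so nn.Module.apply visits a module once per parent it is registered under.

import torch.nn as nn

shared = nn.Linear(4, 4)                 # one module object...
parent = nn.Module()
parent.branch_a = nn.Sequential(shared)  # ...registered in two places
parent.branch_b = nn.Sequential(shared)

visits = []
parent.apply(lambda m: visits.append(id(m)))
assert visits.count(id(shared)) == 2     # visited once per registration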
Pull Request resolved: facebookresearch#379

Differential Revision: D13154498

Pulled By: myleott

fbshipit-source-id: a35575d1956a46cd35ac8b16a719ad20ac3e380a
myleott authored and facebook-github-bot committed Nov 26, 2018
1 parent 3c19878 commit 14506a8
Showing 2 changed files with 23 additions and 8 deletions.
19 changes: 13 additions & 6 deletions fairseq/models/fairseq_incremental_decoder.py
@@ -56,19 +56,26 @@ def reorder_incremental_state(self, incremental_state, new_order):
         previous time step. A typical use case is beam search, where the input
         order changes between time steps based on the selection of beams.
         """
+        seen = set()
+
         def apply_reorder_incremental_state(module):
-            if module != self and hasattr(module, 'reorder_incremental_state'):
-                module.reorder_incremental_state(
-                    incremental_state,
-                    new_order,
-                )
+            if module != self and hasattr(module, 'reorder_incremental_state') \
+                    and module not in seen:
+                seen.add(module)
+                module.reorder_incremental_state(incremental_state, new_order)
+
         self.apply(apply_reorder_incremental_state)
 
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
         if getattr(self, '_beam_size', -1) != beam_size:
+            seen = set()
+
             def apply_set_beam_size(module):
-                if module != self and hasattr(module, 'set_beam_size'):
+                if module != self and hasattr(module, 'set_beam_size') \
+                        and module not in seen:
+                    seen.add(module)
                     module.set_beam_size(beam_size)
+
             self.apply(apply_set_beam_size)
             self._beam_size = beam_size
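
For context on the docstring above, a minimal self-contained sketch (not fairseq's implementation; cache and reorder are illustrative names) of what reordering incremental state means in beam search: cached tensors are index-selected along the beam dimension so the cache follows the surviving beams.

import torch

cache = {'prev_key': torch.arange(3.).view(3, 1)}  # one cached row per beam

def reorder(state, new_order):
    for k, v in state.items():
        state[k] = v.index_select(0, new_order)

reorder(cache, torch.tensor([0, 0, 2]))  # beam 0 kept twice, beam 1 dropped
print(cache['prev_key'].squeeze(1))      # tensor([0., 0., 2.])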
12 changes: 10 additions & 2 deletions fairseq/models/fairseq_model.py
@@ -99,8 +99,12 @@ def apply_remove_weight_norm(module):

         self.apply(apply_remove_weight_norm)
 
+        seen = set()
+
         def apply_make_generation_fast_(module):
-            if module != self and hasattr(module, 'make_generation_fast_'):
+            if module != self and hasattr(module, 'make_generation_fast_') \
+                    and module not in seen:
+                seen.add(module)
                 module.make_generation_fast_(**kwargs)
 
         self.apply(apply_make_generation_fast_)
@@ -115,8 +119,12 @@ def train(mode):

     def prepare_for_onnx_export_(self, **kwargs):
         """Make model exportable via ONNX trace."""
+        seen = set()
+
         def apply_prepare_for_onnx_export_(module):
-            if module != self and hasattr(module, 'prepare_for_onnx_export_'):
+            if module != self and hasattr(module, 'prepare_for_onnx_export_') \
+                    and module not in seen:
+                seen.add(module)
                 module.prepare_for_onnx_export_(**kwargs)
 
         self.apply(apply_prepare_for_onnx_export_)
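
The same seen-set guard appears in all four functions above. A hedged sketch of the pattern factored into a reusable helper (apply_once is a hypothetical name, not part of fairseq):

def apply_once(root, method_name, *args, **kwargs):
    # Call module.<method_name>(*args, **kwargs) once per distinct
    # submodule of root, even when a module is registered in several places.
    seen = set()

    def fn(module):
        if module is not root and hasattr(module, method_name) \
                and module not in seen:
            seen.add(module)
            getattr(module, method_name)(*args, **kwargs)

    root.apply(fn)

For example, apply_once(model, 'set_beam_size', 5) would reach each distinct submodule exactly once, mirroring the fixed set_beam_size above.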
