From 9e4e7fcac5b1f3200beff6c686111b17a99e5c48 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Wed, 14 Nov 2018 05:53:57 -0800
Subject: [PATCH] Fix dummy batch when --max-tokens is small (fixes #347)

---
 fairseq/data/language_pair_dataset.py | 2 +-
 fairseq/data/monolingual_dataset.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py
index 96929a3dc3..3361dec26a 100644
--- a/fairseq/data/language_pair_dataset.py
+++ b/fairseq/data/language_pair_dataset.py
@@ -192,7 +192,7 @@ def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
             max_positions,
             (self.max_source_positions, self.max_target_positions),
         )
-        bsz = num_tokens // max(src_len, tgt_len)
+        bsz = max(num_tokens // max(src_len, tgt_len), 1)
         return self.collater([
             {
                 'id': i,
diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py
index 0635eee741..3c915853c1 100644
--- a/fairseq/data/monolingual_dataset.py
+++ b/fairseq/data/monolingual_dataset.py
@@ -153,7 +153,7 @@ def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
         """Return a dummy batch with a given number of tokens."""
         if isinstance(max_positions, float) or isinstance(max_positions, int):
             tgt_len = min(tgt_len, max_positions)
-        bsz = num_tokens // tgt_len
+        bsz = max(num_tokens // tgt_len, 1)
         target = self.vocab.dummy_sentence(tgt_len + 2)
         source, past_target, future_target = target[1:-1], target[2:], target[:-2]
         source, target = self._make_source_target(source, past_target, future_target)
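
For context, a minimal sketch of the failure mode this patch addresses (the numbers below are illustrative assumptions, not taken from the patch): when `--max-tokens` is smaller than the default dummy sentence length of 128, the old integer division yields a batch size of 0, so `get_dummy_batch` would build an empty batch; clamping with `max(..., 1)` guarantees at least one dummy sentence.

```python
# Illustrative values only: e.g. --max-tokens=100 with the default src_len/tgt_len of 128.
num_tokens, src_len, tgt_len = 100, 128, 128

old_bsz = num_tokens // max(src_len, tgt_len)           # 100 // 128 == 0 -> empty dummy batch
new_bsz = max(num_tokens // max(src_len, tgt_len), 1)   # clamped to 1 -> at least one dummy sentence

print(old_bsz, new_bsz)  # 0 1
```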