From 3707c035acccf7ba19723de9e9e8f73641bc6305 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Wed, 14 Nov 2018 07:41:22 -0800
Subject: [PATCH] Fix dummy batch when --max-tokens is small (fixes #347)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/366

Differential Revision: D13058513

Pulled By: myleott

fbshipit-source-id: a146d2cfb345d404775ed8d6b8e4a4ad4e7a33b4
---
 fairseq/data/language_pair_dataset.py | 2 +-
 fairseq/data/monolingual_dataset.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py
index 96929a3dc3..3361dec26a 100644
--- a/fairseq/data/language_pair_dataset.py
+++ b/fairseq/data/language_pair_dataset.py
@@ -192,7 +192,7 @@ def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
             max_positions,
             (self.max_source_positions, self.max_target_positions),
         )
-        bsz = num_tokens // max(src_len, tgt_len)
+        bsz = max(num_tokens // max(src_len, tgt_len), 1)
         return self.collater([
             {
                 'id': i,
diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py
index 0635eee741..3c915853c1 100644
--- a/fairseq/data/monolingual_dataset.py
+++ b/fairseq/data/monolingual_dataset.py
@@ -153,7 +153,7 @@ def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
         """Return a dummy batch with a given number of tokens."""
         if isinstance(max_positions, float) or isinstance(max_positions, int):
             tgt_len = min(tgt_len, max_positions)
-        bsz = num_tokens // tgt_len
+        bsz = max(num_tokens // tgt_len, 1)
         target = self.vocab.dummy_sentence(tgt_len + 2)
         source, past_target, future_target = target[1:-1], target[2:], target[:-2]
         source, target = self._make_source_target(source, past_target, future_target)
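
Note (not part of the patch): a minimal standalone sketch of the failure mode the clamp fixes. With the default dummy sentence length of 128, any --max-tokens value below 128 makes the integer division return zero, so get_dummy_batch() used to build an empty batch. The concrete numbers below are illustrative assumptions, not values taken from fairseq.

def dummy_batch_size(num_tokens, src_len=128, tgt_len=128):
    """Compute the dummy batch size before and after the fix."""
    old_bsz = num_tokens // max(src_len, tgt_len)          # 100 // 128 == 0
    new_bsz = max(num_tokens // max(src_len, tgt_len), 1)  # clamped to >= 1
    return old_bsz, new_bsz

# e.g. a run with --max-tokens 100 (hypothetical setting)
old_bsz, new_bsz = dummy_batch_size(100)
assert old_bsz == 0  # empty dummy batch: the symptom behind issue #347
assert new_bsz == 1  # patched: always at least one dummy sentence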