diff --git a/modeling/models/rnn/torch_rnn_net.py b/modeling/models/rnn/torch_rnn_net.py index 50a9b4f..2d8db46 100644 --- a/modeling/models/rnn/torch_rnn_net.py +++ b/modeling/models/rnn/torch_rnn_net.py @@ -306,5 +306,5 @@ def zip_collate(batch): def get_test_loader(self, rnn_dataset): return DataLoader( - rnn_dataset, batch_size=3200, shuffle=False, num_workers=2, collate_fn=self.zip_collate, pin_memory=True + rnn_dataset, batch_size=1024, shuffle=False, num_workers=2, collate_fn=self.zip_collate, pin_memory=True )