Skip to content

Commit

Permalink
bug fixed for batch size argument
Browse files Browse the repository at this point in the history
  • Loading branch information
seujung committed Mar 11, 2021
1 parent 190c23a commit 3110b70
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def add_model_specific_args(parent_parser):

parser.add_argument('--batch_size',
type=int,
default=14,
default=28,
help='')

parser.add_argument('--max_len',
type=int,
default=512,
Expand Down Expand Up @@ -175,7 +176,6 @@ def __init__(self, hparams, **kwargs):
self.tokenizer = get_kobart_tokenizer()

def forward(self, inputs):

attention_mask = inputs['input_ids'].ne(self.pad_token_id).float()
decoder_attention_mask = inputs['decoder_input_ids'].ne(self.pad_token_id).float()

Expand Down Expand Up @@ -217,6 +217,7 @@ def validation_epoch_end(self, outputs):
args.test_file,
None,
max_len=args.max_len,
batch_size=args.batch_size,
num_workers=args.num_workers)

checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss',
Expand Down

0 comments on commit 3110b70

Please sign in to comment.