
Commit

fix save steps arg bug (PaddlePaddle#44)
LiuChiachi authored Mar 1, 2021
1 parent 1e955cb commit dd98f7f
Showing 3 changed files with 10 additions and 11 deletions.
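
Context for the diff below: both training scripts tested i % args.log_freq and i % args.save_steps, but i is the batch index within one epoch and resets to zero at every epoch boundary, while global_step was only incremented at the bottom of the loop. The fix increments global_step at the top of the loop and keys both conditions on it, so --log_freq and --save_steps are honored in global steps. The README is updated to pair the full senta_word_dict.txt with its matching --vocab_size.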
8 changes: 4 additions & 4 deletions examples/model_compression/distill_lstm/README.md
@@ -88,13 +88,13 @@ python -u ./run_bert_finetune.py \
 CUDA_VISIBLE_DEVICES=0 python small.py \
     --task_name senta \
     --max_epoch 20 \
-    --vocab_size 29496 \
+    --vocab_size 1256608 \
     --batch_size 64 \
     --model_name bert-wwm-ext-chinese \
     --optimizer adam \
     --lr 3e-4 \
     --dropout_prob 0.2 \
-    --vocab_path senta_word_dict_subset.txt \
+    --vocab_path senta_word_dict.txt \
     --output_dir small_models/senta/

 ```
@@ -131,14 +131,14 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
 ```shell
 CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
     --task_name senta \
-    --vocab_size 29496 \
+    --vocab_size 1256608 \
     --max_epoch 6 \
     --lr 1.0 \
     --dropout_prob 0.1 \
     --batch_size 64 \
     --model_name bert-wwm-ext-chinese \
     --teacher_path pretrained_models/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
-    --vocab_path senta_word_dict_subset.txt \
+    --vocab_path senta_word_dict.txt \
     --output_dir distilled_models/senta

 ```
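A note on the README change: --vocab_size must match the number of entries in the file given by --vocab_path, so switching from the 29496-entry senta_word_dict_subset.txt to the full senta_word_dict.txt requires 1256608. A quick sanity check, as a minimal sketch assuming the dictionary stores one token per line in UTF-8:

```python
# Count vocabulary entries to confirm the value passed to --vocab_size.
# Assumption: senta_word_dict.txt holds one token per line, UTF-8 encoded.
with open("senta_word_dict.txt", encoding="utf-8") as f:
    vocab_size = sum(1 for _ in f)

print(vocab_size)  # expected to match --vocab_size, i.e. 1256608
```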
7 changes: 3 additions & 4 deletions examples/model_compression/distill_lstm/bert_distill.py
@@ -116,6 +116,7 @@ def do_train(agrs):
     for epoch in range(args.max_epoch):
         model.train()
         for i, batch in enumerate(train_data_loader):
+            global_step += 1
             if args.task_name == 'qqp':
                 bert_input_ids, bert_segment_ids, student_input_ids_1, seq_len_1, student_input_ids_2, seq_len_2, labels = batch
             else:
@@ -139,7 +140,7 @@
             optimizer.step()
             optimizer.clear_grad()

-            if i % args.log_freq == 0:
+            if global_step % args.log_freq == 0:
                 print(
                     "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.4f step/s"
                     % (global_step, epoch, i, loss,
@@ -149,7 +150,7 @@
                 print("eval done total : %s s" % (time.time() - tic_eval))
                 tic_train = time.time()

-            if i % args.save_steps == 0:
+            if global_step % args.save_steps == 0:
                 paddle.save(
                     model.state_dict(),
                     os.path.join(args.output_dir,
@@ -158,8 +159,6 @@
                 os.path.join(args.output_dir,
                              "step_" + str(global_step) + ".pdopt"))

-            global_step += 1
-

 if __name__ == '__main__':
     args = parse_args()
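Why keying the checkpoint condition on i was a bug: i restarts at zero every epoch, so i % args.save_steps == 0 fires on the first batch of each epoch and, whenever save_steps exceeds the number of batches per epoch, never again. A self-contained sketch comparing the two conditions, with hypothetical step counts (simplified: the pre-fix code also incremented global_step at the bottom of the loop):

```python
# Compare the pre-fix and post-fix save conditions over a toy run.
# Hypothetical numbers: 4 epochs of 300 batches, --save_steps 500.
max_epoch, batches_per_epoch, save_steps = 4, 300, 500

old_saves, fixed_saves = [], []
global_step = 0
for epoch in range(max_epoch):
    for i in range(batches_per_epoch):
        global_step += 1
        if i % save_steps == 0:            # pre-fix: fires only at i == 0
            old_saves.append(global_step)
        if global_step % save_steps == 0:  # post-fix: fires every 500 steps
            fixed_saves.append(global_step)

print(old_saves)    # [1, 301, 601, 901] -> one save at the start of each epoch
print(fixed_saves)  # [500, 1000] -> every save_steps optimizer steps
```

With the fixed condition, checkpoints land every save_steps optimizer steps regardless of epoch boundaries; the same correction is applied to small.py below.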
6 changes: 3 additions & 3 deletions examples/model_compression/distill_lstm/small.py
@@ -175,6 +175,7 @@ def do_train(args):
     tic_train = time.time()
     for epoch in range(args.max_epoch):
         for i, batch in enumerate(train_data_loader):
+            global_step += 1
             if args.task_name == 'qqp':
                 input_ids_1, seq_len_1, input_ids_2, seq_len_2, labels = batch
                 logits = model(input_ids_1, seq_len_1, input_ids_2, seq_len_2)
@@ -188,7 +189,7 @@
             optimizer.step()
             optimizer.clear_grad()

-            if i % args.log_freq == 0:
+            if global_step % args.log_freq == 0:
                 with paddle.no_grad():
                     print(
                         "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.4f step/s"
@@ -201,15 +202,14 @@
                     print("eval done total : %s s" % (time.time() - tic_eval))
                     tic_train = time.time()

-            if i % args.save_steps == 0:
+            if global_step % args.save_steps == 0:
                 paddle.save(
                     model.state_dict(),
                     os.path.join(args.output_dir,
                                  "step_" + str(global_step) + ".pdparams"))
                 paddle.save(optimizer.state_dict(),
                             os.path.join(args.output_dir,
                                          "step_" + str(global_step) + ".pdopt"))
-            global_step += 1


 if __name__ == '__main__':
