From 92910b8361a6ca90b10215b047823aedf9a9b500 Mon Sep 17 00:00:00 2001
From: haohongxiang
Date: Wed, 28 Sep 2022 04:54:37 +0000
Subject: [PATCH] code style check

---
 examples/language_model/gpt-3/dygraph/run_pretrain.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/language_model/gpt-3/dygraph/run_pretrain.py b/examples/language_model/gpt-3/dygraph/run_pretrain.py
index bab42ce039a5..bf8a29342a55 100644
--- a/examples/language_model/gpt-3/dygraph/run_pretrain.py
+++ b/examples/language_model/gpt-3/dygraph/run_pretrain.py
@@ -153,7 +153,7 @@ def do_train(args):
     dp_rank = hcg.get_data_parallel_rank()
     sharding_rank = hcg.get_sharding_parallel_rank()
 
-    # sharding stage2/3 not support hybrid parallel
+    # sharding stage2/3 not support hybrid parallel now
     if args.sharding_stage in [2, 3]:
         assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
         dp_group = hcg.get_data_parallel_group()
@@ -279,8 +279,9 @@ def do_train(args):
     # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
     if args.sharding_stage in [2, 3]:
         if args.dp_degree > 1:
-            sync_params_buffers(
-                model, comm_group=dp_group, src_rank=dp_group.ranks[0])
+            sync_params_buffers(model,
+                                comm_group=dp_group,
+                                src_rank=dp_group.ranks[0])
         scaler = scaler if args.use_pure_fp16 else None
         model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
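
Note for reviewers skimming only the diff: the second hunk sits in the sharding setup path of do_train(). When sharding stage 2 or 3 is combined with data parallelism, parameter buffers are first broadcast across the data-parallel group so every replica starts from identical weights, and only then are model/optimizer/scaler wrapped for sharding. The sketch below is a plain-Python illustration of that flow under stated assumptions, not the real code: setup_sharding and the two stub functions are hypothetical stand-ins for the Paddle sync_params_buffers utility and the wrap_sharding_2_3 helper defined elsewhere in run_pretrain.py (whose trailing arguments are truncated in this patch).

    # Illustrative sketch only; stubs replace the real Paddle / run_pretrain.py helpers.

    def sync_params_buffers(model, comm_group=None, src_rank=0):
        # Placeholder: broadcasts model parameters/buffers from src_rank to
        # every rank in comm_group.
        pass

    def wrap_sharding_2_3(model, optimizer, scaler):
        # Placeholder: wraps model/optimizer/scaler for sharding stage 2/3.
        return model, optimizer, scaler

    def setup_sharding(args, model, optimizer, scaler, hcg):
        # Hypothetical helper mirroring the hunk at line 279.
        if args.sharding_stage in [2, 3]:
            # Stage 2/3 does not yet combine with tensor/pipeline parallelism
            # (see the assertion in the first hunk).
            assert args.mp_degree == args.pp_degree == 1
            dp_group = hcg.get_data_parallel_group()
            if args.dp_degree > 1:
                # Broadcast parameters from the first rank of the data-parallel
                # group so all replicas start from identical weights.
                sync_params_buffers(model,
                                    comm_group=dp_group,
                                    src_rank=dp_group.ranks[0])
            scaler = scaler if args.use_pure_fp16 else None
            model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler)
        return model, optimizer, scaler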