Skip to content

Commit

Permalink
[Unified Checkpoint] Fix fp32 dtype for using newest paddle(PaddlePad…
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay authored Nov 4, 2024
1 parent 1aa91be commit e2cc3d5
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 42 deletions.
7 changes: 1 addition & 6 deletions paddlenlp/trainer/unified_checkpoint/check_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@
from paddlenlp.utils.log import logger
from paddlenlp.utils.nested import flatten_list

try:
from paddle.base import core
except:
core = None

from .utils import (
get_expected_state_dict,
is_sharding_split_param_mode,
Expand Down Expand Up @@ -200,7 +195,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False,
if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
continue

if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32:
if is_master_weights and state_dict[key].dtype == paddle.float32:
continue

if not is_master_weights:
Expand Down
7 changes: 1 addition & 6 deletions paddlenlp/trainer/unified_checkpoint/load_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@
import paddle.distributed as dist
from paddle.distributed import fleet

try:
from paddle.base import core
except:
core = None

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.transformers.model_utils import _load_state_dict_into_model
from paddlenlp.transformers.utils import device_guard, is_safetensors_available
Expand Down Expand Up @@ -474,7 +469,7 @@ def check_optimizer_param(parameter):
key_name = key.split("/")
static_name = struct2static_name_mappings[key_name[0]]
if has_master_weights:
if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
if model_state_dict[key_name[0]].dtype != paddle.float32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
Expand Down
8 changes: 2 additions & 6 deletions paddlenlp/trainer/unified_checkpoint/load_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,9 @@
import gc
import os

import paddle
from tqdm.auto import tqdm

try:
from paddle.base import core
except:
core = None

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.transformers.model_utils import (
_load_state_dict_into_model,
Expand Down Expand Up @@ -252,7 +248,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
key_name = key.split("/")
static_name = struct2static_name_mappings[key_name[0]]
if has_master_weights:
if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
if model_state_dict[key_name[0]].dtype != paddle.float32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
Expand Down
9 changes: 2 additions & 7 deletions paddlenlp/trainer/unified_checkpoint/load_save_single_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@

import paddle

try:
from paddle.base import core
except:
core = None

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.transformers.model_utils import (
_load_state_dict_into_model,
Expand Down Expand Up @@ -120,7 +115,7 @@ def save_single_card_optimizer(model, optimizer, output_dir):
fp32_weight = {}
for k, v in state_dict.items():
static2struct_name_mappings[v.name] = k
if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
if master_weights is not None and v.dtype == paddle.float32:
fp32_weight[k] = v

# rename optimizer param
Expand Down Expand Up @@ -226,7 +221,7 @@ def load_single_card_optimizer(model, optimizer, resume_from_checkpoint: str):
key_name = key.split("/")
static_name = struct2static_name_mappings[key_name[0]]
if has_master_weights:
if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
if model_state_dict[key_name[0]].dtype != paddle.float32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
Expand Down
9 changes: 2 additions & 7 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
import paddle
from paddle.distributed import fleet

try:
from paddle.base import core
except:
core = None

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.trainer.argparser import strtobool
from paddlenlp.trainer.utils.helper import distributed_isfile
Expand Down Expand Up @@ -281,7 +276,7 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
key_name = key.split("/")
static_name = struct2static_name_mappings[key_name[0]]
if has_master_weights:
if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
if model_state_dict[key_name[0]].dtype != paddle.float32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
Expand Down Expand Up @@ -529,7 +524,7 @@ def unified_optimizer_into_shards(
fp32_weight = {}
for k, v in state_dict.items():
static2struct_name_mappings[v.name] = k
if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
if master_weights is not None and v.dtype == paddle.float32:
if args.dataset_rank > 0: # deal with different dataset rank.
continue
fp32_weight[k] = v
Expand Down
11 changes: 1 addition & 10 deletions paddlenlp/trainer/unified_checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@
import paddle.distributed as dist
from paddle.distributed import fleet

try:
from paddle.base import core
except:
core = None

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption
from paddlenlp.trainer.utils.helper import distributed_isfile
Expand Down Expand Up @@ -231,11 +226,7 @@ def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weight
expected_keys = []
for key in list(sharded_metadata["all_optimizer_keys"]):
key_name = key.split("/")[0]
if (
is_master_weights
and key_name in model_state_dict
and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32
):
if is_master_weights and key_name in model_state_dict and model_state_dict[key_name].dtype == paddle.float32:
continue

if args.use_expert_parallel and args.data_parallel_rank > 0:
Expand Down

0 comments on commit e2cc3d5

Please sign in to comment.