[Unified Checkpoint] update non-merge checkpoint loading, move async_save_info.json location (PaddlePaddle#9321)

* [Unified checkpoint] update optimizer async save signal

* update async_save_info.json file location
DesmonDay authored Oct 28, 2024
1 parent 394aada commit ce3a1ce
Showing 2 changed files with 13 additions and 27 deletions.
12 changes: 7 additions & 5 deletions paddlenlp/trainer/trainer.py
@@ -2308,7 +2308,7 @@ def save_model(
         if output_dir is None:
             output_dir = self.args.output_dir
 
-        if PREFIX_CHECKPOINT_DIR in output_dir:
+        if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
             signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1])
         else:
             signal_dir = self.args.output_signal_dir
@@ -2606,7 +2606,7 @@ def _save(
         # signal_dir is used for asynchronous saving situations.
         signal_dir = self.args.output_signal_dir
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-            if PREFIX_CHECKPOINT_DIR in output_dir:
+            if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
                 signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
             os.makedirs(signal_dir, exist_ok=True)
             logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
@@ -2626,9 +2626,11 @@ def _save(
                 "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
                 "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
             }
-            if os.path.exists(os.path.join(signal_dir, "async_save_info.json")):  # afs cannot overwrite
-                os.remove(os.path.join(signal_dir, "async_save_info.json"))
-            with open(os.path.join(signal_dir, "async_save_info.json"), "w") as f:
+            if os.path.exists(
+                os.path.join(self.args.output_signal_dir, "async_save_info.json")
+            ):  # afs cannot overwrite
+                os.remove(os.path.join(self.args.output_signal_dir, "async_save_info.json"))
+            with open(os.path.join(self.args.output_signal_dir, "async_save_info.json"), "w") as f:
                 json.dump(save_info, f)
 
         if self.args.should_save:
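
This hunk also pins async_save_info.json to the top-level self.args.output_signal_dir rather than the possibly per-checkpoint signal_dir, so the async-save metadata always lives in one fixed place; the existing file is removed first because, per the inline comment, AFS cannot overwrite in place. A small sketch of the resulting paths, with hypothetical directory names:

```python
import os

output_signal_dir = "./output/signals"                           # stands in for self.args.output_signal_dir
signal_dir = os.path.join(output_signal_dir, "checkpoint-2000")  # per-checkpoint signal dir

old_path = os.path.join(signal_dir, "async_save_info.json")         # before: follows signal_dir
new_path = os.path.join(output_signal_dir, "async_save_info.json")  # after: fixed top-level file

print(old_path)  # ./output/signals/checkpoint-2000/async_save_info.json
print(new_path)  # ./output/signals/async_save_info.json
```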
28 changes: 6 additions & 22 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -15,7 +15,6 @@
 import copy
 import json
 import os
-import sys
 
 import paddle
 from paddle.distributed import fleet
@@ -31,13 +30,10 @@
 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _add_variant,
+    load_state_dict,
     unwrap_model,
 )
-from paddlenlp.transformers.utils import (
-    device_guard,
-    dtype_byte_size,
-    is_safetensors_available,
-)
+from paddlenlp.transformers.utils import dtype_byte_size
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_NAME,
@@ -56,12 +52,6 @@
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy
 
-if is_safetensors_available():
-    if sys.platform.startswith("win"):
-        from safetensors.numpy import load_file
-    else:
-        from paddlenlp.utils.safetensors import fast_load_file as load_file
-
 from .async_handler import AsyncCheckpointHandler
 from .check_completion import check_unified_checkpoint, check_unified_optimizer
 from .load_dynamic import (
@@ -282,9 +272,9 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
 
         model_state_dict = get_expected_state_dict(model)
         struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}  # get optimizer param mappings
-        optimizer_state_dict = load_file(optimizer_path)
+        optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected")
         if has_master_weights:
-            master_weights = load_file(master_weights_path)
+            master_weights = load_state_dict(master_weights_path, None, None, device="expected")
 
         # rename and move to paddle.Tensor
         for key in list(optimizer_state_dict.keys()):
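
Here the raw safetensors load_file call is swapped for load_state_dict with device="expected", which is what allows the later hunk to drop the manual tensor conversion. A hedged sketch of the new call, with a made-up checkpoint path; the two None arguments simply mirror the call in the diff (no tensor-parallel actions and no key filter, as far as the diff shows):

```python
import paddle
from paddlenlp.transformers.model_utils import load_state_dict

optimizer_path = "./checkpoint-2000/optimizer-00001-of-00002.safetensors"  # hypothetical path

# With device="expected", the returned dict maps parameter keys to tensors that are
# already placed on the current device, not to numpy arrays as load_file() did.
optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected")
for key, value in optimizer_state_dict.items():
    assert isinstance(value, paddle.Tensor)  # the diff sets `.name` on these values directly
```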
@@ -297,20 +287,14 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
                     key_name = "_".join([static_name, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
-            with device_guard():
-                weight = paddle.Tensor(optimizer_state_dict.pop(key), zero_copy=True)
-            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-            returned_optim_state_dict[key_name] = weight
+            returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
             returned_optim_state_dict[key_name].name = key_name
 
         if has_master_weights:
             returned_optim_state_dict["master_weights"] = {}
             for key in list(master_weights.keys()):
                 static_name = struct2static_name_mappings[key]
-                with device_guard():
-                    weight = paddle.Tensor(master_weights.pop(key), zero_copy=True)
-                weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-                returned_optim_state_dict["master_weights"][static_name] = weight
+                returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
                 returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
         return returned_optim_state_dict
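
Because the loaded values are now already paddle.Tensors on the expected device, the old device_guard() / paddle.Tensor(..., zero_copy=True) / _copy_to(...) sequence disappears and each value is reused as-is, with only the rename kept. A self-contained toy version of the simplified rename loop (the dict contents and name mapping are invented; the real keys come from the unified checkpoint files):

```python
import paddle

# Stand-in for what load_state_dict(..., device="expected") returns: structured
# "param/optimizer-suffix" keys mapping to tensors already on the right device.
optimizer_state_dict = {"linear_0.w_0/moment1_0": paddle.zeros([4, 4])}
struct2static_name_mappings = {"linear_0.w_0": "linear_0.w_0"}  # toy identity mapping
returned_optim_state_dict = {}

for key in list(optimizer_state_dict.keys()):
    struct_name, suffix = key.split("/")
    static_name = struct2static_name_mappings[struct_name]
    key_name = "_".join([static_name, suffix])
    # No device_guard()/zero_copy/_copy_to dance any more: reuse the tensor directly.
    returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
    returned_optim_state_dict[key_name].name = key_name

print({k: v.name for k, v in returned_optim_state_dict.items()})
```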
