
Distributed SaveLoad implementation for semi-auto strategy #59659

Merged (45 commits) on Dec 7, 2023

Conversation

@pangengzheng (Contributor) commented Dec 4, 2023:

PR types: Others

PR changes: Others

Description

card-78318
Design the save_state_dict and load_state_dict APIs to support saving and loading checkpoints for dynamic- and static-graph semi-auto distributed training.

paddle-bot commented Dec 4, 2023:

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

paddle.distributed.get_world_size() > 1 or coordinator_rank != 0
):
raise ValueError(
f"use_dist is False, please set coordinator_rank to 0 and paddle.distributed.get_world_size() to 1, world_size:{paddle.distributed.get_world_size()}, coordinator_rank:{coordinator_rank}"
Contributor:

Why not allow use_dist=False with world_size > 1?

pangengzheng (Author) commented Dec 6, 2023:

use_dist targets the single-card case, but it probably does not need to be user-specified; it can be determined internally via use_dist = True if world_size > 1 else False. save_state_dict is designed to export the model under whatever distributed strategy is in effect at training time: a distributed run exports a distributed checkpoint, and a single-card run exports a single-card one. Exporting a single-card model directly from a distributed run is not supported; to get one, define a single-card model, load it with load_state_dict, and then export it with save_state_dict.
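A minimal sketch of the internal derivation described above (resolve_use_dist is a hypothetical helper name, not a Paddle API):

```python
def resolve_use_dist(world_size: int, coordinator_rank: int) -> bool:
    # Hypothetical helper: rather than a user-facing use_dist flag,
    # derive it from the world size, as the reply suggests.
    if world_size == 1 and coordinator_rank != 0:
        raise ValueError(
            f"coordinator_rank must be 0 for a single-card run, got {coordinator_rank}"
        )
    return world_size > 1
```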

return tuple(local_shape), tuple(global_offset)


def flatten_state_dict(state_dict):
Contributor:

Why return directly?

pangengzheng (Author):

This is a TODO: it is meant to support state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, but that is not implemented yet, so for now the incoming state_dict is left untouched.
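A hedged sketch of what that TODO could look like once implemented (the dot-joined key scheme is an assumption, not the PR's final design):

```python
def flatten_state_dict(state_dict):
    # Flatten nested dicts such as {"model": {...}, "optimizer": {...}}
    # into a single-level dict with dot-joined keys.
    flat = {}

    def _walk(prefix, node):
        for key, value in node.items():
            full_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                _walk(full_key, value)
            else:
                flat[full_key] = value

    _walk("", state_dict)
    return flat
```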

Contributor:

ok

python/paddle/distributed/checkpoint/save_state_dict.py (outdated comment, resolved)
Comment on lines +181 to +183
if coordinator_rank == paddle.distributed.get_rank():
logger.debug(f"metadata:{metadata}")
paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata"))
Contributor:

Why not save the metadata on all ranks?

pangengzheng (Author):

The metadata is global and identical on every rank, so only one copy needs to be saved.

Contributor:

I understand, but wouldn't saving it on every rank make debugging easier, so you don't always have to go to rank 0? The metadata doesn't take much space either.

pangengzheng (Author):

That probably won't work: each machine has multiple cards, and multiple cards writing the same file concurrently can go wrong, leaving contents that don't match expectations.
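The rank-gated write discussed here can be sketched as follows (the helper name is hypothetical, and save_fn stands in for paddle.save in the real code):

```python
import os


def save_metadata_once(metadata, path, unique_id, rank, coordinator_rank, save_fn):
    # Only the coordinator rank writes the metadata file: it is identical
    # on every rank, and concurrent writes to one path risk corruption.
    if rank != coordinator_rank:
        return None
    target = os.path.join(path, f"{unique_id}.metadata")
    save_fn(metadata, target)
    return target
```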

The identifier of a local tensor.
"""

tensor_id: str
Contributor:

tensor_name or tensor_key ?

pangengzheng (Author):

tensor_name doesn't seem quite right: this is an identifier, which is the structure_name in dynamic semi-auto mode and the tensor's name in static semi-auto mode. tensor_key means much the same as tensor_id, so it would also work; if tensor_key seems better, it can be changed.

Contributor:

Right, in the state_dict it is the key.

local_tensor_index not in tensor_id_list
), f"Duplicate tensor_id:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata."
tensor_id_list.append(local_tensor_index.tensor_id)
if local_tensor_index.tensor_id in state_dict:
Contributor:

Is this state_dict the local_state_dict?

pangengzheng (Author):

This state_dict is the one each rank maintains itself, so it is local.

for rank, local_files in enumerate(global_data_files):
if len(local_files) > 0:
local_files = [
f for f in local_files if f in necessary_data_files_set
Contributor:

When does local_files differ from necessary_data_files_set?

pangengzheng (Author):

necessary_data_files_set is the set of data files needed by the keys of the current state_dict; those files may live on other ranks. local_files here is a list that indeed covers all files readable across ranks, but that total can exceed the data files the state_dict actually needs to read, so a filtering step is applied here to process only the files that are actually used.

Contributor:

If it is larger, should that trigger a warning, or is it expected?

pangengzheng (Author):

A larger set is fine and needs no warning, since it does not affect loading the current parameters.
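The per-rank filtering described above can be sketched as (a hypothetical standalone helper, not the PR's exact code):

```python
def filter_necessary_files(global_data_files, necessary_data_files_set):
    # Each rank may be able to read more data files than the current
    # state_dict needs; keep only the necessary ones, per rank.
    return {
        rank: [f for f in local_files if f in necessary_data_files_set]
        for rank, local_files in enumerate(global_data_files)
        if local_files
    }
```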

@@ -0,0 +1,21 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2019->2023

pangengzheng (Author):

done

@@ -0,0 +1,497 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2022->2023

pangengzheng (Author):

done

if f not in file_to_ranks:
file_to_ranks[f] = []
file_to_ranks[f].append(r)
logger.info(f"file_to_ranks:{file_to_ranks}")
Contributor:

Will this family of logger debug messages be cleaned up later? If not, I suggest standardizing them.

pangengzheng (Author):

I plan to clean them up before the final merge. If we standardize them, is there a prescribed format?
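One common convention (an assumption here, not a stated Paddle requirement) is a module-level named logger with debug-level messages, so routine runs stay quiet unless debugging is enabled:

```python
import logging

# Named per-module logger; debug output appears only when explicitly enabled.
logger = logging.getLogger("paddle.distributed.checkpoint")
# Lazy %-formatting avoids building the string when debug is disabled.
logger.debug("file_to_ranks:%s", {"0_0.distcp": [0]})
```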

python/paddle/distributed/checkpoint/load_state_dict.py (outdated comment, resolved)
@@ -0,0 +1,42 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Contributor:

Please check the whole PR; the copyright years are wrong.

pangengzheng (Author):

done

v._local_value().add_(paddle.ones_like(v._local_value()))
paddle.distributed.load_state_dict(state_dict, ckpt_path())
for k, v in state_dict.items():
assert k in local_state_dict, k
Contributor:

What is the trailing k used for?

pangengzheng (Author):

The trailing k is the message to print; the assert form is assert condition, error_message.
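For clarity, the two-operand assert form mentioned here, in a small standalone helper (the function name is illustrative):

```python
def check_keys_match(state_dict, local_state_dict):
    for k in state_dict:
        # assert <condition>, <message>: the missing key k becomes the
        # AssertionError message when the check fails.
        assert k in local_state_dict, k
```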

assert k in local_state_dict, k
if v._is_initialized():
self.check_tensor_eq(v._local_value(), local_state_dict[k])
os.system(f"rm -rf {ckpt_path()}")
Contributor:

Use tempfile.TemporaryDirectory(); you can find examples in other unit tests.

pangengzheng (Author):

done
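The suggested pattern, which cleans up automatically even if the test raises, instead of a manual rm -rf:

```python
import os
import tempfile

# TemporaryDirectory removes the directory (and its contents) on exit,
# replacing manual os.system(f"rm -rf ...") cleanup in tests.
with tempfile.TemporaryDirectory() as ckpt_dir:
    ckpt_file = os.path.join(ckpt_dir, "0_0.distcp")
    open(ckpt_file, "w").close()
    assert os.path.exists(ckpt_file)
removed_dir = ckpt_dir  # already deleted once the with-block exits
```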

pangengzheng (Author):

Chinese API documentation PR: PaddlePaddle/docs#6355

zhiqiu (Contributor):

LGTM

risemeup1 (Contributor):

LGTM

XieYunshen (Contributor):

LGTM
Please also set the unit-test timeout.

@pangengzheng changed the title from "Dist save load" to "Distributed SaveLoad implementation for semi-auto strategy" on Dec 7, 2023
Comment on lines +18 to +21
__all__ = [
"save_state_dict",
"load_state_dict",
]
Contributor:

Only add an API to __all__ at the recommended user path. Since we recommend using paddle.distributed.save_state_dict and paddle.distributed.load_state_dict, there is no need to add them to this list; the imports above can be retained.

pangengzheng (Author):

ok

Comment on lines +349 to +354
def load_state_dict(
state_dict,
path,
process_group=None,
coordinator_rank=0,
) -> None:
Contributor:

I saw in the design document that there is a use_dist parameter. Should use_dist, which is not implemented here, still be implemented? If not, please explain the reason and update the design document.

XiaoguangHu01 (Contributor):

LGTM

sunzhongkai588 (Contributor):

For the API docs, please follow the English template, and pay close attention to blank lines and indentation.

coordinator_rank(int): The rank used to save non distributed values. Rank0 is used by default.

Examples:
.. code-block:: python
Contributor:

Suggested change: add a blank line after ".. code-block:: python", otherwise the rendering on the website breaks.


Examples:
.. code-block:: python
>>> # doctest: +SKIP('Save state dict.')
Contributor:

Suggested change:
>>> # doctest: +SKIP('Save state dict.')
>>> # doctest: +SKIP('state dict not exist')

Please state the skip reason more clearly for readability.

) -> None:
"""
Load the state_dict inplace from a checkpoint path.
Args:
Contributor:

Suggested change: add blank lines between the description, Args, and the other sections, otherwise the website rendering may break.
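Putting the reviewer's formatting suggestions together, a docstring shaped like this (an illustrative sketch, not the PR's final text, and the parameter descriptions are assumptions) renders correctly:

```python
def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0) -> None:
    """
    Load the state_dict inplace from a checkpoint path.

    Args:
        state_dict (dict): The state_dict to load, modified in place.
        path (str): The directory containing the checkpoint files.
        process_group: The process group used for communication. Defaults to None.
        coordinator_rank (int): The rank used to coordinate the checkpoint. Rank 0 by default.

    Examples:
        .. code-block:: python

            >>> # doctest: +SKIP('state dict not exist')
    """
```

Note the blank line between the description and Args, and the blank line after the code-block directive.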

Comment on lines +362 to +363
Example:
.. code-block:: python
Contributor:

Suggested change (same as above): add a blank line after ".. code-block:: python".

coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default.
Example:
.. code-block:: python
>>> # doctest: +SKIP('Load state dict.')
Contributor:

Suggested change:
>>> # doctest: +SKIP('Load state dict.')
>>> # doctest: +SKIP('state dict not exist')

Please state the skip reason more clearly for readability.

sunzhongkai588 (Contributor):

LGTM. Merging first; the related fixes can follow up afterwards.

@zhiqiu zhiqiu merged commit a2c8c9a into PaddlePaddle:develop Dec 7, 2023
27 of 29 checks passed
10 participants