
[Deepspeed] [performance] inefficient load with from_pretrained w/ zero3 #12273

Open · stas00 (Contributor) opened this issue on Jun 20, 2021 · 0 comments
Labels: DeepSpeed, WIP
stas00 commented Jun 20, 2021

🚀 Feature request

Currently, under DeepSpeed ZeRO stage 3, with from_pretrained we:

a. loop over each sub-module in zero.Init

  1. init the sub-module
  2. shard and scatter the shards

b. then, to load the pre-trained weights, we loop over each sub-module (see the sketch after this list):

  1. gather the shards
  2. load_state_dict for that one layer
  3. shard and scatter the shards

c. for any sub-module params that weren't in the pretrained state_dict:

  1. run the postponed module_init, as it was done in Pytorch - Lazy initialization of models #11471
  2. shard and scatter the shards (XXX: I actually don't think deepspeed.zero.GatheredParameters was handled here, so these params don't get ZeRO'ed; need to fix that: [Deepspeed zero3] lazy weights init #12272)
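Roughly, step (b) amounts to a per-sub-module gather/load/scatter like the following. This is only a hedged sketch, not the exact transformers code: it assumes torch.distributed is already initialized, a flat state_dict keyed by full parameter names, and uses PyTorch's internal `_load_from_state_dict`:

```python
import torch
import deepspeed

def load_pretrained_zero3(model, state_dict):
    # sketch of the current flow; function name is illustrative
    for name, module in model.named_modules():
        params = list(module.parameters(recurse=False))
        if not params:
            continue
        prefix = name + "." if name else ""
        # b.1 gather the shards: every rank temporarily holds the full params
        with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
            if torch.distributed.get_rank() == 0:
                # b.2 load_state_dict for this one sub-module only
                module._load_from_state_dict(
                    state_dict, prefix, {}, False, [], [], [])
        # b.3 on exiting the context DeepSpeed re-shards and scatters the
        #     updated params back to all ranks
```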

Because we unnecessarily do scatter/gather/scatter, this takes much longer than just doing the following (sketched after this list):

a. init the modules w/o allocating any storage, as implemented in pt-1.9.0/1.9.1: https://pytorch.org/tutorials/prototype/skip_param_init.html#implementation-details

b. for each sub-module with pretrained weights:

  1. load_state_dict
  2. shard and scatter the shards

c. for any sub-module params that weren't in the pretrained state_dict:

  1. materialize and module_init
  2. shard and scatter the shards
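A hedged sketch of this proposed flow, assuming a PyTorch recent enough to build the model on the meta device via the torch.device("meta") context manager (the pt-1.9 skip-init feature relies on the same mechanism). MyModel, config and state_dict are placeholders, and the final single shard/scatter step is exactly the piece that needs DeepSpeed support:

```python
import torch

# a. init the modules w/o allocating any storage (params land on `meta`)
with torch.device("meta"):
    model = MyModel(config)  # placeholder model class/config

for name, module in model.named_modules():
    named_params = list(module.named_parameters(recurse=False))
    if not named_params:
        continue
    prefix = name + "." if name else ""
    # materialize uninitialized storage for just this sub-module's own params
    for pname, p in named_params:
        setattr(module, pname,
                torch.nn.Parameter(torch.empty_like(p, device="cuda")))
    if all(prefix + pname in state_dict for pname, _ in named_params):
        # b.1 load the pre-trained weights for this one sub-module
        module._load_from_state_dict(state_dict, prefix, {}, False, [], [], [])
    elif hasattr(module, "reset_parameters"):
        # c.1 run the postponed module_init for params w/o pre-trained weights
        module.reset_parameters()
    # b.2 / c.2 shard and scatter this sub-module's params exactly once;
    # this single-scatter step is what needs DeepSpeed support (see below).
```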

Solving this will most likely require support from DeepSpeed (microsoft/DeepSpeed#1142), or perhaps we can simply skip zero.Init when the weights aren't materialized during model creation, so that the very first sharding is postponed to the load_state_dict stage (and to module_init for the sub-modules that don't have pre-trained weights).
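A rough sketch of that alternative, assuming zero.Init's `enabled` flag is all we need and omitting its config arguments; MyModel, config and the flag name are placeholders:

```python
import deepspeed

def create_model(config, materialize_weights_at_init: bool):
    # When the weights will come from a pretrained state_dict, skip zero.Init
    # entirely so the very first shard/scatter happens later, at the
    # load_state_dict stage (or at module_init for sub-modules without
    # pre-trained weights).
    with deepspeed.zero.Init(enabled=materialize_weights_at_init):
        return MyModel(config)  # placeholder model class/config
```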
