🚀 Feature request

Currently under Deepspeed stage3 with `from_pretrained` we:

a. loop over each sub-module in `zero.Init`
   - materialize and `module_init`
   - shard and scatter the shards

b. then to load pre-trained weights we loop over each sub-module:
   - gather the shards
   - `load_state_dict` for the one layer
   - shard and scatter the shards

c. any sub-module params that weren't in the pretrained state_dict
   - `module_init` as it was done in Pytorch - Lazy initialization of models #11471

XXX: I actually don't think `deepspeed.zero.GatheredParameters` was handled here, so these params don't get ZeRO'ed - need to fix that: [Deepspeed zero3] lazy weights init #12272
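To make the round-trip concrete, here is a minimal sketch of that current flow (not the actual `from_pretrained` implementation; `model_cls`, `config`, `state_dict` and `ds_config` are placeholders, and the `zero.Init` config argument is spelled differently across DeepSpeed versions):

```python
import deepspeed
import torch
import torch.distributed as dist


def load_pretrained_zero3_current(model_cls, config, state_dict, ds_config):
    """Rough sketch of the current flow: scatter at init, then gather and
    scatter again per layer while loading the pretrained weights."""
    # a. construct the model inside zero.Init: each sub-module is materialized,
    #    module_init runs, and its params are immediately sharded and scattered
    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = model_cls(config)

    # b. to load pre-trained weights, loop over the params: gather the shards,
    #    copy the checkpoint values in on rank 0, and re-shard/scatter on exit
    for name, param in model.named_parameters():
        if name not in state_dict:
            # c. params not in the state_dict keep their init from step (a);
            #    per the note above this path lacks a GatheredParameters wrapper
            continue
        with deepspeed.zero.GatheredParameters([param], modifier_rank=0):
            if dist.get_rank() == 0:
                param.data.copy_(state_dict[name])
    return model
```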
Because we unnecessarily do scatter/gather/scatter, this takes much longer than just:
a. init the modules w/o allocating any storage as it has been implemented in pt-1.9.0/1.9.1 https://pytorch.org/tutorials/prototype/skip_param_init.html#implementation-details
b. for each sub-module with pretrained weights
   - `load_state_dict` for the one layer
   - shard and scatter the shards

c. any sub-module params that weren't in the pretrained state_dict
   - materialize and `module_init`
   - shard and scatter the shards
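A minimal sketch of this proposed flow, under stated assumptions: `model_cls`, `config`, `state_dict` and `module_init` are placeholders; `shard_and_scatter` stands in for the DeepSpeed-side support requested in microsoft/DeepSpeed#1142; and step (a) uses the pt-1.9 `torch.nn.utils.skip_init` helper from the tutorial linked above, which requires the model class to accept a `device` kwarg:

```python
import torch


def load_pretrained_zero3_lazy(model_cls, config, state_dict,
                               module_init, shard_and_scatter):
    """Rough sketch of the proposed flow: touch every param exactly once,
    then shard it -- no gather/re-scatter round-trip."""
    # a. init the modules w/o running any init computation (pt-1.9 skip-init;
    #    assumes model_cls takes a `device` kwarg, as that mechanism requires)
    model = torch.nn.utils.skip_init(model_cls, config)

    for name, param in model.named_parameters():
        if name in state_dict:
            # b. sub-modules with pretrained weights: load the checkpoint
            #    value for the one layer -- no gather needed
            with torch.no_grad():
                param.copy_(state_dict[name])
        else:
            # c. params that weren't in the pretrained state_dict: run
            #    module_init for them (e.g. the model's _init_weights)
            module_init(name, param)
        # ... then shard and scatter the shards (the step that needs the
        # DeepSpeed-side support mentioned above)
        shard_and_scatter(param)
    return model
```

The key difference from the current flow is that each parameter is written exactly once before its first (and only) partitioning, so the gather/re-scatter round-trip disappears.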
Solving this will most likely require support from Deepspeed, microsoft/DeepSpeed#1142, or perhaps we can just try to remove `zero.Init` if the weights aren't materialized during model creation. So the very first sharding will get postponed to the `load_state_dict` stage (and `module_init` for the sub-modules that don't have pre-trained weights).