Skip to content

Commit

Permalink
use existing is_local_dist_rank_0 fn to check devices
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jul 29, 2024
1 parent aec7187 commit 64751e5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ def _load_state_dict_into_meta_model(
)
):
if is_fsdp_enabled():
param_device = "cpu" if device_map[""] == 0 else "meta"
param_device = "cpu" if is_local_dist_rank_0() else "meta"
# For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
Expand All @@ -945,7 +945,7 @@ def _load_state_dict_into_meta_model(
module, tensor_name = get_module_from_name(model, param_name)
value = getattr(module, tensor_name)
param_to = "cpu"
if is_fsdp_enabled() and value.data.device.index != 0:
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
value = type(value)(value.data.to(param_to), **value.__dict__)
setattr(module, tensor_name, value)
Expand Down

0 comments on commit 64751e5

Please sign in to comment.