Commit
Add move_to_device kwarg to the optimizer's load_state_dict (#1344)
This makes it possible to load an optimizer checkpoint without
automatically moving the optimizer's state to the GPU.
koute authored Sep 19, 2024
1 parent abb0c32 commit 8fc7892
Showing 1 changed file with 5 additions and 2 deletions.

bitsandbytes/optim/optimizer.py
@@ -153,12 +153,14 @@ def fill_qmap(self):
     def __setstate__(self, state):
         super().__setstate__(state)
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, move_to_device=True):
         """Load an optimizer state.
 
         Arguments:
             state_dict (`dict`):
                 An optimizer state (should be returned from a call to `state_dict`) to load.
+            move_to_device (`bool`, defaults to `True`):
+                Whether to move the optimizer's state to the device.
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
@@ -195,7 +197,8 @@ def cast(param, value):
             elif isinstance(value, dict):
                 for k, v in value.items():
                     if k in self.non_castable_tensor_keys:
-                        value[k] = v.to(param.device)
+                        if move_to_device:
+                            value[k] = v.to(param.device)
                     else:
                         value[k] = cast(param, v)
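For context, a minimal usage sketch of the new kwarg (the model class and the checkpoint file name are hypothetical; this assumes the optimizer state was previously saved with torch.save). Deserializing the checkpoint to CPU via map_location and passing move_to_device=False leaves the non-castable state tensors on CPU rather than moving them to each parameter's device:

    import torch
    import bitsandbytes as bnb

    model = MyModel().cuda()  # hypothetical model
    optimizer = bnb.optim.Adam8bit(model.parameters())

    # Deserialize the checkpoint onto the CPU, then load it without
    # automatically pushing the optimizer state to the GPU.
    checkpoint = torch.load("optimizer.pt", map_location="cpu")
    optimizer.load_state_dict(checkpoint, move_to_device=False)

With the default move_to_device=True, behavior is unchanged: state tensors whose keys are in non_castable_tensor_keys are still moved to the parameter's device on load.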
