If you are submitting a bug report, please fill in the following details and use the tag [bug].
Describe the bug
If you run with n_devices=3 for llama2-7b from a HookedTransformer, you will receive an error that the specified device index is out of the ordinal range.
From what I can tell, this is related to how layers are assigned to GPUs and then fetched.
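The failing index arithmetic can be reproduced without any GPUs. The sketch below assumes llama2-7b has 32 transformer layers (an assumption about the config, not taken from a traceback):

```python
# Minimal arithmetic repro of the layer-to-device assignment in
# get_device_for_block_index, assuming llama2-7b's n_layers = 32.
n_layers, n_devices = 32, 3
layers_per_device = n_layers // n_devices  # 32 // 3 == 10
indices = [layer // layers_per_device for layer in range(n_layers)]
# Layers 30 and 31 land on index 3, which does not exist when only
# devices 0, 1, and 2 are available.
print(sorted(set(indices)))  # [0, 1, 2, 3]
```

Any n_layers that is not a multiple of n_devices produces the same overflow on the last n_layers % n_devices layers.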
```python
from typing import Optional, Union

import torch


def get_device_for_block_index(
    index: int,
    cfg: "transformer_lens.HookedTransformerConfig",
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. The distribution
    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

    Args:
        index (int): Model layer index.
        cfg (HookedTransformerConfig): Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used for determining
            the target device. If not provided, the function uses the device specified in the
            configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index.
    """
    assert cfg.device is not None
    layers_per_device = cfg.n_layers // cfg.n_devices
    if device is None:
        device = cfg.device
    device = torch.device(device)
    if device.type == "cpu":
        return device
    # If n_devices = 3, then for the last layers of any model whose n_layers is
    # not a multiple of 3, we assign device index 3, which is out of range
    # since devices are indexed from 0.
    device_index = (device.index or 0) + (index // layers_per_device)
    # As a fix, could we just cap the device index? The assignment of the
    # earlier layers would stay the same.
    # if device_index > (cfg.n_devices - 1):
    #     device_index = cfg.n_devices - 1
    return torch.device(device.type, device_index)
```