Commit ead35b7
auto move inputs to same device as model when invoking `forward_with_alphafold3_inputs`
lucidrains committed Sep 16, 2024
1 parent 3e1795e commit ead35b7
Showing 2 changed files with 12 additions and 4 deletions.
14 changes: 11 additions & 3 deletions alphafold3_pytorch/alphafold3.py
```diff
@@ -10,9 +10,10 @@

 import torch
 from torch import nn
-from torch import Tensor, tensor
+from torch import Tensor, tensor, is_tensor
 from torch.amp import autocast
 import torch.nn.functional as F
+from torch.utils._pytree import tree_map

 from torch.nn import (
     Module,
@@ -227,6 +228,9 @@ def freeze_(m: Module):
 def max_neg_value(t: Tensor):
     return -torch.finfo(t.dtype).max

+def dict_to_device(d, device):
+    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, d)
+
 def pack_one(t, pattern):
     packed, ps = pack([t], pattern)
```
```diff
@@ -263,7 +267,7 @@ def should_checkpoint(
     inputs: Tensor | Tuple[Tensor, ...],
     check_instance_variable: str | None = 'checkpoint'
 ) -> bool:
-    if torch.is_tensor(inputs):
+    if is_tensor(inputs):
         inputs = (inputs,)

     return (
@@ -6344,7 +6348,11 @@ def forward_with_alphafold3_inputs(
             alphafold3_inputs = [alphafold3_inputs]

         batched_atom_inputs = alphafold3_inputs_to_batched_atom_input(alphafold3_inputs, atoms_per_window = self.w)
-        return self.forward(**batched_atom_inputs.model_forward_dict(), **kwargs)
+
+        atom_dict = batched_atom_inputs.model_forward_dict()
+        atom_dict = dict_to_device(atom_dict, device = self.device)
+
+        return self.forward(**atom_dict, **kwargs)

     @typecheck
     def forward(
```
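
With this change, the batched atom input dictionary is moved onto `self.device` before `forward` is called, so inputs built on the CPU no longer need to be moved by hand. A hypothetical call pattern follows; `model` (an `Alphafold3` module) and `af3_input` (an `Alphafold3Input`) are placeholders assumed to be constructed elsewhere, e.g. as in the repository README.

```python
# hypothetical usage sketch; `model` and `af3_input` are assumed to exist already
model = model.cuda()  # parameters now live on the GPU

# the tensors inside af3_input were created on the CPU; as of this commit they are
# moved to model.device inside forward_with_alphafold3_inputs, so no manual
# .to('cuda') calls on the inputs are required
output = model.forward_with_alphafold3_inputs(af3_input)
```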
2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.5.18"
+version = "0.5.19"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" },
```
