diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py
index 5f6438a0..2a91f2fa 100644
--- a/alphafold3_pytorch/alphafold3.py
+++ b/alphafold3_pytorch/alphafold3.py
@@ -10,9 +10,10 @@
 
 import torch
 from torch import nn
-from torch import Tensor, tensor
+from torch import Tensor, tensor, is_tensor
 from torch.amp import autocast
 import torch.nn.functional as F
+from torch.utils._pytree import tree_map
 
 from torch.nn import (
     Module,
@@ -227,6 +228,9 @@ def freeze_(m: Module):
 def max_neg_value(t: Tensor):
     return -torch.finfo(t.dtype).max
 
+def dict_to_device(d, device):
+    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, d)
+
 def pack_one(t, pattern):
     packed, ps = pack([t], pattern)
 
@@ -263,7 +267,7 @@ def should_checkpoint(
     inputs: Tensor | Tuple[Tensor, ...],
     check_instance_variable: str | None = 'checkpoint'
 ) -> bool:
-    if torch.is_tensor(inputs):
+    if is_tensor(inputs):
         inputs = (inputs,)
 
     return (
@@ -6344,7 +6348,11 @@ def forward_with_alphafold3_inputs(
             alphafold3_inputs = [alphafold3_inputs]
 
         batched_atom_inputs = alphafold3_inputs_to_batched_atom_input(alphafold3_inputs, atoms_per_window = self.w)
-        return self.forward(**batched_atom_inputs.model_forward_dict(), **kwargs)
+
+        atom_dict = batched_atom_inputs.model_forward_dict()
+        atom_dict = dict_to_device(atom_dict, device = self.device)
+
+        return self.forward(**atom_dict, **kwargs)
 
     @typecheck
     def forward(
diff --git a/pyproject.toml b/pyproject.toml
index 2b129180..b47df14d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.5.18"
+version = "0.5.19"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" },
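
For context on the change above, here is a minimal, standalone sketch (not part of the patch) of how the new dict_to_device helper behaves: tree_map from torch.utils._pytree applies the given function to every leaf of a nested container, so tensor leaves are moved to the target device while non-tensor values pass through unchanged. The dictionary keys below are illustrative placeholders, not the actual fields returned by model_forward_dict().

import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

# same helper as added in the patch
def dict_to_device(d, device):
    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, d)

# hypothetical nested input dict - keys are made up for illustration
example = dict(
    atom_feats = torch.randn(2, 16, 3),
    molecule_ids = torch.randint(0, 32, (2, 16)),
    metadata = dict(num_atoms = [16, 16], name = 'example')  # non-tensor leaves are left untouched
)

moved = dict_to_device(example, device = 'cpu')  # swap in 'cuda' when available

assert moved['atom_feats'].device == torch.device('cpu')
assert moved['metadata']['name'] == 'example'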