From 856b48e0522b54f75109c1c4659d62c1c3871ed4 Mon Sep 17 00:00:00 2001 From: Parth Raut Date: Thu, 12 Dec 2024 22:14:39 -0500 Subject: [PATCH] forgot to move to cuda --- zeus/utils/framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index c745b9d1..86b9bd8a 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -122,7 +122,7 @@ def all_reduce( return object # wrap object in a tensor - tensor = torch.Tensor(object) + tensor = torch.Tensor(object, device="cuda") # determine operation if operation == "sum":