From 0581f863edeb0e490fedc083af358f45d950b3d4 Mon Sep 17 00:00:00 2001 From: Dmitry Baranchuk Date: Thu, 28 Jul 2022 03:24:55 +0300 Subject: [PATCH] Update module_backend.py --- hivemind/moe/server/module_backend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index f6260371a..5688b8ecc 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -118,9 +118,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: with torch.enable_grad(): args = [ - tensor.detach().requires_grad_(True) - if tensor.dtype in (torch.half, torch.float, torch.double) - else tensor.detach() + tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach() for tensor in args ] kwargs = {