From 58350b276196220bb9c215d604830eab87233501 Mon Sep 17 00:00:00 2001 From: theo-barfoot Date: Mon, 21 Oct 2024 18:12:26 +0000 Subject: [PATCH] tensor_split expects tensor_indices_or_sections to be on cpu --- torchsparsegradutils/indexed_matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsparsegradutils/indexed_matmul.py b/torchsparsegradutils/indexed_matmul.py index ef0c8f7..b41535c 100644 --- a/torchsparsegradutils/indexed_matmul.py +++ b/torchsparsegradutils/indexed_matmul.py @@ -45,7 +45,7 @@ def segment_mm(a, b, seglen_a): if not a.shape[1] == D1 or not seglen_a.shape[0] == R: raise ValueError("Incompatible size for inputs") - segidx_a = torch.cumsum(seglen_a[:-1], dim=0) + segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu() # Ideally the conversions below to nested tensor would be handled natively nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))