Skip to content

Commit

Permalink
tensor_split expects tensor_indices_or_sections to be on cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Oct 21, 2024
1 parent b4f2e71 commit 58350b2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchsparsegradutils/indexed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def segment_mm(a, b, seglen_a):
if not a.shape[1] == D1 or not seglen_a.shape[0] == R:
raise ValueError("Incompatible size for inputs")

segidx_a = torch.cumsum(seglen_a[:-1], dim=0)
segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu()

# Ideally the conversions below to nested tensor would be handled natively
nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))
Expand Down

0 comments on commit 58350b2

Please sign in to comment.