-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added simple workarounds for gather_mm and segment_mm (#57)
* Added simple workarounds for gather_mm and segment_mm. See #56 * bumping python and pytorch version in CI * enabling black on notebooks in CI * updating github actions to avoid deprecation warning
- Loading branch information
Showing
6 changed files
with
237 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
from .sparse_matmul import sparse_mm | ||
from .indexed_matmul import gather_mm, segment_mm | ||
from .sparse_solve import sparse_triangular_solve, sparse_generic_solve | ||
from .sparse_lstsq import sparse_generic_lstsq | ||
|
||
__all__ = ["sparse_mm", "sparse_triangular_solve", "sparse_generic_solve", "sparse_generic_lstsq"] | ||
__all__ = [ | ||
"sparse_mm", | ||
"gather_mm", | ||
"segment_mm", | ||
"sparse_triangular_solve", | ||
"sparse_generic_solve", | ||
"sparse_generic_lstsq", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import torch | ||
|
||
try: | ||
import dgl.ops as dglops | ||
|
||
dgl_installed = True | ||
except ImportError: | ||
dgl_installed = False | ||
|
||
|
||
def segment_mm(a, b, seglen_a): | ||
""" | ||
Performs matrix multiplication according to segments. | ||
See https://docs.dgl.ai/generated/dgl.ops.segment_mm.html | ||
Suppose ``seglen_a == [10, 5, 0, 3]``, the operator will perform | ||
four matrix multiplications:: | ||
a[0:10] @ b[0], a[10:15] @ b[1], | ||
a[15:15] @ b[2], a[15:18] @ b[3] | ||
Args: | ||
a (torch.Tensor): The left operand, 2-D tensor of shape ``(N, D1)`` | ||
b (torch.Tensor): The right operand, 3-D tensor of shape ``(R, D1, D2)`` | ||
seglen_a (torch.Tensor): An integer tensor of shape ``(R,)``. Each element is the length of segments of input ``a``. The summation of all elements must be equal to ``N``. | ||
Returns: | ||
torch.Tensor: The output dense matrix of shape ``(N, D2)`` | ||
""" | ||
if torch.__version__ < (2, 4): | ||
raise NotImplementedError("PyTorch version is too old for nested tesors") | ||
|
||
if dgl_installed: | ||
# DGL is probably more computationally efficient | ||
# See https://github.com/pytorch/pytorch/issues/136747 | ||
return dglops.segment_mm(a, b, seglen_a) | ||
|
||
if not a.dim() == 2 or not b.dim() == 3 or not seglen_a.dim() == 1: | ||
raise ValueError("Input tensors have unexpected dimensions") | ||
|
||
N, _ = a.shape | ||
R, D1, D2 = b.shape | ||
|
||
# Sanity check sizes | ||
if not a.shape[1] == D1 or not seglen_a.shape[0] == R: | ||
raise ValueError("Incompatible size for inputs") | ||
|
||
segidx_a = torch.cumsum(seglen_a[:-1], dim=0) | ||
|
||
# Ideally the conversions below to nested tensor would be handled natively | ||
nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0)) | ||
nested_b = torch.nested.as_nested_tensor(list(map(torch.squeeze, torch.split(b, 1, dim=0)))) | ||
|
||
# The actual gather matmul computation | ||
nested_ab = torch.matmul(nested_a, nested_b) | ||
|
||
# Convert back to tensors, again ideally this would be handled natively | ||
ab = torch.cat(nested_ab.unbind(), dim=0) | ||
return ab | ||
|
||
|
||
def gather_mm(a, b, idx_b): | ||
""" | ||
Gather data according to the given indices and perform matrix multiplication. | ||
See https://docs.dgl.ai/generated/dgl.ops.gather_mm.html | ||
Let the result tensor be ``c``, the operator conducts the following computation: | ||
c[i] = a[i] @ b[idx_b[i]] | ||
, where len(c) == len(idx_b) | ||
Args: | ||
a (torch.Tensor): A 2-D tensor of shape ``(N, D1)`` | ||
b (torch.Tensor): A 3-D tensor of shape ``(R, D1, D2)`` | ||
idx_b (torch.Tensor): An 1-D integer tensor of shape ``(N,)``. | ||
Returns: | ||
torch.Tensor: The output dense matrix of shape ``(N, D2)`` | ||
""" | ||
if torch.__version__ < (2, 4): | ||
raise NotImplementedError("PyTorch version is too old for nested tesors") | ||
|
||
if dgl_installed: | ||
# DGL is more computationally efficient | ||
# See https://github.com/pytorch/pytorch/issues/136747 | ||
return dglops.gather_mm(a, b, idx_b) | ||
|
||
# Dependency free fallback | ||
if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor) or not isinstance(idx_b, torch.Tensor): | ||
raise ValueError("Inputs should be instances of torch.Tensor") | ||
|
||
if not a.dim() == 2 or not b.dim() == 3 or not idx_b.dim() == 1: | ||
raise ValueError("Input tensors have unexpected dimensions") | ||
|
||
N = idx_b.shape[0] | ||
R, D1, D2 = b.shape | ||
|
||
# Sanity check sizes | ||
if not a.shape[0] == N or not a.shape[1] == D1: | ||
raise ValueError("Incompatible size for inputs") | ||
|
||
torchdevice = a.device | ||
src_idx = torch.arange(N, device=torchdevice) | ||
|
||
# Ideally the conversions below to nested tensor would be handled without for looops and without copy | ||
nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)]) | ||
src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)]) | ||
nested_b = torch.nested.as_nested_tensor([b[i, :, :].squeeze() for i in range(R)]) | ||
|
||
# The actual gather matmul computation | ||
nested_ab = torch.matmul(nested_a, nested_b) | ||
|
||
# Convert back to tensors, again, ideally this would be handled natively with no copy | ||
ab_segmented = torch.cat(nested_ab.unbind(), dim=0) | ||
ab = torch.empty((N, D2), device=torchdevice) | ||
ab[src_idx_reshuffled] = ab_segmented | ||
return ab |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import torch | ||
import pytest | ||
|
||
if torch.__version__ < (2, 4): | ||
pytest.skip( | ||
"Skipping test based on nested tensors since an old version of pytorch is used", allow_module_level=True | ||
) | ||
|
||
from torchsparsegradutils import gather_mm, segment_mm | ||
|
||
# Identify Testing Parameters | ||
DEVICES = [torch.device("cpu")] | ||
if torch.cuda.is_available(): | ||
DEVICES.append(torch.device("cuda")) | ||
|
||
TEST_DATA = [ | ||
# name N, R, D1, D2 | ||
("small", 100, 32, 7, 10), | ||
] | ||
|
||
INDEX_DTYPES = [torch.int32, torch.int64] | ||
VALUE_DTYPES = [torch.float32, torch.float64] | ||
|
||
ATOL = 1e-6 # relaxed tolerance to allow for float32 | ||
RTOL = 1e-4 | ||
|
||
|
||
# Define Test Names: | ||
def data_id(shapes): | ||
return shapes[0] | ||
|
||
|
||
def device_id(device): | ||
return str(device) | ||
|
||
|
||
def dtype_id(dtype): | ||
return str(dtype).split(".")[-1] | ||
|
||
|
||
# Define Fixtures | ||
|
||
|
||
@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA]) | ||
def shapes(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES]) | ||
def value_dtype(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES]) | ||
def index_dtype(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES]) | ||
def device(request): | ||
return request.param | ||
|
||
|
||
# Define Tests | ||
|
||
|
||
def test_segment_mm(device, value_dtype, index_dtype, shapes): | ||
_, N, R, D1, D2 = shapes | ||
|
||
a = torch.randn((N, D1), device=device) | ||
b = torch.randn((R, D1, D2), device=device) | ||
seglen_a = torch.randint(low=1, high=int(N / R), size=(R,), device=device) | ||
seglen_a[-1] = N - seglen_a[:-1].sum() | ||
|
||
ab = segment_mm(a, b, seglen_a) | ||
|
||
k = 0 | ||
for i in range(R): | ||
for j in range(seglen_a[i]): | ||
assert torch.allclose(ab[k, :].squeeze(), a[k, :].squeeze() @ b[i, :, :].squeeze(), atol=ATOL, rtol=RTOL) | ||
k += 1 | ||
|
||
|
||
def test_gather_mm(device, value_dtype, index_dtype, shapes): | ||
_, N, R, D1, D2 = shapes | ||
|
||
a = torch.randn((N, D1), device=device) | ||
b = torch.randn((R, D1, D2), device=device) | ||
idx_b = torch.randint(low=0, high=R, size=(N,), device=device) | ||
|
||
ab = gather_mm(a, b, idx_b) | ||
|
||
for i in range(N): | ||
assert torch.allclose(ab[i, :].squeeze(), a[i, :].squeeze() @ b[idx_b[i], :, :].squeeze(), atol=ATOL, rtol=RTOL) |