Skip to content

Commit

Permalink
Fix HeteroLinear with segment_matmul (#5347)
Browse files Browse the repository at this point in the history
* update

* update
  • Loading branch information
rusty1s authored Sep 3, 2022
1 parent ded6fd1 commit e45f792
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `pyg_lib.segment_matmul` integration within `HeteroLinear` ([#5330](https://github.com/pyg-team/pytorch_geometric/pull/5330))
- Added `pyg_lib.segment_matmul` integration within `HeteroLinear` ([#5330](https://github.com/pyg-team/pytorch_geometric/pull/5330), [#5347](https://github.com/pyg-team/pytorch_geometric/pull/5347)))
- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293), [#5341](https://github.com/pyg-team/pytorch_geometric/pull/5341))
- Added `Aggregation.set_validate_args` option to skip validation of `dim_size` ([#5290](https://github.com/pyg-team/pytorch_geometric/pull/5290))
- Added `SparseTensor` support to inference benchmark suite ([#5242](https://github.com/pyg-team/pytorch_geometric/pull/5242), [#5258](https://github.com/pyg-team/pytorch_geometric/pull/5258))
Expand Down
10 changes: 9 additions & 1 deletion torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
except ImportError:
_WITH_PYG_LIB = False

def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor) -> Tensor:
raise NotImplementedError


def is_uninitialized_parameter(x: Any) -> bool:
if not hasattr(nn.parameter, 'UninitializedParameter'):
Expand Down Expand Up @@ -227,6 +230,8 @@ def __init__(self, in_channels: int, out_channels: int, num_types: int,
Linear(in_channels, out_channels, **kwargs)
for _ in range(num_types)
])
self.register_parameter('weight', None)
self.register_parameter('bias', None)

self.reset_parameters()

Expand All @@ -247,10 +252,13 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
type_vec (LongTensor): A vector that maps each entry to a type.
"""
if self._WITH_PYG_LIB:
assert self.weight is not None

if not self.is_sorted:
if (type_vec[1:] < type_vec[:-1]).any():
type_vec, perm = type_vec.sort()
x = x[:, perm]
x = x[perm]

type_vec_ptr = torch.ops.torch_sparse.ind2ptr(
type_vec, self.num_types)
out = segment_matmul(x, type_vec_ptr, self.weight)
Expand Down

0 comments on commit e45f792

Please sign in to comment.