Skip to content

Commit

Permalink
Porting GraphConv layer to PyTorch (deepchem#3960)
Browse files Browse the repository at this point in the history
* Porting GraphConv layer to PyTorch

* Adding docstrings
  • Loading branch information
NimishaDey authored May 17, 2024
1 parent 67e662b commit 996bde3
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 5 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
37 changes: 37 additions & 0 deletions deepchem/models/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

try:
import torch
import torch.nn as nn
import deepchem.models.torch_models.layers as torch_layers
has_torch = True
except ModuleNotFoundError:
Expand Down Expand Up @@ -1450,3 +1451,39 @@ def test_torch_highway_layer():
output_tensor = highway_layer(input_tensor)

assert output_tensor.shape == (batch_size, feat_dim)


@pytest.mark.torch
def test_torch_graph_conv():
    """Test invoking GraphConv against stored reference weights and outputs.

    The layer's randomly initialised parameters are replaced by weights and
    biases loaded from the test assets, and the forward pass is compared with
    a stored reference result.
    """
    import os

    from rdkit import Chem

    out_channels = 2
    n_atoms = 4  # In CCC and C, there are 4 atoms
    raw_smiles = ['CCC', 'C']

    # Resolve assets relative to this file so the test does not depend on the
    # directory pytest is launched from.
    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'assets')

    mols = [Chem.MolFromSmiles(s) for s in raw_smiles]
    featurizer = dc.feat.graph_features.ConvMolFeaturizer()
    mols = featurizer.featurize(mols)
    multi_mol = dc.feat.mol_graphs.ConvMol.agglomerate_mols(mols)
    atom_features = multi_mol.get_atom_features().astype(np.float32)
    degree_slice = multi_mol.deg_slice
    membership = multi_mol.membership
    # Index 0 of the adjacency lists is dropped; the layer expects one
    # adjacency array per degree starting at degree 1.
    deg_adjs = multi_mol.get_deg_adjacency_lists()[1:]
    args = [atom_features, degree_slice, membership] + deg_adjs

    layer = torch_layers.GraphConv(out_channels)

    # NOTE: allow_pickle=True executes pickled payloads on load. It is used
    # here only for asset files that ship with the repository; never point
    # these loads at untrusted data.
    W_list = np.load(os.path.join(assets_dir, 'graphconvlayer_weights.npy'),
                     allow_pickle=True).tolist()
    layer.W_list = nn.ParameterList(
        [nn.Parameter(torch.tensor(k)) for k in W_list])
    b_list = np.load(os.path.join(assets_dir, 'graphconvlayer_biases.npy'),
                     allow_pickle=True).tolist()
    layer.b_list = nn.ParameterList(
        [nn.Parameter(torch.tensor(k)) for k in b_list])

    result = layer(args)
    expected = np.load(os.path.join(assets_dir, 'graphconvlayer_result.npy'))
    assert np.allclose(result.detach().numpy(), expected, atol=1e-4)
    assert result.shape == (n_atoms, out_channels)

    # One weight and one bias per affine map.
    num_deg = 2 * layer.max_degree + (1 - layer.min_degree)
    assert len(list(layer.parameters())) == 2 * num_deg
2 changes: 1 addition & 1 deletion deepchem/models/torch_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from deepchem.models.torch_models.mat import MAT, MATModel
from deepchem.models.torch_models.megnet import MEGNetModel
from deepchem.models.torch_models.normalizing_flows_pytorch import NormalizingFlow
from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather, EdgeNetwork, WeaveLayer, WeaveGather, MolGANConvolutionLayer, MolGANAggregationLayer, MolGANMultiConvolutionLayer, MolGANEncoderLayer, VariationalRandomizer, EncoderRNN, DecoderRNN, AtomicConv
from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather, EdgeNetwork, WeaveLayer, WeaveGather, MolGANConvolutionLayer, MolGANAggregationLayer, MolGANMultiConvolutionLayer, MolGANEncoderLayer, VariationalRandomizer, EncoderRNN, DecoderRNN, AtomicConv, GraphConv
from deepchem.models.torch_models.cnn import CNN
from deepchem.models.torch_models.scscore import ScScore, ScScoreModel
from deepchem.models.torch_models.weavemodel_pytorch import Weave, WeaveModel
Expand Down
184 changes: 183 additions & 1 deletion deepchem/models/torch_models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5441,7 +5441,7 @@ def forward(self, one_electron: torch.Tensor, two_electron: torch.Tensor):
one_electron: torch.Tensor
The one electron feature after passing through the layer which has the shape (batch_size, number of electrons, n_one shape).
two_electron: torch.Tensor
The two electron feature after passing through the layer which has the shape (batch_size, number of electrons, number of electron , n_two shape).
The two electron feature after passing through the layer which has the shape (batch_size, number of electrons, number of electron , n_two shape).
"""
for l in range(self.layer_size):
# Calculating one-electron feature's average
Expand Down Expand Up @@ -6055,3 +6055,185 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
output = H_out * T_out + x * (1 - T_out)

return output


class GraphConv(nn.Module):
    """Graph Convolutional Layers

    This layer implements the graph convolution introduced in [1]_. The graph
    convolution combines per-node feature vectors in a nonlinear fashion with
    the feature vectors for neighboring nodes. This "blends" information in
    local neighborhoods of a graph.

    Examples
    --------
    >>> import deepchem as dc
    >>> import numpy as np
    >>> import deepchem.models.torch_models.layers as torch_layers
    >>> out_channels = 2
    >>> n_atoms = 4 # In CCC and C, there are 4 atoms
    >>> raw_smiles = ['CCC', 'C']
    >>> from rdkit import Chem
    >>> mols = [Chem.MolFromSmiles(s) for s in raw_smiles]
    >>> featurizer = dc.feat.graph_features.ConvMolFeaturizer()
    >>> mols = featurizer.featurize(mols)
    >>> multi_mol = dc.feat.mol_graphs.ConvMol.agglomerate_mols(mols)
    >>> atom_features = multi_mol.get_atom_features().astype(np.float32)
    >>> degree_slice = multi_mol.deg_slice
    >>> membership = multi_mol.membership
    >>> deg_adjs = multi_mol.get_deg_adjacency_lists()[1:]
    >>> args = [atom_features, degree_slice, membership] + deg_adjs
    >>> layer = torch_layers.GraphConv(out_channels)
    >>> result = layer(args)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([4, 2])
    >>> num_deg = 2 * layer.max_degree + (1 - layer.min_degree)
    >>> num_deg
    21

    References
    ----------
    .. [1] Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints."
        Advances in neural information processing systems. 2015. https://arxiv.org/abs/1509.09292
    """

    def __init__(self,
                 out_channel: int,
                 min_deg: int = 0,
                 max_deg: int = 10,
                 activation_fn: Optional[Callable] = None,
                 number_input_features: int = 75,
                 **kwargs):
        """Initialize a graph convolutional layer.

        Parameters
        ----------
        out_channel: int
            The number of output channels per graph node.
        min_deg: int, optional (default 0)
            The minimum allowed degree for each graph node.
        max_deg: int, optional (default 10)
            The maximum allowed degree for each graph node. Note that this
            is set to 10 to handle complex molecules (some organometallic
            compounds have strange structures). If you're using this for
            non-molecular applications, you may need to set this much higher
            depending on your dataset.
        activation_fn: Callable, optional (default None)
            A callable applied to the combined atom features. If you're not
            sure, `torch.nn.functional.relu` (or an instance such as
            `torch.nn.ReLU()`) is probably a good default. When None, no
            activation is applied.
        number_input_features: int, optional (default 75)
            The number of features per input atom, i.e. the number of rows
            of every weight matrix.
        """
        super(GraphConv, self).__init__(**kwargs)
        self.out_channel: int = out_channel
        self.min_degree: int = min_deg
        self.max_degree: int = max_deg
        self.activation_fn: Optional[Callable] = activation_fn
        self.number_input_features: int = number_input_features

        # Two affine maps (neighbor sum + self) for every degree in
        # 1..max_degree, plus one self-only map when degree 0 is allowed.
        num_deg: int = 2 * self.max_degree + (1 - self.min_degree)
        self.W_list: nn.ParameterList = nn.ParameterList([
            nn.Parameter(
                initializers.xavier_uniform_(
                    torch.empty(self.number_input_features,
                                self.out_channel))) for _ in range(num_deg)
        ])
        self.b_list: nn.ParameterList = nn.ParameterList([
            nn.Parameter(initializers.zeros_(torch.empty(self.out_channel,)))
            for _ in range(num_deg)
        ])
        self.built = True

    def __repr__(self) -> str:
        """Return a string representation of the layer.

        Returns
        -------
        str
            The class name followed by the values of the constructor
            arguments `out_channel`, `min_deg`, `max_deg` and `activation_fn`.
        """
        # The constructor arguments are stored as `min_degree`/`max_degree`;
        # the labels keep the constructor's parameter names.
        # flake8: noqa
        return (
            f'{self.__class__.__name__}(out_channel:{self.out_channel},min_deg:{self.min_degree},max_deg:{self.max_degree},activation_fn:{self.activation_fn})'
        )

    def forward(self, inputs: List[np.ndarray]) -> torch.Tensor:
        """Combine per-node features with the features of neighboring nodes.

        Parameters
        ----------
        inputs: List[np.ndarray]
            `inputs[0]` holds the atom features, `inputs[1]` the degree slice
            (its second column gives the number of atoms per degree),
            `inputs[2]` is not read by this layer, and `inputs[3:]` are the
            adjacency lists for degrees 1..max_degree.

        Returns
        -------
        torch.Tensor
            Combined atom features. The result stays attached to the autograd
            graph so the layer's weights and biases can be trained.
        """
        # Put the inputs on the same device as the parameters (if any).
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else None

        # `as_tensor` avoids the copy (and the warning) `torch.tensor` incurs
        # when the features are already a tensor, and keeps their history.
        atom_features: torch.Tensor = torch.as_tensor(inputs[0],
                                                      dtype=torch.float32,
                                                      device=device)

        # Extract graph topology
        deg_slice = inputs[1]
        deg_adj_lists = inputs[3:]

        weights = iter(self.W_list)
        biases = iter(self.b_list)

        # Sum the features of all neighbors, grouped by degree
        deg_summed: List[torch.Tensor] = self._sum_neighbors(
            atom_features, deg_adj_lists)

        # Atoms are laid out contiguously by degree; split them accordingly.
        split_features: Tuple[torch.Tensor,
                              ...] = torch.split(atom_features,
                                                 (deg_slice[:, 1]).tolist())

        # NOTE(review): the loop starts at degree 1 and indexes the split
        # with `deg - min_degree`, so it presumably expects min_degree to be
        # 0 or 1 -- confirm before using larger values.
        new_rel_atoms_collection: List[torch.Tensor] = []
        for deg, rel_atoms in enumerate(deg_summed, start=1):
            # Atoms of this degree
            self_atoms: torch.Tensor = split_features[deg - self.min_degree]

            # Affine map of the neighbor sum plus affine map of the atom
            rel_out: torch.Tensor = torch.matmul(rel_atoms,
                                                 next(weights)) + next(biases)
            self_out: torch.Tensor = torch.matmul(
                self_atoms, next(weights)) + next(biases)
            new_rel_atoms_collection.append(rel_out + self_out)

        # Degree-0 atoms have no neighbors: only the self map applies, and
        # their rows come first in the output.
        if self.min_degree == 0:
            out: torch.Tensor = torch.matmul(split_features[0],
                                             next(weights)) + next(biases)
            new_rel_atoms_collection.insert(0, out)

        # Combine all atoms back into a single tensor
        atom_features = torch.cat(new_rel_atoms_collection, 0)

        if self.activation_fn is not None:
            atom_features = self.activation_fn(atom_features)

        return atom_features

    def _sum_neighbors(self, atoms: torch.Tensor,
                       deg_adj_lists) -> List[torch.Tensor]:
        """Sum neighbor features per degree, keeping the autograd graph."""
        deg_summed: List[torch.Tensor] = []
        for deg in range(1, self.max_degree + 1):
            # Explicit int64 index tensor: adjacency arrays may be any
            # integer dtype.
            neighbor_idx = torch.as_tensor(deg_adj_lists[deg - 1],
                                           dtype=torch.long,
                                           device=atoms.device)
            # Gather the neighbors of every atom and sum over the neighbor axis
            deg_summed.append(torch.sum(atoms[neighbor_idx], 1))
        return deg_summed

    def sum_neigh(self, atoms: torch.Tensor, deg_adj_lists) -> List[np.ndarray]:
        """Store the summed atoms by degree

        Parameters
        ----------
        atoms: torch.Tensor
            Atom features, indexed along the first axis by atom.
        deg_adj_lists: sequence of integer arrays
            Entry `deg - 1` holds the neighbor indices of the atoms with
            degree `deg`, for `deg` in 1..max_degree.

        Returns
        -------
        List[np.ndarray]
            One detached NumPy array per degree containing the summed
            neighbor features.
        """
        return [
            summed.detach().cpu().numpy()
            for summed in self._sum_neighbors(atoms, deg_adj_lists)
        ]
7 changes: 5 additions & 2 deletions docs/source/api_reference/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ Torch Layers

.. autoclass:: deepchem.models.torch_models.layers.MolGANAggregationLayer
:members:

.. autoclass:: deepchem.models.torch_models.layers.MolGANMultiConvolutionLayer
:members:

Expand All @@ -239,7 +239,7 @@ Torch Layers

.. autoclass:: deepchem.models.torch_models.layers.WeaveGather
:members:

.. autoclass:: deepchem.models.torch_models.layers.MXMNetGlobalMessagePassing
:members:

Expand Down Expand Up @@ -276,6 +276,9 @@ Torch Layers
.. autoclass:: deepchem.models.torch_models.layers.HighwayLayer
:members:

.. autoclass:: deepchem.models.torch_models.layers.GraphConv
:members:

.. autoclass:: deepchem.models.torch_models.flows.ClampExp
:members:

Expand Down
3 changes: 2 additions & 1 deletion docs/source/api_reference/torch_layers.csv
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ FerminetElectronFeature, `ref <https://arxiv.org/pdf/1909.02487.pdf>`_, Ferminet
FerminetEnvelope, `ref <https://arxiv.org/pdf/1909.02487.pdf>`_, FerminetModel
MXMNetLocalMessagePassing, `ref <https://arxiv.org/pdf/2011.07457>`_, MXMNetModel
MXMNetModelMXMNetSphericalBasisLayer, `ref <https://arxiv.org/pdf/2011.07457>`_, MXMNetModel
HighwayLayer, `ref <https://arxiv.org/abs/1507.06228>`_,
HighwayLayer, `ref <https://arxiv.org/abs/1507.06228>`_,
GraphConv,`ref <https://arxiv.org/abs/1509.09292>`_,

0 comments on commit 996bde3

Please sign in to comment.