@pytest.mark.torch
def test_torch_graph_conv():
    """Test invoking the torch GraphConv layer against stored reference values.

    The layer's weights and biases are overwritten with arrays stored in the
    ``assets`` directory next to this test file, and the output for the
    molecules ``CCC`` and ``C`` is compared with a stored reference result.
    """
    import os
    from rdkit import Chem

    # Resolve assets relative to this file so the test does not depend on the
    # directory pytest happens to be launched from.
    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              "assets")

    out_channels = 2
    n_atoms = 4  # In CCC and C, there are 4 atoms
    raw_smiles = ['CCC', 'C']
    mols = [Chem.MolFromSmiles(s) for s in raw_smiles]
    featurizer = dc.feat.graph_features.ConvMolFeaturizer()
    mols = featurizer.featurize(mols)
    multi_mol = dc.feat.mol_graphs.ConvMol.agglomerate_mols(mols)
    atom_features = multi_mol.get_atom_features().astype(np.float32)
    degree_slice = multi_mol.deg_slice
    membership = multi_mol.membership
    deg_adjs = multi_mol.get_deg_adjacency_lists()[1:]
    args = [atom_features, degree_slice, membership] + deg_adjs

    layer = torch_layers.GraphConv(out_channels)

    # Replace the randomly initialised parameters with the stored ones so the
    # output is deterministic and comparable with the reference result.
    W_list = np.load(os.path.join(assets_dir, "graphconvlayer_weights.npy"),
                     allow_pickle=True).tolist()
    layer.W_list = nn.ParameterList(
        [nn.Parameter(torch.tensor(k)) for k in W_list])
    b_list = np.load(os.path.join(assets_dir, "graphconvlayer_biases.npy"),
                     allow_pickle=True).tolist()
    layer.b_list = nn.ParameterList(
        [nn.Parameter(torch.tensor(k)) for k in b_list])

    result = layer(args)
    expected = np.load(os.path.join(assets_dir, "graphconvlayer_result.npy"))
    assert np.allclose(result.detach().numpy(), expected, atol=1e-4)
    assert result.shape == (n_atoms, out_channels)

    # One weight and one bias per affine transform.
    num_deg = 2 * layer.max_degree + (1 - layer.min_degree)
    assert len(list(layer.parameters())) == 2 * num_deg
deepchem.models.torch_models.cnn import CNN from deepchem.models.torch_models.scscore import ScScore, ScScoreModel from deepchem.models.torch_models.weavemodel_pytorch import Weave, WeaveModel diff --git a/deepchem/models/torch_models/layers.py b/deepchem/models/torch_models/layers.py index 9ffe03ef3e..d086df5036 100644 --- a/deepchem/models/torch_models/layers.py +++ b/deepchem/models/torch_models/layers.py @@ -5441,7 +5441,7 @@ def forward(self, one_electron: torch.Tensor, two_electron: torch.Tensor): one_electron: torch.Tensor The one electron feature after passing through the layer which has the shape (batch_size, number of electrons, n_one shape). two_electron: torch.Tensor - The two electron feature after passing through the layer which has the shape (batch_size, number of electrons, number of electron , n_two shape). + The two electron feature after passing through the layer which has the shape (batch_size, number of electrons, number of electron , n_two shape). """ for l in range(self.layer_size): # Calculating one-electron feature's average @@ -6055,3 +6055,185 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = H_out * T_out + x * (1 - T_out) return output + + +class GraphConv(nn.Module): + """Graph Convolutional Layers + + This layer implements the graph convolution introduced in [1]_. The graph + convolution combines per-node feature vectures in a nonlinear fashion with + the feature vectors for neighboring nodes. This "blends" information in + local neighborhoods of a graph. 
+ + Example + -------- + >>> import deepchem as dc + >>> import numpy as np + >>> import deepchem.models.torch_models.layers as torch_layers + >>> out_channels = 2 + >>> n_atoms = 4 # In CCC and C, there are 4 atoms + >>> raw_smiles = ['CCC', 'C'] + >>> from rdkit import Chem + >>> mols = [Chem.MolFromSmiles(s) for s in raw_smiles] + >>> featurizer = dc.feat.graph_features.ConvMolFeaturizer() + >>> mols = featurizer.featurize(mols) + >>> multi_mol = dc.feat.mol_graphs.ConvMol.agglomerate_mols(mols) + >>> atom_features = multi_mol.get_atom_features().astype(np.float32) + >>> degree_slice = multi_mol.deg_slice + >>> membership = multi_mol.membership + >>> deg_adjs = multi_mol.get_deg_adjacency_lists()[1:] + >>> args = [atom_features, degree_slice, membership] + deg_adjs + >>> layer = torch_layers.GraphConv(out_channels) + >>> result = layer(args) + >>> type(result) + + >>> result.shape + torch.Size([4, 2]) + >>> num_deg = 2 * layer.max_degree + (1 - layer.min_degree) + >>> num_deg + 21 + + References + ---------- + .. [1] Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." + Advances in neural information processing systems. 2015. https://arxiv.org/abs/1509.09292 + + """ + + def __init__(self, + out_channel: int, + min_deg: int = 0, + max_deg: int = 10, + activation_fn: Optional[Callable] = None, + **kwargs): + """Initialize a graph convolutional layer. + + Parameters + ---------- + out_channel: int + The number of output channels per graph node. + min_deg: int, optional (default 0) + The minimum allowed degree for each graph node. + max_deg: int, optional (default 10) + The maximum allowed degree for each graph node. Note that this + is set to 10 to handle complex molecules (some organometallic + compounds have strange structures). If you're using this for + non-molecular applications, you may need to set this much higher + depending on your dataset. + activation_fn: function + A nonlinear activation function to apply. 
If you're not sure, + `torch.nn.ReLU` is probably a good default for your application. + """ + super(GraphConv, self).__init__(**kwargs) + self.out_channel: int = out_channel + self.min_degree: int = min_deg + self.max_degree: int = max_deg + self.activation_fn: Optional[Callable] = activation_fn + + # Generate the nb_affine weights and biases + num_deg: int = 2 * self.max_degree + (1 - self.min_degree) + self.W_list: nn.ParameterList = nn.ParameterList([ + nn.Parameter( + getattr(initializers, + 'xavier_uniform_')(torch.empty(75, self.out_channel))) + for k in range(num_deg) + ]) + self.b_list: nn.ParameterList = nn.ParameterList([ + nn.Parameter( + getattr(initializers, 'zeros_')(torch.empty(self.out_channel,))) + for k in range(num_deg) + ]) + self.built = True + + def __repr__(self) -> str: + """ + Returns a string representation of the object. + + Returns: + ------- + str: A string that contains the class name followed by the values of its instance variable. + """ + # flake8: noqa + return ( + f'{self.__class__.__name__}(out_channel:{self.out_channel},min_deg:{self.min_deg},max_deg:{self.max_deg},activation_fn:{self.activation_fn})' + ) + + def forward(self, inputs: List[np.ndarray]) -> torch.Tensor: + """ + The forward pass combines per-node feature vectors in a nonlinear fashion with + the feature vectors for neighboring nodes. 
+ Parameters + ---------- + inputs: List[np.ndarray] + Should contain atom features and arrays describing graph topology + Returns: + ------- + torch.Tensor + Combined atom features + """ + + # Extract atom_features + atom_features: torch.Tensor = torch.tensor(inputs[0]) + + # Extract graph topology + deg_slice: np.ndarray = inputs[1] + deg_adj_lists: List[np.ndarray] = inputs[3:] + + W = iter(self.W_list) + b = iter(self.b_list) + + # Sum all neighbors using adjacency matrix + deg_summed: List[np.ndarray] = self.sum_neigh(atom_features, + deg_adj_lists) + + # Get collection of modified atom features + new_rel_atoms_collection = [] + + split_features: Tuple[torch.Tensor, + ...] = torch.split(atom_features, + (deg_slice[:, 1]).tolist()) + for deg in range(1, self.max_degree + 1): + # Obtain relevant atoms for this degree + rel_atoms: torch.Tensor = torch.from_numpy(deg_summed[deg - 1]) + + # Get self atoms + self_atoms: torch.Tensor = split_features[deg - self.min_degree] + + # Apply hidden affine to relevant atoms and append + rel_out: torch.Tensor = torch.matmul(rel_atoms.type(torch.float32), + next(W)) + next(b) + self_out: torch.Tensor = torch.matmul( + self_atoms.type(torch.float32), next(W)) + next(b) + out: torch.Tensor = rel_out + self_out + new_rel_atoms_collection.append( + torch.from_numpy(out.detach().numpy())) + + # Determine the min_deg=0 case + if self.min_degree == 0: + self_atoms = split_features[0] + + # Only use the self layer + out = torch.matmul(self_atoms.type(torch.float32), + next(W)) + next(b) + new_rel_atoms_collection.insert( + 0, torch.from_numpy(out.detach().numpy())) + + # Combine all atoms back into the list + atom_features = torch.concat(new_rel_atoms_collection, 0) + + if self.activation_fn is not None: + atom_features = self.activation_fn(atom_features) + + return atom_features + + def sum_neigh(self, atoms: torch.Tensor, deg_adj_lists) -> List[np.ndarray]: + """Store the summed atoms by degree""" + deg_summed = [] + + for deg in 
range(1, self.max_degree + 1): + gathered_atoms: torch.Tensor = atoms[deg_adj_lists[deg - 1]] + # Sum along neighbors as well as self, and store + summed_atoms: torch.Tensor = torch.sum(gathered_atoms, 1) + deg_summed.append(summed_atoms.detach().numpy()) + + return deg_summed diff --git a/docs/source/api_reference/layers.rst b/docs/source/api_reference/layers.rst index ba569fa050..7e25b867b5 100644 --- a/docs/source/api_reference/layers.rst +++ b/docs/source/api_reference/layers.rst @@ -224,7 +224,7 @@ Torch Layers .. autoclass:: deepchem.models.torch_models.layers.MolGANAggregationLayer :members: - + .. autoclass:: deepchem.models.torch_models.layers.MolGANMultiConvolutionLayer :members: @@ -239,7 +239,7 @@ Torch Layers .. autoclass:: deepchem.models.torch_models.layers.WeaveGather :members: - + .. autoclass:: deepchem.models.torch_models.layers.MXMNetGlobalMessagePassing :members: @@ -276,6 +276,9 @@ Torch Layers .. autoclass:: deepchem.models.torch_models.layers.HighwayLayer :members: +.. autoclass:: deepchem.models.torch_models.layers.GraphConv + :members: + .. autoclass:: deepchem.models.torch_models.flows.ClampExp :members: diff --git a/docs/source/api_reference/torch_layers.csv b/docs/source/api_reference/torch_layers.csv index fefa181f98..8aed9ea816 100644 --- a/docs/source/api_reference/torch_layers.csv +++ b/docs/source/api_reference/torch_layers.csv @@ -52,4 +52,5 @@ FerminetElectronFeature, `ref `_, Ferminet FerminetEnvelope, `ref `_, FerminetModel MXMNetLocalMessagePassing, `ref `_, MXMNetModel MXMNetModelMXMNetSphericalBasisLayer, ref``_, MXMNetModel -HighwayLayer, `ref `_, \ No newline at end of file +HighwayLayer, `ref `_, +GraphConv,`ref `_,