Skip to content

Commit

Permalink
rewrite code for peptides, including residue readout
Browse files Browse the repository at this point in the history
  • Loading branch information
akensert committed Apr 15, 2024
1 parent 8b21faf commit 155bed1
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 43 deletions.
72 changes: 52 additions & 20 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from molexpress.ops import chem_ops


class MolecularGraphEncoder:
class PeptideGraphEncoder:
def __init__(
self,
atom_featurizers: list[featurizers.Featurizer],
Expand All @@ -17,10 +17,21 @@ def __init__(
self.node_encoder = MolecularNodeEncoder(atom_featurizers)
self.edge_encoder = MolecularEdgeEncoder(bond_featurizers, self_loops=self_loops)

def __call__(self, molecule: types.Molecule | types.SMILES | types.InChI) -> np.ndarray:
molecule = chem_ops.get_molecule(molecule)
return {**self.node_encoder(molecule), **self.edge_encoder(molecule)}

def __call__(self, molecules: list[types.Molecule | types.SMILES | types.InChI]) -> types.MolecularGraph:
    """Encode a peptide, given as a sequence of residues, as one disjoint graph.

    Each element of `molecules` is a single residue, supplied either as a
    molecule object or as a SMILES / InChI string. Every residue is encoded
    as its own molecular graph, and the per-residue graphs are then merged
    into a single disjoint graph.

    Returns:
        The merged graph dict, augmented with a "residue_size" array that
        holds the atom count of each residue (in input order).
    """
    molecular_graphs = []
    residue_sizes = []
    for molecule in molecules:
        # Normalize string inputs (SMILES/InChI) to a molecule object.
        molecule = chem_ops.get_molecule(molecule)
        # Node and edge encoders each return a dict of arrays; merge them
        # into one graph dict for this residue.
        molecular_graph = {
            **self.node_encoder(molecule),
            **self.edge_encoder(molecule)
        }
        molecular_graphs.append(molecular_graph)
        residue_sizes.append(molecule.GetNumAtoms())
    graph = self._merge_molecular_graphs(molecular_graphs)
    graph["residue_size"] = np.array(residue_sizes)
    return graph

@staticmethod
def _collate_fn(
    data: list[tuple[types.MolecularGraph, np.ndarray]],
) -> tuple[types.MolecularGraph, np.ndarray]:
    """Collate (peptide graph, label) pairs into one disjoint batch graph.

    Intended for use as a DataLoader `collate_fn`. Merges the peptide
    graphs of the batch into a single disjoint graph and attaches:
      - "peptide_size": number of residues in each peptide, and
      - "residue_size": atom count of every residue, concatenated
        across the batch.

    Returns:
        A tuple of (disjoint batch graph, stacked labels).
    """
    x, y = list(zip(*data))

    disjoint_graph = PeptideGraphEncoder._merge_molecular_graphs(x)
    # Residues per peptide. Equivalent to the previous
    # np.concatenate([graph["residue_size"].shape[:1] ...]) but direct.
    disjoint_graph["peptide_size"] = np.array(
        [graph["residue_size"].shape[0] for graph in x], dtype="int32"
    )
    # Atoms per residue, concatenated over the whole batch.
    disjoint_graph["residue_size"] = np.concatenate(
        [graph["residue_size"] for graph in x]
    ).astype("int32")
    return disjoint_graph, np.stack(y)

@staticmethod
def _merge_molecular_graphs(
molecular_graphs: list[types.MolecularGraph],
) -> types.MolecularGraph:

disjoint_graph = {}
num_nodes = np.array([
g["node_state"].shape[0] for g in molecular_graphs
])

disjoint_graph["node_state"] = np.concatenate([graph["node_state"] for graph in x])
disjoint_molecular_graph = {}

if "edge_state" in x[0]:
disjoint_graph["edge_state"] = np.concatenate([graph["edge_state"] for graph in x])
disjoint_molecular_graph["node_state"] = np.concatenate([
g["node_state"] for g in molecular_graphs
])

edge_src = np.concatenate([graph["edge_src"] for graph in x])
edge_dst = np.concatenate([graph["edge_dst"] for graph in x])
num_edges = np.array([graph["edge_src"].shape[0] for graph in x])
indices = np.repeat(range(len(x)), num_edges)
if "edge_state" in molecular_graphs[0]:
disjoint_molecular_graph["edge_state"] = np.concatenate([
g["edge_state"] for g in molecular_graphs
])

edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs])
edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs])
num_edges = np.array([graph["edge_src"].shape[0] for graph in molecular_graphs])
indices = np.repeat(range(len(molecular_graphs)), num_edges)
edge_incr = np.concatenate([[0], num_nodes[:-1]])
edge_incr = np.take_along_axis(edge_incr, indices, axis=0)

disjoint_graph["edge_src"] = edge_src + edge_incr
disjoint_graph["edge_dst"] = edge_dst + edge_incr
disjoint_graph["graph_indicator"] = np.repeat(range(len(x)), num_nodes)
disjoint_molecular_graph["edge_src"] = edge_src + edge_incr
disjoint_molecular_graph["edge_dst"] = edge_dst + edge_incr

return disjoint_graph, np.stack(y)
return disjoint_molecular_graph


class Composer:
Expand Down Expand Up @@ -103,7 +133,7 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray:

if molecule.GetNumBonds() == 0:
edge_state = np.zeros(
shape=(0, self.output_dim + int(self.self_loops)),
shape=(int(self.self_loops), self.output_dim + int(self.self_loops)),
dtype=self.output_dtype
)
return {
Expand Down Expand Up @@ -144,4 +174,6 @@ def __init__(

def __call__(self, molecule: types.Molecule) -> types.MolecularGraph:
    """Encode the atoms of `molecule` as a node-feature matrix.

    Applies `self.featurizer` to every atom and stacks the per-atom
    feature vectors into a (num_atoms, num_features) array.

    Returns:
        A dict with a single "node_state" entry holding that array.
    """
    # np.stack over the comprehension already yields the final array;
    # the previous second np.stack around it was redundant.
    node_encodings = np.stack(
        [self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0
    )
    return {"node_state": node_encodings}
3 changes: 2 additions & 1 deletion molexpress/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from molexpress.layers.base_layer import BaseLayer as BaseLayer
from molexpress.layers.gcn_conv import GCNConv as GCNConv
from molexpress.layers.gin_conv import GINConv as GINConv
from molexpress.layers.readout import Readout as Readout
from molexpress.layers.peptide_readout import PeptideReadout as PeptideReadout
from molexpress.layers.residue_readout import ResidueReadout as ResidueReadout
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from molexpress.ops import gnn_ops


class Readout(keras.layers.Layer):
class PeptideReadout(keras.layers.Layer):
def __init__(self, mode: str = "mean", **kwargs) -> None:
super().__init__(**kwargs)
self.mode = mode
Expand All @@ -18,14 +18,21 @@ def __init__(self, mode: str = "mean", **kwargs) -> None:
self._readout_fn = gnn_ops.segment_mean

def build(self, input_shape: dict[str, tuple[int, ...]]) -> None:
    """Validate that the layer inputs will carry a 'peptide_size' entry.

    Raises:
        ValueError: if 'peptide_size' is missing, since readout cannot
            segment node states into peptides without it.
    """
    if "peptide_size" not in input_shape:
        raise ValueError("Cannot perform readout: 'peptide_size' not found.")

def call(self, inputs: types.MolecularGraph) -> types.Array:
    """Pool node states into one state vector per peptide.

    Uses "peptide_size" (residues per peptide) and "residue_size" (atoms
    per residue) to derive one segment id per node, then applies the
    configured readout (mean/sum/max) over each peptide's nodes.
    """
    peptide_size = keras.ops.cast(inputs["peptide_size"], "int32")
    residue_size = keras.ops.cast(inputs["residue_size"], "int32")
    n_peptides = keras.ops.shape(peptide_size)[0]
    # Atoms per peptide: sum the residue sizes belonging to each peptide.
    # keras.ops.arange (rather than Python range) keeps this backend/graph
    # safe when n_peptides is a symbolic dimension.
    repeats = keras.ops.segment_sum(
        residue_size,
        keras.ops.repeat(keras.ops.arange(n_peptides), peptide_size),
    )
    # One segment id per node, identifying the peptide it belongs to.
    segment_ids = keras.ops.repeat(keras.ops.arange(n_peptides), repeats)
    return self._readout_fn(
        data=inputs["node_state"],
        segment_ids=segment_ids,
        num_segments=None,
        sorted=False,
    )
58 changes: 58 additions & 0 deletions molexpress/layers/residue_readout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import keras

from molexpress import types
from molexpress.ops import gnn_ops


class ResidueReadout(keras.layers.Layer):
    """Readout that pools node states into per-residue states.

    Produces a dense tensor of shape (n_peptides, max_n_residues,
    n_features); peptides shorter than the longest one in the batch are
    zero-padded along the residue axis.
    """

    def __init__(self, mode: str = "mean", **kwargs) -> None:
        super().__init__(**kwargs)
        self.mode = mode
        if self.mode == "max":
            self._readout_fn = keras.ops.segment_max
        elif self.mode == "sum":
            self._readout_fn = keras.ops.segment_sum
        else:
            # Default (and any unrecognized mode) falls back to mean pooling.
            self._readout_fn = gnn_ops.segment_mean

    def build(self, input_shape: dict[str, tuple[int, ...]]) -> None:
        """Raise ValueError if the inputs will not carry 'residue_size'."""
        if "residue_size" not in input_shape:
            raise ValueError("Cannot perform readout: 'residue_size' not found.")

    def call(self, inputs: types.MolecularGraph) -> types.Array:
        """Pool node states per residue, then group residues per peptide."""
        peptide_size = keras.ops.cast(inputs["peptide_size"], "int32")
        residue_size = keras.ops.cast(inputs["residue_size"], "int32")
        n_residues = keras.ops.shape(residue_size)[0]
        # One segment id per node, identifying the residue it belongs to.
        # keras.ops.arange (not Python range) keeps this backend/graph safe.
        segment_ids = keras.ops.repeat(keras.ops.arange(n_residues), residue_size)
        residue_state = self._readout_fn(
            data=inputs["node_state"],
            segment_ids=segment_ids,
            num_segments=None,
            sorted=False,
        )
        # Make the static shape known for downstream layers.
        residue_state = keras.ops.reshape(
            residue_state,
            (n_residues, keras.ops.shape(inputs["node_state"])[-1]),
        )

        if keras.ops.shape(peptide_size)[0] == 1:
            # Single peptide in the batch: just prepend the batch dimension.
            return residue_state[None]

        # BUG FIX: keras.ops.split expects cumulative split *indices*, not
        # per-peptide sizes. E.g. sizes [2, 3, 1] must split at [2, 5];
        # passing peptide_size[:-1] ([2, 3]) produced wrong groupings for
        # batches of 3+ peptides.
        split_indices = keras.ops.cumsum(peptide_size)[:-1]
        residues = keras.ops.split(residue_state, split_indices)

        # Stack with zero padding along the residue axis.
        # Resulting shape: (n_peptides, max_n_residues, n_features).
        max_residue_size = max(int(keras.ops.shape(r)[0]) for r in residues)
        return keras.ops.stack([
            keras.ops.pad(r, [(0, max_residue_size - keras.ops.shape(r)[0]), (0, 0)])
            for r in residues
        ])



4 changes: 4 additions & 0 deletions molexpress/ops/gnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def aggregate(
"""
num_nodes = keras.ops.shape(node_state)[0]

# Instead of casting to int, throw an error if not int?
edge_src = keras.ops.cast(edge_src, "int32")
edge_dst = keras.ops.cast(edge_dst, "int32")

expected_rank = 2
current_rank = len(keras.ops.shape(edge_src))
for _ in range(expected_rank - current_rank):
Expand Down
39 changes: 22 additions & 17 deletions notebooks/examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
"from molexpress import layers\n",
"from molexpress.datasets import featurizers\n",
"from molexpress.datasets import encoders\n",
"\n",
"from rdkit import Chem\n",
"from molexpress.ops.chem_ops import get_molecule\n",
"\n",
"import torch"
]
Expand All @@ -34,7 +33,7 @@
"metadata": {},
"outputs": [],
"source": [
"mol = Chem.MolFromSmiles('CCO')\n",
"mol = get_molecule('C(C(=O)O)N')\n",
"\n",
"print(featurizers.AtomType(vocab={'O'}, oov=False)(mol.GetAtoms()[0]))\n",
"print(featurizers.AtomType(vocab={'O'}, oov=True)(mol.GetAtoms()[0]))\n",
Expand Down Expand Up @@ -67,13 +66,15 @@
" featurizers.BondType()\n",
"]\n",
"\n",
"encoder = encoders.MolecularGraphEncoder(\n",
"peptide_graph_encoder = encoders.PeptideGraphEncoder(\n",
" atom_featurizers=atom_featurizers, \n",
" bond_featurizers=bond_featurizers,\n",
" self_loops=True # adds one dim to edge state\n",
")\n",
"\n",
"encoder(mol)"
"mol2 = get_molecule('CC(C(=O)O)N')\n",
"\n",
"peptide_graph_encoder([mol, mol2])"
]
},
{
Expand All @@ -91,8 +92,12 @@
"metadata": {},
"outputs": [],
"source": [
"x_dummy = ['CC', 'CC', 'CCO', 'CCCN']\n",
"y_dummy = [1., 2., 3., 4.]\n",
"x_dummy = [\n",
" ['CC(C)C(C(=O)O)N', 'C(C(=O)O)N'], \n",
" ['C(C(=O)O)N', 'CC(C(=O)O)N', 'C(C(=O)O)N'], \n",
" ['CC(C(=O)O)N']\n",
"]\n",
"y_dummy = [1., 2., 3.]\n",
"\n",
"\n",
"class TinyDataset(torch.utils.data.Dataset):\n",
Expand All @@ -107,13 +112,13 @@
" def __getitem__(self, index):\n",
" x = self.x[index]\n",
" y = self.y[index]\n",
" x = encoder(x)\n",
" return x, y\n",
" x = peptide_graph_encoder(x)\n",
" return x, [y]\n",
"\n",
"torch_dataset = TinyDataset(x_dummy, y_dummy)\n",
"\n",
"dataset = torch.utils.data.DataLoader(\n",
" torch_dataset, batch_size=2, collate_fn=encoder._collate_fn)\n",
" torch_dataset, batch_size=2, collate_fn=peptide_graph_encoder._collate_fn)\n",
"\n",
"for x, y in dataset:\n",
" print(f'x = {x}\\ny = {y}', end='\\n' + '---' * 30 + '\\n')"
Expand Down Expand Up @@ -141,14 +146,16 @@
"\n",
" self.gcn1 = layers.GINConv(32)\n",
" self.gcn2 = layers.GINConv(32)\n",
" self.readout = layers.Readout()\n",
" self.readout = layers.ResidueReadout()\n",
" self.lstm = torch.nn.LSTM(32, 32, 1, batch_first=True)\n",
" self.linear = torch.nn.Linear(32, 1)\n",
"\n",
" def forward(self, x):\n",
" x = self.gcn1(x)\n",
" x = self.gcn2(x)\n",
" x = self.readout(x)\n",
" x = self.linear(x)\n",
" x, (_, _) = self.lstm(x)\n",
" x = self.linear(x[:, -1, :])\n",
" return x\n",
"\n",
"model = TinyGCNModel().to('cuda')"
Expand All @@ -169,18 +176,16 @@
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
"loss_fn = torch.nn.MSELoss()\n",
"\n",
"for _ in range(30):\n",
" loss_sum = 0.\n",
" for x, y in dataset:\n",
" optimizer.zero_grad()\n",
" \n",
" outputs = model(x)\n",
" \n",
" y = torch.tensor(y, dtype=torch.float32).to('cuda')\n",
" loss = loss_fn(outputs, y[:, None])\n",
" loss = loss_fn(outputs, y)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
Expand All @@ -192,7 +197,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ad047588-9926-4838-a264-193476897b4b",
"id": "9fe0fe29-34d1-445a-9ea7-81e2e3aa0046",
"metadata": {},
"outputs": [],
"source": []
Expand Down

0 comments on commit 155bed1

Please sign in to comment.