Skip to content

Commit

Permalink
add pretraining modules
Browse files Browse the repository at this point in the history
  • Loading branch information
akensert committed Jun 21, 2024
1 parent a92e4d2 commit 7b010c3
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 122 deletions.
66 changes: 59 additions & 7 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ def __init__(
atom_featurizers: list[featurizers.Featurizer],
bond_featurizers: list[featurizers.Featurizer] = None,
self_loops: bool = False,
supports_masking: bool = False,
) -> None:
self.node_encoder = MolecularNodeEncoder(atom_featurizers)
self.edge_encoder = MolecularEdgeEncoder(bond_featurizers, self_loops=self_loops)
self.node_encoder = MolecularNodeEncoder(
atom_featurizers, supports_masking=supports_masking)
self.edge_encoder = MolecularEdgeEncoder(
bond_featurizers, self_loops=self_loops, supports_masking=supports_masking)

def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray:
residue_graphs = []
Expand Down Expand Up @@ -78,6 +81,45 @@ def collate_fn(
else:
return disjoint_peptide_batch_graph, np.stack(y)

@staticmethod
def masked_collate_fn(
    data: list[types.MolecularGraph],
    node_masking_rate: float = 0.25,
    edge_masking_rate: float = 0.25,
) -> types.MolecularGraph:
    """
    Merge a list of graphs into a single disjoint graph and randomly mask it.

    Unlike a labelled collate function, `data` is a plain list of
    MolecularGraphs; no labels are accepted or returned. The reconstruction
    targets are stored inside the returned graph instead.

    For both the node and the edge states, each row is masked independently
    with probability `node_masking_rate` / `edge_masking_rate`. A masked row is
    replaced by a vector that is zero everywhere except for a 1 in the last
    column. NOTE(review): this assumes the last feature column is reserved for
    the mask token (encoders built with `supports_masking=True`) -- confirm
    against the encoder configuration used by the caller.

    Returns the merged graph with 'node_state' / 'edge_state' overwritten by
    their masked versions and four extra entries:
    'node_loss_weight' / 'edge_loss_weight' (1 for masked rows, 0 otherwise,
    in the dtype of the state) and 'node_label' / 'edge_label' (copies of the
    unmasked states).
    """

    def _mask_rows(state: np.ndarray, masking_rate: float):
        # One uniform draw per row; rows whose draw falls below the rate are masked.
        row_mask = np.random.uniform(size=state.shape[0]) < masking_rate
        # The mask token: all zeros with the last column switched on.
        mask_token = np.zeros_like(state)
        mask_token[:, -1] = 1.
        masked_state = np.where(row_mask[:, None], mask_token, state)
        # astype already returns a fresh array, so no extra copy is needed.
        return masked_state, row_mask.astype(state.dtype), np.copy(state)

    graph = PeptideGraphEncoder._merge_molecular_graphs(data)

    # Nodes are handled before edges so the sequence of random draws (and thus
    # seeded reproducibility) is the same as when both were written out inline.
    for prefix, rate in (('node', node_masking_rate), ('edge', edge_masking_rate)):
        masked_state, loss_weight, label = _mask_rows(graph[f'{prefix}_state'], rate)
        graph[f'{prefix}_loss_weight'] = loss_weight
        graph[f'{prefix}_label'] = label
        graph[f'{prefix}_state'] = masked_state

    return graph

@staticmethod
def _merge_molecular_graphs(
molecular_graphs: list[types.MolecularGraph],
Expand Down Expand Up @@ -139,10 +181,14 @@ def output_dtype(self):

class MolecularEdgeEncoder:
def __init__(
self, featurizers: list[featurizers.Featurizer], self_loops: bool = False
self,
featurizers: list[featurizers.Featurizer],
self_loops: bool = False,
supports_masking: bool = False,
) -> None:
self.featurizer = Composer(featurizers)
self.self_loops = self_loops
self.supports_masking = supports_masking
self.output_dim = self.featurizer.output_dim
self.output_dtype = self.featurizer.output_dtype

Expand Down Expand Up @@ -170,12 +216,14 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray:

if bond is None:
assert self.self_loops, "Found a bond to be None."
bond_encoding = np.zeros(self.output_dim + 1, dtype=self.output_dtype)
bond_encoding[-1] = 1
bond_encoding = np.zeros(
self.output_dim + int(self.self_loops) + int(self.supports_masking),
dtype=self.output_dtype)
bond_encoding[-(int(self.self_loops) + int(self.supports_masking))] = 1
else:
bond_encoding = self.featurizer(bond)
if self.self_loops:
bond_encoding = np.pad(bond_encoding, (0, 1))
bond_encoding = np.pad(
bond_encoding, (0, int(self.self_loops) + int(self.supports_masking)))

bond_encodings.append(bond_encoding)

Expand All @@ -190,11 +238,15 @@ class MolecularNodeEncoder:
def __init__(
    self, featurizers: list[featurizers.Featurizer], supports_masking: bool = False
) -> None:
    """
    Set up the node encoder.

    Args:
        featurizers: Atom featurizers, combined into a single `Composer`.
        supports_masking: Stored on the instance; when true, `__call__`
            appends one zero-valued column to the node encodings.
    """
    composed = Composer(featurizers)
    self.featurizer = composed
    self.supports_masking = supports_masking

def __call__(self, molecule: types.Molecule) -> dict[str, np.ndarray]:
    """
    Encode every atom of `molecule` as one row of a node-state matrix.

    Each atom returned by `molecule.GetAtoms()` is passed through
    `self.featurizer` and the results are stacked along axis 0. When
    `self.supports_masking` is true, one zero-valued column is appended on the
    right of the matrix.

    Returns:
        A dict with the single key "node_state" holding the stacked encodings.

    Raises:
        ValueError: raised by `np.stack` when the molecule has no atoms.
    """
    node_encodings = np.stack(
        [self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0
    )
    if self.supports_masking:
        # Pad only the feature axis: no extra rows, one extra trailing column.
        node_encodings = np.pad(node_encodings, [(0, 0), (0, 1)])
    # The encodings are already one stacked array; re-stacking would only copy it.
    return {"node_state": node_encodings}
125 changes: 13 additions & 112 deletions molexpress/datasets/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,120 +8,21 @@

from molexpress import types


DEFAULT_VOCABULARY = {
"AtomType": {
"H",
"He",
"Li",
"Be",
"B",
"C",
"N",
"O",
"F",
"Ne",
"Na",
"Mg",
"Al",
"Si",
"P",
"S",
"Cl",
"Ar",
"K",
"Ca",
"Sc",
"Ti",
"V",
"Cr",
"Mn",
"Fe",
"Co",
"Ni",
"Cu",
"Zn",
"Ga",
"Ge",
"As",
"Se",
"Br",
"Kr",
"Rb",
"Sr",
"Y",
"Zr",
"Nb",
"Mo",
"Tc",
"Ru",
"Rh",
"Pd",
"Ag",
"Cd",
"In",
"Sn",
"Sb",
"Te",
"I",
"Xe",
"Cs",
"Ba",
"La",
"Ce",
"Pr",
"Nd",
"Pm",
"Sm",
"Eu",
"Gd",
"Tb",
"Dy",
"Ho",
"Er",
"Tm",
"Yb",
"Lu",
"Hf",
"Ta",
"W",
"Re",
"Os",
"Ir",
"Pt",
"Au",
"Hg",
"Tl",
"Pb",
"Bi",
"Po",
"At",
"Rn",
"Fr",
"Ra",
"Ac",
"Th",
"Pa",
"U",
"Np",
"Pu",
"Am",
"Cm",
"Bk",
"Cf",
"Es",
"Fm",
"Md",
"No",
"Lr",
"Rf",
"Db",
"Sg",
"Bh",
"Hs",
"Mt",
"Ds",
"Rg",
"Cn",
'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca',
'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn',
'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr',
'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn',
'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd',
'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb',
'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th',
'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm',
'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds',
'Rg', 'Cn'
},
"Hybridization": {"s", "sp", "sp2", "sp3", "sp3d", "sp3d2", "unspecified"},
"CIPCode": {"R", "S", "None"},
Expand Down
3 changes: 2 additions & 1 deletion molexpress/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from molexpress.layers.gcn_conv import GCNConv as GCNConv
from molexpress.layers.gin_conv import GINConv as GINConv
from molexpress.layers.peptide_readout import PeptideReadout as PeptideReadout
from molexpress.layers.residue_readout import ResidueReadout as ResidueReadout
from molexpress.layers.residue_readout import ResidueReadout as ResidueReadout
from molexpress.layers.gather_incident import GatherIncident as GatherIncident
19 changes: 19 additions & 0 deletions molexpress/layers/gather_incident.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import keras

from molexpress import types
from molexpress.ops import gnn_ops


class GatherIncident(keras.layers.Layer):
    """
    Layer that pairs up the states of the two nodes incident to each edge.

    The node states are gathered once with 'edge_src' and once with 'edge_dst'
    and the two results are concatenated along axis 1.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def call(self, inputs: types.MolecularGraph):
        """
        Return the source and destination node states joined along axis 1.

        Reads the 'node_state', 'edge_src' and 'edge_dst' entries of `inputs`.
        Presumably 'node_state' is (num_nodes, dim), giving a
        (num_edges, 2 * dim) output -- TODO confirm.
        """
        source_states = gnn_ops.gather(inputs['node_state'], inputs['edge_src'])
        target_states = gnn_ops.gather(inputs['node_state'], inputs['edge_dst'])
        incident_states = [source_states, target_states]
        return keras.ops.concatenate(incident_states, axis=1)
12 changes: 11 additions & 1 deletion molexpress/ops/gnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def transform(
state_transformed += bias
return state_transformed


def aggregate(
node_state: types.Array,
edge_src: types.Array,
Expand Down Expand Up @@ -83,6 +82,17 @@ def aggregate(
)
return node_state_updated

def gather(
    node_state: types.Array,
    edge: types.Array,
) -> types.Array:
    """
    Pick rows of `node_state` using the indices in `edge`.

    `edge` is cast to int32 and given trailing axes until it has rank 2, then
    used as the index array for `take_along_axis` over axis 0 of `node_state`.
    Presumably `node_state` is (num_nodes, dim) and `edge` is (num_edges,)
    -- TODO confirm against callers.
    """
    indices = keras.ops.cast(edge, "int32")
    # take_along_axis needs the index array to have as many axes as a rank-2 state.
    missing_axes = 2 - len(keras.ops.shape(indices))
    while missing_axes > 0:
        indices = keras.ops.expand_dims(indices, axis=-1)
        missing_axes -= 1
    return keras.ops.take_along_axis(node_state, indices, axis=0)

def segment_mean(
data: types.Array,
Expand Down
Loading

0 comments on commit 7b010c3

Please sign in to comment.