Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes #3

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Union

import numpy as np

from molexpress import types
Expand All @@ -23,27 +25,32 @@ def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI])
for residue in residues:
residue = chem_ops.get_molecule(residue)
residue_graph = {
**self.node_encoder(residue),
**self.node_encoder(residue),
**self.edge_encoder(residue)
}
residue_graphs.append(residue_graph)
residue_sizes.append(residue.GetNumAtoms())
disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
return disjoint_peptide_graph

disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32")
return disjoint_peptide_graph

@staticmethod
def _collate_fn(
data: list[tuple[types.MolecularGraph, np.ndarray]],
def collate_fn(
data: list[Union[types.MolecularGraph, tuple[types.MolecularGraph, np.ndarray]]],
) -> tuple[types.MolecularGraph, np.ndarray]:
"""TODO: Not sure where to implement this collate function.
Temporarily putting it here.

Procedure:
Merges list of graphs into a single disjoint graph.
"""
Merge list of graphs into a single disjoint graph.

disjoint_peptide_graphs, y = list(zip(*data))
Data can be a list of MolecularGraphs or a list of tuples where the first element is a
MolecularGraph and the second element is a label.

"""
if isinstance(data[0], tuple):
disjoint_peptide_graphs, y = list(zip(*data))
else:
disjoint_peptide_graphs = data
y = None

disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
Expand All @@ -54,7 +61,11 @@ def _collate_fn(
disjoint_peptide_batch_graph["residue_size"] = np.concatenate([
g["residue_size"] for g in disjoint_peptide_graphs
]).astype("int32")
return disjoint_peptide_batch_graph, np.stack(y)

if y is None:
return disjoint_peptide_batch_graph
else:
return disjoint_peptide_batch_graph, np.stack(y)

@staticmethod
def _merge_molecular_graphs(
Expand Down
30 changes: 13 additions & 17 deletions molexpress/layers/residue_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def build(self, input_shape: dict[str, tuple[int, ...]]) -> None:
raise ValueError("Cannot perform readout: 'residue_size' not found.")

def call(self, inputs: types.MolecularGraph) -> types.Array:
peptide_size = keras.ops.cast(inputs['peptide_size'], 'int32')
residue_size = keras.ops.cast(inputs['residue_size'], 'int32')
peptide_size = keras.ops.cast(inputs["peptide_size"], "int32")
residue_size = keras.ops.cast(inputs["residue_size"], "int32")
n_residues = keras.ops.shape(residue_size)[0]
segment_ids = keras.ops.repeat(range(n_residues), residue_size)
residue_state = self._readout_fn(
Expand All @@ -34,25 +34,21 @@ def call(self, inputs: types.MolecularGraph) -> types.Array:
)
# Make shape known
residue_state = keras.ops.reshape(
residue_state,
(
keras.ops.shape(residue_size)[0],
keras.ops.shape(inputs['node_state'])[-1]
)
residue_state,
(keras.ops.shape(residue_size)[0], keras.ops.shape(inputs["node_state"])[-1]),
)

if keras.ops.shape(peptide_size)[0] == 1:
# Single peptide in batch
return residue_state[None]

# Split and stack (with padding in the second dim)
# Resulting shape: (n_peptides, n_residues, n_features)
residues = keras.ops.split(residue_state, peptide_size[:-1])
residues = keras.ops.split(residue_state, keras.ops.cumsum(peptide_size)[:-1])
max_residue_size = keras.ops.max([len(r) for r in residues])
return keras.ops.stack([
keras.ops.pad(r, [(0, max_residue_size-keras.ops.shape(r)[0]), (0, 0)])
for r in residues
])



return keras.ops.stack(
[
keras.ops.pad(r, [(0, max_residue_size - keras.ops.shape(r)[0]), (0, 0)])
for r in residues
]
)