Skip to content

Commit

Permalink
feat: jaqpot-403/Docstrings_and_remove_torchscript (#102)
Browse files Browse the repository at this point in the history
* docstrings_and_only_onnx_support

* 2_onnx_examples
  • Loading branch information
johnsaveus authored Nov 4, 2024
1 parent 0e90114 commit faa57ff
Show file tree
Hide file tree
Showing 10 changed files with 644 additions and 240 deletions.
78 changes: 78 additions & 0 deletions examples/graph_onnx_classif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pandas as pd
from jaqpotpy.descriptors.graph import SmilesGraphFeaturizer
from rdkit import Chem
from jaqpotpy.datasets import SmilesGraphDataset
from jaqpotpy.models.torch_geometric_models.graph_neural_network import (
GraphSageNetwork,
pyg_to_onnx,
)
from jaqpotpy.models.trainers.graph_trainers import BinaryGraphModelTrainer
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss

from torch_geometric.loader import DataLoader
from jaqpotpy import Jaqpot

df = pd.read_csv("./jaqpotpy/test_data/test_data_smiles_classification.csv")

train_smiles = list(df["SMILES"].iloc[:100])
train_y = list(df["ACTIVITY"].iloc[:100])

val_smiles = list(df["SMILES"].iloc[100:200])
val_y = list(df["ACTIVITY"].iloc[100:200])

featurizer = SmilesGraphFeaturizer()
featurizer.add_atom_feature("symbol", ["C", "O", "N", "F", "Cl", "Br", "I"])

featurizer.add_bond_feature(
"bond_type",
[
Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC,
],
)

train_dataset = SmilesGraphDataset(
smiles=train_smiles, y=train_y, featurizer=featurizer
)
val_dataset = SmilesGraphDataset(smiles=val_smiles, y=val_y, featurizer=featurizer)

train_dataset.precompute_featurization()
val_dataset.precompute_featurization()

input_dim = featurizer.get_num_node_features()
edge_dim = featurizer.get_num_edge_features()

model = GraphSageNetwork(
input_dim=input_dim,
hidden_layers=2,
hidden_dim=16,
output_dim=1,
dropout_proba=0.5,
seed=42,
)

optimizer = Adam(model.parameters(), lr=0.001)
loss = BCEWithLogitsLoss()
epochs = 5
trainer = BinaryGraphModelTrainer(model, epochs, optimizer, loss)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
trainer.train(train_loader, val_loader)

onnx_model = pyg_to_onnx(model, featurizer)

jaqpot = Jaqpot()
jaqpot.login()

jaqpot.deploy_torch_model(
onnx_model,
featurizer=featurizer,
name="Graph Sage Network",
description="Graph Sage Network for binary classification",
target_name="ACTIVITY",
visibility="PRIVATE",
task="binary_classification",
)
78 changes: 78 additions & 0 deletions examples/graph_onnx_reg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pandas as pd
from jaqpotpy.descriptors.graph import SmilesGraphFeaturizer
from rdkit import Chem
from jaqpotpy.datasets import SmilesGraphDataset
from jaqpotpy.models.torch_geometric_models.graph_neural_network import (
GraphSageNetwork,
pyg_to_onnx,
)
from jaqpotpy.models.trainers.graph_trainers import RegressionGraphModelTrainer
from torch.optim import Adam
from torch.nn import MSELoss

from torch_geometric.loader import DataLoader
from jaqpotpy import Jaqpot

df = pd.read_csv("./jaqpotpy/test_data/test_data_smiles_regression.csv")

train_smiles = list(df["SMILES"].iloc[:100])
train_y = list(df["ACTIVITY"].iloc[:100])

val_smiles = list(df["SMILES"].iloc[100:200])
val_y = list(df["ACTIVITY"].iloc[100:200])

featurizer = SmilesGraphFeaturizer()
featurizer.add_atom_feature("symbol", ["C", "O", "N", "F", "Cl", "Br", "I"])

featurizer.add_bond_feature(
"bond_type",
[
Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC,
],
)

train_dataset = SmilesGraphDataset(
smiles=train_smiles, y=train_y, featurizer=featurizer
)
val_dataset = SmilesGraphDataset(smiles=val_smiles, y=val_y, featurizer=featurizer)

train_dataset.precompute_featurization()
val_dataset.precompute_featurization()

input_dim = featurizer.get_num_node_features()
edge_dim = featurizer.get_num_edge_features()

model = GraphSageNetwork(
input_dim=input_dim,
hidden_layers=2,
hidden_dim=16,
output_dim=1,
dropout_proba=0.5,
seed=42,
)

optimizer = Adam(model.parameters(), lr=0.001)
loss = MSELoss()
epochs = 5
trainer = RegressionGraphModelTrainer(model, epochs, optimizer, loss)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
trainer.train(train_loader, val_loader)

onnx_model = pyg_to_onnx(model, featurizer)

jaqpot = Jaqpot()
jaqpot.login()

jaqpot.deploy_torch_model(
onnx_model,
featurizer=featurizer,
name="Graph Sage Network",
description="Graph Sage Network for regression",
target_name="ACTIVITY",
visibility="PRIVATE",
task="regression",
)
74 changes: 56 additions & 18 deletions jaqpotpy/datasets/graph_pyg_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

class SmilesGraphDataset(Dataset):
"""
A PyTorch Dataset class for handling SMILES strings as graphs.
This class overrides `__getitem__` and `__len__` (check source code for methods' docstrings).
A PyTorch Dataset class for handling SMILES strings as graph data suitable for training
graph neural networks. The class transforms SMILES strings into graph representations using
a specified featurizer and optionally supports target values for supervised learning tasks.
Attributes:
smiles (list): A list of SMILES strings.
y (list, optional): A list of target values.
featurizer (SmilesGraphFeaturizer): The object to transform SMILES strings into graph representations.
precomputed_features (list, optional): A list of precomputed features. If precompute_featurization() is not called, this attribute remains None.
smiles (list): A list of SMILES strings to be converted into graph data.
y (list, optional): A list of target values associated with each SMILES string.
featurizer (SmilesGraphFeaturizer): A featurizer to transform SMILES into graph representations.
precomputed_features (list, optional): Precomputed graph features; remains None until `precompute_featurization` is called.
"""

def __init__(
Expand All @@ -20,7 +22,15 @@ def __init__(
y: Optional[list] = None,
featurizer: Optional[SmilesGraphFeaturizer] = None,
):
"""The SmilesGraphDataset constructor."""
"""
Initializes the SmilesGraphDataset with SMILES strings, target values, and an optional featurizer.
Args:
smiles (list): List of SMILES strings to be transformed into graphs.
y (list, optional): List of target values for supervised learning tasks.
featurizer (SmilesGraphFeaturizer, optional): A featurizer for converting SMILES into graphs;
if not provided, a default featurizer is used.
"""
super().__init__()
self.smiles = smiles
self.y = y
Expand All @@ -36,7 +46,16 @@ def __init__(
self.precomputed_features = None

def precompute_featurization(self):
"""Precomputes the featurization of the dataset before being accessed by __getitem__"""
"""
Precomputes the featurized graph representations of the SMILES strings in the dataset.
This method prepares the graph data in advance, which can improve efficiency when
accessing individual data samples. Each SMILES string is transformed into a graph
representation (and paired with its target value if available).
Sets:
self.precomputed_features (list): A list of graph features precomputed for each SMILES.
"""
if self.y:
self.precomputed_features = [
self.featurizer(sm, y) for sm, y in zip(self.smiles, self.y)
Expand All @@ -45,25 +64,37 @@ def precompute_featurization(self):
self.precomputed_features = [self.featurizer(sm) for sm in self.smiles]

def get_num_node_features(self):
"""Returns the number of node features."""
"""
Returns the number of node features (atom-level features) in each graph representation.
Returns:
int: The number of features associated with each node (atom).
"""
return len(self.get_atom_feature_labels())

def get_num_edge_features(self):
"""Returns the number of edge features."""
"""
Returns the number of edge features (bond-level features) in each graph representation.
Returns:
int: The number of features associated with each edge (bond).
"""
return len(self.get_bond_feature_labels())

def __getitem__(self, idx):
"""
Retrieves the featurized graph Data object and target value for a given index.
Retrieves a featurized graph representation and target value (if available) for a specific index.
Args:
idx (int): Index of the data to retrieve.
idx (int): Index of the sample to retrieve.
Returns:
torch_geometric.data.Data: A torch_geometric.data.Data object for this single sample containing:
torch_geometric.data.Data: A graph data object containing:
- x (torch.Tensor): Node feature matrix with shape [num_nodes, num_node_features].
- edge_index (LongTensor): Graph connectivity in COO format with shape [2, num_edges].
- edge_attr (torch.Tensor, optional): Edge feature matrix with shape [num_edges, num_edge_features].
- y (float): Graph-level ground-truth label.
- smiles (str): The SMILES string corresponding to the particular sample.
- y (float, optional): Target value associated with the graph, if provided.
- smiles (str): The SMILES string for the specific sample.
"""
if self.precomputed_features:
return self.precomputed_features[idx]
Expand All @@ -73,11 +104,18 @@ def __getitem__(self, idx):

def __len__(self):
"""
__len__ functionality is important for the DataLoader to determine batching,
shuffling and iterating over the dataset.
Returns the total number of SMILES strings in the dataset, necessary for data loading operations.
Returns:
int: The length of the dataset (number of SMILES strings).
"""
return len(self.smiles)

def __repr__(self) -> str:
"""Official string representation of the Dataset Object"""
"""
Returns a formal string representation of the SmilesGraphDataset object, indicating its class name.
Returns:
str: Class name as the representation of the dataset object.
"""
return self.__class__.__name__
Loading

0 comments on commit faa57ff

Please sign in to comment.