Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Integration of torch models in main #34

Closed
wants to merge 104 commits into from

Conversation

jpitoskas
Copy link
Contributor

@jpitoskas jpitoskas commented Jun 14, 2024

Incorporate Torch Model training into jaqpotpy

The jaqpotpy_torch supports the following types of modelling tasks:

  • Binary Classification of SMILES using Graph NNs
  • Multiclass Classification of SMILES using Graph NNs
  • Regression of SMILES using Graph NNs
  • Binary Classification of SMILES combined with external features using using Graph NNs along with a Fully Connected NN
  • Multiclass Classification of SMILES combined with external features using using Graph NNs along with a Fully Connected NN
  • Regression from SMILES combined with external features using using Graph NNs along with a Fully Connected NN
  • Binary Classification of Tabular Data using a Fully Connected NN
  • Multiclass Classification of Tabular Data using a Fully Connected NN
  • Regression of Tabular Data using a Fully Connected NN

Notes for future implementation

Featurizers

Featurizers that inherit from our base featurizer class Featurizer must implement the featurize method.

Datasets

Datasets must inherit from Torch Dataset and implement/override the __len__ and __getitem__ methods.

Trainers

Trainers that inherit from our base class TorchModelTrainer must implement get_model_type, train, evaluate, prepare_for_deployment.

Models

Models must inherit form nn.Module and implement __init__ and forward.

jpitoskas and others added 30 commits May 15, 2024 15:30
This commit adds the following implementations to the models_torch subpackage:
- Added __init__.py for the subpackage
- Implemented GraphAttentionNetwork in graph_attention_network.py
- Implemented GraphConvolutionalNetwork in graph_convolutional_network.py
- Implemented GraphSAGENetwork in graph_sage_network.py
- Implemented GraphTransformerNetwork in graph_transformer_network.py

These implementations provide configurable nn architectural support for training graph-based models using PyTorch.
Create a new directory named jaqpotpy_torch/ to organize all torch-related code.
We'll decide later whether to keep this code here or move it to an entirely new standalone torch-specific package.
This commit initializes the featurizers_torch subpackage, adding the following implementations:
- Added __init__.py for the subpackage
- Implemented SmilesGraphFeaturizer in smiles_graph_featurizer.py.

The SmilesGraphFeaturizer class is designed to create custom graph featurizations from SMILES strings.
It offers highly configurable options, allowing users to choose from a wide range of both atom and bond characteristics to be included.
This commit initializes the datasets_torch subpackage, adding the following implementations:
- Added __init__.py for the subpackage
- Implemented SmilesGraphDataset in smiles_graph_dataset.py.

The SmilesGraphDataset class is designed to create a custom torch Dataset for graph-featurized SMILES.
Its __getitem__ method is overridden to return a torch_geometric Data object enacpsulating the following information:
- Node attributes (x)
- Edge indices (edge_index)
- Edge attributes (edge_attr)
- Target labels (y)
- The original SMILES representation (smiles)

SmilesGraphDataset enables straightforward integration into torch-based ML pipelines, facilitating the development of graph-based predictive models.
This commit removes the _torch suffix directory names within the jaqpotpy_torch module:

- Renamed datasets_torch directory to datasets
- Renamed featurizers_torch directory to featurizers
- Renamed models_torch directory to models
This commit initializes the trainers subpackage, adding the following implementations:
- Added __init__.py for the subpackage
- Implemented an initial version of TorchModelTrainer abstract class in torch_model_trainer.py.
This commit initializes the trainers subpackage, adding the following implementations:
- Added BinaryGraphModelTrainer, RegressionGraphModelTrainer in __init__.py
- Extended TorchModelTrainer class with additional attributes.
This commit adds the following implementations to the trainers subpackage:
- BinaryGraphModelTrainer subclass
- RegressionGraphModelTrainer subclass
This commit adds the SmilesGraphDatasetWithExternal class implementation to the datasets subpackage.
This class inherits from SmilesGraphDataset, and adds the functionality of providing an external
feature vector along with the smiles representation.
This commit adds the implementation of the Featurizer abstract class.
Also the abstract method featurize() is defined.
This commit adds the FullyConnectedNetwork class implementation to the models.
This commit adds the following implementations to the models subpackage:
- GraphAttentionNetworkWithExternal
- GraphConvolutionalNetworkWithExternal
- GraphSAGENetworkWithExternal
- GraphTransformerNetworkWithExternal

In these models the corresponding graph neural network is employed to produce
global level representations from smiles. Then these are concatenated with the
external feature vectors and the concatenated vector is passed through a fully
connected network to produce the final output.
- Fixed a circular import error of the FullyConnectedNetwork class
- Added super().__init__() to all the models supporting external features
This commit adds the implementation of the deploy_model for both
RegressionGraphModelTrainer and BinaryGraphModelTrainer.

deploy_model() is an abstract method of the TorchModelTrainer base class,
and must be implemented in every class that inherits from TorchModelTrainer,
to support model deployment on Jaqpot.
In this commit we:
- Implement the deployment logic for models and trainers that use external features
- Set deploy_model function to be on the TorchModelTrainer class
- Define the abstract method prepare_for_deployment() with a dynamic set of arguments per trainer subclass which transforms the data into the appropriate JSON
- Add 'SMILES' in a protected namespace so that external features can't be named like this
- Fix bugs regarding model input arguments
This commit provides:
- Ready for deployment torch models are implemented
- Everything up to date with the current structure of the API
This commit implements:
- TabularDataset class inheriting from torch.utils.data.Dataset
- BinaryFCModelTrainer & RegressionFCModelTrainer
- The required changed in the BinaryModelTrainer and RegressionModelTrainer abstract classes to support data from torch.utils.data.DataLoader as well
@jpitoskas jpitoskas changed the title feat: Getting up to date with main feat: Integration of torch models in main Jun 30, 2024
@alarv
Copy link
Member

alarv commented Sep 11, 2024

Closing this as it has been integrated into the main branch as part of JAQPOT-254 and is partly done on #47

@alarv alarv closed this Sep 11, 2024
@alarv alarv deleted the feat/JAQPOT-62/torch-graph-training branch September 11, 2024 12:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants