This repository is deprecated and has been replaced by https://github.com/swansonk14/chemprop, which contains a stable, clean version of the code. This repo is no longer actively maintained and should be used for reference only.
This repository contains graph convolutional networks (also known as message passing networks) for molecular property prediction.
Requirements:
- cuda >= 8.0 + cuDNN
- Python 3/conda: Please follow the installation guide on https://conda.io/miniconda.html
- Create a conda environment with
conda create -n <name> python=3.6
- Activate the environment with
conda activate <name>
- pytorch: Please follow the installation guide on https://pytorch.org/
- Typically it's
conda install pytorch torchvision -c pytorch
- tensorflow: Needed for Tensorboard training visualization
- CPU-only:
pip install tensorflow
- GPU:
pip install tensorflow-gpu
- RDKit:
conda install -c rdkit rdkit
- Other packages:
pip install -r requirements.txt
- Note that if you get warning messages about kyotocabinet, it's safe to ignore them.
- This is a requirement from descriptastorus, but not necessary for the parts of descriptastorus that we use.
- See https://github.com/bp-kelley/descriptastorus for installation details if you want to install anyway.
The data file must be a CSV file with a header row. For example:
smiles,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53
CCOc1ccc2nc(S(N)(=O)=O)sc2c1,0,0,1,,,0,0,1,0,0,0,0
CCN1C(=O)NC(c2ccccc2)C1=O,0,0,0,0,0,0,0,,0,,0,0
...
Datasets from deepchem are available in the `data` directory.
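As an illustration of the file format above, here is a minimal sketch of reading such a CSV with Python's standard `csv` module. It is not the loader used by this repository: the function name `read_dataset` is hypothetical, and the sketch assumes that the first column holds the SMILES string, that all other columns are numeric targets, and that empty cells represent missing labels.

```python
import csv
import io


def read_dataset(csv_text):
    """Parse CSV text into (task_names, rows); empty cells become None."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)      # the first row is the header
    task_names = header[1:]    # assumption: column 0 holds the SMILES string
    rows = []
    for line in reader:
        if not line:           # skip blank lines
            continue
        smiles, raw_targets = line[0], line[1:]
        # Assumption: an empty cell means "no label for this task".
        targets = [float(v) if v != "" else None for v in raw_targets]
        rows.append((smiles, targets))
    return task_names, rows


# The header and the two rows shown in the example above.
example = (
    "smiles,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,"
    "NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53\n"
    "CCOc1ccc2nc(S(N)(=O)=O)sc2c1,0,0,1,,,0,0,1,0,0,0,0\n"
    "CCN1C(=O)NC(c2ccccc2)C1=O,0,0,0,0,0,0,0,,0,,0,0\n"
)
task_names, rows = read_dataset(example)
```

Each entry of `rows` pairs a SMILES string with one target per task, so a multi-task dataset such as Tox21 fits in a single file.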
To train a model, run:
python train.py --data_path <path> --dataset_type <type> --save_dir <dir>
where `<path>` is the path to a CSV file containing a dataset, `<type>` is either "classification" or "regression" depending on the type of the dataset, and `<dir>` is the directory where model checkpoints will be saved.
For example:
python train.py --data_path data/tox21.csv --dataset_type classification --save_dir tox21_checkpoints
Notes:
- Classification is assumed to be binary.
- Empty values in the CSV are ignored.
- `--save_dir` may be left out if you don't want to save model checkpoints.
- The default metric for classification is AUC and the default metric for regression is RMSE. The qm8 and qm9 datasets use MAE instead of RMSE, so you need to specify `--metric mae`.
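For reference, the two regression metrics named in the notes can be sketched in plain Python as follows. This is only an illustration of the definitions, not the code this repository uses. The function names are hypothetical, and skipping `None` targets is an assumption modelled on the note that empty CSV values are ignored.

```python
import math


def _labelled_pairs(targets, preds):
    """Keep only (target, prediction) pairs whose target is present."""
    return [(t, p) for t, p in zip(targets, preds) if t is not None]


def rmse(targets, preds):
    """Root mean squared error over the labelled entries."""
    pairs = _labelled_pairs(targets, preds)
    return math.sqrt(sum((t - p) ** 2 for t, p in pairs) / len(pairs))


def mae(targets, preds):
    """Mean absolute error over the labelled entries."""
    pairs = _labelled_pairs(targets, preds)
    return sum(abs(t - p) for t, p in pairs) / len(pairs)
```

RMSE penalizes large errors more heavily than MAE because the errors are squared before averaging.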
By default, the data in `--data_path` will be split randomly into train, validation, and test sets using the seed specified by `--seed` (default = 0). By default, the train set contains 80% of the data while the validation and test sets contain 10% of the data each. These sizes can be controlled with `--split_sizes` (for example, the default would be `--split_sizes 0.8 0.1 0.1`).
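A seeded random split of this kind can be sketched as follows. This is a minimal illustration under stated assumptions, not the splitting code in this repository: the function name `split_data` is hypothetical, and the rounding of the split boundaries may differ from what `train.py` does.

```python
import random


def split_data(data, sizes=(0.8, 0.1, 0.1), seed=0):
    """Shuffle with a fixed seed, then cut into train/validation/test parts."""
    items = list(data)
    random.Random(seed).shuffle(items)  # the same seed gives the same split
    train_end = int(sizes[0] * len(items))
    val_end = int((sizes[0] + sizes[1]) * len(items))
    return items[:train_end], items[train_end:val_end], items[val_end:]


train, val, test = split_data(range(100), sizes=(0.8, 0.1, 0.1), seed=0)
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible without touching the global random state.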
To use a different dataset for testing, specify `--separate_test_path`. In this case, the data in `--data_path` will be split into only train and validation sets (80% and 20% of the data), and the test set will contain all the data in `--separate_test_path`.
k-fold cross-validation can be run by specifying the `--num_folds` argument (which is 1 by default). For example:
python train.py --data_path data/tox21.csv --dataset_type classification --num_folds 5
To train an ensemble, specify the number of models in the ensemble with the `--ensemble_size` argument (which is 1 by default). For example:
python train.py --data_path data/tox21.csv --dataset_type classification --ensemble_size 5
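One common way to combine the models of an ensemble is to average their predictions. The sketch below shows that idea only; whether this repository averages in exactly this way is an assumption, and the function name `average_predictions` is hypothetical.

```python
def average_predictions(per_model_preds):
    """Average predictions across models.

    per_model_preds[m][i][t] is the prediction of model m
    for molecule i on task t.
    """
    num_models = len(per_model_preds)
    return [
        # zip(*mol_preds) groups the per-task values across all models
        [sum(values) / num_models for values in zip(*mol_preds)]
        # zip(*per_model_preds) groups each molecule's predictions across models
        for mol_preds in zip(*per_model_preds)
    ]


model_a = [[0.0, 1.0], [0.5, 0.5]]  # 2 molecules x 2 tasks
model_b = [[1.0, 1.0], [0.5, 0.0]]
ensemble = average_predictions([model_a, model_b])
```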
The base message passing architecture can be modified in a range of ways that can be controlled through command line arguments. The full range of options can be seen in `parsing.py`. Suggested modifications are:
- `--hidden_size <int>` Control the hidden size of the neural network layers.
- `--depth <int>` Control the number of message passing steps.
- `--virtual_edges` Adds "virtual" edges connecting non-bonded atoms to improve information flow. This works very well on some datasets (e.g. QM9) but very poorly on others (e.g. delaney).
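To make the idea behind `--virtual_edges` concrete, the sketch below lists every pair of atoms that is not joined by a real bond. It only illustrates the concept on plain atom indices. The function name `find_virtual_edges` is hypothetical, and how this repository actually builds and featurizes these edges is not shown here.

```python
from itertools import combinations


def find_virtual_edges(num_atoms, bonds):
    """Return all atom pairs (i, j), i < j, that share no real bond."""
    bonded = {frozenset(bond) for bond in bonds}  # order-independent lookup
    return [
        (i, j)
        for i, j in combinations(range(num_atoms), 2)
        if frozenset((i, j)) not in bonded
    ]


# A 4-atom chain 0-1-2-3 has three real bonds and three non-bonded pairs.
chain_bonds = [(0, 1), (1, 2), (2, 3)]
virtual = find_virtual_edges(4, chain_bonds)
```

With such edges added, information can pass between distant atoms in a single message passing step, at the cost of a graph whose edge count grows quadratically with the number of atoms.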
To load a trained model and make predictions, run `predict.py` and specify:
- `--test_path <path>` Path to the data to predict on.
- A checkpoint by using either:
  - `--checkpoint_dir <dir>` Directory where the model checkpoint(s) are saved (i.e. `--save_dir` during training). This will walk the directory, load all `.pt` files it finds, and treat the models as an ensemble.
  - `--checkpoint_path <path>` Path to a model checkpoint file (`.pt` file).
- `--preds_path` Path where a CSV file containing the predictions will be saved.
- (Optional) `--write_smiles` Add this flag if you would like to save the SMILES string for each molecule alongside the property prediction.
For example:
python predict.py --test_path data/tox21.csv --checkpoint_dir tox21_checkpoints --preds_path tox21_preds.csv
or
python predict.py --test_path data/tox21.csv --checkpoint_path tox21_checkpoints/fold_0/model_0/model.pt --preds_path tox21_preds.csv
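The directory walk described for `--checkpoint_dir` can be sketched with `os.walk`. This is an illustration rather than the code in `predict.py`: the function name `find_checkpoints` is hypothetical, and the sorted return order is a choice made here for reproducibility.

```python
import os


def find_checkpoints(checkpoint_dir):
    """Recursively collect the paths of all .pt files under checkpoint_dir."""
    paths = []
    for root, _dirs, files in os.walk(checkpoint_dir):
        for name in files:
            if name.endswith(".pt"):
                paths.append(os.path.join(root, name))
    return sorted(paths)
```

Given the layout in the second example above, `find_checkpoints("tox21_checkpoints")` would include `tox21_checkpoints/fold_0/model_0/model.pt` in its result if that file exists.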
During training, TensorBoard logs are automatically saved to the same directory as the model checkpoints. To view TensorBoard logs, run `tensorboard --logdir=<dir>` where `<dir>` is the path to the checkpoint directory. Then navigate to http://localhost:6006.
We tested our model on 14 deepchem benchmark datasets (http://moleculenet.ai/), ranging from physical chemistry to biophysics properties. To train our model on those datasets, run:
bash run.sh 1
where 1 is the random seed for randomly splitting the dataset into training, validation and testing (not applied to datasets with scaffold splitting).
We compared our model against the graph convolution in deepchem. Our results are averaged over 3 runs with different random seeds, namely different splits across datasets. Unless otherwise indicated, all models were trained using hidden size 1800, depth 6, and master node. We did a few hyperparameter experiments on qm9, but did no searching on the other datasets, so there may still be further room for improvement.
Results on classification datasets (AUC score, the higher the better)
Dataset | Size | Ours | GraphConv (deepchem) |
---|---|---|---|
Bace | 1,513 | 0.884 ± 0.034 | 0.783 ± 0.014 |
BBBP | 2,039 | 0.922 ± 0.012 | 0.690 ± 0.009 |
Tox21 | 7,831 | 0.851 ± 0.015 | 0.829 ± 0.006 |
Toxcast | 8,576 | 0.748 ± 0.014 | 0.716 ± 0.014 |
Sider | 1,427 | 0.643 ± 0.027 | 0.638 ± 0.012 |
clintox | 1,478 | 0.882 ± 0.022 | 0.807 ± 0.047 |
MUV | 93,087 | 0.067 ± 0.03* | 0.046 ± 0.031 |
HIV | 41,127 | 0.821 ± 0.034† | 0.763 ± 0.016 |
PCBA | 437,928 | 0.218 ± 0.001* | 0.136 ± 0.003 |
Results on regression datasets (score, the lower the better)
Dataset | Size | Ours | GraphConv/MPNN (deepchem) |
---|---|---|---|
delaney | 1,128 | 0.687 ± 0.037 | 0.58 ± 0.03 |
Freesolv | 642 | 0.915 ± 0.154 | 1.15 ± 0.12 |
Lipo | 4,200 | 0.565 ± 0.052 | 0.655 ± 0.036 |
qm8 | 21,786 | 0.008 ± 0.000 | 0.0143 ± 0.0011 |
qm9 | 133,884 | 2.47 ± 0.036 | 3.2 ± 1.5 |
†HIV was trained with hidden size 1800 and depth 6 but without the master node.
*MUV and PCBA use a much older version of the model.