* PyTorch wrapper for the forward pass on CPU
* CMake file for the PyTorch wrapper
* PyTorch wrapper for the backward pass (not yet working)
* Wrap CpuANISymmetryFunctions as a custom PyTorch class
* PyTorch wrapper of the backward pass
* Simplify the PyTorch wrapper
* PyTorch wrapper for the CUDA implementation
* Fix a typo
* Simplify the PyTorch wrapper
* Fix the memory leak in the PyTorch wrapper
* Pass the box vectors to the PyTorch wrapper
* Unify the names in the PyTorch wrapper
* Implement integration with TorchANI via the PyTorch wrapper
* Simplify and add a check to the TorchANI integration
* Rename the PyTorch wrapper component
* Add a test for TorchANISymmetryFunctions
* Fix the serialization of TorchANISymmetryFunctions
* Add a test for the serialization of TorchANISymmetryFunctions
* Add more molecules for the TorchANISymmetryFunctions tests
* Update the TorchANISymmetryFunctions tests to use all the molecules
* Improve the CMake file for NNPOpsPyTorch
* Add installation instructions for NNPOpsPyTorch
* Fix the import of NNPOps in Python
* Add a usage example for NNPOpsPyTorch
* Fix the import in the example
* Add docstrings for TorchANISymmetryFunctions
* Add more general text about the wrapper
* Fix a typo
* Add a benchmark script for TorchANISymmetryFunctions
* Make PyTorch and NNPOps run on the same GPU
Raimondas Galvelis authored on Oct 28, 2020 · 1 parent fadbc78 · commit 667a282
Showing 13 changed files with 1,695 additions and 0 deletions.
Benchmark script for `TorchANISymmetryFunctions`:

```python
import mdtraj
import time
import torch
import torchani

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device('cuda')

# Load a test molecule and prepare the input tensors
mol = mdtraj.load('molecules/2iuz_ligand.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
positions = torch.tensor(mol.xyz, dtype=torch.float32, requires_grad=True, device=device)

# Set up the original TorchANI AEV computer and the optimized replacement
nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
speciesPositions = nnp.species_converter((species, positions))
symmFuncRef = nnp.aev_computer
symmFunc = TorchANISymmetryFunctions(nnp.aev_computer).to(device)

# Warm-up pass, which also makes positions.grad available for zeroing
aev_ref = symmFuncRef(speciesPositions).aevs
sum_aev_ref = torch.sum(aev_ref)
sum_aev_ref.backward()
grad_ref = positions.grad.clone()

# Time the original implementation
N = 10000
start = time.time()
for _ in range(N):
    aev_ref = symmFuncRef(speciesPositions).aevs
    sum_aev_ref = torch.sum(aev_ref)
    positions.grad.zero_()
    sum_aev_ref.backward()
delta = time.time() - start
grad_ref = positions.grad.clone()
print('Original TorchANI symmetry functions')
print(f'  Duration: {delta} s')
print(f'  Speed: {delta/N*1000} ms/it')

# Warm-up pass for the optimized implementation
aev = symmFunc(speciesPositions).aevs
sum_aev = torch.sum(aev)
positions.grad.zero_()
sum_aev.backward()
grad = positions.grad.clone()

# Time the optimized implementation
N = 40000
start = time.time()
for _ in range(N):
    aev = symmFunc(speciesPositions).aevs
    sum_aev = torch.sum(aev)
    positions.grad.zero_()
    sum_aev.backward()
delta = time.time() - start
grad = positions.grad.clone()
print('Optimized TorchANI symmetry functions')
print(f'  Duration: {delta} s')
print(f'  Speed: {delta/N*1000} ms/it')

# Check that the optimized results agree with the reference
aev_error = torch.max(torch.abs(aev - aev_ref))
grad_error = torch.max(torch.abs(grad - grad_ref))
print(aev_error)
print(grad_error)
assert aev_error < 0.0002
assert grad_error < 0.007
```
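A caveat on the timings above: the loops measure wall-clock time with `time.time()` and no explicit CUDA synchronization, so queued but unfinished kernels can blur the numbers. Below is a minimal sketch of a synchronized variant; the helper `timeIt` and the iteration count are ours, while the rest reuses the objects defined in the script:

```python
import time
import torch

def timeIt(fn, iterations):
    # Drain pending GPU work so the clock starts from an idle device
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iterations):
        fn()
    # Wait for every launched kernel to finish before stopping the clock
    torch.cuda.synchronize()
    return (time.time() - start) / iterations

# Example usage with the benchmark's objects (symmFunc, speciesPositions, positions):
def step():
    positions.grad.zero_()
    torch.sum(symmFunc(speciesPositions).aevs).backward()

print(f'Speed: {timeIt(step, 1000) * 1000} ms/it')
```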
CMake file for the PyTorch wrapper:

```cmake
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)

set(NAME NNPOps)
set(LIBRARY ${NAME}PyTorch)
project(${NAME} LANGUAGES CXX CUDA)

find_package(Python REQUIRED)
find_package(PythonLibs REQUIRED)
find_package(Torch REQUIRED)

set(CMAKE_INSTALL_RPATH_USE_LINK_PATH true)

add_library(${LIBRARY} SHARED SymmetryFunctions.cpp
                              ../ani/CpuANISymmetryFunctions.cpp
                              ../ani/CudaANISymmetryFunctions.cu)
target_compile_features(${LIBRARY} PRIVATE cxx_std_14)
target_include_directories(${LIBRARY} PRIVATE ${PYTHON_INCLUDE_DIRS})
target_include_directories(${LIBRARY} PRIVATE ../ani)
target_link_libraries(${LIBRARY} ${TORCH_LIBRARIES} ${PYTHON_LIBRARIES})

install(TARGETS ${LIBRARY} DESTINATION ${Python_SITEARCH}/${NAME})
install(FILES SymmetryFunctions.py DESTINATION ${Python_SITEARCH}/${NAME})
```
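A note on `find_package(Torch)`: it needs `Torch_DIR` (or `CMAKE_PREFIX_PATH`) to point at the CMake config shipped inside the installed PyTorch package; the README below hard-codes the site-packages path. A small sketch of how to query that path from PyTorch itself, assuming a reasonably recent PyTorch that provides `torch.utils.cmake_prefix_path`:

```python
# Print the CMake prefix of the installed PyTorch; the output can be passed
# to CMake as -DCMAKE_PREFIX_PATH so that find_package(Torch) succeeds.
import torch

print(torch.utils.cmake_prefix_path)
```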
README for the PyTorch wrapper:

# PyTorch wrapper for NNPOps

The *NNPOps* functionality is available in *PyTorch* (https://pytorch.org/).

## Optimized TorchANI symmetry functions

An optimized drop-in replacement for `torchani.AEVComputer` (https://aiqm.github.io/torchani/api.html?highlight=speciesaev#torchani.AEVComputer).

### Example

```python
import mdtraj
import torch
import torchani

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device('cuda')

# Load a molecule
molecule = mdtraj.load('molecule.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in molecule.top.atoms]], device=device)
positions = torch.tensor(molecule.xyz, dtype=torch.float32, requires_grad=True, device=device)

# Construct ANI-2x and replace its native featurizer with the NNPOps implementation
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer)

# Compute the energy and forces
energy = nnp((species, positions)).energies
energy.backward()
forces = -positions.grad.clone()

print(energy, forces)
```

## Installation

### Prerequisites

- *Linux*
- A complete *CUDA Toolkit* (https://developer.nvidia.com/cuda-downloads)
- *Miniconda* (https://docs.conda.io/en/latest/miniconda.html#linux-installers)

### Build & install

- Create a *Conda* environment
  ```bash
  $ conda create -n nnpops \
                 -c pytorch \
                 -c conda-forge \
                 cmake \
                 git \
                 gxx_linux-64 \
                 make \
                 mdtraj \
                 pytest \
                 python=3.8 \
                 pytorch=1.6 \
                 torchani=2.2
  $ conda activate nnpops
  ```
- Get the source code
  ```bash
  $ git clone https://github.com/peastman/NNPOps.git
  ```
- Configure, build, and install
  ```bash
  $ mkdir build
  $ cd build
  $ cmake ../NNPOps/pytorch \
          -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
          -DCMAKE_CUDA_HOST_COMPILER=$CXX \
          -DTorch_DIR=$CONDA_PREFIX/lib/python3.8/site-packages/torch/share/cmake/Torch \
          -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
  $ make install
  ```
- Optional: run the tests
  ```bash
  $ cd ../NNPOps/pytorch
  $ pytest TestSymmetryFunctions.py
  ```
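The commit message also mentions fixing and testing the serialization of `TorchANISymmetryFunctions`. A minimal sketch of what a TorchScript round-trip could look like, assuming the wrapped model is scriptable (the file name `model.pt` is illustrative):

```python
import torch
import torchani

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

# Build the model and swap in the optimized symmetry functions
nnp = torchani.models.ANI2x(periodic_table_index=True)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer)

# Round-trip through TorchScript: script, save, and reload the model
scripted = torch.jit.script(nnp)
scripted.save('model.pt')
restored = torch.jit.load('model.pt')
```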
`SymmetryFunctions.cpp`, the C++ source of the PyTorch wrapper:

```cpp
/**
 * Copyright (c) 2020 Acellera
 * Authors: Raimondas Galvelis
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <stdexcept>
#include <cuda_runtime.h>
#include <torch/script.h>
#include "CpuANISymmetryFunctions.h"
#include "CudaANISymmetryFunctions.h"

#define CHECK_CUDA_RESULT(result) \
    if (result != cudaSuccess) { \
        throw std::runtime_error(std::string("Encountered error ")+cudaGetErrorName(result)+" at "+__FILE__+":"+std::to_string(__LINE__)); \
    }

class CustomANISymmetryFunctions : public torch::CustomClassHolder {
public:
    CustomANISymmetryFunctions(int64_t numSpecies_,
                               double Rcr,
                               double Rca,
                               const std::vector<double>& EtaR,
                               const std::vector<double>& ShfR,
                               const std::vector<double>& EtaA,
                               const std::vector<double>& Zeta,
                               const std::vector<double>& ShfA,
                               const std::vector<double>& ShfZ,
                               const std::vector<int64_t>& atomSpecies_,
                               const torch::Tensor& positions) : torch::CustomClassHolder() {

        tensorOptions = torch::TensorOptions().device(positions.device()); // Data type of float by default
        int numAtoms = atomSpecies_.size();
        int numSpecies = numSpecies_;
        const std::vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());

        std::vector<RadialFunction> radialFunctions;
        for (const float eta: EtaR)
            for (const float rs: ShfR)
                radialFunctions.push_back({eta, rs});

        std::vector<AngularFunction> angularFunctions;
        for (const float eta: EtaA)
            for (const float zeta: Zeta)
                for (const float rs: ShfA)
                    for (const float thetas: ShfZ)
                        angularFunctions.push_back({eta, rs, zeta, thetas});

        const torch::Device& device = tensorOptions.device();
        if (device.is_cpu())
            symFunc = std::make_shared<CpuANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
        if (device.is_cuda()) {
            // PyTorch allows choosing a GPU via "torch.device", but it doesn't set it as the CUDA default
            CHECK_CUDA_RESULT(cudaSetDevice(device.index()));
            symFunc = std::make_shared<CudaANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
        }

        radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, tensorOptions);
        angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, tensorOptions);
        positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);
    }

    torch::autograd::tensor_list forward(const torch::Tensor& positions_, const torch::optional<torch::Tensor>& periodicBoxVectors_) {

        const torch::Tensor positions = positions_.to(tensorOptions);

        torch::Tensor periodicBoxVectors;
        float* periodicBoxVectorsPtr = nullptr;
        if (periodicBoxVectors_) {
            periodicBoxVectors = periodicBoxVectors_->to(tensorOptions);
            // Assign to the outer pointer; re-declaring it here would shadow it,
            // and the box vectors would never reach the computation
            periodicBoxVectorsPtr = periodicBoxVectors.data_ptr<float>();
        }

        symFunc->computeSymmetryFunctions(positions.data_ptr<float>(), periodicBoxVectorsPtr, radial.data_ptr<float>(), angular.data_ptr<float>());

        return {radial, angular};
    }

    torch::Tensor backward(const torch::autograd::tensor_list& grads) {

        const torch::Tensor radialGrad = grads[0].clone();
        const torch::Tensor angularGrad = grads[1].clone();

        symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());

        return positionsGrad;
    }

private:
    torch::TensorOptions tensorOptions;
    std::shared_ptr<ANISymmetryFunctions> symFunc;
    torch::Tensor radial;
    torch::Tensor angular;
    torch::Tensor positionsGrad;
};

class GradANISymmetryFunction : public torch::autograd::Function<GradANISymmetryFunction> {

public:
    static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx,
                                                int64_t numSpecies,
                                                double Rcr,
                                                double Rca,
                                                const std::vector<double>& EtaR,
                                                const std::vector<double>& ShfR,
                                                const std::vector<double>& EtaA,
                                                const std::vector<double>& Zeta,
                                                const std::vector<double>& ShfA,
                                                const std::vector<double>& ShfZ,
                                                const std::vector<int64_t>& atomSpecies,
                                                const torch::Tensor& positions,
                                                const torch::optional<torch::Tensor>& periodicBoxVectors) {

        const auto symFunc = torch::intrusive_ptr<CustomANISymmetryFunctions>::make(
            numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions);
        ctx->saved_data["symFunc"] = symFunc;

        return symFunc->forward(positions, periodicBoxVectors);
    }

    static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, const torch::autograd::tensor_list& grads) {

        const auto symFunc = ctx->saved_data["symFunc"].toCustomClass<CustomANISymmetryFunctions>();
        torch::Tensor positionsGrad = symFunc->backward(grads);
        ctx->saved_data.erase("symFunc");

        return { torch::Tensor(), // numSpecies
                 torch::Tensor(), // Rcr
                 torch::Tensor(), // Rca
                 torch::Tensor(), // EtaR
                 torch::Tensor(), // ShfR
                 torch::Tensor(), // EtaA
                 torch::Tensor(), // Zeta
                 torch::Tensor(), // ShfA
                 torch::Tensor(), // ShfZ
                 torch::Tensor(), // atomSpecies
                 positionsGrad,   // positions
                 torch::Tensor()}; // periodicBoxVectors
    }
};

static torch::autograd::tensor_list ANISymmetryFunctionsOp(int64_t numSpecies,
                                                           double Rcr,
                                                           double Rca,
                                                           const std::vector<double>& EtaR,
                                                           const std::vector<double>& ShfR,
                                                           const std::vector<double>& EtaA,
                                                           const std::vector<double>& Zeta,
                                                           const std::vector<double>& ShfA,
                                                           const std::vector<double>& ShfZ,
                                                           const std::vector<int64_t>& atomSpecies,
                                                           const torch::Tensor& positions,
                                                           const torch::optional<torch::Tensor>& periodicBoxVectors) {

    return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions, periodicBoxVectors);
}

TORCH_LIBRARY(NNPOps, m) {
    m.class_<CustomANISymmetryFunctions>("CustomANISymmetryFunctions")
        .def(torch::init<int64_t,                     // numSpecies
                         double,                      // Rcr
                         double,                      // Rca
                         const std::vector<double>&,  // EtaR
                         const std::vector<double>&,  // ShfR
                         const std::vector<double>&,  // EtaA
                         const std::vector<double>&,  // Zeta
                         const std::vector<double>&,  // ShfA
                         const std::vector<double>&,  // ShfZ
                         const std::vector<int64_t>&, // atomSpecies
                         const torch::Tensor&>())     // positions
        .def("forward", &CustomANISymmetryFunctions::forward)
        .def("backward", &CustomANISymmetryFunctions::backward);
    m.def("ANISymmetryFunctions", ANISymmetryFunctionsOp);
}
```
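Once the shared library is loaded, `TORCH_LIBRARY(NNPOps, m)` makes the op callable from Python as `torch.ops.NNPOps.ANISymmetryFunctions`, which is presumably what `SymmetryFunctions.py` builds on. A rough sketch of a direct call; the library path and all parameter values below are illustrative, not taken from the source:

```python
import torch

# Load the compiled extension; the actual path depends on where CMake installed it
torch.ops.load_library('NNPOpsPyTorch.so')

# Toy inputs: 3 atoms of 2 species; the argument order mirrors the
# TORCH_LIBRARY registration above
positions = torch.zeros((3, 3), dtype=torch.float32)
radial, angular = torch.ops.NNPOps.ANISymmetryFunctions(
    2,              # numSpecies
    5.2,            # Rcr
    3.5,            # Rca
    [16.0],         # EtaR
    [0.9],          # ShfR
    [8.0],          # EtaA
    [32.0],         # Zeta
    [0.9],          # ShfA
    [0.19],         # ShfZ
    [0, 1, 1],      # atomSpecies
    positions,      # positions (a CPU tensor selects the CPU implementation)
    None)           # periodicBoxVectors
```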
(The remaining changed files are not shown in this view.)