Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A better PyTorch wrapper #19

Merged
merged 14 commits into from
Apr 26, 2021
4 changes: 1 addition & 3 deletions pytorch/BenchmarkTorchANISymmetryFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
sum_aev.backward()
grad = positions.grad.clone()

N = 40000
N = 100000
start = time.time()
for _ in range(N):
aev = symmFunc(speciesPositions).aevs
Expand All @@ -55,7 +55,5 @@

aev_error = torch.max(torch.abs(aev - aev_ref))
grad_error = torch.max(torch.abs(grad - grad_ref))
print(aev_error)
print(grad_error)
assert aev_error < 0.0002
assert grad_error < 0.007
33 changes: 13 additions & 20 deletions pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,39 +44,32 @@ print(energy, forces)

### Build & install

- Crate a *Conda* environment
- Get the source code
```bash
$ conda create -n nnpops \
-c pytorch \
-c conda-forge \
cmake \
git \
gxx_linux-64 \
make \
mdtraj \
pytest \
python=3.8 \
pytorch=1.6 \
torchani=2.2
$ conda activate nnpops
$ git clone https://github.com/openmm/NNPOps.git
```
- Get the source code

- Crate a *Conda* environment
```bash
$ git clone https://github.com/peastman/NNPOps.git
$ cd NNPOps
$ conda create -f pytorch/environment.yml
$ conda activate nnpops
```

- Configure, build, and install
```bash
$ mkdir build
$ cd build
$ cmake ../NNPOps/pytorch \
$ cmake ../pytorch \
-DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
-DCMAKE_CUDA_HOST_COMPILER=$CXX \
-DTorch_DIR=$CONDA_PREFIX/lib/python3.8/site-packages/torch/share/cmake/Torch \
-DTorch_DIR=$CONDA_PREFIX/lib/python3.9/site-packages/torch/share/cmake/Torch \
-DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
$ make install
```
- Optional: run tests
- Optional: run tests and benchmarks
```bash
$ cd ../NNPOps/pytorch
$ cd ../pytorch
$ pytest TestSymmetryFunctions.py
$ python BenchmarkTorchANISymmetryFunctions.py
```
198 changes: 103 additions & 95 deletions pytorch/SymmetryFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,52 @@
throw std::runtime_error(std::string("Encountered error ")+cudaGetErrorName(result)+" at "+__FILE__+":"+std::to_string(__LINE__));\
}

class CustomANISymmetryFunctions : public torch::CustomClassHolder {
namespace NNPOps {
namespace ANISymmetryFunctions {

class Holder;
using std::vector;
using HolderPtr = torch::intrusive_ptr<Holder>;
using torch::Tensor;
using torch::optional;
using Context = torch::autograd::AutogradContext;
using torch::autograd::tensor_list;

class Holder : public torch::CustomClassHolder {
public:
CustomANISymmetryFunctions(int64_t numSpecies_,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies_,
const torch::Tensor& positions) : torch::CustomClassHolder() {

// Constructor for an uninitialized object
// Note: this is need for serialization
Holder() : torch::CustomClassHolder() {};

Holder(int64_t numSpecies_,
double Rcr,
double Rca,
const vector<double>& EtaR,
const vector<double>& ShfR,
const vector<double>& EtaA,
const vector<double>& Zeta,
const vector<double>& ShfA,
const vector<double>& ShfZ,
const vector<int64_t>& atomSpecies_,
const Tensor& positions) : torch::CustomClassHolder() {

// Construct an uninitialized object
// Note: this is needed for Python bindings
if (numSpecies_ == 0)
return;

tensorOptions = torch::TensorOptions().device(positions.device()); // Data type of float by default
int numAtoms = atomSpecies_.size();
int numSpecies = numSpecies_;
const std::vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());
const vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());

std::vector<RadialFunction> radialFunctions;
vector<RadialFunction> radialFunctions;
for (const float eta: EtaR)
for (const float rs: ShfR)
radialFunctions.push_back({eta, rs});

std::vector<AngularFunction> angularFunctions;
vector<AngularFunction> angularFunctions;
for (const float eta: EtaA)
for (const float zeta: Zeta)
for (const float rs: ShfA)
Expand All @@ -77,11 +98,11 @@ class CustomANISymmetryFunctions : public torch::CustomClassHolder {
positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);
};

torch::autograd::tensor_list forward(const torch::Tensor& positions_, const torch::optional<torch::Tensor>& periodicBoxVectors_) {
tensor_list forward(const Tensor& positions_, const optional<Tensor>& periodicBoxVectors_) {

const torch::Tensor positions = positions_.to(tensorOptions);
const Tensor positions = positions_.to(tensorOptions);

torch::Tensor periodicBoxVectors;
Tensor periodicBoxVectors;
float* periodicBoxVectorsPtr = nullptr;
if (periodicBoxVectors_) {
periodicBoxVectors = periodicBoxVectors_->to(tensorOptions);
Expand All @@ -93,99 +114,86 @@ class CustomANISymmetryFunctions : public torch::CustomClassHolder {
return {radial, angular};
};

torch::Tensor backward(const torch::autograd::tensor_list& grads) {
Tensor backward(const tensor_list& grads) {

const torch::Tensor radialGrad = grads[0].clone();
const torch::Tensor angularGrad = grads[1].clone();
const Tensor radialGrad = grads[0].clone();
const Tensor angularGrad = grads[1].clone();

symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());

return positionsGrad;
}
};

bool is_initialized() {
return bool(symFunc);
};

private:
torch::TensorOptions tensorOptions;
std::shared_ptr<ANISymmetryFunctions> symFunc;
torch::Tensor radial;
torch::Tensor angular;
torch::Tensor positionsGrad;
std::shared_ptr<::ANISymmetryFunctions> symFunc;
Tensor radial;
Tensor angular;
Tensor positionsGrad;
};

class GradANISymmetryFunction : public torch::autograd::Function<GradANISymmetryFunction> {
class AutogradFunctions : public torch::autograd::Function<AutogradFunctions> {

public:
static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx,
int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

const auto symFunc = torch::intrusive_ptr<CustomANISymmetryFunctions>::make(
numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions);
ctx->saved_data["symFunc"] = symFunc;

return symFunc->forward(positions, periodicBoxVectors);
static tensor_list forward(Context *ctx,
const HolderPtr& holder,
const Tensor& positions,
const optional<Tensor>& periodicBoxVectors) {

ctx->saved_data["holder"] = holder;

return holder->forward(positions, periodicBoxVectors);
};

static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, const torch::autograd::tensor_list& grads) {

const auto symFunc = ctx->saved_data["symFunc"].toCustomClass<CustomANISymmetryFunctions>();
torch::Tensor positionsGrad = symFunc->backward(grads);
ctx->saved_data.erase("symFunc");

return { torch::Tensor(), // numSpecies
torch::Tensor(), // Rcr
torch::Tensor(), // Rca
torch::Tensor(), // EtaR
torch::Tensor(), // ShfR
torch::Tensor(), // EtaA
torch::Tensor(), // Zeta
torch::Tensor(), // ShfA
torch::Tensor(), // ShfZ
torch::Tensor(), // atomSpecies
positionsGrad, // positions
torch::Tensor()}; // periodicBoxVectors
static tensor_list backward(Context *ctx, const tensor_list& grads) {

const auto holder = ctx->saved_data["holder"].toCustomClass<Holder>();
Tensor positionsGrad = holder->backward(grads);
ctx->saved_data.erase("holder");

return { Tensor(), // holder
positionsGrad, // positions
Tensor() }; // periodicBoxVectors
};
};

static torch::autograd::tensor_list ANISymmetryFunctionsOp(int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions, periodicBoxVectors);
tensor_list operation(const optional<HolderPtr>& holder,
const Tensor& positions,
const optional<Tensor>& periodicBoxVectors) {

return AutogradFunctions::apply(*holder, positions, periodicBoxVectors);
}

TORCH_LIBRARY(NNPOpsANISymmetryFunctions, m) {
m.class_<Holder>("Holder")
.def(torch::init<int64_t, // numSpecies
double, // Rcr
double, // Rca
const vector<double>&, // EtaR
const vector<double>&, // ShfR
const vector<double>&, // EtaA
const vector<double>&, // Zeta
const vector<double>&, // ShfA
const vector<double>&, // ShfZ
const vector<int64_t>&, // atomSpecies
const Tensor&>()) // positions
.def("forward", &Holder::forward)
.def("backward", &Holder::backward)
.def("is_initialized", &Holder::is_initialized)
.def_pickle(
// __getstate__
// Note: nothing is during serialization
[](const HolderPtr& self) -> int64_t { return 0; },
// __setstate__
// Note: a new uninitialized object is create during deserialization
[](int64_t state) -> HolderPtr { return HolderPtr::make(); }
);
m.def("operation", operation);
}

TORCH_LIBRARY(NNPOps, m) {
m.class_<CustomANISymmetryFunctions>("CustomANISymmetryFunctions")
.def(torch::init<int64_t, // numSpecies
double, // Rcr
double, // Rca
const std::vector<double>&, // EtaR
const std::vector<double>&, // ShfR
const std::vector<double>&, // EtaA
const std::vector<double>&, // Zeta
const std::vector<double>&, // ShfA
const std::vector<double>&, // ShfZ
const std::vector<int64_t>&, // atomSpecies
const torch::Tensor&>()) // positions
.def("forward", &CustomANISymmetryFunctions::forward)
.def("backward", &CustomANISymmetryFunctions::backward);
m.def("ANISymmetryFunctions", ANISymmetryFunctionsOp);
}
} // namespace ANISymmetryFunctions
} // namespace NNPOps
23 changes: 17 additions & 6 deletions pytorch/SymmetryFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from torchani.aev import SpeciesAEV

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
torch.classes.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))

Holder = torch.classes.NNPOpsANISymmetryFunctions.Holder
operation = torch.ops.NNPOpsANISymmetryFunctions.operation

class TorchANISymmetryFunctions(torch.nn.Module):
"""Optimized TorchANI symmetry functions
Expand Down Expand Up @@ -66,7 +70,6 @@ def __init__(self, symmFunc: torchani.AEVComputer):
Arguments:
symmFunc: the instance of torchani.AEVComputer (https://aiqm.github.io/torchani/api.html#torchani.AEVComputer)
"""

super().__init__()

self.numSpecies = symmFunc.num_species
Expand All @@ -79,6 +82,10 @@ def __init__(self, symmFunc: torchani.AEVComputer):
self.ShfA = symmFunc.ShfA[0, 0, :, 0].tolist()
self.ShfZ = symmFunc.ShfZ[0, 0, 0, :].tolist()

# Create an uninitialized holder
self.holder = Holder(0, 0, 0, [], [] , [] , [], [] , [], [], Tensor())
assert not self.holder.is_initialized()

self.triu_index = torch.tensor([0]) # A dummy variable to make TorchScript happy ;)

def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
Expand All @@ -100,7 +107,6 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
species, positions = speciesAndPositions
if species.shape[0] != 1:
raise ValueError('Batched molecule computation is not supported')
species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript
if species.shape + (3,) != positions.shape:
raise ValueError('Inconsistent shapes of "species" and "positions"')
if cell is not None:
Expand All @@ -113,10 +119,15 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
if pbc_ != [True, True, True]:
raise ValueError('Only fully periodic systems are supported, i.e. pbc = [True, True, True]')

symFunc = torch.ops.NNPOps.ANISymmetryFunctions
radial, angular = symFunc(self.numSpecies, self.Rcr, self.Rca, self.EtaR, self.ShfR,
self.EtaA, self.Zeta, self.ShfA, self.ShfZ,
species_, positions[0], cell)
if not self.holder.is_initialized():
species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript
self.holder = Holder(self.numSpecies, self.Rcr, self.Rca,
self.EtaR, self.ShfR,
self.EtaA, self.Zeta, self.ShfA, self.ShfZ,
species_, positions)
assert self.holder.is_initialized()

radial, angular = operation(self.holder, positions[0], cell)
features = torch.cat((radial, angular), dim=1).unsqueeze(0)

return SpeciesAEV(species, features)
12 changes: 12 additions & 0 deletions pytorch/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: nnpops
channels:
- conda-forge
dependencies:
- cmake
- gxx_linux-64
- make
- mdtraj
- torchani 2.2
- pytest
- python 3.9
- pytorch 1.8.0