
Support SchNet-like models as well #10

Closed
jchodera opened this issue Oct 9, 2020 · 37 comments

jchodera commented Oct 9, 2020

SchNet-like models have significant advantages in that they use atom environment embeddings instead of separate pairs of networks to compute energy contributions. Would it be possible to support these as well?

peastman commented Oct 9, 2020

Definitely. What do you consider "SchNet-like models" to include? Are there other particular models you have in mind?

jchodera commented Oct 9, 2020

The key feature is the use of an atom embedding vector in computing the interaction potential, which allows facile extension of these models to broader chemistries. Here's the SchNet version of this (the first "embedding" stage maps element identities Z_i, upper left):
[SchNet architecture diagram]
Here's AIMNet:
[AIMNet architecture diagram]
Here's PhysNet:
[PhysNet architecture diagram]

jchodera commented Oct 9, 2020

Notably, these atom embeddings could potentially be computed externally if they are fixed throughout the entire simulation!

peastman commented Oct 9, 2020

My question is what other models you had in mind. Does "SchNet-like models" just mean "SchNet", or some larger class of models? How general does the implementation need to be? If we also want to support models that are different than SchNet, in what ways are they different?

jchodera commented Oct 10, 2020

My question is what other models you had in mind.

I was thinking of the three models I listed above: SchNet, AIMNet, and PhysNet.

Clebsch-Gordan nets are also very exciting, but are significantly more challenging.

How general does the implementation need to be?

It's worth looking at the architectures above to see whether some large class(es) easily emerge from them.

If we also want to support models that are different than SchNet, in what ways are they different?

Worth some study! I haven't looked deeply into this yet.

peastman commented

@giadefa based on your work with SchNet, do you have a good sense of the performance bottlenecks? What are the critical operations to accelerate?

peastman commented

I've been trying to figure out how SchNet actually works. Here are the sources I've been looking at.

The original paper
A later paper
A paper on SchNetPack
The source code for SchNetPack

The papers are often vague and confusing, and it appears that what's implemented in SchNetPack is substantially different from what any of them describes.

First question: does SchNet use a cutoff?

The original paper makes no mention of cutoffs at all. It also specifies that their radial basis consists of gaussians with centers all the way out to 30 A.

The second paper explicitly says that they don't use a cutoff for the tests on small molecules (QM9 and MD17), and that they use gaussians with centers out to 20 A. For the tests on crystals they say they use a cutoff of 5 A, but they're vague about what that means. They just say, "Given a filter W(r_jb-r_ia) over all atoms with |r_jb-r_ia| < r_cut..." That implies they're just imposing a sharp cutoff, not smoothing it in any way.

But looking at the SchNetPack source code tells a different story. The SchNetInteraction class accepts an optional cutoff_network argument to implement a smooth cutoff. The default value is the cosine-based cutoff function used by ANI. How exactly does this cutoff function get applied? Here's the code in CFConv.

        # pass expanded interactomic distances through filter block
        W = self.filter_network(f_ij)
        # apply cutoff
        if self.cutoff_network is not None:
            C = self.cutoff_network(r_ij)
            W = W * C.unsqueeze(-1)

The output of filter_network is multiplied by the cutoff function. So what is filter_network? It's defined in the SchNetInteraction class.

        self.filter_network = nn.Sequential(
            Dense(n_spatial_basis, n_filters, activation=shifted_softplus),
            Dense(n_filters, n_filters),
        )

So we take the basis functions, send them through a dense layer, apply a shifted softplus activation, and then send them through another dense layer without any activation. The output of that second dense layer gets multiplied by the cutoff function. Take a look at the diagram of the CFConv block shown above. Notice that it lists two dense layers both with shifted softplus activation. As far as I can tell, that second shifted softplus isn't actually used in SchNetPack.
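
Putting those pieces together, here is a rough PyTorch sketch of what the filter pipeline appears to compute (assuming the ANI-style cosine cutoff 0.5*(cos(pi*r/r_cut) + 1); the names below are illustrative, not the SchNetPack API):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def shifted_softplus(x):
    # softplus shifted so the activation is zero at x = 0
    return F.softplus(x) - math.log(2.0)

def cosine_cutoff(r, r_cut):
    # ANI-style cosine cutoff: 1 at r = 0, smoothly decays to 0 at r = r_cut
    return torch.where(r < r_cut,
                       0.5 * (torch.cos(math.pi * r / r_cut) + 1.0),
                       torch.zeros_like(r))

class FilterNetwork(nn.Module):
    # Two dense layers, shifted softplus after the first only,
    # then multiply by the cutoff (matching the code quoted above).
    def __init__(self, n_basis, n_filters, r_cut):
        super().__init__()
        self.dense1 = nn.Linear(n_basis, n_filters)
        self.dense2 = nn.Linear(n_filters, n_filters)  # no activation afterwards
        self.r_cut = r_cut

    def forward(self, f_ij, r_ij):
        W = self.dense2(shifted_softplus(self.dense1(f_ij)))
        return W * cosine_cutoff(r_ij, self.r_cut).unsqueeze(-1)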

Second question: what is the width of the gaussians?

The original paper states that gamma = 10 A. That doesn't make sense: it needs to have units of inverse distance squared. Did they mean 10 1/A^2? Or maybe 1/(10 A)^2? The second paper says, "We have set the grid spacing and scaling parameter gamma to be 0.1 1/A^2 for all models in this work." That's equally confusing, since the grid spacing and scaling parameter have different units. So turning to the source code, I see that it calculates the centers and widths of the gaussians as

        offset = torch.linspace(start, stop, n_gaussians)
        widths = torch.FloatTensor((offset[1] - offset[0]) * torch.ones_like(offset))

It does indeed hardcode the widths to equal the spacing between centers. And here's how they get interpreted:

coeff = -0.5 / torch.pow(widths, 2)
...
gauss = torch.exp(coeff * torch.pow(diff, 2))

So assuming a spacing of 0.1 A like in the papers, the actual value of gamma is 50 1/A^2.
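
A quick check of that arithmetic (assuming the 0.1 A spacing and the same formula as the snippet above):

width = 0.1                # A, equal to the grid spacing
coeff = -0.5 / width ** 2  # as computed in the snippet above
print(-coeff)              # 50.0, i.e. gamma = 50 1/A^2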

Here's another place where the code doesn't match the paper. The first atom-wise layer in the interaction block is defined as

self.in2f = Dense(n_in, n_filters, bias=False, activation=None)

Notice that they set bias=False. That contradicts the papers, which say that they include biases on all atom-wise layers (eq. 2 in the second paper).

Given all these differences, I'm not sure what "SchNet" even means. Is it the model implemented in SchNetPack? The one described in the papers? Add in that the papers don't provide enough detail to actually implement it, and that they have parameter values with the wrong units. Their model and results are impossible to reproduce.

giadefa commented Oct 20, 2020 via email

peastman commented

Has SchNet ever been validated on small molecules with a cutoff? The papers above never use a cutoff with small molecules.

If not, the following would be a really useful paper.

  1. Train SchNet with a cutoff on the same dataset that was used for ANI-2.
  2. Compare the two models on both accuracy and speed.
  3. Make the trained model available online so people can use it directly without having to train their own model.

giadefa commented Oct 20, 2020 via email

peastman commented

Good to know, thanks!

peastman commented

Based on my testing of SchNetPack in #14, I think there's room for a custom op to have a significant speedup. It looks to me like the operation we want to implement is the cfconv layer, that is, equation 3 in https://aip.scitation.org/doi/10.1063/1.5019779. For every atom pair it computes the radial basis functions, passes them through two dense layers, multiplies the result by a cutoff function, multiplies by the input features, and sums over all neighbors of each atom.
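
For reference, here is a rough dense (all-pairs) PyTorch sketch of that computation. It ignores the neighbor list, periodic boundaries, and the exclusion of self pairs, and the names (cfconv_filter, centers, width) are illustrative rather than anything from SchNetPack:

import torch

def cfconv_reference(positions, x, cfconv_filter, centers, width):
    # Dense all-pairs reference for eq. 3: y_i = sum_j x_j * W(r_ij).
    # positions: (n_atoms, 3); x: (n_atoms, n_features); centers: (n_basis,).
    # cfconv_filter maps (f_ij, r_ij) to per-pair filters of width n_features,
    # e.g. the FilterNetwork sketched earlier in this thread.
    r_ij = torch.cdist(positions, positions)                                # (n_atoms, n_atoms)
    f_ij = torch.exp(-0.5 * ((r_ij.unsqueeze(-1) - centers) / width) ** 2)  # radial basis expansion
    W = cfconv_filter(f_ij, r_ij)                                           # (n_atoms, n_atoms, n_features)
    return torch.einsum('jf,ijf->if', x, W)                                 # sum over neighbors j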

Are the following assumptions reasonable?

  1. If using a cutoff, it will hardcode the cosine cutoff function. (SchNetPack lets you specify an arbitrary function.)
  2. The number of dense layers and the activation functions will be hardcoded to match what's in SchNetPack.
  3. Initially this will be for inference only. Allowing it to be used for training requires it to calculate gradients with respect to all the weights and biases of the dense layers. That's both more complicated and a lot more memory intensive. If necessary we could add that feature later, but I'd like to keep it simpler to start with.

peastman commented

Also, what does equation 3 in that paper even mean? It claims to involve an element-wise product between x, which is a 2D tensor of shape (numAtoms, numFeatures), and W, a 3D tensor of shape (numAtoms, numAtoms, numFeatures)??? How did this paper even get published?

jchodera commented

If using a cutoff, it will hardcode the cosine cutoff function. (SchNetPack lets you specify an arbitrary function.)

Could we use something like lepton to make this flexible?

A cosine cutoff may not be ideal since it only forces the function value and first derivative to zero, but not necessarily higher order derivatives, I think.
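
Concretely, for the ANI-style cosine cutoff $f_c(r) = \tfrac{1}{2}(\cos(\pi r/r_c) + 1)$, the value and first derivative vanish at the cutoff but the second derivative does not:

$$f_c(r_c) = 0, \qquad f_c'(r_c) = -\frac{\pi}{2 r_c}\sin(\pi) = 0, \qquad f_c''(r_c) = -\frac{\pi^2}{2 r_c^2}\cos(\pi) = \frac{\pi^2}{2 r_c^2} \neq 0.$$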

The number of dense layers and the activation functions will be hardcoded to match what's in SchNetPack.

Could this later be made flexible as well?

Instead of coalescing the RBF computation and the NN-based message passing, what about splitting the RBF computation into a separate kernel and then using the ML framework for the NN message passing? Could that still lead to increased performance while preserving flexibility (by having the RBF computation as a GPU-accelerated custom Op)?

peastman commented

Could we use something like lepton to make this flexible?

Only if we're going to do runtime compilation of kernels, which makes the whole thing a lot more complicated.

Instead of coalescing the RBF computation + the NN-based message passing, what about splitting the RBF computation into a separate kernel that would then use the ML framework for the NN message passing?

I don't think so. We need to keep everything that relates to pairs of atoms inside a single kernel and only ever write out quantities that relate to single atoms. Otherwise there's little room for speedup.

jchodera commented

I think this is turning out to be a good example that will help guide some difficult architectural decisions we have to make.

Ideally, we'd like to support several use cases

  1. Accelerated execution of published (and experimental) models in OpenMM
  2. Accelerated training of QML models in ML frameworks like pytorch
  3. Accelerated research of QML models in ML frameworks like pytorch

Use case 1 (accelerated execution) is the narrowest, just requiring we accelerate production models for simulation, with published models being the most important. Naïvely, it might seem like we should take the approach of freezing all choices of hyperparameters (cutoffs, number of basis functions, number of message-passing rounds) to the published model values, but as we're finding out, there doesn't seem to be consensus among the published papers and production software about what these hyperparameters should be, and it seems like several sets are in regular use.

In case 2, it's likely these same "standard" sets of hyperparameters would be used.

For case 3, it's extremely common practice to do large-scale hyperparameter searches as part of the architecture search for building models that work well. There will be particularly high interest in model distillation as well---identifying less complex models that are highly performant on available hardware at the cost of an acceptable amount of error.

It seems like we're going to need a way to generate custom ops that either take hyperparameters as arguments or synthesize the code on demand given the hyperparameters.

It looks like torch supports JIT compilation of C++ extensions, as well as JIT compilation of C++/CUDA extensions, suggesting there may be a way to manage this. Not sure about TensorFlow and JAX (which both use XLA).
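
For example, PyTorch can build and load an extension at import time with torch.utils.cpp_extension.load (a minimal sketch; the extension name and source paths are hypothetical):

from torch.utils import cpp_extension

# JIT-compile and load a C++/CUDA extension; the sources could be generated
# on demand with the chosen hyperparameters baked in.
ext = cpp_extension.load(
    name="cfconv_ext",                         # hypothetical extension name
    sources=["cfconv.cpp", "cfconv_cuda.cu"],  # hypothetical source files
    verbose=True,
)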

peastman commented

Naïvely, it might seem like we should take the approach of freezing all choices of hyperparameters (cutoffs, number of basis functions, number of message-passing rounds) to the published model values

I definitely wouldn't do that. There's no reason not to make hyperparameters user configurable. It doesn't require JIT compilation. They're just variables the user can specify.

peastman commented

I'm starting to develop an API for this. Tell me if the following sounds reasonable. @raimis, in particular, can you tell me whether it will be possible to wrap this for PyTorch in a reasonable way?

There will be two classes, CFConv and CFConvNeighbors. Ideally you'll create all these objects in advance (one CFConv for each convolutional layer and a single CFConvNeighbors to hold the neighbor list, which can be shared between all the layers). I understand that might be difficult with PyTorch though.

The constructors will be

CFConv(int numAtoms, int inputWidth, int outputWidth, int numGaussians, float cutoff)

and

CFConvNeighbors(int numAtoms, float cutoff)

giving the number of atoms, the widths (number of features) of the input and output tensors, the number of Gaussians to use for the radial basis function (evenly spaced between 0 and cutoff) and the cutoff distance.

To perform an evaluation, you first build the neighbor list by calling

neighbors.build(positions, periodicBoxVectors)

This only needs to be done once for any set of coordinates. Each layer is then evaluated by calling

cfconv.compute(neighbors, positions, periodicBoxVectors, input, output, w1, b1, w2, b2)

where the last six arguments are each a float* pointing to memory on either the host or device. The w's and b's are the weights and biases of the two dense layers contained within the cfconv. To compute gradients you call

cfconv.backprop(neighbors, positions, periodicBoxVectors, outputDeriv, inputDeriv, positionDeriv, w1, b1, w2, b2)

You pass it outputDeriv, containing the derivatives of the model's output (e.g. energy) with respect to this layer's outputs. It backpropagates to compute the derivatives with respect to this layer's inputs and the positions. (It's a little confusing, because the layer's output depends directly on the positions, and also on the inputs which indirectly depend on the positions.)
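
In other words, for a layer $y = f(x, r)$ with input features $x$ and positions $r$, the call presumably computes only the direct term, something like

$$\texttt{inputDeriv} = \Big(\frac{\partial y}{\partial x}\Big)^{T} \texttt{outputDeriv}, \qquad \texttt{positionDeriv} = \Big(\frac{\partial y}{\partial r}\Big)^{T} \texttt{outputDeriv},$$

with the indirect dependence of $x$ on the positions picked up when the earlier layers are backpropagated in turn.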

The API would be cleaner if the weights and biases were provided to the constructor rather than to the compute() and backprop() methods, but possibly that would make it harder to wrap efficiently for PyTorch?

jchodera commented Oct 26, 2020

In case it's helpful, this documentation thoroughly covers development of C++ and CUDA extensions for Torch, and this documentation covers adding Python extensions to PyTorch autograd.

raimis commented Oct 26, 2020

Regarding the potential problems with the PyTorch wrapper:

There will be two classes, CFConv and CFConvNeighbors. Ideally you'll create all these objects in advance (one CFConv for each convolutional layer and a single CFConvNeighbors to hold the neighbor list, which can be shared between all the layers). I understand that might be difficult with PyTorch though.

As discussed in #5, I haven't found a way to create the objects in advance with PyTorch. So the objects have to be constructed and destroyed for each execution.

To perform an evaluation, you first build the neighbor list by calling

neighbors.build(positions, periodicBoxVectors)

This only needs to be done once for any set of coordinates. Each layer is then evaluated by calling

cfconv.compute(neighbors, positions, periodicBoxVectors, input, output, w1, b1, w2, b2)

PyTorch operations can only return one or more torch::Tensor objects. So CFConv and CFConvNeighbors cannot be wrapped as separate operations; otherwise there is no way to pass the neighbors object.

raimis commented Oct 26, 2020

Meanwhile, I have several ideas for how to make the ANISymmetryFunctions wrapper (#5) more efficient. I should be able to share working code within a few days.

peastman commented

Do you have a suggestion for an alternate way of handling the neighbor list? Evaluating the model for a single set of coordinates involves six convolution layers, with forward and backward passes on each. If we have to rebuild the same neighbor list 12 times, that will kill the efficiency.

peastman commented

The documentation at https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html describes how to wrap C++ classes for use in Python and TorchScript. It looks to me like it ought to be straightforward to expose the CFConvNeighbors class so you create an instance of it, call build() on it, and then pass it as an argument to the CFConv op. Am I misunderstanding something? The same would work with ANI: first create an ANISymmetryFunctions object, then pass it as an argument to the op.

raimis commented Oct 27, 2020

The documentation at https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html describes how to wrap C++ classes for use in Python and TorchScript. It looks to me like it ought to be straightforward to expose the CFConvNeighbors class so you create an instance of it, call build() on it, and then pass it as an argument to the CFConv op. Am I misunderstanding something?

If you want autograd to work, you need to implement it like this: https://pytorch.org/tutorials/advanced/cpp_autograd.html#using-custom-autograd-function-in-c

peastman commented

The forward() function can take whatever arguments you want. One of them can be a CFConvNeighbors object. On the page I linked above, see the section titled "Moving Custom Classes To/From IValues" for how to pass custom classes as arguments.

raimis commented Oct 28, 2020

This discussion starts to repeat the one in #5.

PyTorch internals have limitations. Some of them are neither very intuitive, nor explicitly mentioned in the tutorials. While making #5, I had to find them the hard way. @peastman, you are welcome to learn the same way too.

peastman commented

Can you explain what the limitation is? They give detailed instructions and examples of how to wrap an arbitrary C++ class so you can instantiate it and invoke methods on it. They give detailed instructions on how to allow that class to be passed to arbitrary PyTorch functions. They explicitly say that the functions you define for an operator can accept an arbitrary number of arguments of arbitrary types. So what is the limitation you're referring to?

raimis commented Oct 28, 2020

Can you explain what the limitation is? They give detailed instructions and examples of how to wrap an arbitrary C++ class so you can instantiate it and invoke methods on it. They give detailed instructions on how to allow that class to be passed to arbitrary PyTorch functions. They explicitly say that the functions you define for an operator can accept an arbitrary number of arguments of arbitrary types.

None of these examples shows that it works with autograd! Why?

So what is the limitation you're referring to?

As mentioned in #10 (comment), if you want autograd to work, you have to implement it like this: https://pytorch.org/tutorials/advanced/cpp_autograd.html#using-custom-autograd-function-in-c.

peastman commented

I created a minimal test following the documentation at https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html and https://pytorch.org/tutorials/advanced/cpp_extension.html. It works perfectly. First I created a C++ class that stores a floating point value, and defines forward and backward passes for scaling a tensor by that value.

#include <torch/script.h>

class ValueHolder : public torch::CustomClassHolder {
public:
    ValueHolder(double v) : value(v) {
    }
    double getValue() const {
        return value;
    }
    void setValue(double v) {
        value = v;
    }
    at::Tensor scaleForward(torch::Tensor input) {
        return input*value;
    }
    at::Tensor scaleBackward(torch::Tensor input) {
        return input/value;
    }
private:
    double value;
};

I wrapped it as they described.

TORCH_LIBRARY(optest, m) {
    m.class_<ValueHolder>("ValueHolder")
        .def(torch::init<double>())
        .def("getValue", &ValueHolder::getValue)
        .def("setValue", &ValueHolder::setValue)
        .def("scaleForward", &ValueHolder::scaleForward)
        .def("scaleBackward", &ValueHolder::scaleBackward);
}

Creating instances of the object and calling methods on it works exactly as you'd expect.

import torch

torch.classes.load_library("build/liboptest.dylib")
h = torch.classes.optest.ValueHolder(2)
print(h.getValue())
h.setValue(5)
print(h.getValue())
2.0
5.0

I then defined an autograd function implemented with this class. It works just as you'd expect.

class ScaleOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, holder):
        ctx.holder = holder 
        return holder.scaleForward(x)

    @staticmethod
    def backward(ctx, x):
        return ctx.holder.scaleBackward(x), None

x = torch.tensor([3.0], requires_grad=True)
y = ScaleOp.apply(x, h)
print(y)
y.backward()
print(x.grad)
tensor([15.], grad_fn=<ScaleOpBackward>)
tensor([0.2000])

raimis commented Oct 29, 2020

It seems that torch.autograd.Function in the Python API behaves differently from torch::autograd::Function in the C++ API. OK, one more idea to try.

raimis commented Nov 4, 2020

@peastman I have tried to adapt your idea, but it doesn't work with JIT. The problem is demonstrated below.

Let's write a module:

class Scale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h = torch.classes.optest.ValueHolder(5)

    def forward(self, x):
        return ScaleOp.apply(x, self.h)

It works as expected:

scale = Scale()
z = scale(x)
print(z)
x.grad.zero_()
z.backward()
print(x.grad)
tensor([15.], grad_fn=<ScaleOpBackward>)
tensor([0.2000])

Let's compile the module to TorchScript:

torch.jit.script(scale)

It fails:

RuntimeError: 
Tried to access nonexistent attribute or method 'holder' of type '__torch__.ScaleOp'. Did you forget to initialize an attribute in __init__()?:
  File "test.py", line 18
    @staticmethod
    def backward(ctx, x):
        return ctx.holder.scaleBackward(x), None
               ~~~~~~~~~~ <--- HERE
'ScaleOp.backward' is being compiled since it was called from '__torch__.ScaleOp'
  File "test.py", line 32
    def forward(self, x):
        return ScaleOp.apply(x, self.h)
               ~~~~~~~ <--- HERE
'__torch__.ScaleOp' is being compiled since it was called from 'Scale.forward'
  File "test.py", line 32
    def forward(self, x):
        return ScaleOp.apply(x, self.h)
               ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

It seems that JIT doesn't pass all the attributes of ctx.

peastman commented Nov 4, 2020

Apparently this isn't even specific to C++ extensions. PyTorch doesn't support JIT even for autodiff functions written in Python: pytorch/pytorch#22329

Let's see if we can come up with a workaround...

peastman commented Nov 4, 2020

I figured it out. It only took one hack!

I implemented the function in C++ instead of Python. That does support JIT. Mostly it's a straightforward port of the Python code above. forward() takes an input tensor and a ValueHolder and invokes its scaleForward() method. It also needs to store a reference to the ValueHolder into the context so it will be available to backward(). That's where we run into a problem. Unlike Python, which lets you save arbitrary values, the C++ API only lets you save tensors.

Hence the hack: I get a pointer to the ValueHolder, cast it to a 64 bit int, and store that in a tensor. Then I cast it back again in backward(). Here's the code.

class ScaleFunction : public torch::autograd::Function<ScaleFunction> {
public:
    static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor x, const c10::intrusive_ptr<ValueHolder>& holder) {
        long long pointer = (long long) holder.get();
        auto options = torch::TensorOptions().dtype(torch::kInt64);
        ctx->save_for_backward({torch::ones(1, options)*pointer});
        return holder->scaleForward(x);
    }
    static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, torch::autograd::tensor_list grad_outputs) {
        long long pointer = ctx->get_saved_variables()[0].accessor<long long, 1>()[0];
        const ValueHolder* holder = reinterpret_cast<const ValueHolder*>(pointer);
        return {holder->scaleBackward(grad_outputs[0]), torch::Tensor()};
    }
};

static torch::autograd::tensor_list ScaleOp(const torch::Tensor x, const c10::intrusive_ptr<ValueHolder>& holder) {
    return {ScaleFunction::apply(x, holder)};
}

peastman commented Nov 5, 2020

Here's what the Python code looks like.

class ScaleModule(torch.nn.Module):
    def __init__(self, scale):
        super(ScaleModule, self).__init__()
        self.holder = torch.classes.optest.ValueHolder(scale)

    def forward(self, x):
        return torch.ops.optest.ScaleOp(x, self.holder)

scale = torch.jit.script(ScaleModule(3))
x = torch.tensor([3.0], requires_grad=True)
y = scale(x)
print(y)
y[0].backward()
print(x.grad)

raimis commented Nov 6, 2020

Hence the hack: I get a pointer to the ValueHolder, cast it to a 64 bit int, and store that in a tensor. Then I cast it back again in backward().

In #5, I found a way to pass a class between forward and backward without hacks.

Anyway, the fact that forward accepts c10::intrusive_ptr<T> could be a breakthrough. I haven't seen that documented. How did you find that out?

Also, could you share your TORCH_LIBRARY for ScaleOp with c10::intrusive_ptr<ValueHolder>? I guess it should be something along the lines of this example (https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html#defining-custom-operators-that-take-or-return-bound-c-classes).

peastman commented Nov 6, 2020

In #5, I found a way to pass a class between forward and backward without hacks.

Cool! I hadn't realized that was possible.

Anyway, the fact that forward accepts c10::intrusive_ptr<T> could be a breakthrough. I haven't seen that documented. How did you find that out?

It's what they do in https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html.

Also, could you share your TORCH_LIBRARY for ScaleOp with c10::intrusive_ptr<ValueHolder>?

TORCH_LIBRARY(optest, m) {
    m.class_<ValueHolder>("ValueHolder")
        .def(torch::init<double>())
        .def("getValue", &ValueHolder::getValue)
        .def("setValue", &ValueHolder::setValue)
        .def("scaleForward", &ValueHolder::scaleForward)
        .def("scaleBackward", &ValueHolder::scaleBackward);
    m.def("ScaleOp", &ScaleOp);
}

peastman commented Dec 3, 2020

Closing this since we now have SchNet support.

peastman closed this as completed Dec 3, 2020