Training #22
Some further thinking on this. Essentially we need to devise a local derivative for each primop:
The ops with discontinuities don't really have derivatives, so I think if we just don't backprop through them it'll be fine. It seems like where they are used, we already have other paths that maintain the derivative, like max(x, 0). The autograd compiler should essentially just do a reverse DFS from the loss node through the forward graph and iteratively build up the backward graph starting at the loss node. Along the DFS, if it encounters a 0 derivative, it should stop going down that branch (since the chain rule specifies that every other derivative along that path will end up 0 because of the multiply).
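For illustration, here's a tiny self-contained sketch of that pruning idea, using toy scalar "local derivatives" rather than luminal's actual graph types (a real implementation would visit nodes in reverse topological order instead of recursing):

```rust
use std::collections::HashMap;

/// Toy forward graph: each node lists its inputs together with the local
/// derivative of its output w.r.t. that input (0.0 stands in for ops like
/// Mod / LessThan whose derivative is zero wherever it is defined).
struct Node {
    inputs: Vec<(usize, f32)>, // (input node id, local derivative)
}

/// Reverse DFS from the loss node: multiply local derivatives along the way
/// (chain rule) and prune any branch whose local derivative is zero.
fn backprop(nodes: &[Node], id: usize, upstream: f32, grads: &mut HashMap<usize, f32>) {
    *grads.entry(id).or_insert(0.0) += upstream;
    for &(input, local) in &nodes[id].inputs {
        if local == 0.0 {
            continue; // everything further down this branch stays zero
        }
        backprop(nodes, input, upstream * local, grads);
    }
}

fn main() {
    let nodes = vec![
        Node { inputs: vec![] },                   // 0: leaf (a weight)
        Node { inputs: vec![(0, 2.0), (0, 0.0)] }, // 1: one live path, one dead path
        Node { inputs: vec![(1, 3.0)] },           // 2: the loss
    ];
    let mut grads = HashMap::new();
    backprop(&nodes, 2, 1.0, &mut grads);
    assert_eq!(grads[&0], 6.0); // only the nonzero path contributes: 1.0 * 3.0 * 2.0
}
```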
Hi, just a comment: if I understand correctly, MaxReduce takes the maximum of multiple quantities. If that is the case, it does not have a discontinuity, merely a small set of non-differentiability points, just like the special case max(x, 0) does. For such functions the correct thing to do is to take the local gradient of the active (maximizing) quantity, and if there are multiple, choose any convex combination of the gradients of the active ones. For Mod and LessThan, the path you propose is correct because at all points of continuity they are piecewise constant, so their gradient is zero wherever it is defined.
@daniel-vainsencher Great catch, yeah, the local gradient should just be 1 for the elements of the input to max reduce that equal the corresponding output (the inputs that actually are the max), and 0 for each input that isn't.
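As a toy illustration of that rule (plain slices, not luminal's primop types): the local gradient is just a mask over the inputs, and splitting it evenly across ties matches the convex-combination suggestion above.

```rust
/// Local gradient of a max-reduce over a slice: 1 for the elements equal to
/// the maximum (split evenly across ties), 0 for everything else. Putting a
/// full 1.0 on every tied element is an equally valid subgradient choice.
fn max_reduce_grad(inputs: &[f32]) -> Vec<f32> {
    let max = inputs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let ties = inputs.iter().filter(|&&x| x == max).count() as f32;
    inputs
        .iter()
        .map(|&x| if x == max { 1.0 / ties } else { 0.0 })
        .collect()
}

fn main() {
    // The max is 5.0 at index 2, so only that position gets gradient 1.0.
    assert_eq!(max_reduce_grad(&[1.0, 3.0, 5.0, 2.0]), vec![0.0, 0.0, 1.0, 0.0]);
}
```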
Let's do a 2-graph solution: we have the gradient graph, which is responsible for doing whatever is needed to produce the gradients for an update, and the optimizer graph, which takes in model weights and gradients and produces new weights. Both graphs can be stateful if needed (for gradient accumulation in the gradient graph and optimizer state in the optimizer graph, for instance), and gradients + weights are transferred from the gradient graph to the optimizer graph when a weight update is done, and new weights are transferred back again. We can make nice APIs for this if needed.
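To sketch the shape of that split (made-up types, purely illustrative and not luminal's API): each graph keeps its own state, and only weights and gradients cross the boundary.

```rust
/// Hypothetical stand-ins for the two graphs; the point is the boundary:
/// the gradient graph hands over gradients, the optimizer graph hands back
/// new weights, and each side keeps its own state.
struct GradientGraph { accumulated: Vec<f32> }        // e.g. gradient accumulation state
struct OptimizerGraph { momentum: Vec<f32>, lr: f32 } // e.g. optimizer state

impl GradientGraph {
    /// Stand-in for running forward + backward; returns (and resets) the
    /// accumulated gradients.
    fn compute_grads(&mut self, _weights: &[f32]) -> Vec<f32> {
        std::mem::take(&mut self.accumulated)
    }
}

impl OptimizerGraph {
    /// Take weights + gradients, produce new weights (SGD with momentum
    /// as a placeholder update rule).
    fn step(&mut self, weights: &[f32], grads: &[f32]) -> Vec<f32> {
        let lr = self.lr;
        weights
            .iter()
            .zip(grads)
            .zip(self.momentum.iter_mut())
            .map(|((w, g), m)| {
                *m = 0.9 * *m + g;
                w - lr * *m
            })
            .collect()
    }
}

fn main() {
    let mut weights = vec![1.0_f32, -2.0];
    let mut grad_graph = GradientGraph { accumulated: vec![0.5, -0.5] };
    let mut opt_graph = OptimizerGraph { momentum: vec![0.0, 0.0], lr: 0.1 };

    let grads = grad_graph.compute_grads(&weights); // gradients cross the boundary
    weights = opt_graph.step(&weights, &grads);     // new weights come back
    println!("{weights:?}"); // prints [0.95, -1.95]
}
```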
Now we have autograd fully implemented and tested for transformers! examples/train_math_net contains the first training example using SGD. I'll implement Adam soon and do an MNIST training example, and then this issue can be closed.
Apologies if this is a little presumptuous, but I'd like to suggest maybe aiming for a training example using a dataset like CIFAR-10 (or other RGB dataset, although I think that might be the simplest RGB dataset) instead of MNIST. I've been playing around with various Rust deep learning libraries for a while, and it often seems like the methods used to get good classification results on MNIST don't really translate well to RGB/multi-dimensional data that require convolutional layers and more complex architectures to get good results. As a result, even though it's super common to use MNIST as sort of an MVP example, I tend to think it kind of falls short of actually being "minimum viable" for moving onto solving a lot of real world problems. Luminal looks like it's really starting to come together, and I'm excited to see where things go!
@quietlychris Great point. Let's do CIFAR instead. Do you know how much more time CIFAR would take to train vs MNIST? Like 2x longer or 10x longer? Ideally we want these training jobs to finish in a few minutes because I want them to run on every commit.
https://github.com/tysam-code/hlb-CIFAR10 claims CIFAR10 @ 94% in <7 seconds on an A100. They've apparently done a lot to get that fast, so they're an ambitious comparison point. That said, I think having training working on anything is a great milestone, and reproducing world-class training speed can totally be a separate issue :)
Wow, that's nuts. For luminal_training, the training run needs to happen just with primops to test the autograd / training stuff in isolation, so I'm thinking MNIST for that. We can also have other training jobs use CUDA or Metal for platforms that support those (which we'll definitely need to test), so CIFAR should be a good fit.
Apologies for the delay on this; things have been a bit crazy recently. Off-hand, CIFAR-10 usually does take a while longer than MNIST, I believe typically because the requisite convolutional layers are much more compute-intensive. If you're running them on GitHub Actions, I would expect that to take at least a couple of minutes. That said, the PyTorch tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html says the following when switching to CPU
I don't have the actual benchmarks off-hand, but I seem to recall it being in the "pour yourself a cup of coffee, but don't start a new project" range of time between training runs. That said, I'm not sure that you'd actually need to train until reaching >90% on CIFAR? Just stacking fully-connected layers can get you to about 40-45%, and a little convolution will get you around 50%, but getting up into the >70% range seems to really be the inflection point where the library capability/network architecture is fine and now it's purely a train-for-more-epochs problem. All of which is to say, for a regularly-run CI job, it might be fine to target only training to a lower accuracy threshold and still get most of the same benefits without the increased time trade-off for the last few percent.
Sounds exciting :) I'm curious what your thoughts are on potentially using Enzyme for the autodiff? Its performance advantages look promising so far. Or are there no significant advantages to be expected in the case of luminal because the derivatives of the primops are quite simple, and the automatically derived gradient graph can be heavily optimized before it even comes to the LLVM stage?
@janroden Good idea. I don't think there would be a huge benefit, since luminal's primitive ops are quite bounded, so the problem of autodiff is very closed. You can see in luminal_train that the autograd compiler is ~150 lines, so we don't need to bring in a general-purpose autodiff.
Thanks for the explanation! Makes sense. The concept of the primops seems quite powerful :)
It should be possible to write an autograd compiler that runs on a primgraph, derives a backward graph, and attaches it to the end of the main graph. With this, we can then run the rest of the optimizations on the full graph and get good performance for free.
We will also need an API around it for transferring weights between different sets (prev weights, new weights), and an optimizer library.
The compiler should implement reverse-mode differentiation, since we want the gradient of a single output (the loss) with respect to every node; reverse mode computes all of those in one backward pass, whereas forward mode would need a separate pass per input.
The compiler assumes the graph is a primgraph. Ensure this is enough to create gradient nodes for each model node. The compiler should take in a node set of model weights and automatically derive the gradient graph by walking backward through the existing graph. Does this happen with the optimizer inside the forward graph, or is the optimizer totally separate? Perhaps the optimizer is a function; hopefully it can be implemented with primops.
The compiler should not only create the gradient nodes, but also add those gradients to the weight nodes to get new weights. The new weight set should be returned from the compiler.
In a normal training loop, the weights aren't marked as kept, but the updated weights are. After the graph is run, the previous weights are gone and the new weights are present. We should then simply swap the node sets (similar to how KV caching works in mistral).
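As a toy, self-contained illustration of that update-then-swap flow (plain scalars and hypothetical names rather than luminal's graph node types):

```rust
use std::collections::HashMap;

type NodeId = usize;

/// Stand-in for what the compiled update nodes would compute:
/// new_weight = weight - lr * grad. In the real graph these would be
/// Add/Mul primop nodes appended by the autograd compiler.
fn build_new_weights(
    weights: &HashMap<NodeId, f32>,
    grads: &HashMap<NodeId, f32>,
    lr: f32,
) -> HashMap<NodeId, f32> {
    weights
        .iter()
        .map(|(&id, &w)| (id, w - lr * grads.get(&id).copied().unwrap_or(0.0)))
        .collect()
}

fn main() {
    let mut weights = HashMap::from([(0, 1.0_f32), (1, -0.5)]);
    let grads = HashMap::from([(0, 0.5_f32), (1, -0.25)]);

    // "Run" the graph: old weights are consumed, new weights are kept...
    let new_weights = build_new_weights(&weights, &grads, 0.5);
    // ...then swap the sets so the next step sees the updated weights.
    weights = new_weights;

    assert_eq!(weights[&0], 0.75);   //  1.0 - 0.5 *  0.5
    assert_eq!(weights[&1], -0.375); // -0.5 - 0.5 * -0.25
}
```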
So training consists of something like this:
Semantics still need to be worked out for how gradients are accumulated before they're applied, etc.
Todo: