
Training #22

Open · 3 of 4 tasks · Tracked by #29
jafioti opened this issue Feb 6, 2024 · 13 comments
Labels: advanced (Issues requiring good knowledge of the codebase)

Comments

jafioti (Owner) commented Feb 6, 2024

It should be possible to write an autograd compiler that runs on a primgraph, and derives a backward graph and attaches it to the end of the main graph. With this, we can then run the rest of the optimizations on the full graph, and get good performance for free.

We will also need an API around it for transferring weights between different sets (previous weights, new weights), plus an optimizer library.

The compiler should implement reverse-mode differentiation, since it derives gradients of every node with respect to one output.
The compiler assumes the graph is a primgraph; we need to ensure that is enough to create gradient nodes for each model node. It should take in a node set of model weights and automatically derive the gradient graph by walking backward through the existing graph. Does this happen with the optimizer inside the forward graph, or is the optimizer totally separate? Perhaps the optimizer is a function; hopefully it can be implemented with primops.

The compiler should not only create the gradient nodes, but also add those gradients to the weight nodes to get new weights. The new weight set should be returned from the compiler.

In a normal training loop, the weights aren't marked as kept, but the updated weights are. After the graph is run, the previous weights are gone and the new weights are present. We should then simply swap the node sets (similar to how KV caching works in the Mistral example).

So training consists of something like this:

let gpus = [...];
let input_tensor = cx.tensor();
let target_tensor = cx.tensor();
let output = model.forward(input_tensor);
let loss = loss_fn(output, target_tensor);
let learning_rate = cx.constant(1e-4);
adam_w(loss, learning_rate);
let old_weights = state_dict(model);

// Derive the backward graph and get back the updated weight set
let new_weights = cx.compile(Autograd(old_weights, loss));
// Distribute training over gpus and lower to the target backend
cx.compile((GenericCompiler, CudaCompiler::<bf16>, DistributedDataParallel(gpus)));

for (input, target) in data {
    input_tensor.set(input);
    target_tensor.set(target);
    cx.execute();
    // Transfer new weights back to old weights (weight update)
    transfer_nodes_same_graph(new_weights, old_weights, &mut cx);
    println!("Loss: {}", loss.data());
    loss.drop();
    // Update lr
    learning_rate.set([1e-3]);
}

Semantics still need to be worked out for how gradients are accumulated before they're applied, etc.

Todo:

  • Sketch out concepts fully (what the optimizer is, derivatives of each node, etc)
  • Compiler to derive and attach backward graph to forward graph (https://colah.github.io/posts/2015-08-Backprop/)
  • Optimizers (SGD, Adam)
  • Train mnist feedforward
jafioti added the "advanced" label Mar 1, 2024
jafioti mentioned this issue Mar 2, 2024
jafioti pinned this issue Mar 2, 2024
jafioti (Owner) commented Mar 16, 2024

Some further thinking on this.

Essentially we need to derive a local derivative for each primop:

  • Log2 - 1 / (x * ln(2))
  • Exp2 - exp2(x) * ln(2)
  • Sin - cos(x)
  • Sqrt - 1/(2 * sqrt(x))
  • Recip - -1 / x^2
  • Add - 1
  • Mul - dy/da = b, dy/db = a
  • Mod - Discontinuities
  • LessThan - Discontinuities
  • SumReduce - 1
  • MaxReduce - Discontinuities
  • Contiguous - 1

The ops with discontinuities don't really have derivatives, so I think if we just don't backprop through them it'll be fine. It seems like where they are used, we already have other paths that maintain the derivative like max(x, 0).
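
For concreteness, here is a scalar sketch of that table in plain Rust. This is just the math each gradient node has to encode, not the compiler itself, and the string-keyed dispatch is purely illustrative:

// Scalar version of the local derivatives listed above.
fn unary_local_grad(op: &str, x: f64) -> f64 {
    match op {
        "Log2" => 1.0 / (x * std::f64::consts::LN_2),  // d/dx log2(x)
        "Exp2" => x.exp2() * std::f64::consts::LN_2,   // d/dx 2^x
        "Sin" => x.cos(),
        "Sqrt" => 1.0 / (2.0 * x.sqrt()),
        "Recip" => -1.0 / (x * x),                     // d/dx (1/x)
        "Add" | "SumReduce" | "Contiguous" => 1.0,     // derivative is 1 w.r.t. each input
        _ => 0.0,                                      // Mod, LessThan: treated as zero
    }
}

// Mul has two inputs: y = a * b, so dy/da = b and dy/db = a.
fn mul_local_grads(a: f64, b: f64) -> (f64, f64) {
    (b, a)
}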

The autograd compiler should essentially just do a reverse DFS from the loss node through the forward graph and iteratively build up the backward graph, starting at the loss node. Along the DFS, if it encounters a zero derivative, it should stop going down that branch (since the chain rule means every other derivative along that path will end up zero because of the multiply).
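
A rough sketch of that traversal over a toy scalar graph (this is not Luminal's Graph type; nodes are assumed to be stored in topological order and to already know their local derivatives):

use std::collections::HashMap;

// Toy node: which nodes feed it, and the local derivative w.r.t. each input.
struct Node {
    inputs: Vec<usize>,
    local_grads: Vec<f64>,
}

// Reverse-mode pass: seed d(loss)/d(loss) = 1, then walk the graph backward,
// multiplying by local derivatives (chain rule) and summing over paths.
// Branches with a zero local derivative are skipped entirely.
fn backward(nodes: &[Node], loss: usize) -> HashMap<usize, f64> {
    let mut grads: HashMap<usize, f64> = HashMap::new();
    grads.insert(loss, 1.0);
    for id in (0..nodes.len()).rev() {                 // reverse topological order
        let Some(&upstream) = grads.get(&id) else { continue };
        for (&inp, &local) in nodes[id].inputs.iter().zip(&nodes[id].local_grads) {
            if local == 0.0 { continue; }              // dead branch: everything past it stays 0
            *grads.entry(inp).or_insert(0.0) += upstream * local;
        }
    }
    grads
}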

daniel-vainsencher commented

Hi, just a comment: if I understand correctly, MaxReduce takes the maximum of multiple quantities. If that is the case, it does not have a discontinuity, merely a small set of non-differentiability points, just like the special case max(x,0) does.

For such functions the correct thing to do is to take the local gradient of the active (maximizing) quantity, and if there are multiple, choose any convex combination of the gradients of the active ones.

For Mod and LessThan, the path you propose is correct: they are piecewise constant at all points of continuity, so their gradient is zero wherever it is defined.

jafioti (Owner) commented Mar 20, 2024

@daniel-vainsencher Great catch. Yeah, the local gradient should just be 1 for the elements of the input to MaxReduce that equal the corresponding output (the inputs that actually are the max), and 0 for each input that isn't the max.
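
In 1-D that rule looks something like the sketch below (not the actual compiler code; ties all get 1 here, which is one valid choice of subgradient rather than a convex combination):

// Gradient mask for a 1-D MaxReduce: 1.0 where the input equals the reduced
// maximum (the active elements), 0.0 everywhere else.
fn max_reduce_grad_mask(input: &[f32]) -> Vec<f32> {
    let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    input.iter().map(|&x| if x == max { 1.0 } else { 0.0 }).collect()
}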

jafioti (Owner) commented Mar 21, 2024

Commit f65b0292495a66e941edbdb7374de5ddf3668a9c contains an initial implementation for the autograd.

Let's do a two-graph solution: a gradient graph, which is responsible for doing whatever is needed to produce the gradients for an update, and an optimizer graph, which takes in model weights and gradients and produces new weights.

Both graphs can be stateful if needed (for gradient accumulation in the gradient graph and optimizer state in the optimizer graph, for instance). Gradients + weights are transferred from the gradient graph to the optimizer graph when a weight update is done, and the new weights are transferred back again. We can make nice APIs for this if needed.
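
To make the split concrete, here is a toy version with plain structs standing in for the two graphs. The names and the SGD-with-momentum optimizer are stand-ins for illustration, not the luminal_training API:

// Stateful "gradient graph": holds current weights and accumulated gradients.
struct GradientGraph { weights: Vec<f32>, grad_accum: Vec<f32> }
// Stateful "optimizer graph": holds optimizer state (momentum) and hyperparams.
struct OptimizerGraph { momentum: Vec<f32>, lr: f32 }

impl OptimizerGraph {
    // Takes weights + gradients in, produces new weights (SGD with momentum).
    fn step(&mut self, weights: &[f32], grads: &[f32]) -> Vec<f32> {
        let lr = self.lr;
        weights.iter().zip(grads).zip(self.momentum.iter_mut())
            .map(|((&w, &g), m)| { *m = 0.9 * *m + g; w - lr * *m })
            .collect()
    }
}

// "Weight update": transfer weights + grads over, step, transfer new weights back.
fn update(grad_graph: &mut GradientGraph, opt_graph: &mut OptimizerGraph) {
    let new_weights = opt_graph.step(&grad_graph.weights, &grad_graph.grad_accum);
    grad_graph.weights = new_weights;
    grad_graph.grad_accum.iter_mut().for_each(|g| *g = 0.0); // reset accumulation
}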

jafioti (Owner) commented Apr 3, 2024

Now we have autograd fully implemented and tested for transformers! examples/train_math_net contains the first training example, using SGD.

I'll implement Adam soon and do an MNIST training example, and then this issue can be closed.
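
For reference, the standard Adam update that optimizer would perform, written elementwise so it maps naturally onto primops (mul/add/sqrt/recip). This is the textbook rule with the usual default hyperparameters, not the eventual luminal implementation:

// One Adam step over flat parameter/gradient slices. m and v are the running
// first/second-moment estimates; t is the 1-indexed step count.
fn adam_step(w: &mut [f32], g: &[f32], m: &mut [f32], v: &mut [f32], lr: f32, t: i32) {
    let (b1, b2, eps) = (0.9_f32, 0.999_f32, 1e-8_f32);
    for i in 0..w.len() {
        m[i] = b1 * m[i] + (1.0 - b1) * g[i];            // first-moment EMA
        v[i] = b2 * v[i] + (1.0 - b2) * g[i] * g[i];     // second-moment EMA
        let m_hat = m[i] / (1.0 - b1.powi(t));           // bias correction
        let v_hat = v[i] / (1.0 - b2.powi(t));
        w[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}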

quietlychris commented

Apologies if this is a little presumptuous, but I'd like to suggest aiming for a training example using a dataset like CIFAR-10 (or another RGB dataset, though CIFAR-10 is probably the simplest) instead of MNIST. I've been playing around with various Rust deep learning libraries for a while, and it often seems like the methods used to get good classification results on MNIST don't translate well to RGB/multi-dimensional data that requires convolutional layers and more complex architectures to get good results. As a result, even though it's super common to use MNIST as a sort of MVP example, I tend to think it falls short of actually being "minimum viable" for moving on to solving a lot of real-world problems. Luminal looks like it's really starting to come together, and I'm excited to see where things go!

jafioti (Owner) commented Apr 8, 2024

@quietlychris Great point. Let's do CIFAR instead. Do you know how much more time CIFAR would take to train vs MNIST? Like 2x longer, or 10x longer? Ideally these training jobs would finish in a few minutes, because I want them to run on every commit.

daniel-vainsencher commented

https://github.com/tysam-code/hlb-CIFAR10 claims CIFAR10 @ 94% in <7 seconds on an A100.

They've apparently done a lot to get it that fast, so they're an ambitious comparison point. That said, I think having training working on anything is a great milestone, and reproducing world-class training speed can totally be a separate issue :)

jafioti (Owner) commented Apr 9, 2024

Wow, that's nuts. For luminal_training, the training run needs to happen just with primops to test the autograd / training stuff in isolation, so I'm thinking MNIST for that. We can also have other training jobs use CUDA or Metal for platforms that support those (which we'll definitely need to test), so CIFAR should be a good fit.

quietlychris commented

Apologies for the delay on this; things have been a bit crazy recently. Off-hand, CIFAR-10 usually does take a while longer than MNIST, typically because the requisite convolutional layers are much more compute-intensive. If you're running them on GitHub Actions, I would expect that to take at least a couple of minutes. That said, the PyTorch tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html says the following when switching to CPU:

Why don’t I notice MASSIVE speedup compared to CPU? Because your network is really small.

I don't have the actual benchmarks off-hand, although I seem to recall it being in the "pour yourself a cup of coffee, but don't start a new project" range of time between training runs.

That said, I'm not sure you'd actually need to train until reaching >90% on CIFAR. Just stacking fully-connected layers can get you to about 40-45%, and a little convolution will get you to around 50%, but getting up into the >70% range seems to be the real inflection point where the library capability/network architecture is fine and it's purely a train-for-more-epochs problem. All of which is to say: for a regularly-run CI job, it might be fine to train only to a lower accuracy threshold and still get most of the same benefits, without the increased time trade-off for the last few percent.

janroden (Contributor) commented Jun 7, 2024

Sounds exciting :) I'm curious what your thoughts are on potentially using Enzyme for the autodiff? Its performance advantages look promising so far. Or are there no significant advantages to be expected in the case of luminal, because the derivatives of the primops are quite simple and the automatically derived gradient graph can be heavily optimized before it even gets to the LLVM stage?

  • Talk at DOE CSGF
  • Enzyme Rust

Thanks!

jafioti (Owner) commented Jun 8, 2024

@janroden Good idea, but I don't think there would be a huge benefit since Luminal's primitive ops are quite bounded, so the problem of autodiff is very closed. You can see in luminal_train that the autograd compiler is ~150 lines, so we don't need to bring in a general-purpose autodiff.

janroden (Contributor) commented Jun 9, 2024

@janroden Good idea, but I don't think there would be a huge benefit since Luminal's primitive ops are quite bounded, so the problem of autodiff is very closed. You can see in luminal_train that the autograd compiler is ~150 lines, so we don't need to bring in a general-purpose autodiff.

Thanks for the explanation! Makes sense. The concept of the primops seems quite powerful :)
