docs: migrate most examples to Reactant #1180

Merged
merged 17 commits
Jan 8, 2025
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.4"
version = "1.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -88,7 +88,7 @@ Compat = "4.16"
ComponentArrays = "0.15.18"
ConcreteStructs = "0.2.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.16"
Enzyme = "0.13.28"
EnzymeCore = "0.8.8"
FastClosures = "0.3.2"
Flux = "0.15, 0.16"
32 changes: 32 additions & 0 deletions README.md
@@ -170,6 +170,38 @@ gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss
(x, dev(rand(rng, Float32, 10, 2))), train_state)
```

## 🤸 Quickstart with Reactant

```julia
using Lux, Random, Optimisers, Reactant, Enzyme

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))

dev = reactant_device()

ps, st = Lux.setup(rng, model) |> dev

x = rand(rng, Float32, 128, 2) |> dev

# We need to compile the model before we can use it.
model_forward = @compile model(x, ps, Lux.testmode(st))
model_forward(x, ps, Lux.testmode(st))

# Gradients can be computed using Enzyme
@jit Enzyme.gradient(Reverse, sum ∘ first ∘ Lux.apply, Const(model), x, ps, Const(st))

# All of this can be automated using the TrainState API
train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

gs, loss, stats, train_state = Training.single_train_step!(
AutoEnzyme(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)
```
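
For completeness, here is a minimal sketch of how the single step above might be wrapped in a full training loop. The random targets, the helper name `train_loop!`, and the epoch count are placeholders for illustration only:

```julia
# Placeholder data: random inputs and targets, just to exercise the loop.
x_train = dev(rand(rng, Float32, 128, 2))
y_train = dev(rand(rng, Float32, 10, 2))

function train_loop!(train_state, x, y; epochs=100)
    for epoch in 1:epochs
        _, loss, _, train_state = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), train_state
        )
        epoch % 10 == 0 && println("Epoch $epoch: loss = $loss")
    end
    return train_state
end

train_state = train_loop!(train_state, x_train, y_train)
```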

## 📚 Examples

Look in the [examples](/examples/) directory for self-contained usage examples. The [documentation](https://lux.csail.mit.edu) has examples sorted into proper categories.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -66,6 +66,7 @@ julia = "1.10"
[sources]
Lux = { path = "../" }
LuxLib = { path = "../lib/LuxLib" }
LuxCUDA = { path = "../lib/LuxCUDA" }
LuxCore = { path = "../lib/LuxCore" }
MLDataDevices = { path = "../lib/MLDataDevices" }
LuxTestUtils = { path = "../lib/LuxTestUtils" }
10 changes: 6 additions & 4 deletions docs/make.jl
@@ -1,7 +1,6 @@
using Documenter, DocumenterVitepress, Pkg
using Lux, LuxCore, LuxLib, WeightInitializers, NNlib
using LuxTestUtils, MLDataDevices
using LuxCUDA

using Optimisers # for some docstrings

@@ -78,8 +77,10 @@ pages = [
#! format: on

deploy_config = Documenter.auto_detect_deploy_system()
deploy_decision = Documenter.deploy_folder(deploy_config; repo="github.com/LuxDL/Lux.jl",
devbranch="main", devurl="dev", push_preview=true)
deploy_decision = Documenter.deploy_folder(
deploy_config; repo="github.com/LuxDL/Lux.jl",
devbranch="main", devurl="dev", push_preview=true
)

makedocs(;
sitename="Lux.jl Docs",
@@ -96,7 +97,8 @@ makedocs(;
repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}",
format=DocumenterVitepress.MarkdownVitepress(;
repo="github.com/LuxDL/Lux.jl", devbranch="main", devurl="dev",
deploy_url="https://lux.csail.mit.edu", deploy_decision),
deploy_url="https://lux.csail.mit.edu", deploy_decision
),
draft=false,
pages
)
4 changes: 2 additions & 2 deletions docs/src/index.md
@@ -23,8 +23,8 @@ hero:

features:
- icon: 🚀
title: Fast & Extendible
details: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs.
title: Fast & Extendable
details: Lux.jl is written in Julia itself, making it extremely extendable. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs.
link: /introduction

- icon: 🐎
50 changes: 33 additions & 17 deletions docs/src/introduction/index.md
@@ -25,8 +25,7 @@ Pkg.add("Lux")

```@example quickstart
using Lux, Random, Optimisers, Zygote
using LuxCUDA # For CUDA support
# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
```

We take randomness very seriously
@@ -66,26 +65,33 @@ y, st = Lux.apply(model, x, ps, st)
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state)
gs, loss, stats, train_state = Lux.Training.compute_gradients(
AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)

## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state)
gs, loss, stats, train_state = Training.single_train_step!(
AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)
```

## Defining Custom Layers

We can train our model using the code above, but let's go ahead and see how to use Reactant.
Reactant is a Julia frontend that generates MLIR and then compiles it using XLA (after
running a suite of optimization passes). It is currently the recommended way to train large
models in Lux. For more details on using Reactant, see the [manual](@ref reactant-compilation).

```@example custom_compact
using Lux, Random, Optimisers, Zygote
using LuxCUDA # For CUDA support
# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support
using Lux, Random, Optimisers, Reactant, Enzyme
using Printf # For pretty printing

dev = gpu_device()
dev = reactant_device()
```

We will define a custom MLP using the `@compact` macro. The macro takes in a list of
@@ -97,10 +103,12 @@ n_in = 1
n_out = 1
nlayers = 3

model = @compact(w1=Dense(n_in => 32),
model = @compact(
w1=Dense(n_in => 32),
w2=[Dense(32 => 32) for i in 1:nlayers],
w3=Dense(32 => n_out),
act=relu) do x
act=relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
@@ -116,21 +124,24 @@ We can initialize the model and train it with the same code as before!
rng = Random.default_rng()
Random.seed!(rng, 0)

ps, st = Lux.setup(Xoshiro(0), model) |> dev
ps, st = Lux.setup(rng, model) |> dev

x = rand(rng, Float32, n_in, 32) |> dev

model(x, ps, st) # 1×32 Matrix and updated state as output.
@jit model(x, ps, st) # 1×32 Matrix and updated state as output.

x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :) |> dev
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
x_data, y_data = dev(x_data), dev(y_data)

function train_model!(model, ps, st, x_data, y_data)
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))

for iter in 1:1000
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), MSELoss(),
(x_data, y_data), train_state)
_, loss, _, train_state = Lux.Training.single_train_step!(
AutoEnzyme(), MSELoss(),
(x_data, y_data), train_state
)
if iter % 100 == 1 || iter == 1000
@printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
end
@@ -155,6 +166,11 @@ packages mentioned in this documentation are available via the Julia General Reg

You can install all those packages via `import Pkg; Pkg.add(<package name>)`.

## XLA (CPU/GPU/TPU) Support

Lux.jl supports XLA compilation for CPU, GPU, and TPU using
[Reactant.jl](https://github.com/EnzymeAD/Reactant.jl).
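
As an illustration, here is a minimal sketch of compiling a Lux model with Reactant; the layer sizes and input data are arbitrary placeholders:

```julia
using Lux, Random, Reactant

dev = reactant_device()  # selects CPU, GPU, or TPU depending on what Reactant finds

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 4, 8) |> dev

# Compile the forward pass with XLA, then run the compiled function.
forward = @compile model(x, ps, st)
forward(x, ps, st)
```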

## GPU Support

GPU Support for Lux.jl requires loading additional packages:
10 changes: 10 additions & 0 deletions docs/src/manual/compiling_lux_models.md
@@ -124,6 +124,16 @@ fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme |> cpu_device())

## [Using the `TrainState` API](@id compile_lux_model_trainstate)

!!! tip "Debugging TrainState API Failures"

If the code fails to compile with Reactant, it is useful to dump the HLO. Starting the
Julia session with the `LUX_DUMP_REACTANT_HLO_OPTIMIZE` environment variable set to
`no_enzyme`, `false`, or `true` will dump the HLO to a file (the filename will be
displayed). This is useful information to provide when opening an issue.

Alternatively, you can set the global reference `Lux.DUMP_REACTANT_HLO_OPT_MODE` to a
symbol corresponding to the `optimize` keyword argument of `@code_hlo`.
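
As a rough sketch of these two options (the script name below is a placeholder, and the `Ref`-style assignment assumes `DUMP_REACTANT_HLO_OPT_MODE` is a `Ref` holding a symbol; adjust if the actual API differs):

```julia
# Option 1: set the environment variable when launching Julia, e.g.
#   LUX_DUMP_REACTANT_HLO_OPTIMIZE=no_enzyme julia --project=. train.jl

# Option 2: set the global reference inside the session before compiling the
# failing model (assumes it behaves like a `Ref` holding a symbol).
using Lux
Lux.DUMP_REACTANT_HLO_OPT_MODE[] = :no_enzyme  # mirrors `optimize=:no_enzyme` in `@code_hlo`
```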

Now that we have seen the low-level API, let's see how to train the model without any of this
boilerplate. Simply follow these steps:

13 changes: 6 additions & 7 deletions docs/src/manual/gpu_management.md
@@ -1,12 +1,5 @@
# GPU Management

!!! info

Starting from `v0.5`, Lux has transitioned to a new GPU management system. The old
system using `cpu` and `gpu` functions is still in place but will be removed in `v1`.
Using the old functions might lead to performance regressions if used inside
performance critical code.

`Lux.jl` can handle multiple GPU backends. Currently, the following backends are supported:

```@example gpu_management
@@ -16,6 +9,12 @@ using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI
supported_gpu_backends()
```

!!! tip "GPU Support via Reactant"

If you are using Reactant, you can use the [`reactant_device`](@ref) function to
automatically select the Reactant backend if it is available. Additionally, to force
Reactant to use the `gpu` backend, you can run `Reactant.set_default_backend("gpu")`
(this happens automatically when a GPU is available).
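
A minimal sketch of that path, assuming a CUDA GPU is visible to Reactant:

```julia
using Lux, Reactant

Reactant.set_default_backend("gpu")  # optional: the GPU backend is picked automatically when available
dev = reactant_device()

x = dev(rand(Float32, 3, 4))  # data moved via `dev` now lives on the Reactant (GPU) backend
```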

!!! danger "Metal Support"

Support for Metal GPUs should be considered extremely experimental at this point.
5 changes: 3 additions & 2 deletions docs/tutorials.jl
@@ -1,10 +1,11 @@
#! format: off
const BEGINNER_TUTORIALS = [
"Basics/main.jl" => "CUDA",
"Basics/main.jl" => "CPU",
"PolynomialFitting/main.jl" => "CUDA",
"SimpleRNN/main.jl" => "CUDA",
# Technically this is run on CPU but we need a better machine to run it
"SimpleChains/main.jl" => "CUDA",
"OptimizationIntegration/main.jl" => "CUDA",
"OptimizationIntegration/main.jl" => "CPU",
]
const INTERMEDIATE_TUTORIALS = [
"NeuralODE/main.jl" => "CUDA",
2 changes: 0 additions & 2 deletions examples/Basics/Project.toml
@@ -2,7 +2,6 @@
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -12,6 +11,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ComponentArrays = "0.15.18"
ForwardDiff = "0.10"
Lux = "1"
LuxCUDA = "0.3"
Optimisers = "0.4.1"
Zygote = "0.6"
14 changes: 8 additions & 6 deletions examples/Basics/main.jl
@@ -109,12 +109,14 @@ W * x
# the `cu` function (or the `gpu` function exported by `Lux`), and it supports all of the
# above operations with the same syntax.

using LuxCUDA

if LuxCUDA.functional()
x_cu = cu(rand(5, 3))
@show x_cu
end
# ```julia
# using LuxCUDA
#
# if LuxCUDA.functional()
# x_cu = cu(rand(5, 3))
# @show x_cu
# end
# ```

# ## (Im)mutability

11 changes: 5 additions & 6 deletions examples/ConditionalVAE/main.jl
@@ -151,7 +151,8 @@ end
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
total_images = grid_rows * grid_cols
imgs = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img
cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) : colorview(RGB, img)
cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) :
colorview(RGB, permutedims(img, (3, 1, 2)))
return cimg'
end
return create_image_grid(imgs, grid_rows, grid_cols)
@@ -239,23 +240,21 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
for epoch in 1:epochs
loss_total = 0.0f0
total_samples = 0
total_time = 0.0

start_time = time()
for (i, X) in enumerate(train_dataloader)
throughput_tic = time()
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, X, train_state)
throughput_toc = time()

loss_total += loss
total_samples += size(X, ndims(X))
total_time += throughput_toc - throughput_tic

if i % 250 == 0 || i == length(train_dataloader)
throughput = total_samples / total_time
throughput = total_samples / (time() - start_time)
@printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput
end
end
total_time = time() - start_time

train_loss = loss_total / length(train_dataloader)
throughput = total_samples / total_time
6 changes: 0 additions & 6 deletions examples/HyperNet/Project.toml
@@ -1,5 +1,4 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
@@ -9,19 +8,14 @@ OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1.10"
ComponentArrays = "0.15.18"
Lux = "1"
LuxCUDA = "0.3"
MLDatasets = "0.7"
MLUtils = "0.4"
OneHotArrays = "0.2.5"
Optimisers = "0.4.1"
Setfield = "1"
Statistics = "1"
Zygote = "0.6"
4 changes: 2 additions & 2 deletions examples/HyperNet/main.jl
@@ -2,8 +2,8 @@

# ## Package Imports

using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Zygote

CUDA.allowscalar(false)

2 changes: 0 additions & 2 deletions examples/OptimizationIntegration/Project.toml
@@ -2,7 +2,6 @@
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
@@ -16,7 +15,6 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
CairoMakie = "0.12.10"
ComponentArrays = "0.15.18"
Lux = "1"
LuxCUDA = "0.3.3"
MLUtils = "0.4.4"
Optimization = "4"
OptimizationOptimJL = "0.4"