Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump compats and update tutorials for Optimization v4 #950

Merged
merged 12 commits into from
Nov 5, 2024
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Boltz = "1"
ChainRulesCore = "1"
ComponentArrays = "0.15.17"
ConcreteStructs = "0.2"
DataInterpolations = "5, 6"
Vaibhavdixit02 marked this conversation as resolved.
Show resolved Hide resolved
DataInterpolations = "6.4"
DelayDiffEq = "5.47.3"
DiffEqCallbacks = "3.6.2"
Distances = "0.10.11"
Expand All @@ -54,9 +54,9 @@ LuxLib = "1.2"
NNlib = "0.9.22"
OneHotArrays = "0.2.5"
Optimisers = "0.3"
Optimization = "3.25.0"
OptimizationOptimJL = "0.3.0"
OptimizationOptimisers = "0.2.1"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.76.0"
Printf = "1.10"
Random = "1.10"
Expand Down
8 changes: 4 additions & 4 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ MLUtils = "0.4"
NNlib = "0.9"
OneHotArrays = "0.2"
Optimisers = "0.3"
Optimization = "3.9"
OptimizationOptimJL = "0.2, 0.3"
OptimizationOptimisers = "0.2"
OptimizationPolyalgorithms = "0.2"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OptimizationPolyalgorithms = "0.3"
OrdinaryDiffEq = "6.31"
Plots = "1.36"
Printf = "1"
Expand Down
34 changes: 17 additions & 17 deletions docs/src/examples/augmented_neural_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ function plot_contour(model, ps, st, npoints = 300)
return contour(x, y, sol; fill = true, linewidth = 0.0)
end

loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)

dataloader = concentric_sphere(
2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256)

iter = 0
cb = function (ps, l)
cb = function (state, l)
global iter
iter += 1
if iter % 10 == 0
Expand All @@ -87,15 +87,15 @@ end
model, ps, st = construct_model(1, 2, 64, 0)
opt = OptimizationOptimisers.Adam(0.005)

loss_node(model, dataloader.data[1], dataloader.data[2], ps, st)
loss_node(model, (dataloader.data[1], dataloader.data[2]), ps, st)

println("Training Neural ODE")

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plt_node = plot_contour(model, res.u, st)

Expand All @@ -106,10 +106,10 @@ println()
println("Training Augmented Neural ODE")

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plot_contour(model, res.u, st)
```
Expand Down Expand Up @@ -229,7 +229,7 @@ We use the L2 distance between the model prediction `model(x)` and the actual pr
optimization objective.

```@example augneuralode
loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)
```

#### Dataset
Expand All @@ -248,7 +248,7 @@ Additionally, we define a callback function which displays the total loss at spe

```@example augneuralode
iter = 0
cb = function (ps, l)
cb = function (state, l)
global iter
iter += 1
if iter % 10 == 0
Expand Down Expand Up @@ -276,10 +276,10 @@ for `20` epochs.
model, ps, st = construct_model(1, 2, 64, 0)

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plot_contour(model, res.u, st)
```
Expand All @@ -297,10 +297,10 @@ a function which can be expressed by the neural ode. For more details and proofs
model, ps, st = construct_model(1, 2, 64, 1)

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plot_contour(model, res.u, st)
```
Expand Down
63 changes: 28 additions & 35 deletions docs/src/examples/hamiltonian_nn.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Before getting to the explanation, here's some code to start with. We will follo

```@example hamiltonian_cp
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
ComponentArrays, Optimization, OptimizationOptimisers, IterTools
ComponentArrays, Optimization, OptimizationOptimisers, MLUtils

t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
Expand All @@ -23,37 +23,33 @@ p_t = reshape(cos.(2π_32 * t), 1, :)
dqdt = 2π_32 .* p_t
dpdt = -2π_32 .* q_t

data = vcat(q_t, p_t)
target = vcat(dqdt, dpdt)
data = cat(q_t, p_t; dims = 1)
target = cat(dqdt, dpdt; dims = 1)
B = 256
NEPOCHS = 100
dataloader = ncycle(
((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
for i in 1:(size(data, 2) ÷ B)),
NEPOCHS)

hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
NEPOCHS = 500
dataloader = DataLoader((data, target); batchsize = B)

hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray

opt = OptimizationOptimisers.Adam(0.01f0)

function loss_function(ps, data, target)
function loss_function(ps, databatch)
data, target = databatch
pred, st_ = hnn(data, ps, st)
return mean(abs2, pred .- target), pred
return mean(abs2, pred .- target)
end

function callback(ps, loss, pred)
function callback(state, loss)
println("[Hamiltonian NN] Loss: ", loss)
return false
end

opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
Optimization.AutoForwardDiff())
opt_prob = OptimizationProblem(opt_func, ps_c)
opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff())
opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)

res = Optimization.solve(opt_prob, opt, dataloader; callback)
res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)

ps_trained = res.u

Expand All @@ -75,7 +71,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we

```@example hamiltonian
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
ComponentArrays, Optimization, OptimizationOptimisers, IterTools
ComponentArrays, Optimization, OptimizationOptimisers, MLUtils

t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
Expand All @@ -87,40 +83,37 @@ dpdt = -2π_32 .* q_t
data = cat(q_t, p_t; dims = 1)
target = cat(dqdt, dpdt; dims = 1)
B = 256
NEPOCHS = 100
dataloader = ncycle(
((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
for i in 1:(size(data, 2) ÷ B)),
NEPOCHS)
NEPOCHS = 500
dataloader = DataLoader((data, target); batchsize = B)
```

### Training the HamiltonianNN

We parameterize the with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization.

```@example hamiltonian
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray
hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)

opt = OptimizationOptimisers.Adam(0.01f0)
opt = OptimizationOptimisers.Adam(0.005f0)

function loss_function(ps, data, target)
pred, st_ = hnn(data, ps, st)
return mean(abs2, pred .- target), pred
function loss_function(ps, databatch)
(data, target) = databatch
pred = hnn_stateful(data, ps)
return mean(abs2, pred .- target)
end

function callback(ps, loss, pred)
function callback(state, loss)
println("[Hamiltonian NN] Loss: ", loss)
return false
end

opt_func = OptimizationFunction(
(ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps_c)
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)

res = solve(opt_prob, opt, dataloader; callback)
res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)

ps_trained = res.u
```
Expand Down
18 changes: 9 additions & 9 deletions docs/src/examples/mnist_conv_neural_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,30 +89,30 @@ end
# burn in accuracy
accuracy(m, ((img, lab),), ps, st)

function loss_function(ps, x, y)
function loss_function(ps, data)
(x, y) = data
pred, _ = m(x, ps, st)
return logitcrossentropy(pred, y), pred
return logitcrossentropy(pred, y)
end

# burn in loss
loss_function(ps, img, lab)
loss_function(ps, (img, lab))

opt = OptimizationOptimisers.Adam(0.005)
iter = 0

opt_func = OptimizationFunction(
(ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps);
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader);

function callback(ps, l, pred)
function callback(state, l)
global iter += 1
iter % 10 == 0 &&
@info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
@info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end

# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; maxiters = 5, callback)
res = Optimization.solve(opt_prob, opt; epochs = 5, callback)
acc = accuracy(m, dataloader, res.u, st)
acc # hide
```
36 changes: 18 additions & 18 deletions docs/src/examples/mnist_neural_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,29 +81,29 @@ end

accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy

function loss_function(ps, x, y)
function loss_function(ps, data)
(x, y) = data
pred, st_ = m(x, ps, st)
return logitcrossentropy(pred, y), pred
return logitcrossentropy(pred, y)
end

loss_function(ps, x_train1, y_train1) # burn in loss
loss_function(ps, (x_train1, y_train1)) # burn in loss

opt = OptimizationOptimisers.Adam(0.05)
iter = 0

opt_func = OptimizationFunction(
(ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps)
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader)

function callback(ps, l, pred)
function callback(state, l)
global iter += 1
iter % 10 == 0 &&
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end

# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
res = Optimization.solve(opt_prob, opt; callback, epochs = 5)
accuracy(m, dataloader, res.u, st)
```

Expand Down Expand Up @@ -285,12 +285,13 @@ final output of our model. `logitcrossentropy` takes in the prediction from our
model `model(x)` and compares it to actual output `y`:

```@example mnist
function loss_function(ps, x, y)
function loss_function(ps, data)
(x, y) = data
pred, st_ = m(x, ps, st)
return logitcrossentropy(pred, y), pred
return logitcrossentropy(pred, y)
end

loss_function(ps, x_train1, y_train1) # burn in loss
loss_function(ps, (x_train1, y_train1)) # burn in loss
```

#### Optimizer
Expand All @@ -309,14 +310,13 @@ This callback function is used to print both the training and testing accuracy a
```@example mnist
iter = 0

opt_func = OptimizationFunction(
(ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps)
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader)

function callback(ps, l, pred)
function callback(state, l)
global iter += 1
iter % 10 == 0 &&
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end
```
Expand All @@ -329,6 +329,6 @@ for Neural ODE is given by `nn_ode.p`:

```@example mnist
# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
res = Optimization.solve(opt_prob, opt; callback, epochs = 5)
accuracy(m, dataloader, res.u, st)
```
Loading
Loading