Skip to content

Commit

Permalink
docs: run on a small subset of MNIST
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 12, 2024
1 parent 0362d59 commit 4fa64f0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ DiffEqFluxDataInterpolationsExt = "DataInterpolations"
ADTypes = "1.5"
Aqua = "0.8.7"
BenchmarkTools = "1.5.0"
Boltz = "0.4.1"
Boltz = "0.4.2"
ChainRulesCore = "1"
ComponentArrays = "0.15.17"
ConcreteStructs = "0.2"
Expand Down
3 changes: 1 addition & 2 deletions docs/src/examples/mnist_conv_neural_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true
function loadmnist(batchsize)
# Load MNIST
dataset = MNIST(; split = :train)
dataset = MNIST(; split = :train)[1:2000] # Partial load for demonstration
imgs = dataset.features
labels_raw = dataset.targets
Expand Down Expand Up @@ -114,6 +114,5 @@ end
# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; maxiters = 5, callback)
acc = accuracy(m, dataloader, res.u, st)
@assert acc > 0.8 # hide
acc # hide
```
10 changes: 4 additions & 6 deletions docs/src/examples/mnist_neural_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
function loadmnist(batchsize)
# Load MNIST
dataset = MNIST(; split = :train)
dataset = MNIST(; split = :train)[1:2000] # Partial load for demonstration
imgs = dataset.features
labels_raw = dataset.targets
Expand Down Expand Up @@ -104,7 +104,7 @@ end
# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
@assert accuracy(m, dataloader, res.u, st) > 0.8
accuracy(m, dataloader, res.u, st)
```

## Step-by-Step Description
Expand Down Expand Up @@ -151,7 +151,7 @@ logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
function loadmnist(batchsize)
# Load MNIST
dataset = MNIST(; split = :train)
dataset = MNIST(; split = :train)[1:2000] # Partial load for demonstration
imgs = dataset.features
labels_raw = dataset.targets
Expand Down Expand Up @@ -324,7 +324,5 @@ for Neural ODE is given by `nn_ode.p`:
```@example mnist
# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
acc = accuracy(m, dataloader, res.u, st)
@assert acc > 0.8 # hide
acc # hide
accuracy(m, dataloader, res.u, st)
```

0 comments on commit 4fa64f0

Please sign in to comment.