Skip to content

Commit

Permalink
docs: update SimpleRNN
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 6, 2025
1 parent 7908b2f commit 5d2a714
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
7 changes: 1 addition & 6 deletions examples/SimpleRNN/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[compat]
ADTypes = "1.10"
JLD2 = "0.5"
Lux = "1"
LuxCUDA = "0.3"
MLUtils = "0.4"
Optimisers = "0.4.1"
Statistics = "1"
Zygote = "0.6"
29 changes: 19 additions & 10 deletions examples/SimpleRNN/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# ## Package Imports

using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

# ## Dataset

Expand All @@ -34,9 +34,11 @@ function get_dataloaders(; dataset_size=1000, sequence_length=50)
## Create DataLoaders
return (
## Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
DataLoader(
collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false),
## Don't shuffle the validation data
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false)
)
end

# ## Creating a Classifier
Expand Down Expand Up @@ -128,31 +130,38 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
# ## Training the Model

function main(model_type)
dev = gpu_device()
dev = reactant_device()
cdev = cpu_device()

## Get the dataloaders
train_loader, val_loader = get_dataloaders() .|> dev
train_loader, val_loader = get_dataloaders() |> dev

## Create the model
model = model_type(2, 8, 1)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model) |> dev
ps, st = Lux.setup(Random.default_rng(), model) |> dev

train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
model_compiled = if dev isa ReactantDevice
@compile model(first(train_loader)[1], ps, Lux.testmode(st))
else
model
end
ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

for epoch in 1:25
## Train the model
for (x, y) in train_loader
(_, loss, _, train_state) = Training.single_train_step!(
AutoZygote(), lossfn, (x, y), train_state)

ad, lossfn, (x, y), train_state
)
@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
end

## Validate the model
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
ŷ, st_ = model(x, train_state.parameters, st_)
ŷ, st_ = model_compiled(x, train_state.parameters, st_)
ŷ, y = cdev(ŷ), cdev(y)
loss = lossfn(ŷ, y)
acc = accuracy(ŷ, y)
@printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
Expand Down

0 comments on commit 5d2a714

Please sign in to comment.