Skip to content

Commit

Permalink
fix: update to latest Reactant changes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 4, 2024
1 parent 5c44c6b commit e3f350a
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 19 deletions.
6 changes: 1 addition & 5 deletions .buildkite/testing.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
steps:
- group: ":julia: (Lux) CUDA GPU"
steps:
- label: ":julia: Julia {{matrix.julia}} + CUDA GPU + {{matrix.group}}"
- label: ":julia: Julia {{matrix.julia}} + CUDA GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
Expand Down Expand Up @@ -45,17 +45,13 @@ steps:
include(joinpath(dir, "../test/runtests.jl"))'
env:
BACKEND_GROUP: "CUDA"
LUX_TEST_GROUP: "{{matrix.group}}"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/
timeout_in_minutes: 60
matrix:
setup:
julia:
- "1.10"
- "1"
group:
- "!reactant"
- "reactant"

- group: ":julia: (Lux) AMD GPU"
steps:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ NNlib = "0.9.24"
Optimisers = "0.3.3"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.3"
Reactant = "0.2.4"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Optimisers = "0.3.3"
Pkg = "1.10"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.1"
Reactant = "0.2.4"
StableRNGs = "1"
StaticArrays = "1"
WeightInitializers = "1"
Expand Down
13 changes: 5 additions & 8 deletions docs/src/manual/compiling_lux_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,12 @@ function train_model(model, ps, st, dataloader)
train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
for iteration in 1:1000
for (xᵢ, yᵢ) in dataloader
grads, loss, stats, train_state = Training.single_train_step!(
for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
_, loss, _, train_state = Training.single_train_step!(
AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
end
if iteration % 100 == 0 || iteration == 1
# We need to do this since scalar outputs are currently expressed as a zero-dim
# array
loss = Array(loss)[]
@printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
if (iteration % 100 == 0 || iteration == 1) && i == 1
@printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ EnzymeCore = "0.8.5"
Functors = "0.4.12"
MLDataDevices = "1"
Random = "1.10"
Reactant = "0.2.3"
Reactant = "0.2.4"
ReverseDiff = "1.15"
Setfield = "1"
Tracker = "0.2.34"
Expand Down
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ MLUtils = "0.4.4"
Metal = "1"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2"
Reactant = "0.2.4"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Expand Down
2 changes: 0 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ if ("all" in LUX_TEST_GROUP || "misc" in LUX_TEST_GROUP)
set_preferences!(Lux, "eltype_mismatch_handling" => "none"; force=true)
end

Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="error")

const RETESTITEMS_NWORKERS = parse(
Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 4))))

Expand Down

0 comments on commit e3f350a

Please sign in to comment.