Skip to content

Commit

Permalink
Test with the new frules
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 24, 2024
1 parent 908c224 commit 031deac
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 14 deletions.
8 changes: 5 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "df8a9208b4276382055ff54a66a4252730918e13"
project_hash = "914538f40e552ac89a85de7921db9eaf76294f1a"

[[deps.ADTypes]]
git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224"
Expand Down Expand Up @@ -574,9 +574,11 @@ version = "0.1.20"

[[deps.LuxLib]]
deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"]
git-tree-sha1 = "edbf65f5ceb15ebbfad9d03c6a846d83b9a97baf"
git-tree-sha1 = "8143e3dbdcfff587e9595b58c4b637e74c090fbf"
repo-rev = "ap/more_frules"
repo-url = "https://github.com/LuxDL/LuxLib.jl.git"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
version = "0.3.16"
version = "0.3.17"

[deps.LuxLib.extensions]
LuxLibAMDGPUExt = "AMDGPU"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand Down
6 changes: 4 additions & 2 deletions docs/src/tutorials/basic_mnist_deq.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ function train_model(
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
i % 50 == 1 &&
@printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
Expand All @@ -151,7 +152,8 @@ function train_model(
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
i % 50 == 1 &&
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
Expand Down
6 changes: 4 additions & 2 deletions docs/src/tutorials/reduced_dim_deq.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ function train_model(
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
i % 50 == 1 &&
@printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
Expand All @@ -145,7 +146,8 @@ function train_model(
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
i % 50 == 1 &&
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
Expand Down
1 change: 0 additions & 1 deletion ext/DeepEquilibriumNetworksZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ function CRC.rrule(
end

## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
## FIXME: This will be broken in the new Lux release let's fix this
function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
end
Expand Down
5 changes: 2 additions & 3 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,8 @@ julia> model(x, ps, st);
```
"""
function MultiScaleDeepEquilibriumNetwork(
main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple},
solver, scales; kwargs...)
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...)
l1 = Parallel(nothing, main_layers...)
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)

Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ CRC.@non_differentiable __zeros_init(::Any, ::Any)
## Don't rely on SciMLSensitivity's choice
@inline __default_sensealg(prob) = nothing

@inline function __gaussian_like(rng::AbstractRNG, x)
y = similar(x)
@inline function __gaussian_like(rng::AbstractRNG, x::AbstractArray)
y = similar(x)::typeof(x)
randn!(rng, y)
return y
end
Expand Down
2 changes: 1 addition & 1 deletion test/layers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end
jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
_jacobian_regularizations

@testset "Solver: $(__nameof(solver))" for solver in SOLVERS,
@testset "Solver: $(__nameof(solver)) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
mtype in model_type,
jacobian_regularization in jacobian_regularizations

Expand Down

0 comments on commit 031deac

Please sign in to comment.