diff --git a/Manifest.toml b/Manifest.toml
index ca0bc6d1..87ebbe10 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.10.2"
 manifest_format = "2.0"
-project_hash = "df8a9208b4276382055ff54a66a4252730918e13"
+project_hash = "914538f40e552ac89a85de7921db9eaf76294f1a"
 
 [[deps.ADTypes]]
 git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224"
@@ -574,9 +574,11 @@ version = "0.1.20"
 
 [[deps.LuxLib]]
 deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"]
-git-tree-sha1 = "edbf65f5ceb15ebbfad9d03c6a846d83b9a97baf"
+git-tree-sha1 = "8143e3dbdcfff587e9595b58c4b637e74c090fbf"
+repo-rev = "ap/more_frules"
+repo-url = "https://github.com/LuxDL/LuxLib.jl.git"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
-version = "0.3.16"
+version = "0.3.17"
 
     [deps.LuxLib.extensions]
     LuxLibAMDGPUExt = "AMDGPU"
diff --git a/Project.toml b/Project.toml
index d3bab847..37d5d09d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
+LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md
index 644f52fe..3684f4a7 100644
--- a/docs/src/tutorials/basic_mnist_deq.md
+++ b/docs/src/tutorials/basic_mnist_deq.md
@@ -138,7 +138,8 @@ function train_model(
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
     end
 
     acc = accuracy(model, data_test, ps, model_st.st) * 100
@@ -151,7 +152,8 @@
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
     end
 
     acc = accuracy(model, data_test, ps, model_st.st) * 100
diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md
index c91f5fcd..9f72ac69 100644
--- a/docs/src/tutorials/reduced_dim_deq.md
+++ b/docs/src/tutorials/reduced_dim_deq.md
@@ -132,7 +132,8 @@ function train_model(
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
     end
 
     acc = accuracy(model, data_test, ps, model_st.st) * 100
@@ -145,7 +146,8 @@
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
     end
 
     acc = accuracy(model, data_test, ps, model_st.st) * 100
diff --git a/ext/DeepEquilibriumNetworksZygoteExt.jl b/ext/DeepEquilibriumNetworksZygoteExt.jl
index 688bd2ca..a04697e0 100644
--- a/ext/DeepEquilibriumNetworksZygoteExt.jl
+++ b/ext/DeepEquilibriumNetworksZygoteExt.jl
@@ -50,7 +50,6 @@ function CRC.rrule(
 end
 
 ## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
-## FIXME: This will be broken in the new Lux release let's fix this
 function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
     return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
 end
diff --git a/src/layers.jl b/src/layers.jl
index 466b9a64..995f94db 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -314,9 +314,8 @@
 julia> model(x, ps, st);
 ```
 """
-function MultiScaleDeepEquilibriumNetwork(
-        main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple},
-        solver, scales; kwargs...)
+function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
+        post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...)
     l1 = Parallel(nothing, main_layers...)
     l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
 
diff --git a/src/utils.jl b/src/utils.jl
index dfc13210..647636dc 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -87,8 +87,8 @@ CRC.@non_differentiable __zeros_init(::Any, ::Any)
 ## Don't rely on SciMLSensitivity's choice
 @inline __default_sensealg(prob) = nothing
 
-@inline function __gaussian_like(rng::AbstractRNG, x)
-    y = similar(x)
+@inline function __gaussian_like(rng::AbstractRNG, x::AbstractArray)
+    y = similar(x)::typeof(x)
     randn!(rng, y)
     return y
 end
diff --git a/test/layers_tests.jl b/test/layers_tests.jl
index 75b6f68d..aa19ea45 100644
--- a/test/layers_tests.jl
+++ b/test/layers_tests.jl
@@ -34,7 +34,7 @@
     jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
                                _jacobian_regularizations
 
-    @testset "Solver: $(__nameof(solver))" for solver in SOLVERS,
+    @testset "Solver: $(__nameof(solver)) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
         mtype in model_type,
         jacobian_regularization in jacobian_regularizations