Skip to content

Commit

Permalink
Add tests for AD backends
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Nov 1, 2022
1 parent fbde477 commit 37c9759
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker"]
32 changes: 31 additions & 1 deletion test/optimise.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Flux.Optimise
using Flux.Optimise: runall
using Flux.Optimise: runall, ZygoteImplicitBackend, ZygoteExplicitBackend
using Flux: Params, gradient
import FillArrays, ComponentArrays
using Test
Expand Down Expand Up @@ -45,6 +45,36 @@ end
end
end

@testset "AD backends" begin
# this is hack to make Tracker work
AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad
AD.value_and_gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...)

function _loss_and_model(::ZygoteImplicitBackend, loss, model)
return () -> loss(model), Flux.params(model)
end
_loss_and_model(ad, loss, model) = loss, model

function _check_gradient(::ZygoteImplicitBackend, model, grad)
return grad[model[1].weight] == 2 .* Flux.ones32(5, 10) &&
grad[model[2].weight] == 10 .* Flux.ones32(2, 5)
end
function _check_gradient(ad, model, grad)
return grad[1].layers[1].weight == 2 .* Flux.ones32(5, 10) &&
grad[1].layers[2].weight == 10 .* Flux.ones32(2, 5)
end

@testset for ad in [ZygoteImplicitBackend(), ZygoteExplicitBackend(), AD.TrackerBackend()]
model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false))
x = Flux.ones32(10)
_loss, _model = _loss_and_model(ad, m -> sum(m(x)), model)
val, grad = AD.value_and_gradient(ad, _loss, _model)
@test val == sum(model(x))
@test _check_gradient(ad, model, grad)
@test _check_gradient(ad, model, AD.gradient(ad, _loss, _model))
end
end

@testset "Training Loop" begin
i = 0
l = 1
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using Flux: params
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
import Tracker
using Zygote
using AbstractDifferentiation
using CUDA

Random.seed!(0)
Expand Down

0 comments on commit 37c9759

Please sign in to comment.