diff --git a/Project.toml b/Project.toml index d003fdcb22..5aa1a1f203 100644 --- a/Project.toml +++ b/Project.toml @@ -53,6 +53,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker"] diff --git a/test/optimise.jl b/test/optimise.jl index e922d3c0b8..7442e4aed4 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -1,5 +1,5 @@ using Flux.Optimise -using Flux.Optimise: runall +using Flux.Optimise: runall, ZygoteImplicitBackend, ZygoteExplicitBackend using Flux: Params, gradient import FillArrays, ComponentArrays using Test @@ -45,6 +45,36 @@ end end end +@testset "AD backends" begin + # this is a hack to make Tracker work: define the AD gradient entry points for TrackerBackend via Tracker.withgradient + AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad + AD.value_and_gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...) 
+ + function _loss_and_model(::ZygoteImplicitBackend, loss, model) + return () -> loss(model), Flux.params(model) + end + _loss_and_model(ad, loss, model) = loss, model + + function _check_gradient(::ZygoteImplicitBackend, model, grad) + return grad[model[1].weight] == 2 .* Flux.ones32(5, 10) && + grad[model[2].weight] == 10 .* Flux.ones32(2, 5) + end + function _check_gradient(ad, model, grad) + return grad[1].layers[1].weight == 2 .* Flux.ones32(5, 10) && + grad[1].layers[2].weight == 10 .* Flux.ones32(2, 5) + end + + @testset for ad in [ZygoteImplicitBackend(), ZygoteExplicitBackend(), AD.TrackerBackend()] + model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false)) + x = Flux.ones32(10) + _loss, _model = _loss_and_model(ad, m -> sum(m(x)), model) + val, grad = AD.value_and_gradient(ad, _loss, _model) + @test val == sum(model(x)) + @test _check_gradient(ad, model, grad) + @test _check_gradient(ad, model, AD.gradient(ad, _loss, _model)) + end +end + @testset "Training Loop" begin i = 0 l = 1 diff --git a/test/runtests.jl b/test/runtests.jl index 9027b114fc..ed04582b32 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,9 @@ using Flux: params using Test using Random, Statistics, LinearAlgebra using IterTools: ncycle +import Tracker using Zygote +using AbstractDifferentiation using CUDA Random.seed!(0)