Skip to content

Commit

Permalink
Added test that compares symbolic pullback to zygote pullback.
Browse files Browse the repository at this point in the history
Flipped order of functions.
  • Loading branch information
benedict-96 committed Dec 16, 2024
1 parent f495dc8 commit 5ebf5f1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
AbstractNeuralNetworks = "0.3, 0.4"
Documenter = "1.8.0"
ForwardDiff = "0.10.38"
GeometricMachineLearning = "0.3.7"
Latexify = "0.16.5"
RuntimeGeneratedFunctions = "0.5"
Symbolics = "5, 6"
Expand All @@ -28,6 +29,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
GeometricMachineLearning = "194d25b2-d3f5-49f0-af24-c124f4aa80cc"

[targets]
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote"]
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote", "GeometricMachineLearning"]
11 changes: 6 additions & 5 deletions src/derivatives/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,14 @@ import Random
Random.seed!(123)
c = Chain(Dense(2, 1, tanh))
nn = SymbolicNeuralNetwork(c)
nn = NeuralNetwork(c)
snn = SymbolicNeuralNetwork(nn)
loss = FeedForwardLoss()
pb = SymbolicPullback(nn, loss)
ps = initialparameters(c) |> NeuralNetworkParameters
pb = SymbolicPullback(snn, loss)
input_output = (rand(2), rand(1))
loss_and_pullback = pb(ps, nn.model, input_output)
pv_values = loss_and_pullback[2](1)
loss_and_pullback = pb(nn.params, nn.model, input_output)
# note that we apply the second argument to another input `1`
pb_values = loss_and_pullback[2](1)
@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn.params, nn.input, soutput), nn)
Expand Down
38 changes: 38 additions & 0 deletions test/derivatives/pullback.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using SymbolicNeuralNetworks
using SymbolicNeuralNetworks: _get_params, _get_contents
using AbstractNeuralNetworks
using Symbolics
using GeometricMachineLearning: ZygotePullback
using Test
import Random
Random.seed!(123)

compare_values(arr1::Array, arr2::Array) = @test arr1 arr2
function compare_values(nt1::NamedTuple, nt2::NamedTuple)
@assert keys(nt1) == keys(nt2)
NamedTuple{keys(nt1)}((compare_values(arr1, arr2) for (arr1, arr2) in zip(values(nt1), values(nt2))))
end

function compare_symbolic_pullback_to_zygote_pullback(input_dim::Integer, output_dim::Integer, second_dim::Integer=1)
c = Chain(Dense(input_dim, output_dim, tanh))
nn = NeuralNetwork(c)
snn = SymbolicNeuralNetwork(nn)
loss = FeedForwardLoss()
spb = SymbolicPullback(snn, loss)
input_output = (rand(input_dim, second_dim), rand(output_dim, second_dim))
loss_and_pullback = spb(nn.params, nn.model, input_output)
# note that we apply the second argument to another input `1`
pb_values = loss_and_pullback[2](1)

zpb = ZygotePullback(loss)
loss_and_pullback_zygote = zpb(nn.params, nn.model, input_output)
pb_values_zygote = loss_and_pullback_zygote[2](1) |> _get_contents |> _get_params

compare_values(pb_values, pb_values_zygote)
end

for input_dim (2, 3)
for output_dim (1, 2)
compare_symbolic_pullback_to_zygote_pullback(input_dim, output_dim)
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ using SafeTestsets
@safetestset "Symbolic Params " begin include("symbolic_neuralnet/symbolize.jl") end
@safetestset "Tests associated with 'build_function.jl' " begin include("build_function/build_function.jl") end
@safetestset "Tests associated with 'build_function_double_input.jl' " begin include("build_function/build_function_double_input.jl") end
@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end
@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end
@safetestset "Compare Zygote Pullback with Symbolic Pullback " begin include("derivatives/pullback.jl") end

0 comments on commit 5ebf5f1

Please sign in to comment.