diff --git a/Project.toml b/Project.toml index 080b343..22fb6ef 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" AbstractNeuralNetworks = "0.3, 0.4" Documenter = "1.8.0" ForwardDiff = "0.10.38" +GeometricMachineLearning = "0.3.7" Latexify = "0.16.5" RuntimeGeneratedFunctions = "0.5" Symbolics = "5, 6" @@ -28,6 +29,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +GeometricMachineLearning = "194d25b2-d3f5-49f0-af24-c124f4aa80cc" [targets] -test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote"] +test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote", "GeometricMachineLearning"] diff --git a/src/derivatives/pullback.jl b/src/derivatives/pullback.jl index dba9c53..bd48d4d 100644 --- a/src/derivatives/pullback.jl +++ b/src/derivatives/pullback.jl @@ -47,13 +47,14 @@ import Random Random.seed!(123) c = Chain(Dense(2, 1, tanh)) -nn = SymbolicNeuralNetwork(c) +nn = NeuralNetwork(c) +snn = SymbolicNeuralNetwork(nn) loss = FeedForwardLoss() -pb = SymbolicPullback(nn, loss) -ps = initialparameters(c) |> NeuralNetworkParameters +pb = SymbolicPullback(snn, loss) input_output = (rand(2), rand(1)) -loss_and_pullback = pb(ps, nn.model, input_output) -pv_values = loss_and_pullback[2](1) +loss_and_pullback = pb(nn.params, nn.model, input_output) +# note that we apply the second argument to another input `1` +pb_values = loss_and_pullback[2](1) @variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)] symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn.params, nn.input, soutput), nn) diff --git a/test/derivatives/pullback.jl b/test/derivatives/pullback.jl new file mode 100644 index 0000000..9eee5c7 --- /dev/null +++ b/test/derivatives/pullback.jl @@ -0,0 +1,38 @@ +using SymbolicNeuralNetworks +using SymbolicNeuralNetworks: _get_params, _get_contents +using AbstractNeuralNetworks +using Symbolics +using GeometricMachineLearning: ZygotePullback +using Test +import Random +Random.seed!(123) + +compare_values(arr1::Array, arr2::Array) = @test arr1 ≈ arr2 +function compare_values(nt1::NamedTuple, nt2::NamedTuple) + @assert keys(nt1) == keys(nt2) + NamedTuple{keys(nt1)}((compare_values(arr1, arr2) for (arr1, arr2) in zip(values(nt1), values(nt2)))) +end + +function compare_symbolic_pullback_to_zygote_pullback(input_dim::Integer, output_dim::Integer, second_dim::Integer=1) + c = Chain(Dense(input_dim, output_dim, tanh)) + nn = NeuralNetwork(c) + snn = SymbolicNeuralNetwork(nn) + loss = FeedForwardLoss() + spb = SymbolicPullback(snn, loss) + input_output = (rand(input_dim, second_dim), rand(output_dim, second_dim)) + loss_and_pullback = spb(nn.params, nn.model, input_output) + # note that we apply the second argument to another input `1` + pb_values = loss_and_pullback[2](1) + + zpb = ZygotePullback(loss) + loss_and_pullback_zygote = zpb(nn.params, nn.model, input_output) + pb_values_zygote = loss_and_pullback_zygote[2](1) |> _get_contents |> _get_params + + compare_values(pb_values, pb_values_zygote) +end + +for input_dim ∈ (2, 3) + for output_dim ∈ (1, 2) + compare_symbolic_pullback_to_zygote_pullback(input_dim, output_dim) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 8a79390..edb44a1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,4 +7,5 @@ using SafeTestsets @safetestset "Symbolic Params " begin include("symbolic_neuralnet/symbolize.jl") end @safetestset "Tests associated with 'build_function.jl' " begin include("build_function/build_function.jl") end @safetestset "Tests associated with 'build_function_double_input.jl' " begin include("build_function/build_function_double_input.jl") end -@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end \ No newline at end of file +@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end +@safetestset "Compare Zygote Pullback with Symbolic Pullback " begin include("derivatives/pullback.jl") end \ No newline at end of file