In a perfect world we probably would not need SymbolicNeuralNetworks. Its motivation comes mainly from Zygote's inability to handle second-order derivatives in a decent way [1]. We also note that if Enzyme matures further, there may be no need for SymbolicNeuralNetworks in the future. For now (December 2024), SymbolicNeuralNetworks offers a good way to incorporate derivatives into the loss function.
SymbolicNeuralNetworks was created to take advantage of Symbolics for training neural networks, both by accelerating their evaluation and by simplifying the computation of arbitrary derivatives of the neural network. This package is based on AbstractNeuralNetworks and can be used with GeometricMachineLearning.
SymbolicNeuralNetworks creates a symbolic expression of the neural network, computes arbitrary combinations of derivatives, and uses RuntimeGeneratedFunctions to compile a Julia function.
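The three steps above (symbolic expression, derivatives, compiled function) can be sketched with Symbolics alone. This is only an illustration and not the internal implementation of SymbolicNeuralNetworks: the one-neuron "network" and the names `w`, `b`, `f`, `df` and `d2f` are assumptions made for this example.

```julia
using Symbolics

# Symbolic input `x` and hypothetical parameters `w`, `b` of a one-neuron "network".
@variables x w b

# Step 1: symbolic expression of the network output.
y = tanh(w * x + b)

# Step 2: first- and second-order derivatives with respect to the input.
dy  = Symbolics.derivative(y, x)
d2y = Symbolics.derivative(dy, x)

# Step 3: compile the expressions into ordinary Julia functions.
# With `expression = Val{false}`, `build_function` returns a callable
# (a RuntimeGeneratedFunction) instead of an `Expr`.
f   = build_function(y,   x, w, b; expression = Val{false})
df  = build_function(dy,  x, w, b; expression = Val{false})
d2f = build_function(d2y, x, w, b; expression = Val{false})

# Evaluate the network and its derivatives at x = 0.5, w = 2.0, b = 0.1.
f(0.5, 2.0, 0.1)
df(0.5, 2.0, 0.1)
d2f(0.5, 2.0, 0.1)
```

Because the derivatives are computed symbolically before compilation, the second-order derivative is just another compiled function and can appear in a loss like any other term.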
To create a symbolic neural network, we first design a model with AbstractNeuralNetworks:
using AbstractNeuralNetworks
c = Chain(Dense(2, 2, tanh), Linear(2, 1))
We now call SymbolicNeuralNetwork:
using SymbolicNeuralNetworks
nn = SymbolicNeuralNetwork(c)
We now train the neural network by using SymbolicPullback [2]:
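# build the pullback (and the associated loss) from the symbolic neural network
pb = SymbolicPullback(nn)
using GeometricMachineLearning
# we generate the data and process them with `GeometricMachineLearning.DataLoader`
x_vec = -1.:.1:1.
y_vec = -1.:.1:1.
# every column of `xy_data` is one point [x, y] of the grid
xy_data = hcat([[x, y] for x in x_vec, y in y_vec]...)
# target function: a Gaussian bump centered at the origin
f(x::Vector) = exp(-sum(x.^2))
z_data = mapreduce(i -> f(xy_data[:, i]), hcat, axes(xy_data, 2))
dl = DataLoader(xy_data, z_data)
# allocate the network parameters on the CPU and set up the optimizer
nn_cpu = NeuralNetwork(c, CPU())
o = Optimizer(AdamOptimizer(), nn_cpu)
n_epochs = 1000
batch = Batch(10)
# train with the symbolic pullback
o(nn_cpu, dl, batch, n_epochs, pb.loss, pb)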
We can also train the neural network with Zygote-based [3] automatic differentiation (AD):
pb_zygote = GeometricMachineLearning.ZygotePullback(FeedForwardLoss())
o(nn_cpu, dl, batch, n_epochs, pb_zygote.loss, pb_zygote)
We use git hooks, e.g., to enforce that all tests pass before pushing. To activate these hooks, run the following command once:
git config core.hooksPath .githooks
Footnotes
[1] In some cases it is possible to perform second-order differentiation with Zygote, but it is not entirely clear when this works and when it does not.
[2] This example is discussed in detail in the docs.
[3] Note that here we can actually use Zygote without problems, as the loss does not involve any complicated derivatives.