NNODE fails on simple logistic curve #634

NeuralPDE.jl fails on this simple example of a logistic curve. I'm not sure whether this would be helped by changing the neural net (as the output is between 0 and 1), though I've tried lots of structures, as well as trying to refine the optimization. ADAM+BFGS should work OK for this example.
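For reference, the logistic ODE here, $u' = u(1 - u)$ with $u(0) = u_0$, has the closed-form solution

$$u(t) = \frac{u_0 e^t}{1 + u_0 (e^t - 1)},$$

a sigmoid rising from $u_0 = 0.01$ toward 1; this is the curve the trained network needs to reproduce.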
Comments
I think the problem is not, strictly speaking, with NeuralPDE, but that depending on the initialization, the NN sometimes gets stuck in local minima. With networks this small, I think this is a real concern. Your code sometimes actually trains successfully for me using Adam, BFGS, or both. You could try making the network wider or deeper to increase the chances of training success.
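As an illustration of "wider or deeper", a minimal sketch with Lux (the layer sizes here are arbitrary, not a tested recommendation):

```julia
import Lux

# Wider (32 units) and deeper (three hidden layers) than the
# 1 -> 20 -> 1 chain used in the code below; sizes are illustrative.
luxchain = Lux.Chain(Lux.Dense(1, 32, Lux.σ),
                     Lux.Dense(32, 32, Lux.σ),
                     Lux.Dense(32, 32, Lux.σ),
                     Lux.Dense(32, 1))
```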
I've tried wider and deeper (and both at once), and sometimes I get a sigmoid-ish curve, but it's still way off the true solution, and with many more parameters than a simple basis-function approach would require. I was hoping that this example wouldn't require anything large or complex on the neural-net side. The kinds of problems I'm looking at will have the same problem as this simple logistic curve.
I'm more confused after playing around with it more. Sometimes the loss is really low at the end (like 1e-10) but the solution is totally wrong; for successful runs the loss is more like 1e-7. Sometimes I get NaN solutions/parameters without an error. I also tried using Lux instead of Flux, since the docs recommend that for getting double precision, but with Lux I was never able to train successfully: the loss got lower more consistently, but the solutions just don't seem to work.

Another thing I just realized is that the issue title references NNODE (https://docs.sciml.ai/NeuralPDE/stable/manual/ode/#NeuralPDE.NNODE) but the code doesn't actually use NNODE (or is it used internally?). Maybe the NNODE specialization would work better, as this is an ODE, but I could not get NNODE to work based on the example code there; I always get "Optimization algorithm not found. Either the chosen algorithm is not a valid solver …"
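On the double-precision point, a minimal sketch of forcing Float64 parameters on a Flux chain; `Flux.f64` is assumed to be available, as in recent Flux releases:

```julia
import Flux

# Convert the chain's parameters to Float64; single-precision weights
# can stall training at tight tolerances such as abstol = 1e-10.
chain = Flux.f64(Flux.Chain(Flux.Dense(1, 20, Flux.σ), Flux.Dense(20, 1)))
```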
Sorry, the subject comes from a discussion of this issue on the Julia Discourse.
Here's my code for using the NNODE solver:

```julia
using Flux
using NeuralPDE
using OrdinaryDiffEq, Optimisers
using Plots
import Lux, OptimizationOptimisers, OptimizationOptimJL
eq = (u, p, t) -> u .* (1 .- u)
tspan = (0.0, 10.0)
u0 = 0.01
prob = ODEProblem(eq, u0, tspan)
ode_sol = solve(prob, Tsit5(), saveat=0.1)
function run()
    # Flux chain kept for reference; NNODE below is given the Lux chain
    chain = Flux.Chain(Flux.Dense(1, 20, σ), Flux.Dense(20, 1))
    luxchain = Lux.Chain(Lux.Dense(1, 20, Lux.σ), Lux.Dense(20, 1))
    opt = OptimizationOptimJL.BFGS()
    # opt = OptimizationOptimisers.Adam(0.1)
    nnode_sol = solve(prob, NeuralPDE.NNODE(luxchain, opt), dt=1 / 10.0, verbose=true, abstol=1.0e-10, maxiters=5000)
    ts = tspan[1]:0.01:tspan[2]
    xpred = [nnode_sol(t) for t in ts]
    plot(ode_sol, label="ODE solution")
    plot!(ts, xpred, label="NNODE solution")
end

run()
```
I've been playing with this and applying a new training strategy to it, and the mean absolute difference between the NNODE solution and the ODE solution is about 0.02. With some other strategies this difference is sometimes very large, which confuses me for such a simple problem. Here is the code; I am going to try to make this better.
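A sketch of how that mean absolute difference might be computed, reusing `tspan`, `ode_sol`, and `nnode_sol` from the snippet above (this assumes `nnode_sol(t)` returns a scalar, as the plotting code above suggests):

```julia
using Statistics

# Evaluate both solutions on a common grid and compare pointwise.
ts = tspan[1]:0.1:tspan[2]
nn_vals  = [nnode_sol(t) for t in ts]
ode_vals = ode_sol(ts).u          # interpolate the Tsit5 solution
mean_abs_diff = mean(abs.(nn_vals .- ode_vals))
```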
Seems like QuadratureTraining works pretty well for this: I am getting the mean difference in solutions to be 0.01 or less fairly consistently. Here is the code and a graph.
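A sketch of what that likely looks like with the same setup as the earlier snippet; passing a `strategy` keyword to `NNODE` is per the NeuralPDE docs, but the default `QuadratureTraining()` arguments are an assumption and may vary across versions:

```julia
# Same prob, luxchain, and opt as in the code above, but with
# QuadratureTraining as the NNODE training strategy instead of the
# dt-based grid.
alg = NeuralPDE.NNODE(luxchain, opt; strategy = NeuralPDE.QuadratureTraining())
nnode_sol = solve(prob, alg, verbose = true, abstol = 1.0e-10, maxiters = 5000)
```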
👍 I think we should change the docs to use QuadratureTraining in the tutorials and then close this. GridTraining shouldn't be used (as the docs say in other places).
@ChrisRackauckas can we close this?