Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Physics informed neural operator ode #806

Merged
merged 160 commits into from
Oct 30, 2024
Merged

Physics informed neural operator ode #806

merged 160 commits into from
Oct 30, 2024

Conversation

KirillZubov
Copy link
Member

@KirillZubov KirillZubov commented Feb 12, 2024

Implementation Physics-informed neural operator method for solve parametric Ordinary Differential Equations (ODE) use DeepOnet.

#575

Checklist

  • pino ode
  • family ode by parameter
  • physics informed DeepOnet
  • tests
  • addition loss test
  • doc
  • multiple parameters
  • test with vector outputs and multiple parameters
  • imigrate to LuxNeuralOperators
  • interpretation output with another mesh
  • vector output Vector output for PINO ODE #871
  • update docs

https://arxiv.org/abs/2103.10974
https://arxiv.org/abs/2111.03794

@KirillZubov
Copy link
Member Author

@ChrisRackauckas I need help with packages version. Adding dependency of NeuralOperator.jl to project, fail CI. I've tried a little to line up suitable versions but not success.

@ChrisRackauckas
Copy link
Member

What's left here before review?

@KirillZubov
Copy link
Member Author

What's left here before review?

@ChrisRackauckas all is done here, but it is necessary for the tests to pass, but this is related to Lux updates and not to the code in this PR

@ChrisRackauckas
Copy link
Member

These two tests don't pass and that doesn't seem lux related? https://github.com/SciML/NeuralPDE.jl/actions/runs/11405946017/job/31738592363?pr=806#step:6:1096

It just needs to match https://github.com/SciML/NeuralPDE.jl/actions/runs/11395770870/job/31708526179 and make sure the new tests pass, but the new tests you added don't pass.

Comment on lines +100 to +117
function physics_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {C <: DeepONet, T}
p, t = x
f = prob.f
out = phi(x, θ)
if size(p, 1) == 1
f_vec = reduce(hcat,
[reduce(vcat, [f(out[j, i], p[1, i], t[j]) for j in axes(t, 2)])
for i in axes(p, 2)])
else
f_vec = reduce(hcat,
[reduce(vcat, [f(out[j, i], p[:, i], t[j]) for j in axes(t, 2)])
for i in axes(p, 2)])
end
du = dfdx(phi, x, θ)
norm = prod(size(du))
sum(abs2, du .- f_vec) / norm
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't this use the ODE code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it has parameters and coordinates on input, unlike the usual ODE with just cord, also ODE does not process DeepONet and NeuralOperator in general. It can't be reduced to instance ODE task. If only on the contrary, ODE is a special case of PinoODE as a parametric ODE with one parameter

Comment on lines +142 to +153
function initial_condition_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {
C <: DeepONet, T}
p, t = x
t0 = reshape([prob.tspan[1]], (1, 1, 1))
x0 = (p, t0)
u = phi(x0, θ)
u0 = size(prob.u0, 1) == 1 ? fill(prob.u0, size(u)) :
reduce(vcat, [fill(u0, size(u)) for u0 in prob.u0])
norm = prod(size(u0))
sum(abs2, u .- u0) / norm
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't this use the ODE code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the same here

@ChrisRackauckas
Copy link
Member

Seems a bit odd to take over sol.interp like that, what does sol(t) do in this case?

@KirillZubov
Copy link
Member Author

KirillZubov commented Oct 28, 2024

Seems a bit odd to take over sol.interp like that, what does sol(t) do in this case?

it wasn't identified sol(t). I added sol(t) and replace interp with sol(t)

@KirillZubov
Copy link
Member Author

it is reason why CI fail LuxDL/LuxLib.jl#179

@test ground_solution≈predict_sol rtol=0.05
p, t = get_trainset(chain, bounds, 100, tspan, 0.01)
ground_solution = ground_analytic.(u0, p, t)
predict_sol = sol(reduce(vcat, (p, t)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't match the documented interface though, so it's really odd. sol.original(...) would be allowed to do this, but what I was saying is that we shouldn't have an interface break here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you think it would be better ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sol(t) should do the normal interoplation w.r.t. p in the ODEProblem. You can overload sol.interp to do extra things and document that as well, but sol(t) is described as having very specific behavior which should be kept the same with all other ODE solvers.

SciMLBase.allowscomplex(::PINOODE) = true

function (sol::SciMLBase.AbstractODESolution)(t::AbstractArray)
p, _ = sol.t
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand how .t would have p?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how can we place t and parameters in ODEsolution separately without override t is how all input data(t and p), or is it better identify PINOODEsolution here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would've guessed you'd do:

function (sol::SciMLBase.AbstractODESolution)(t::Union{Number,AbstractArray})
  sol.interp(sol.prob.p, t)
end

? If that works then this PR is complete.

Copy link
Member Author

@KirillZubov KirillZubov Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, but with Chain it need that size(t) == size(p). I guess, It need add the warning "t should be same size as p"

or

function (sol::SciMLBase.AbstractODESolution)(t::Union{Number,AbstractArray})
   p = gen(sol.prob.p, size(t)) # generate p same size as t 
   sol.interp(p, t)
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented it with the warning

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright, I think we're good to go.

@ChrisRackauckas ChrisRackauckas merged commit 783b220 into master Oct 30, 2024
35 of 56 checks passed
@ChrisRackauckas ChrisRackauckas deleted the pino_ode branch October 30, 2024 15:53
@ChrisRackauckas
Copy link
Member

🎉 🎉 🎉 🎉 🎉 🎉

@KirillZubov
Copy link
Member Author

🥳🥳🥳🥳🥳

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants