
Derivative in loss function error #1464
Open
stelmo opened this issue Jan 13, 2021 · 8 comments

@stelmo
stelmo commented Jan 13, 2021

Hi, I am trying to implement a PINN as described here using Flux. Essentially, I am trying to train a neural network whose loss function includes its time derivative (time is one of its inputs). Below is a very minimal example:

using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1)) # [u0, k, t] -> u(t)
ps = Flux.params(m)

function loss(x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function

    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3]) #  problem source (3rd input is time)
    
    return fitloss + derivativeloss
end

xt = rand(3)
yt = rand(1)

gs = gradient(ps) do
    loss(xt, yt)
end # this generates a foreigncall exception

This issue seems to be pervasive, see here and #1338 and #1257 and here (the last one is me on the Discourse forum). I have tried all the suggestions in the aforementioned links, but nothing seems to work. Do you have a workaround, or is this some built-in limitation of Flux/Zygote?

@DhairyaLGandhi
Member

DhairyaLGandhi commented Jan 13, 2021

Thanks for bringing this up! If you remove the need to use params, you get a more meaningful error.

function loss(m, x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function

    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3]) #  problem source (3rd input is time)
    
    return fitloss + derivativeloss
end

gs = gradient(m, xt, yt) do m, x, y
    loss(m, x, y)
end 

julia> gs = gradient(m, xt, yt) do m, x, y
           loss(m, x, y)
       end # this generates a foreigncall exception
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#380#381")(::Nothing) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/array.jl:58
 [3] (::Zygote.var"#2288#back#382"{Zygote.var"#380#381"})(::Nothing) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] (::Zygote.var"#152#153"{Zygote.var"#2288#back#382"{Zygote.var"#380#381"},Tuple{Tuple{Nothing,Nothing},Tuple{Nothing}}})(::Nothing) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/lib.jl:191
 [5] (::Zygote.var"#1699#back#154"{Zygote.var"#152#153"{Zygote.var"#2288#back#382"{Zygote.var"#380#381"},Tuple{Tuple{Nothing,Nothing},Tuple{Nothing}}}})(::Nothing) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [6] #372 at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/array.jl:38 [inlined]

The same can be seen using FluxML/Zygote.jl#823 directly as well.

The Fix

This comes from picking out the first element of the model output. If you instead use something like

julia> function loss(m, x, y)
           fitloss = Flux.Losses.mse(m(x), y) # typical loss function

           derivativeloss = abs2(gradient(a -> sum(m(a)), x)[1][3]) #  problem source (3rd input is time)

           return fitloss + derivativeloss
       end

(notice the call to sum instead of indexing the first element), things work fine. Since the model here has a single output, sum(m(a)) equals m(a)[1], but its pullback avoids the mutating getindex adjoint.

In cases where we do need the element-wise grads, sum doesn't quite cut it, and we would need to write the getindex adjoint so that it does not mutate an array (which it currently does for performance). We could always write a fallback rule that handles this as well.
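For illustration only, here is a rough, untested sketch of what a non-mutating adjoint for scalar indexing of a vector could look like. This is not what Zygote ships; it just shows the idea of building the cotangent with map instead of writing into a zeroed array:

using Zygote

# Untested sketch: non-mutating adjoint for scalar indexing of a vector.
# The cotangent for x is built with map rather than setindex! into zeros.
Zygote.@adjoint function Base.getindex(x::AbstractVector, i::Integer)
  x[i], dy -> (map(j -> j == i ? dy : zero(dy), eachindex(x)), nothing)
end

For large arrays this would be slower than the mutating version, which is why the shipped adjoint is written the way it is.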

@stelmo
Author

stelmo commented Jan 13, 2021

Thank you very much! This fixes the one-dimensional case :) I will add this thread to my Discourse question (as well as to the other question there).

How would one write the adjoint for the getindex function? I will need it to extend the PINN system to multiple dimensions...

I looked at the source code for Base.getindex and it is not very clear to me what's going on - hopefully I won't have to change anything there? The documentation of Zygote mentions that one should use ChainRules to define custom adjoint rules. Should I use that instead to define a fallback rule?

@DhairyaLGandhi
Member

What would the multidimensional case entail?

You're better off checking out the code for the getindex adjoint in Zygote, in src/lib/array.jl, and writing an adjoint based on that.

@stelmo
Author

stelmo commented Jan 13, 2021

Essentially I mean element-wise gradients, e.g.

m = Chain(Dense(5, 10, relu), Dense(10, 10, relu), Dense(10, 2)) # [u0_1, u0_2, k1, k2, t] -> [u1(t), u2(t)]

function loss(m, x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function
    
    derivativeloss = 0.0f0
    
    for i=1:size(x, 2)
        for j=1:2 # dimension to take gradient of
            derivativeloss += abs2( gradient(a -> m(a)[j], x[:, i])[1][5] ) # ||du_j/dt|| for j=1,2; this mutates again :(
        end
    end
    
    return fitloss + derivativeloss
end

xt = rand(5, 10)
yt = rand(2, 10)

gs = gradient(m, xt, yt) do m, x, y
    loss(m, x, y)
end

Okay, I've looked in Zygote/src/lib/array.jl and found this:

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

∇getindex(x::AbstractArray, inds) = dy -> begin
  if inds isa  NTuple{<:Any, Integer}
    dx = _zero(x, typeof(dy))
    dx[inds...] = dy
  else
    dx = _zero(x, eltype(dy))
    dxv = view(dx, inds...)
    dxv .= accum.(dxv, _droplike(dy, dxv))
  end
  return (dx, map(_->nothing, inds)...)
end

Is the mutation happening here:

dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))

Would I have to change that somehow (I am not sure at all what to do)? I really appreciate your help in this!

@DhairyaLGandhi
Member

Right, anywhere setindex! gets in the way.
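In the meantime, one workaround you could try (an untested sketch, same idea as the sum fix above) is to avoid scalar indexing of the model output inside the inner gradient by selecting the component with a constant mask:

function loss(m, x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function

    derivativeloss = 0.0f0
    for i in 1:size(x, 2), j in 1:2
        ej = Float32.(1:2 .== j) # one-hot mask over the two outputs, built without mutation
        # sum(m(a) .* ej) picks out output j without the mutating getindex adjoint
        derivativeloss += abs2(gradient(a -> sum(m(a) .* ej), x[:, i])[1][5])
    end

    return fitloss + derivativeloss
end

The remaining indexing (x[:, i] and [1][5]) only goes through getindex once, which is fine; the error comes from differentiating the mutating getindex pullback a second time, as happens with m(a)[j] inside the inner gradient.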

@mkalia94

Bump.
Zygote.jl now has a jacobian function, but it hits the same mutating-arrays error. This is quite important for physics-informed machine learning in general, so it would be great to have some clarity on it! @DhairyaLGandhi I'm not quite sure I get your last comment. Thanks!
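Roughly what I mean is a pattern like this (just a sketch, with placeholder names): take a Jacobian of the network with respect to its inputs inside the loss, then differentiate that loss with respect to the parameters:

using Flux, Zygote

m = Chain(Dense(3, 10, relu), Dense(10, 1))
xt = rand(Float32, 3)

# Jacobian of the network output w.r.t. its input, used inside the loss
physloss(x) = sum(abs2, Zygote.jacobian(m, x)[1])

gs = gradient(() -> physloss(xt), Flux.params(m)) # hits the same mutating-array error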

@ToucheSir
Member

This is totally not my wheelhouse, but there are a few threads on discourse about representing PINNs that you could look into.

@NagaChaitanya96

(quoting @stelmo's multidimensional example and getindex question from above)

This code isn't correct, right? I'm getting an error at the gradient part.
