This might not be a bug, but at the least it seems to be a mysterious error.
I'm using Flux, and in the loss function (which is processed by Zygote and so uses ChainRules.jl) I'm using `prod`. The surprising thing was that it ran for several iterations before failing with a scalar indexing error.
I'm guessing the error happens when it goes into this branch because it finds a zero somewhere. That would explain why some iterations occur before hitting this error. I don't know if the code can be changed to avoid scalar indexing, or maybe a more informative error, or maybe it's just something I need to better understand.
I forgot how this works, but as you observe, the code which is careful about zeros is behind a test `any(iszero, x)`, precisely to let all-nonzero `x` work without error on a GPU.
It wouldn't be crazy to add an explicit message there, like `x isa AbstractGPUArray && @error "rule for prod found zeros..." maxlog=1`. Without one, the reason for the error is a bit mysterious, since it may work for the first iteration but fail later. Usually scalar indexing errors depend only on the types, not the values.
The alternatives are just to use the broadcasting branch on GPU arrays (and accept that getting NaN is your signal that something is wrong), or to write a GPU kernel to do this correctly (perhaps using KernelAbstractions to be device-agnostic).
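To make the trade-off between the two branches concrete, here is a minimal CPU-only sketch. The helper names `naive_prod_grad` and `careful_prod_grad` are hypothetical and this is not the ChainRules source; it only illustrates why a broadcasting formula gives NaN at a zero, while a zero-aware version needs element-by-element access (the kind of access that triggers scalar indexing errors on a GPU array).

```julia
# Hypothetical helpers for illustration only, not the actual ChainRules rule.

# Broadcasting form: d(prod(x))/dx[i] = prod(x) / x[i].
# No scalar indexing, but a zero entry gives 0/0 = NaN.
naive_prod_grad(x) = prod(x) ./ x

# Zero-aware form: the product of all the *other* entries,
# computed here with a plain scalar loop (fine on CPU, not on GPU arrays).
function careful_prod_grad(x)
    g = similar(x)
    for i in eachindex(x)
        acc = one(eltype(x))
        for j in eachindex(x)
            j == i || (acc *= x[j])   # skip the i-th factor
        end
        g[i] = acc
    end
    return g
end

x = [2.0, 0.0, 3.0]
naive_prod_grad(x)     # the entry belonging to the zero is NaN
careful_prod_grad(x)   # [0.0, 6.0, 0.0]
```

The quadratic loop is only there to keep the sketch short; the point is that the correct answer at a zero cannot be recovered from `prod(x) ./ x` alone.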
That's for ChainRules. Can you share more about what your actual use is? Inserting something like `clamp.(x, 0.001, 0.99)` is one way you might avoid problems.
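As a rough sketch of the clamping workaround, assuming the inputs are meant to lie in (0, 1) such as probabilities: the bounds below are the ones suggested above, while the data and the variable names `safe` and `g` are made up for illustration. Note that clamping changes the value of the loss wherever an entry falls outside the bounds, so whether it is acceptable depends on the use case.

```julia
# Made-up input containing an exact zero.
x = Float32[0.5, 0.0, 0.9]

# Keep every entry strictly inside (0, 1) before taking the product.
safe = clamp.(x, 0.001f0, 0.99f0)

# With no zeros left, the broadcasting formula for the gradient stays finite.
g = prod(safe) ./ safe
```

Because `safe` contains no zeros, the `any(iszero, x)` test would not be triggered for it, so the zero-handling code path is never reached.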