Enzyme using fallback BLAS replacements in deep neural network #692
The latter warning implies you aren't using the latest main. What happens when you use the latest Enzyme? Secondly, from |
No, it's using Enzyme in one of the adjoints. Though last I checked on main it'll still fail with Lux until #645 is fixed. The getfield missing derivatives are kind of the last piece of the puzzle. |
Yeah but I presume the Lux AD they are doing successfully isn't going through Enzyme, right? Thus the perf issue they see is Zygote? |
No it would be the Enzyme BLAS fallback. |
Hm okay, in that case -- so I can properly understand, can you make a version of this code that is just a call to Enzyme.autodiff? |
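(For readers unfamiliar with the call being asked for here, a bare reverse-mode invocation on a toy in-place function looks roughly like the sketch below; the function and values are illustrative and not part of the thread.)
using Enzyme
# Toy in-place function: du .= p .* u (returns nothing, as reverse mode expects for mutating code)
f!(du, u, p) = (du .= p .* u; nothing)
du, d_du = zeros(2), ones(2)      # d_du is the adjoint seed for the output du
u, d_u = [1.0, 2.0], zeros(2)     # the gradient w.r.t. u accumulates into d_u
p, d_p = [3.0, 4.0], zeros(2)     # the gradient w.r.t. p accumulates into d_p
Enzyme.autodiff(Reverse, f!, Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p))
# Expected: d_u ≈ p and d_p ≈ u, since ∂(p .* u)/∂u = p and ∂(p .* u)/∂p = u with a seed of ones.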
It would just be the UDE part:
function ude!(du,u,p,t,q)
knownPred = knownDynamics(u,p.predefined_params,q)
nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
for i in 1:length(u)
du[i] = knownPred[i]+nnPred[i]
end
end
w.r.t. u, p. |
Can you give an example that initializes the internals? I'm not familiar with those packages.
using Enzyme, Random, Lux, ComponentArrays
rng = Random.default_rng()
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
knownPred = knownDynamics(u,p.predefined_params,q)
nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
for i in 1:length(u)
du[i] = knownPred[i]+nnPred[i]
end
end
du = ?
d_du = ?
u = ?
d_u = ?
p = ?
d_p = ?
t = ?
q = ?
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Duplicated(u, d_u), Const(p), Const(d_p), Const(t), Const(q)) |
using Enzyme, Random, Lux, ComponentArrays
rng = Random.default_rng()
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t)
knownPred = knownDynamics(u,p.predefined_params)
nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
for i in 1:length(u)
du[i] = knownPred[i]+nnPred[i]
end
nothing
end
du = training_data[:,1]
d_du = training_data[:,1]
u = training_data[:,1]
d_u = training_data[:,1]
p = ps_dynamics
d_p = copy(ps_dynamics)
t = 0.0
ude!(du,u,p,t)
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Duplicated(u, d_u), Const(p), Const(t))
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Const(u), DuplicatedNoNeed(p, d_p), Const(t)) |
Hope this is not tangential to the Lux specifics above, but back to the original question about slow performance with the BLAS fallback: this is something that is experienced in the following simple setup:
function mymatmul(x::AbstractMatrix, w::AbstractMatrix)
out = sum(w * x)
return out
end;
seed!(123)
bs = 4096
f = 256
h1 = 512
w = randn(h1, f) .* 0.01;
x = randn(f, bs) .* 0.01;
dw = zeros(h1, f);
# dx = zeros(f, bs);
Forward pass (0.01 sec):
julia> @time mymatmul(x, w)
0.012209 seconds (3 allocations: 16.000 MiB)
0.5842598323078428
Forward-backward (5 secs, over 100x slower):
julia> @time _, y = Enzyme.autodiff(ReverseWithPrimal, mymatmul, Const(x), Duplicated(w, dw))
5.066479 seconds (4.11 k allocations: 32.125 MiB)
((nothing, nothing), 0.5842598323078416)
This is from the current main branch version:
|
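(One caveat on the numbers above: the first @time of an Enzyme.autodiff call also pays Enzyme's compilation cost. A steady-state measurement, e.g. with BenchmarkTools, separates that from the per-call cost of the fallback BLAS path; this is a suggestion, not something run in the thread.)
using BenchmarkTools, Enzyme
@btime mymatmul($x, $w)
@btime Enzyme.autodiff(ReverseWithPrimal, mymatmul, Const($x), Duplicated($w, $dw))
# Note: dw accumulates across benchmark runs; zero it afterwards if you want to inspect it.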
@ChrisRackauckas yeah, this does not differentiate with Enzyme atm, so I'm presuming it's actually Zygote or a different fallback causing the issues. The error is:
No augmented forward pass found for ijl_box_char
declare nonnull {} addrspace(10)* @ijl_box_char(i32 zeroext) local_unnamed_addr #3
Stacktrace:
[1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:5040
[2] macro expansion
@ ./logging.jl:362 [inlined]
[3] macro expansion
@ ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:58 [inlined]
[4] inlining_policy
@ ~/git/Enzyme.jl/src/compiler/interpreter.jl:181
[5] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo,
I'll go add the function now regardless, to get it closer to working -- but FYI, I don't think Enzyme is the cause. |
Summarizing for @jarroyoe: this code does not currently differentiate with Enzyme, so Enzyme is not the cause of the performance issue (the warnings come from Enzyme being asked to differentiate something, but the result presumably isn't being used, since Enzyme would error). Thus your performance question is best asked on the issue tracker of a different library. We'll work on making this differentiate, so let's keep the issue open regardless. |
@jeremiedb suboptimal performance on the fallback BLAS is expected (hence the warning about it). Feel free to open a separate issue to track it. Resolving it requires either someone adding the BLAS rules in EnzymeRules or finishing up @ZuseZ4's work on Enzyme-internal BLAS rules. |
@wsmoses thank you for this information. To summarize, is this issue entirely in the Enzyme.jl team's hands now, or can something be done through the DiffEqFlux.jl part of the script? |
You could define an EnzymeRule for the unsupported part of the code. |
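(For concreteness, a rough sketch of what such a custom rule could look like for the mymatmul example earlier in the thread, loosely following the EnzymeRules custom-rules documentation; the exact ConfigWidth/AugmentedReturn signatures have changed across Enzyme.jl versions, so treat this as illustrative rather than the supported API.)
using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse
using Enzyme.EnzymeRules
mymatmul(x, w) = sum(w * x)
# Augmented forward pass: run the primal with the normal (fast) BLAS and stash x for the reverse pass.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(mymatmul)},
                          ::Type{<:Active}, x::Const, w::Duplicated)
    primal = needs_primal(config) ? func.val(x.val, w.val) : nothing
    return AugmentedReturn(primal, nothing, x.val)
end
# Reverse pass: d(sum(w*x))/dw[i,k] = sum_j x[k,j], scaled by the incoming adjoint dret.
function reverse(config::ConfigWidth{1}, func::Const{typeof(mymatmul)},
                 dret::Active, tape, x::Const, w::Duplicated)
    w.dval .+= dret.val .* sum(tape, dims=2)'   # accumulate into the shadow of w
    return (nothing, nothing)
end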
The current bottleneck is getfield: #645, #644. What's going on in the OP's case is that it does a try/catch on Enzyme, which emits the warning but then fails (errors); the error is caught, and it falls back to ReverseDiff in scalar mode (to handle the mutation). It should be doing tape compilation there, but that's all a bit beside the point. The key is that what I showed above is what it tries to do with Enzyme, and that fails. |
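(Concretely, the vjp choice described above is controlled by the sensealg passed to solve; a hedged sketch of pinning it down explicitly, assuming SciMLSensitivity's adjoint API rather than anything shown in the thread, and reusing _prob and T from the predict function in the script elsewhere in this thread:)
using SciMLSensitivity
# Force the Enzyme-based vjp inside the adjoint (the path that currently errors for this model):
solve(_prob, Rodas4P(), saveat = T, abstol = 1e-7, reltol = 1e-7,
      sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
# The fallback described above (ReverseDiff vjp with a compiled tape):
solve(_prob, Rodas4P(), saveat = T, abstol = 1e-7, reltol = 1e-7,
      sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))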
using Enzyme, Random, Lux, ComponentArrays
Enzyme.API.printall!(true)
rng = Random.default_rng()
neuralNetwork = Lux.Chain(Lux.Dense(2,1),Lux.Dense(1,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
# @show ps64
function ude!(u,p, neuralNetwork)
neuralNetwork = Base.inferencebarrier(neuralNetwork.layers.layer_1)
# y = neuralNetwork(u, p.layer_1, NamedTuple()) # st.layer_1)
(y, _) = neuralNetwork(u, p.layer_1, NamedTuple()) # st.layer_1)
y[1]::Float64
end
u = Float64[0,0]
d_u = Float64[0,0]
p = ps64
d_p = copy(ps64)
ude!(u,p, neuralNetwork)
Enzyme.autodiff(Reverse, ude!, Duplicated(u, d_u), Const(p), neuralNetwork)
This hits a GC error atm, investigating. |
The GC part of the error should now be fixed; the type instability persists. |
Okay, still needs work:
using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Random
# using ComponentArrays, Lux, Plots, Random, StatsBase
# using DelimitedFiles, Serialization
rng = Random.default_rng()
const neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
const ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
knownPred = knownDynamics(u,p.predefined_params,q)
nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
for i in 1:length(u)
du[i] = knownPred[i]+nnPred[i]
end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
_prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
Array(solve(_prob, Rodas4P(), saveat = T,
abstol=1e-7, reltol=1e-7
))
end
function loss_function(p)
X̂ = predict(p)
sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
ADAM(),
maxiters = 10)
|
GC issue now fixed on main, and once https://github.com//pull/911 lands, the output is now:
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/cy24l/src/utils.jl:56
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/cy24l/src/utils.jl:56
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│ on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│ consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
(the DiffCache warning above repeats many more times in the output) |
With the new BLAS handling (currently off by default, but it can be enabled with a flag), there is no longer a performance warning from Enzyme (though there's still an unrelated DiffCache one above). Full code below:
using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Plots, Random, StatsBase
using DelimitedFiles, Serialization
rng = Random.default_rng()
using Enzyme
Enzyme.API.runtimeActivity!(true)
Enzyme.Compiler.bitcode_replacement!(false)
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
knownPred = knownDynamics(u,p.predefined_params,q)
nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
for i in 1:length(u)
du[i] = knownPred[i]+nnPred[i]
end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
_prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
Array(solve(_prob, Rodas4P(), saveat = T,
abstol=1e-7, reltol=1e-7
))
end
function loss_function(p)
X̂ = predict(p)
sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
ADAM(),
maxiters = 10)
That said, your code itself was also type-unstable in places that would be a performance bottleneck; consider changing it to something like:
using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Plots, Random, StatsBase
using DelimitedFiles, Serialization
rng = Random.default_rng()
using Enzyme
# Enzyme.API.runtimeActivity!(true)
Enzyme.Compiler.bitcode_replacement!(false)
const neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ltup = Lux.setup(rng, neuralNetwork)
const ps = ltup[1]
const st = ltup[2]
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
const training_data = rand(2,50)
const ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
knownPred = knownDynamics(u,p.predefined_params,q)
nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
for i in 1:length(u)
du[i] = knownPred[i]+nnPred[i]
end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
const prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
_prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
Array(solve(_prob, Rodas4P(), saveat = T,
abstol=1e-7, reltol=1e-7
))
end
function loss_function(p)
X̂ = predict(p)
sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
ADAM(),
maxiters = 10) |
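(The substantive change in this version is marking the globals that ude! and predict capture (neuralNetwork, ps, st, training_data, prob_nn) as const; a non-const global cannot be type-inferred when referenced inside a function. A minimal toy illustration, not taken from the thread:)
a = 1.0            # non-const global: f cannot be inferred and is typed as returning Any
f(x) = x + a
const b = 1.0      # const global: g infers as returning Float64
g(x) = x + b
# @code_warntype f(1.0)   # shows the Any coming from the non-const global
# @code_warntype g(1.0)   # fully inferred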
When I run this code:
I get repetitions of the following warnings, and performance is significantly slowed down.
However, if I change my neural network architecture to include a single hidden layer:
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,2))
The warnings disappear. The warnings also disappear when my neural network has a single input:
I haven't tested a single-hidden-layer neural network at production scale, but the case with a single input performs significantly better than the case with two inputs (the single input completes the production run in less than a day, while the double input takes several days to go through a single iteration of the for loop and runs out of 250GB of RAM).
Besides changing my neural network to a single hidden layer (not ideal), how can this issue be fixed?