Skip to content

Commit

Permalink
Merge pull request #172 from JuliaGNI/replace-ifs-with-dispatch
Browse files Browse the repository at this point in the history
Replaced ifs with dispatch.
  • Loading branch information
michakraus authored Nov 14, 2024
2 parents bf31168 + ece1417 commit 8366ad1
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/data_loader/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,17 @@ function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{NeuralNetworkP
count += 1
# these `copy`s should not be necessary! coming from a Zygote problem!
input_nt_output_nt = convert_input_and_batch_indices_to_array(dl, batch, batch_indices) |> _copy
loss_value, pullback = if typeof(input_nt_output_nt) <: Tuple
Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt...), ps)
else
Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt), ps)
end
loss_value, pullback = _pullback(ps, model, input_nt_output_nt, loss)
total_error += loss_value
dp = return_correct_named_tuple(pullback(one(loss_value))[1])
optimization_step!(opt, λY, ps, dp)
end
total_error / count
end

_pullback(ps, model, input_nt_output_nt, loss) = Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt), ps)
_pullback(ps, model, input_nt_output_nt::Tuple, loss) = Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt...), ps)

# this is needed because of the specific way in which we store nn parameters
return_correct_named_tuple(dx::NamedTuple{(:params, )}) = dx.params
return_correct_named_tuple(dx) = dx
Expand Down

0 comments on commit 8366ad1

Please sign in to comment.