Skip to content

Commit

Permalink
Replaced ifs with dispatch.
Browse files Browse the repository at this point in the history
Replaced ifs with dispatch.

Forgot to add loss to function arguments.

Added loss as input argument.

Fix typo.
  • Loading branch information
benedict-96 committed Nov 13, 2024
1 parent 4bc5fc5 commit ece1417
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/data_loader/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,17 @@ function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{NeuralNetworkP
count += 1
# these `copy`s should not be necessary! coming from a Zygote problem!
input_nt_output_nt = convert_input_and_batch_indices_to_array(dl, batch, batch_indices) |> _copy
loss_value, pullback = if typeof(input_nt_output_nt) <: Tuple
Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt...), ps)
else
Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt), ps)
end
loss_value, pullback = _pullback(ps, model, input_nt_output_nt, loss)
total_error += loss_value
dp = return_correct_named_tuple(pullback(one(loss_value))[1])
optimization_step!(opt, λY, ps, dp)
end
total_error / count
end

_pullback(ps, model, input_nt_output_nt, loss) = Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt), ps)
_pullback(ps, model, input_nt_output_nt::Tuple, loss) = Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt...), ps)

# this is needed because of the specific way in which we store nn parameters
return_correct_named_tuple(dx::NamedTuple{(:params, )}) = dx.params
return_correct_named_tuple(dx) = dx
Expand Down

0 comments on commit ece1417

Please sign in to comment.