From ece14175ef4a0289b9c3828008106095eb7f0f38 Mon Sep 17 00:00:00 2001 From: benedict-96 Date: Wed, 13 Nov 2024 15:48:24 +0100 Subject: [PATCH] Replaced ifs with dispatch. Replaced ifs with dispatch. Forgot to add loss to function arguments. Added loss as input argument. Fix typo. --- src/data_loader/optimize.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/data_loader/optimize.jl b/src/data_loader/optimize.jl index ffb1a54a2..70c1c9689 100644 --- a/src/data_loader/optimize.jl +++ b/src/data_loader/optimize.jl @@ -55,11 +55,7 @@ function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{NeuralNetworkP count += 1 # these `copy`s should not be necessary! coming from a Zygote problem! input_nt_output_nt = convert_input_and_batch_indices_to_array(dl, batch, batch_indices) |> _copy - loss_value, pullback = if typeof(input_nt_output_nt) <: Tuple - Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt...), ps) - else - Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt), ps) - end + loss_value, pullback = _pullback(ps, model, input_nt_output_nt, loss) total_error += loss_value dp = return_correct_named_tuple(pullback(one(loss_value))[1]) optimization_step!(opt, λY, ps, dp) @@ -67,6 +63,9 @@ function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{NeuralNetworkP total_error / count end +_pullback(ps, model, input_nt_output_nt, loss) = Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt), ps) +_pullback(ps, model, input_nt_output_nt::Tuple, loss) = Zygote.pullback(ps -> loss(model, ps, input_nt_output_nt...), ps) + # this is needed because of the specific way in which we store nn parameters return_correct_named_tuple(dx::NamedTuple{(:params, )}) = dx.params return_correct_named_tuple(dx) = dx