diff --git a/src/optimise/train.jl b/src/optimise/train.jl index d487032ddf..dd5ed2c4a8 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,5 +1,5 @@ using Juno -import Zygote: Params, gradient +import Zygote: Params, gradient, pullback """ update!(x, x̄) @@ -94,16 +94,18 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ -function train!(loss, ps, data, opt; cb = () -> ()) +function train!(loss, ps, data, opt; cb = (x...) -> (), prehooks = (x...) -> true) ps = Params(ps) cb = runall(cb) + train_prehooks = runall(prehooks) @progress for d in data try - gs = gradient(ps) do + l, back = pullback(ps) do loss(batchmemaybe(d)...) end - update!(opt, ps, gs) - cb() + gs = back(l) + all(train_prehooks(l, gs, d)) && update!(opt, ps, gs) + cb(l, gs, d) catch ex if ex isa StopException break