Skip to content

Commit

Permalink
Switch to update! only
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Oct 15, 2022
1 parent 6a284ba commit 8b8eebc
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/optimise/Optimise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Zygote
import Zygote: Params, gradient
using AbstractDifferentiation
import Optimisers
import Optimisers: update, update!
import Optimisers: update!
using LinearAlgebra
import ArrayInterface
using ProgressLogging: @progress, @withprogress, @logprogress
Expand Down
3 changes: 1 addition & 2 deletions src/optimise/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)

return opt, xs
end
Optimisers.update(opt::AbstractOptimiser, xs::Params, gs) = update!(opt, xs, gs)

# Callback niceties
call(f, xs...) = f(xs...)
Expand Down Expand Up @@ -139,7 +138,7 @@ function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () ->
try
_loss = _build_loss(ad, loss, batchmemaybe(d))
gs = _gradient_only(AD.gradient(ad, _loss, model))
optstate, model = update(optstate, model, gs)
optstate, model = update!(optstate, model, gs)
cb()
catch ex
if ex isa StopException
Expand Down

0 comments on commit 8b8eebc

Please sign in to comment.