From dc838d055abbb07e05bea8d96b89cdf21a9c49c8 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 30 Mar 2023 11:09:17 +0200 Subject: [PATCH 1/3] overwrite progbar at each epoch --- examples/mlp_mnist.jl | 3 +- src/ProgressMeter/ProgressMeter.jl | 566 +---------------------------- src/logging.jl | 2 +- src/trainer.jl | 37 +- 4 files changed, 44 insertions(+), 564 deletions(-) diff --git a/examples/mlp_mnist.jl b/examples/mlp_mnist.jl index 12274e2..00ac562 100644 --- a/examples/mlp_mnist.jl +++ b/examples/mlp_mnist.jl @@ -58,11 +58,12 @@ Tsunami.fit!(model, trainer, train_loader, test_loader) # TRAIN FROM SCRATCH -trainer = Trainer(max_epochs = 3, +trainer = Trainer(max_epochs = 5, default_root_dir = @__DIR__, accelerator = :cpu, checkpointer = true, logger = true, + progress_bar = true, ) fit_state = Tsunami.fit!(model, trainer, train_loader, test_loader) diff --git a/src/ProgressMeter/ProgressMeter.jl b/src/ProgressMeter/ProgressMeter.jl index 8b3dc6c..7cddafc 100644 --- a/src/ProgressMeter/ProgressMeter.jl +++ b/src/ProgressMeter/ProgressMeter.jl @@ -3,15 +3,17 @@ # https://github.com/timholy/ProgressMeter.jl/pull/261 # The change with respect to the original package are: # - merged PR #261 +# - removed docstring so that Documenter does not complain they are not in the docs +# - removed ProgressThresh and ProgressUnknown since not needed +# - removed @showprogress module ProgressMeter using Printf: @sprintf using Distributed -export Progress, ProgressThresh, ProgressUnknown, BarGlyphs, next!, update!, cancel, finish!, @showprogress, progress_map, progress_pmap, ijulia_behavior +export Progress, BarGlyphs, next!, update!, cancel, finish!, ijulia_behavior -ProgressMeter abstract type AbstractProgress end @@ -57,6 +59,7 @@ mutable struct Progress <: AbstractProgress check_iterations::Int prev_update_count::Int threads_used::Vector{Int} + keep::Bool # whether to keep the progress bar after completion function Progress(n::Integer; dt::Real=0.1, @@ -69,13 +72,14 @@ mutable struct Progress <: AbstractProgress start::Integer=0, enabled::Bool = true, showspeed::Bool = false, + keep::Bool = (offset == 0), ) CLEAR_IJULIA[] = clear_ijulia() reentrantlocker = Threads.ReentrantLock() counter = start tinit = tsecond = tlast = time() printed = false - new(n, reentrantlocker, dt, counter, tinit, tsecond, tlast, printed, desc, barlen, barglyphs, color, output, offset, 0, start, enabled, showspeed, 1, 1, Int[]) + new(n, reentrantlocker, dt, counter, tinit, tsecond, tlast, printed, desc, barlen, barglyphs, color, output, offset, 0, start, enabled, showspeed, 1, 1, Int[], keep) end end @@ -87,89 +91,6 @@ Progress(n::Integer, dt::Real, desc::AbstractString="Progress: ", Progress(n::Integer, desc::AbstractString, offset::Integer=0) = Progress(n, desc=desc, offset=offset) -mutable struct ProgressThresh{T<:Real} <: AbstractProgress - thresh::T - reentrantlocker::Threads.ReentrantLock - dt::Float64 - val::T - counter::Int - triggered::Bool - tinit::Float64 - tlast::Float64 - printed::Bool # true if we have issued at least one status update - desc::String # prefix to the percentage, e.g. "Computing..." - color::Symbol # default to green - output::IO # output stream into which the progress is written - numprintedvalues::Int # num values printed below progress in last iteration - offset::Int # position offset of progress bar (default is 0) - enabled::Bool # is the output enabled - showspeed::Bool # should the output include average time per iteration - check_iterations::Int - prev_update_count::Int - threads_used::Vector{Int} - - function ProgressThresh{T}(thresh; - dt::Real=0.1, - desc::AbstractString="Progress: ", - color::Symbol=:green, - output::IO=stderr, - offset::Integer=0, - enabled = true, - showspeed::Bool = false) where T - CLEAR_IJULIA[] = clear_ijulia() - reentrantlocker = Threads.ReentrantLock() - tinit = tlast = time() - printed = false - new{T}(thresh, reentrantlocker, dt, typemax(T), 0, false, tinit, tlast, printed, desc, color, output, 0, offset, enabled, showspeed, 1, 1, Int[]) - end -end -ProgressThresh(thresh::Real; kwargs...) = ProgressThresh{typeof(thresh)}(thresh; kwargs...) - -# Legacy constructor calls -ProgressThresh(thresh::Real, dt::Real, desc::AbstractString="Progress: ", - color::Symbol=:green, output::IO=stderr; - offset::Integer=0) = - ProgressThresh(thresh; dt=dt, desc=desc, color=color, output=output, offset=offset) - -ProgressThresh(thresh::Real, desc::AbstractString, offset::Integer=0) = ProgressThresh(thresh; desc=desc, offset=offset) - - -mutable struct ProgressUnknown <: AbstractProgress - done::Bool - reentrantlocker::Threads.ReentrantLock - dt::Float64 - counter::Int - spincounter::Int - triggered::Bool - tinit::Float64 - tlast::Float64 - printed::Bool # true if we have issued at least one status update - desc::String # prefix to the percentage, e.g. "Computing..." - color::Symbol # default to green - spinner::Bool # show a spinner - output::IO # output stream into which the progress is written - numprintedvalues::Int # num values printed below progress in last iteration - enabled::Bool # is the output enabled - showspeed::Bool # should the output include average time per iteration - check_iterations::Int - prev_update_count::Int - threads_used::Vector{Int} -end - -function ProgressUnknown(;dt::Real=0.1, desc::AbstractString="Progress: ", color::Symbol=:green, spinner::Bool=false, output::IO=stderr, enabled::Bool = true, showspeed::Bool = false) - CLEAR_IJULIA[] = clear_ijulia() - reentrantlocker = Threads.ReentrantLock() - tinit = tlast = time() - printed = false - ProgressUnknown(false, reentrantlocker, dt, 0, 0, false, tinit, tlast, printed, desc, color, spinner, output, 0, enabled, showspeed, 1, 1, Int[]) -end - -ProgressUnknown(dt::Real, desc::AbstractString="Progress: ", - color::Symbol=:green, output::IO=stderr; kwargs...) = - ProgressUnknown(dt=dt, desc=desc, color=color, output=output; kwargs...) - -ProgressUnknown(desc::AbstractString; kwargs...) = ProgressUnknown(desc=desc; kwargs...) - #...length of percentage and ETA string with days is 29 characters, speed string is always 14 extra characters function tty_width(desc, output, showspeed::Bool) full_width = displaysize(output)[2] @@ -209,7 +130,7 @@ end # update progress display function updateProgress!(p::Progress; showvalues = (), truncate_lines = false, valuecolor = :blue, - offset::Integer = p.offset, keep = (offset == 0), desc::Union{Nothing,AbstractString} = nothing, + offset::Integer = p.offset, keep = p.keep, desc::Union{Nothing,AbstractString} = nothing, ignore_predictor = false) !p.enabled && return if p.counter == 2 # ignore the first loop given usually uncharacteristically slow @@ -222,6 +143,7 @@ function updateProgress!(p::Progress; showvalues = (), truncate_lines = false, v p.desc = desc end p.offset = offset + p.keep = keep if p.counter >= p.n if p.counter == p.n #&& p.printed @@ -241,7 +163,7 @@ function updateProgress!(p::Progress; showvalues = (), truncate_lines = false, v move_cursor_up_while_clearing_lines(p.output, p.numprintedvalues) printover(p.output, msg, p.color) printvalues!(p, showvalues; color = valuecolor, truncate = truncate_lines) - if keep + if p.keep println(p.output) else print(p.output, "\r\u1b[A" ^ (p.offset + p.numprintedvalues)) @@ -291,133 +213,6 @@ function updateProgress!(p::Progress; showvalues = (), truncate_lines = false, v return nothing end -function updateProgress!(p::ProgressThresh; showvalues = (), truncate_lines = false, valuecolor = :blue, - offset::Integer = p.offset, keep = (offset == 0), desc = p.desc, ignore_predictor = false) - !p.enabled && return - p.offset = offset - p.desc = desc - if p.val <= p.thresh && !p.triggered - p.triggered = true - if p.printed - t = time() - elapsed_time = t - p.tinit - p.triggered = true - dur = durationstring(elapsed_time) - msg = @sprintf "%s Time: %s (%d iterations)" p.desc dur p.counter - if p.showspeed - sec_per_iter = elapsed_time / p.counter - msg = @sprintf "%s (%s)" msg speedstring(sec_per_iter) - end - print(p.output, "\n" ^ (p.offset + p.numprintedvalues)) - move_cursor_up_while_clearing_lines(p.output, p.numprintedvalues) - printover(p.output, msg, p.color) - printvalues!(p, showvalues; color = valuecolor, truncate = truncate_lines) - if keep - println(p.output) - else - print(p.output, "\r\u1b[A" ^ (p.offset + p.numprintedvalues)) - end - flush(p.output) - end - return - end - - if ignore_predictor || predicted_updates_per_dt_have_passed(p) - t = time() - if p.counter > 2 - p.check_iterations = calc_check_iterations(p, t) - end - if t > p.tlast+p.dt && !p.triggered - msg = @sprintf "%s (thresh = %g, value = %g)" p.desc p.thresh p.val - if p.showspeed - elapsed_time = t - p.tinit - sec_per_iter = elapsed_time / p.counter - msg = @sprintf "%s (%s)" msg speedstring(sec_per_iter) - end - print(p.output, "\n" ^ (p.offset + p.numprintedvalues)) - move_cursor_up_while_clearing_lines(p.output, p.numprintedvalues) - printover(p.output, msg, p.color) - printvalues!(p, showvalues; color = valuecolor, truncate = truncate_lines) - print(p.output, "\r\u1b[A" ^ (p.offset + p.numprintedvalues)) - flush(p.output) - # Compensate for any overhead of printing. This can be - # especially important if you're running over a slow network - # connection. - p.tlast = t + 2*(time()-t) - p.printed = true - p.prev_update_count = p.counter - end - end -end - -const spinner_chars = ['◐','◓','◑','◒'] -const spinner_done = '✓' - -spinner_char(p::ProgressUnknown, spinner::AbstractChar) = spinner -spinner_char(p::ProgressUnknown, spinner::AbstractVector{<:AbstractChar}) = - p.done ? spinner_done : spinner[p.spincounter % length(spinner) + firstindex(spinner)] -spinner_char(p::ProgressUnknown, spinner::AbstractString) = - p.done ? spinner_done : spinner[nextind(spinner, 1, p.spincounter % length(spinner))] - -function updateProgress!(p::ProgressUnknown; showvalues = (), truncate_lines = false, valuecolor = :blue, desc = p.desc, - ignore_predictor = false, spinner::Union{AbstractChar,AbstractString,AbstractVector{<:AbstractChar}} = spinner_chars) - !p.enabled && return - p.desc = desc - if p.done - if p.printed - t = time() - elapsed_time = t - p.tinit - dur = durationstring(elapsed_time) - if p.spinner - msg = @sprintf "%c %s \t Time: %s" spinner_char(p, spinner) p.desc dur - p.spincounter += 1 - else - msg = @sprintf "%s %d \t Time: %s" p.desc p.counter dur - end - if p.showspeed - sec_per_iter = elapsed_time / p.counter - msg = @sprintf "%s (%s)" msg speedstring(sec_per_iter) - end - move_cursor_up_while_clearing_lines(p.output, p.numprintedvalues) - printover(p.output, msg, p.color) - printvalues!(p, showvalues; color = valuecolor, truncate = truncate_lines) - println(p.output) - flush(p.output) - end - return - end - if ignore_predictor || predicted_updates_per_dt_have_passed(p) - t = time() - if p.counter > 2 - p.check_iterations = calc_check_iterations(p, t) - end - if t > p.tlast+p.dt - dur = durationstring(t-p.tinit) - if p.spinner - msg = @sprintf "%c %s \t Time: %s" spinner_char(p, spinner) p.desc dur - p.spincounter += 1 - else - msg = @sprintf "%s %d \t Time: %s" p.desc p.counter dur - end - if p.showspeed - elapsed_time = t - p.tinit - sec_per_iter = elapsed_time / p.counter - msg = @sprintf "%s (%s)" msg speedstring(sec_per_iter) - end - move_cursor_up_while_clearing_lines(p.output, p.numprintedvalues) - printover(p.output, msg, p.color) - printvalues!(p, showvalues; color = valuecolor, truncate = truncate_lines) - flush(p.output) - # Compensate for any overhead of printing. This can be - # especially important if you're running over a slow network - # connection. - p.tlast = t + 2*(time()-t) - p.printed = true - p.prev_update_count = p.counter - return - end - end -end predicted_updates_per_dt_have_passed(p::AbstractProgress) = p.counter <= 2 || # otherwise the first 2 are never printed, independently of dt @@ -442,14 +237,14 @@ function lock_if_threading(f::Function, p::AbstractProgress) end end -function next!(p::Union{Progress, ProgressUnknown}; step::Int = 1, options...) +function next!(p::Union{Progress}; step::Int = 1, options...) lock_if_threading(p) do p.counter += step updateProgress!(p; ignore_predictor = step == 0, options...) end end -function next!(p::Union{Progress, ProgressUnknown}, color::Symbol; step::Int = 1, options...) +function next!(p::Union{Progress}, color::Symbol; step::Int = 1, options...) lock_if_threading(p) do p.color = color p.counter += step @@ -458,7 +253,7 @@ function next!(p::Union{Progress, ProgressUnknown}, color::Symbol; step::Int = 1 end -function update!(p::Union{Progress, ProgressUnknown}, counter::Int=p.counter, color::Symbol=p.color; options...) +function update!(p::Union{Progress}, counter::Int=p.counter, color::Symbol=p.color; options...) lock_if_threading(p) do counter_changed = p.counter != counter p.counter = counter @@ -468,26 +263,18 @@ function update!(p::Union{Progress, ProgressUnknown}, counter::Int=p.counter, co end -function update!(p::ProgressThresh, val=p.val, color::Symbol=p.color; increment::Bool = true, options...) - lock_if_threading(p) do - p.val = val - if increment - p.counter += 1 - end - p.color = color - updateProgress!(p; options...) - end -end - -function cancel(p::AbstractProgress, msg::AbstractString = "Aborted before all tasks were completed", color = :red; showvalues = (), truncate_lines = false, valuecolor = :blue, offset = p.offset, keep = (offset == 0)) +function cancel(p::AbstractProgress, msg::AbstractString = "Aborted before all tasks were completed", color = :red; + showvalues = (), truncate_lines = false, valuecolor = :blue, + offset = p.offset, keep = p.keep) lock_if_threading(p) do p.offset = offset + p.keep = keep if p.printed print(p.output, "\n" ^ (p.offset + p.numprintedvalues)) move_cursor_up_while_clearing_lines(p.output, p.numprintedvalues) printover(p.output, msg, color) printvalues!(p, showvalues; color = valuecolor, truncate = truncate_lines) - if keep + if p.keep println(p.output) else print(p.output, "\r\u1b[A" ^ (p.offset + p.numprintedvalues)) @@ -503,16 +290,6 @@ function finish!(p::Progress; options...) end end -function finish!(p::ProgressThresh; options...) - update!(p, p.thresh; options...) -end - -function finish!(p::ProgressUnknown; options...) - lock_if_threading(p) do - p.done = true - updateProgress!(p; options...) - end -end # Internal method to print additional values below progress bar function printvalues!(p::AbstractProgress, showvalues; color = :normal, truncate = false) @@ -632,313 +409,6 @@ function speedstring(sec_per_iter) return " >100 d/it" end -function showprogress_process_expr(node, metersym) - if !isa(node, Expr) - node - elseif node.head === :break || node.head === :return - # special handling for break and return statements - quote - ($finish!)($metersym) - $node - end - elseif node.head === :for || node.head === :while - # do not process inner loops - # - # FIXME: do not process break and return statements in inner functions - # either - node - else - # process each subexpression recursively - Expr(node.head, [showprogress_process_expr(a, metersym) for a in node.args]...) - end -end - -struct ProgressWrapper{T} - obj::T - meter::Progress -end - -Base.length(wrap::ProgressWrapper) = Base.length(wrap.obj) -function Base.iterate(wrap::ProgressWrapper, state...) - ir = iterate(wrap.obj, state...) - - if ir === nothing - finish!(wrap.meter) - elseif !isempty(state) - next!(wrap.meter) - end - - ir -end - -function showprogressdistributed(args...) - if length(args) < 1 - throw(ArgumentError("@showprogress @distributed requires at least 1 argument")) - end - progressargs = args[1:end-1] - expr = Base.remove_linenums!(args[end]) - - if expr.head != :macrocall || expr.args[1] != Symbol("@distributed") - throw(ArgumentError("malformed @showprogress @distributed expression")) - end - - distargs = filter(x -> !(x isa LineNumberNode), expr.args[2:end]) - na = length(distargs) - if na == 1 - loop = distargs[1] - elseif na == 2 - reducer = distargs[1] - loop = distargs[2] - else - println("$distargs $na") - throw(ArgumentError("wrong number of arguments to @distributed")) - end - if loop.head !== :for - throw(ArgumentError("malformed @distributed loop")) - end - var = loop.args[1].args[1] - r = loop.args[1].args[2] - body = loop.args[2] - - setup = quote - n = length($(esc(r))) - p = Progress(n, $([esc(arg) for arg in progressargs]...)) - ch = RemoteChannel(() -> Channel{Bool}(n)) - end - - if na == 1 - # would be nice to do this with @sync @distributed but @sync is broken - # https://github.com/JuliaLang/julia/issues/28979 - compute = quote - display = @async let i = 0 - while i < n - take!(ch) - next!(p) - i += 1 - end - end - @distributed for $(esc(var)) = $(esc(r)) - $(esc(body)) - put!(ch, true) - end - nothing - end - else - compute = quote - display = @async while take!(ch) next!(p) end - results = @distributed $(esc(reducer)) for $(esc(var)) = $(esc(r)) - x = $(esc(body)) - put!(ch, true) - x - end - put!(ch, false) - results - end - end - - quote - $setup - results = $compute - wait(display) - results - end -end - - -macro showprogress(args...) - showprogress(args...) -end - -function showprogress(args...) - if length(args) < 1 - throw(ArgumentError("@showprogress requires at least one argument.")) - end - progressargs = args[1:end-1] - expr = args[end] - if expr.head == :macrocall && expr.args[1] == Symbol("@distributed") - return showprogressdistributed(args...) - end - orig = expr = copy(expr) - if expr.args[1] == :|> # e.g. map(x->x^2) |> sum - expr.args[2] = showprogress(progressargs..., expr.args[2]) - return expr - end - metersym = gensym("meter") - mapfuns = (:map, :asyncmap, :reduce, :pmap) - kind = :invalid # :invalid, :loop, or :map - - if isa(expr, Expr) - if expr.head == :for - outerassignidx = 1 - loopbodyidx = lastindex(expr.args) - kind = :loop - elseif expr.head == :comprehension - outerassignidx = lastindex(expr.args) - loopbodyidx = 1 - kind = :loop - elseif expr.head == :typed_comprehension - outerassignidx = lastindex(expr.args) - loopbodyidx = 2 - kind = :loop - elseif expr.head == :call && expr.args[1] in mapfuns - kind = :map - elseif expr.head == :do - call = expr.args[1] - if call.head == :call && call.args[1] in mapfuns - kind = :map - end - end - end - - if kind == :invalid - throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, map, reduce, or pmap; got $expr")) - elseif kind == :loop - # As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree. - if expr.head !== :for - @assert length(expr.args) == loopbodyidx - expr = expr.args[outerassignidx] = copy(expr.args[outerassignidx]) - @assert expr.head === :generator - outerassignidx = lastindex(expr.args) - loopbodyidx = 1 - end - - # Transform the first loop assignment - loopassign = expr.args[outerassignidx] = copy(expr.args[outerassignidx]) - if loopassign.head === :block # this will happen in a for loop with multiple iteration variables - for i in 2:length(loopassign.args) - loopassign.args[i] = esc(loopassign.args[i]) - end - loopassign = loopassign.args[1] = copy(loopassign.args[1]) - end - @assert loopassign.head === :(=) - @assert length(loopassign.args) == 2 - obj = loopassign.args[2] - loopassign.args[1] = esc(loopassign.args[1]) - loopassign.args[2] = :(ProgressWrapper(iterable, $(esc(metersym)))) - - # Transform the loop body break and return statements - if expr.head === :for - expr.args[loopbodyidx] = showprogress_process_expr(expr.args[loopbodyidx], metersym) - end - - # Escape all args except the loop assignment, which was already appropriately escaped. - for i in 1:length(expr.args) - if i != outerassignidx - expr.args[i] = esc(expr.args[i]) - end - end - if orig !== expr - # We have additional escaping to do; this will occur for comprehensions with julia 0.5 or later. - for i in 1:length(orig.args)-1 - orig.args[i] = esc(orig.args[i]) - end - end - - setup = quote - iterable = $(esc(obj)) - $(esc(metersym)) = Progress(length(iterable), $([esc(arg) for arg in progressargs]...)) - end - - if expr.head === :for - return quote - $setup - $expr - end - else - # We're dealing with a comprehension - return quote - begin - $setup - rv = $orig - next!($(esc(metersym))) - rv - end - end - end - else # kind == :map - - # isolate call to map - if expr.head == :do - call = expr.args[1] - else - call = expr - end - - # get args to map to determine progress length - mapargs = collect(Any, filter(call.args[2:end]) do a - return isa(a, Symbol) || !(a.head in (:kw, :parameters)) - end) - if expr.head == :do - insert!(mapargs, 1, :nothing) # to make args for ncalls line up - end - - # change call to progress_map - mapfun = call.args[1] - call.args[1] = :progress_map - - # escape args as appropriate - for i in 2:length(call.args) - call.args[i] = esc(call.args[i]) - end - if expr.head == :do - expr.args[2] = esc(expr.args[2]) - end - - # create appropriate Progress expression - lenex = :(ncalls($(esc(mapfun)), ($([esc(a) for a in mapargs]...),))) - progex = :(Progress($lenex, $([esc(a) for a in progressargs]...))) - - # insert progress and mapfun kwargs - push!(call.args, Expr(:kw, :progress, progex)) - push!(call.args, Expr(:kw, :mapfun, esc(mapfun))) - - return expr - end -end - -function progress_map(args...; mapfun=map, - progress=Progress(ncalls(mapfun, args)), - channel_bufflen=min(1000, ncalls(mapfun, args)), - kwargs...) - f = first(args) - other_args = args[2:end] - channel = RemoteChannel(()->Channel{Bool}(channel_bufflen), 1) - local vals - @sync begin - # display task - @async while take!(channel) - next!(progress) - end - - # map task - @sync begin - vals = mapfun(other_args...; kwargs...) do x... - val = f(x...) - put!(channel, true) - yield() - return val - end - put!(channel, false) - end - end - return vals -end - - -progress_pmap(args...; kwargs...) = progress_map(args...; mapfun=pmap, kwargs...) - -function ncalls(mapfun::Function, map_args) - if mapfun == pmap && length(map_args) >= 2 && isa(map_args[2], AbstractWorkerPool) - relevant = map_args[3:end] - else - relevant = map_args[2:end] - end - if isempty(relevant) - error("Unable to determine number of calls in $mapfun. Too few arguments?") - else - return maximum(length(arg) for arg in relevant) - end -end end \ No newline at end of file diff --git a/src/logging.jl b/src/logging.jl index 1938090..415dcfa 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -153,7 +153,7 @@ function store_for_val_prog_bar!(metalogger::MetaLogger, name::AbstractString, v metalogger.values_for_val_progressbar[name] = value end -function values_for_train_progressbar(metalogger::MetaLogger) +function values_for_train_progbar(metalogger::MetaLogger) dict = metalogger.values_for_train_progressbar ks = sort(collect(keys(dict))) return [(k, roundval(dict[k])) for k in ks] diff --git a/src/trainer.jl b/src/trainer.jl index b4563da..2cc9b43 100644 --- a/src/trainer.jl +++ b/src/trainer.jl @@ -80,7 +80,9 @@ the fit state during the execution of `fit!`. - **progress\\_bar**: It `true`, shows a progress bar during training. Default: `true`. -- **val\\_every\\_n\\_epochs**: Perform a validation loop every after every N training epochs. +- **val\\_every\\_n\\_epochs**: Perform a validation loop every after every N training epochs. + The validation loop is in any case performed at the end of the last training epoch. + Set to 0 or negative to disable validation. Default: `1`. # Additional Fields @@ -121,14 +123,14 @@ Tsunami.fit!(model, trainer, train_dataloader, val_dataloader) optimisers = nothing end -function val_loop(model, trainer, val_dataloader; device) +function val_loop(model, trainer, val_dataloader; device, progbar_offset = 0, progbar_keep = true) val_dataloader === nothing && return fit_state = trainer.fit_state oldstage = fit_state.stage fit_state.stage = :validation - valprogressbar = Progress(length(val_dataloader); desc="Validation: ", - showspeed=true, enabled=true, color=:green) + valprogressbar = Progress(length(val_dataloader); desc="Val Epoch $(fit_state.epoch): ", + showspeed=true, enabled=trainer.progress_bar, color=:green, offset=progbar_offset, keep=progbar_keep) for (batch_idx, batch) in enumerate(val_dataloader) fit_state.batchsize = MLUtils.numobs(batch) batch = batch |> device @@ -154,14 +156,16 @@ function train_loop(model, trainer, train_dataloader, val_dataloader; device, ma oldstage = fit_state.stage fit_state.stage = :training + islastepoch = fit_state.epoch == trainer.max_epochs if trainer.lr_schedulers !== nothing lr = trainer.lr_schedulers(fit_state.epoch) Optimisers.adjust!(trainer.optimisers, lr) end - progressbar = Progress(length(train_dataloader); desc="Train Epoch $(fit_state.epoch): ", - showspeed=true, enabled = trainer.progress_bar, color=:yellow) + train_progbar = Progress(length(train_dataloader); desc="Train Epoch $(fit_state.epoch): ", + showspeed=true, enabled = trainer.progress_bar, color=:yellow, + keep = islastepoch) ## SINGLE EPOCH TRAINING LOOP for (batch_idx, batch) in enumerate(train_dataloader) @@ -177,13 +181,13 @@ function train_loop(model, trainer, train_dataloader, val_dataloader; device, ma Optimisers.update!(trainer.optimisers, model, grads[1]) - ProgressMeter.next!(progressbar, - showvalues = values_for_train_progressbar(trainer.metalogger), + ProgressMeter.next!(train_progbar, + showvalues = values_for_train_progbar(trainer.metalogger), valuecolor = :yellow) fit_state.step == max_steps && break end - ProgressMeter.finish!(progressbar) + ProgressMeter.finish!(train_progbar) ## EPOCH END fit_state.stage = :train_epoch_end @@ -195,8 +199,12 @@ function train_loop(model, trainer, train_dataloader, val_dataloader; device, ma fit_state.stage = :training ## VALIDATION - if (val_dataloader !== nothing && fit_state.epoch % trainer.val_every_n_epochs == 0) - val_loop(model, trainer, val_dataloader; device) + if val_dataloader !== nothing && trainer.val_every_n_epochs > 0 + if (fit_state.epoch % trainer.val_every_n_epochs == 0) || islastepoch + val_loop(model, trainer, val_dataloader; device, + progbar_offset = islastepoch ? 0 : train_progbar.numprintedvalues + 1, + progbar_keep = islastepoch) + end end fit_state.stage = oldstage @@ -254,8 +262,8 @@ function fit!( max_steps, max_epochs = compute_max_steps_and_epochs(trainer.max_steps, trainer.max_epochs) if trainer.fast_dev_run - max_epochs = 1 - max_steps = 1 + # max_steps = 1 + # max_epochs = 1 trainer.val_every_n_epochs = 1 empty!(trainer.loggers) @@ -265,6 +273,7 @@ function fit!( if val_dataloader !== nothing check_val_step(model, trainer, first(val_dataloader)) end + return fit_state end print_fit_initial_summary(model, trainer, device) @@ -284,7 +293,7 @@ function fit!( trainer.optimisers = optimisers |> device trainer.lr_schedulers = lr_schedulers - val_loop(model, trainer, val_dataloader; device) + val_loop(model, trainer, val_dataloader; device, progbar_keep=false) for epoch in start_epoch:max_epochs fit_state.epoch = epoch From c3c01057e019b61ce4f8750ddd4f2708dae782c1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 30 Mar 2023 11:11:15 +0200 Subject: [PATCH 2/3] comment --- src/ProgressMeter/ProgressMeter.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ProgressMeter/ProgressMeter.jl b/src/ProgressMeter/ProgressMeter.jl index 7cddafc..8af2064 100644 --- a/src/ProgressMeter/ProgressMeter.jl +++ b/src/ProgressMeter/ProgressMeter.jl @@ -1,11 +1,12 @@ # Had to copy and past the entire ProgressMeter package here # since the mantainer is not responsive. See: # https://github.com/timholy/ProgressMeter.jl/pull/261 -# The change with respect to the original package are: +# The changes with respect to the original package are: # - merged PR #261 # - removed docstring so that Documenter does not complain they are not in the docs # - removed ProgressThresh and ProgressUnknown since not needed # - removed @showprogress +# - added the `keep` keyword argument to Progress module ProgressMeter From 67f251646bb8e4e7cd2e01b82c144c76d609e644 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 30 Mar 2023 11:48:17 +0200 Subject: [PATCH 3/3] fix tests --- examples/mlp_mnist.jl | 2 +- test/trainer.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/mlp_mnist.jl b/examples/mlp_mnist.jl index 00ac562..50aaf9d 100644 --- a/examples/mlp_mnist.jl +++ b/examples/mlp_mnist.jl @@ -58,7 +58,7 @@ Tsunami.fit!(model, trainer, train_loader, test_loader) # TRAIN FROM SCRATCH -trainer = Trainer(max_epochs = 5, +trainer = Trainer(max_epochs = 3, default_root_dir = @__DIR__, accelerator = :cpu, checkpointer = true, diff --git a/test/trainer.jl b/test/trainer.jl index a4d155d..6936cdb 100644 --- a/test/trainer.jl +++ b/test/trainer.jl @@ -57,8 +57,8 @@ end train_dataloader = make_dataloader(nx, ny) trainer = SilentTrainer(max_epochs=2, fast_dev_run=true) fit_state = Tsunami.fit!(model, trainer, train_dataloader) - @test fit_state.epoch == 1 - @test fit_state.step == 1 + @test fit_state.epoch == 0 + @test fit_state.step == 0 end @testset "val_every_n_epochs" begin