Skip to content

Commit

Permalink
Add actual condition to run orchestrate_diagnostics
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed May 19, 2024
1 parent 9eb9f8d commit c925edc
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 35 deletions.
11 changes: 10 additions & 1 deletion docs/src/internals.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,17 @@ because it creates a new integrator obtained by copying all the fields of the
old one and adding the diagnostics (with
[`Accessors`](https://github.com/JuliaObjects/Accessors.jl)).

The `DiagnosticsHandler` also contains three `BitVectors`: `active_compute`,
`active_output`, `active_sync`. These `BitVectors` have the same length as the
number of scheduled diagnostics and signal whether something should done at a
given step. The `BitVectors` are defined and preallocated trying to reduce the
inference allocations that result from operations like `filter` on lists of
`ScheduledDiagnostics`. They are updated by a callback that is run before
`orchestrate_diagnostics` and they can be used to determine if
`orchestrate_diagnostics` should be run at all.

## Orchestrate diagnostics

One of the design goals for `orchestrate_diagnostics` is to keep all the
broadcasted expression in the same function scope. This opens a path to optimize
the number of GPU kernel launches.
the number of GPU kernel launches.
131 changes: 101 additions & 30 deletions src/clima_diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ struct DiagnosticsHandler{
many times the given diagnostics was computed from the last time it was output to
disk."""
counters::COUNT

"""Bitvectors that identify which diagnostics are active at the given step. This is here
mostly to reduce inference allocations that would result from operations like filter."""
active_compute::BitVector
active_output::BitVector
active_sync::BitVector
end

"""
Expand Down Expand Up @@ -105,23 +111,22 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
end
end

num_diagnostics = length(scheduled_diagnostics)
active_compute = BitVector(ntuple(_ -> false, num_diagnostics))
active_output = BitVector(ntuple(_ -> false, num_diagnostics))
active_sync = BitVector(ntuple(_ -> false, num_diagnostics))

return DiagnosticsHandler(
Tuple(scheduled_diagnostics),
storage,
accumulators,
counters,
active_compute,
active_output,
active_sync,
)
end

# Does the writer associated to `diag` need to be synced?
# It does only when it has a sync_schedule that is a callable and that
# callable returns true when called on the integrator
function _needs_sync(diag, integrator)
hasproperty(diag.output_writer, :sync_schedule) || return false
isnothing(diag.output_writer.sync_schedule) && return false
return diag.output_writer.sync_schedule(integrator)
end

"""
orchestrate_diagnostics(integrator, diagnostic_handler::DiagnosticsHandler)
Expand All @@ -133,15 +138,10 @@ function orchestrate_diagnostics(
diagnostic_handler::DiagnosticsHandler,
)
scheduled_diagnostics = diagnostic_handler.scheduled_diagnostics
active_compute = Bool[]
active_output = Bool[]
active_sync = Bool[]

for diag in scheduled_diagnostics
push!(active_compute, diag.compute_schedule_func(integrator))
push!(active_output, diag.output_schedule_func(integrator))
push!(active_sync, _needs_sync(diag, integrator))
end
active_compute = diagnostic_handler.active_compute
active_output = diagnostic_handler.active_output
active_sync = diagnostic_handler.active_sync

# Compute
for diag_index in 1:length(scheduled_diagnostics)
Expand Down Expand Up @@ -238,22 +238,92 @@ function orchestrate_diagnostics(
return nothing
end

# Does the writer associated to `diag` need to be synced?
# It does only when it has a sync_schedule that is a callable and that
# callable returns true when called on the integrator
function _needs_sync(diag, integrator)
hasproperty(diag.output_writer, :sync_schedule) || return false
isnothing(diag.output_writer.sync_schedule) && return false
return diag.output_writer.sync_schedule(integrator)
end

"""
update_diagnostic_handler_bitvectors(integrator, diagnostics_handler::DiagnosticsHandler)
Update the `active_{compute, update, sync}` bitvector in `diagnostics_handler`.
The `diagnostics_handler` contains three bitvectors that determine which actions should be
taken at the current iterations. They are preallocated mostly to avoid inference allocations
that result from operations with groups of `ScheduledDiagnostics`, but they can also be used
to determine if `orchestrate_diagnostics` should be called or not.
This function evaluates the various `schedule_func`s for all the `ScheduledDiagnostics` in
the `diagnostics_handler` and updates the bitvectors. This function should be called at
every step.
"""
function update_diagnostic_handler_bitvectors(integrator, diagnostics_handler)
scheduled_diagnostics = diagnostics_handler.scheduled_diagnostics

for index in 1:length(scheduled_diagnostics)
diag = scheduled_diagnostics[index]
diagnostics_handler.active_compute[index] =
diag.compute_schedule_func(integrator)
diagnostics_handler.active_output[index] =
diag.output_schedule_func(integrator)
diagnostics_handler.active_sync[index] = _needs_sync(diag, integrator)
end

return nothing
end

"""
check_callback_condition(integrator, diagnostics_handler::DiagnosticsHandler)
Return true when `orchestrate_diagnostics` should be called.
"""
function check_callback_condition(integrator, diagnostics_handler)
return any(diagnostics_handler.active_compute) ||
any(diagnostics_handler.active_output) ||
any(diagnostics_handler.active_sync)
end

"""
DiagnosticsCallback(diagnostics_handler::DiagnosticsHandler)
DiagnosticsCallbacks(diagnostics_handler::DiagnosticsHandler)
Translate a `DiagnosticsHandler` into two SciML callbacks ready to be used.
The first updates internal counters in `diagnostics_handler` that check if the diagnostics
have to be computed. The second actually computes and outputs the diagnostics.
Translate a `DiagnosticsHandler` into a SciML callback ready to be used.
"""
function DiagnosticsCallback(diagnostics_handler::DiagnosticsHandler)
sciml_callback(integrator) =
function DiagnosticsCallbacks(diagnostics_handler::DiagnosticsHandler)

# We use trivial condition to update the condition to run orchestrate_diagnostics
trivial_condition = (_, _, _) -> true

sciml_callback_update_diagnostic_handler_bitvectors(integrator) =
update_diagnostic_handler_bitvectors(integrator, diagnostics_handler)

orchestrate_condition =
(_, _, integrator) ->
check_callback_condition(integrator, diagnostics_handler)

sciml_callback_orchestrate_diagnostics(integrator) =
orchestrate_diagnostics(integrator, diagnostics_handler)

# SciMLBase.DiscreteCallback checks if the given condition is true at the end of each
# step. So, we set a condition that is always true, the callback is called at the end of
# every step. This callback runs `orchestrate_callbacks`, which manages which
# diagnostics functions to call
condition = (_, _, _) -> true
continuous_callbacks = ()
discrete_callbacks = (
SciMLBase.DiscreteCallback(
trivial_condition,
sciml_callback_update_diagnostic_handler_bitvectors,
),
SciMLBase.DiscreteCallback(
orchestrate_condition,
sciml_callback_orchestrate_diagnostics,
),
)

return SciMLBase.DiscreteCallback(condition, sciml_callback)
return SciMLBase.CallbackSet(continuous_callbacks, discrete_callbacks)
end

"""
Expand All @@ -263,7 +333,7 @@ end
Return a new `integrator` with diagnostics defined by `scheduled_diagnostics`.
`IntegratorWithDiagnostics` is conceptually similar to defining a `DiagnosticsHandler`,
constructing its associated `DiagnosticsCallback`, and adding such callback to a given
constructing its associated `DiagnosticsCallbacks`, and adding such callbacks to a given
integrator.
The new integrator is identical to the previous one with the only difference that it has a
Expand All @@ -284,11 +354,12 @@ function IntegratorWithDiagnostics(integrator, scheduled_diagnostics)
integrator.t;
integrator.dt,
)
diagnostics_callback = DiagnosticsCallback(diagnostics_handler)
diagnostics_callbacks =
DiagnosticsCallbacks(diagnostics_handler).discrete_callbacks

continuous_callbacks = integrator.callback.continuous_callbacks
discrete_callbacks =
(integrator.callback.discrete_callbacks..., diagnostics_callback)
(integrator.callback.discrete_callbacks..., diagnostics_callbacks...)
callback = SciMLBase.CallbackSet(continuous_callbacks, discrete_callbacks)

Accessors.@reset integrator.callback = callback
Expand Down
19 changes: 15 additions & 4 deletions test/diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,28 @@ include("TestTools.jl")
dt,
)

diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
diag_cbs = ClimaDiagnostics.DiagnosticsCallbacks(diagnostic_handler)

prob = SciMLBase.ODEProblem(
ClimaTimeSteppers.ClimaODEFunction(T_exp! = exp_tendency!),
Y,
(t0, tf),
p,
)

@test ClimaDiagnostics.check_callback_condition(
prob,
diagnostic_handler,
) === false

algo = ClimaTimeSteppers.ExplicitAlgorithm(ClimaTimeSteppers.RK4())

SciMLBase.solve(prob, algo, dt = dt, callback = diag_cb)
SciMLBase.solve(prob, algo, dt = dt, callback = diag_cbs)

@test ClimaDiagnostics.check_callback_condition(
prob,
diagnostic_handler,
) === true

@test length(keys(dict_writer.dict[short_name])) ==
convert(Int, 1 + (tf - t0) / dt)
Expand All @@ -100,7 +111,7 @@ include("TestTools.jl")
dt,
)

diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
diag_cbs = ClimaDiagnostics.DiagnosticsCallbacks(diagnostic_handler)

prob = SciMLBase.ODEProblem(
ClimaTimeSteppers.ClimaODEFunction(T_exp! = exp_tendency!),
Expand All @@ -110,7 +121,7 @@ include("TestTools.jl")
)
algo = ClimaTimeSteppers.ExplicitAlgorithm(ClimaTimeSteppers.RK4())

SciMLBase.solve(prob, algo, dt = dt, callback = diag_cb)
SciMLBase.solve(prob, algo, dt = dt, callback = diag_cbs)

@test length(keys(dict_writer.dict[short_name])) ==
convert(Int, (tf - t0) / 5dt)
Expand Down

0 comments on commit c925edc

Please sign in to comment.