Skip to content

Commit

Permalink
Add NaN checker to simulations by default (#1198)
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-ramadhan authored Dec 4, 2020
2 parents 5ddace1 + e3afd05 commit c3b688f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/Diagnostics/Diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export
run_diagnostic!,
TimeInterval, IterationInterval, WallTimeInterval

using CUDA
using Oceananigans
using Oceananigans.Operators
using Oceananigans.Utils: TimeInterval, IterationInterval, WallTimeInterval
Expand Down
19 changes: 8 additions & 11 deletions src/Diagnostics/nan_checker.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
"""
NaNChecker{T, F} <: AbstractDiagnostic
A diagnostic that checks for `NaN` values and aborts the simulation if any are found.
"""
struct NaNChecker{T, F} <: AbstractDiagnostic
schedule :: T
fields :: F
Expand All @@ -11,16 +6,18 @@ end
"""
NaNChecker(; schedule, fields)
Returns a `NaNChecker` that checks for `NaN` anywhere within `fields`
when `schedule` actuates.
Returns a `NaNChecker` that checks for a `NaN` anywhere in `fields` when `schedule` actuates.
`fields` should be a named tuple. The simulation is aborted if a `NaN` is found.
"""
function NaNChecker(model=nothing; schedule, fields)
    # `model` is accepted positionally but is not used here; only the two keyword
    # arguments are forwarded to the positional (schedule, fields) constructor.
    return NaNChecker(schedule, fields)
end

# Scan every field held by the checker for NaN values. As soon as a field's
# underlying `data.parent` array contains a NaN, raise an error that reports the
# model clock's time and iteration together with the name of the offending field.
# NOTE(review): this span comes from a diff view, so two versions of the loop appear
# back to back. Presumably the first (iterating `nc.fields` directly) is the removed
# one and the `pairs(...)` / `CUDA.@allowscalar` version is its replacement; confirm
# against the committed file.
function run_diagnostic!(nc::NaNChecker, model)
for (name, field) in nc.fields
if any(isnan, field.data.parent)
t, i = model.clock.time, model.clock.iteration
error("time = $t, iteration = $i: NaN found in $name. Aborting simulation.")
# `pairs` yields (name, field) entries, which is what the destructuring needs when
# `fields` is a named tuple (plain iteration over a named tuple yields values only).
for (name, field) in pairs(nc.fields)
# NOTE(review): presumably wraps the check so scalar access to GPU-backed data is
# permitted while scanning; verify against the CUDA.jl documentation.
CUDA.@allowscalar begin
if any(isnan, field.data.parent)
t, i = model.clock.time, model.clock.iteration
error("time = $t, iteration = $i: NaN found in $name. Aborting simulation.")
end
end
end
end
3 changes: 3 additions & 0 deletions src/Simulations/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ function Simulation(model; Δt,
"recalculate the time step every iteration which can be slow."
end

diagnostics[:nan_checker] = NaNChecker(fields=(u=model.velocities.u,),
schedule=IterationInterval(iteration_interval))

run_time = 0.0

return Simulation(model, Δt, stop_criteria, stop_iteration, stop_time, wall_time_limit,
Expand Down
31 changes: 15 additions & 16 deletions test/test_diagnostics.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
using Oceananigans.Diagnostics

# Test helper: builds a small model on a 4x2x1 grid, writes a NaN into a velocity
# field and advances the model. The test set that calls this helper wraps it in
# `@test_throws ErrorException`, i.e. the NaN is expected to abort the run.
# NOTE(review): this span comes from a diff view, so two variants are interleaved:
# one taking `(arch, FT)` that attaches a `NaNChecker` by hand and calls
# `time_step!`, and one taking `(arch)` that builds a `Simulation` and calls `run!`.
# Presumably the `(arch)` variant is the post-change code; confirm against the
# committed file.
function nan_checker_aborts_simulation(arch, FT)
function nan_checker_aborts_simulation(arch)
grid = RegularCartesianGrid(size=(4, 2, 1), extent=(1, 1, 1))
model = IncompressibleModel(grid=grid, architecture=arch, float_type=FT)
model = IncompressibleModel(grid=grid, architecture=arch)
simulation = Simulation(model, Δt=1, stop_iteration=1)

model.velocities.u[1, 1, 1] = NaN

# It checks for NaNs in w by default.
nc = NaNChecker(model; schedule=IterationInterval(1), fields=Dict(:w => model.velocities.w.data.parent))
push!(model.diagnostics, nc)
run!(simulation)

model.velocities.w[3, 2, 1] = NaN

time_step!(model, 1, 1)
return nothing
end

TestModel(::GPU, FT, ν=1.0, Δx=0.5) =
Expand Down Expand Up @@ -68,18 +67,20 @@ get_time(model) = model.clock.time
# Test helper: stores a `NaNChecker` in `simulation.diagnostics` under the key `:nc`
# and returns whether looking the diagnostics up by integer position gives back that
# same checker.
# NOTE(review): this span comes from a diff view, so two `nc` constructions and two
# `return` statements are interleaved. Presumably the `fields=model.velocities`
# construction and the position-2 lookup are the post-change lines; confirm against
# the committed file.
function diagnostics_getindex(arch, FT)
model = TestModel(arch, FT)
simulation = Simulation(model, Δt=0, stop_iteration=0)
nc = NaNChecker(model; schedule=IterationInterval(1), fields=Dict(:w => model.velocities.w.data.parent))
nc = NaNChecker(model, schedule=IterationInterval(1), fields=model.velocities)
simulation.diagnostics[:nc] = nc
return simulation.diagnostics[1] == nc

# The first diagnostic is the NaN checker.
return simulation.diagnostics[2] == nc
end

function diagnostics_setindex(arch, FT)
model = TestModel(arch, FT)
simulation = Simulation(model, Δt=0, stop_iteration=0)

nc1 = NaNChecker(model; schedule=IterationInterval(1), fields=Dict(:w => model.velocities.w.data.parent))
nc2 = NaNChecker(model; schedule=IterationInterval(2), fields=Dict(:u => model.velocities.u.data.parent))
nc3 = NaNChecker(model; schedule=IterationInterval(3), fields=Dict(:v => model.velocities.v.data.parent))
nc1 = NaNChecker(model, schedule=IterationInterval(1), fields=model.velocities)
nc2 = NaNChecker(model, schedule=IterationInterval(2), fields=model.velocities)
nc3 = NaNChecker(model, schedule=IterationInterval(3), fields=model.velocities)

push!(simulation.diagnostics, nc1, nc2)
simulation.diagnostics[2] = nc3
Expand All @@ -93,9 +94,7 @@ end
# Runs the NaN-checker abort test once per architecture in `archs`, asserting that
# the helper throws an `ErrorException`.
# NOTE(review): this span comes from a diff view. The inner `for FT in float_types`
# loop and the single `(arch)` call are presumably the removed and added variants
# respectively; confirm against the committed file.
for arch in archs
@testset "NaN Checker [$(typeof(arch))]" begin
@info " Testing NaN Checker [$(typeof(arch))]"
for FT in float_types
@test_throws ErrorException nan_checker_aborts_simulation(arch, FT)
end
@test_throws ErrorException nan_checker_aborts_simulation(arch)
end
end

Expand Down

0 comments on commit c3b688f

Please sign in to comment.