From 41220d85351522a5b69610b1fbe258af658cea60 Mon Sep 17 00:00:00 2001 From: ali-ramadhan Date: Mon, 23 Nov 2020 21:38:59 -0500 Subject: [PATCH 1/6] Faster NaN checker that accepts named tuples. --- src/Diagnostics/Diagnostics.jl | 1 + src/Diagnostics/nan_checker.jl | 21 ++++++++++----------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/Diagnostics/Diagnostics.jl b/src/Diagnostics/Diagnostics.jl index cde101d8bf..26ed43f156 100644 --- a/src/Diagnostics/Diagnostics.jl +++ b/src/Diagnostics/Diagnostics.jl @@ -7,6 +7,7 @@ export run_diagnostic!, TimeInterval, IterationInterval, WallTimeInterval +using CUDA using Oceananigans using Oceananigans.Operators using Oceananigans.Utils: TimeInterval, IterationInterval, WallTimeInterval diff --git a/src/Diagnostics/nan_checker.jl b/src/Diagnostics/nan_checker.jl index 5dc75e98cf..70255687ad 100644 --- a/src/Diagnostics/nan_checker.jl +++ b/src/Diagnostics/nan_checker.jl @@ -1,8 +1,3 @@ -""" - NaNChecker{F} <: AbstractDiagnostic - -A diagnostic that checks for `NaN` values and aborts the simulation if any are found. -""" struct NaNChecker{T, F} <: AbstractDiagnostic schedule :: T fields :: F @@ -11,16 +6,20 @@ end """ NaNChecker(; schedule, fields) -Returns a `NaNChecker` that checks for `NaN` anywhere within `fields` -when `schedule` actuates. +Returns a `NaNChecker` that checks for a `NaN` in the first grid points of `fields` +when `schedule` actuates. `fields` should be a named tuple. + +The simulation is aborted if a `NaN` is found. """ NaNChecker(model=nothing; schedule, fields) = NaNChecker(schedule, fields) function run_diagnostic!(nc::NaNChecker, model) - for (name, field) in nc.fields - if any(isnan, field.data.parent) - t, i = model.clock.time, model.clock.iteration - error("time = $t, iteration = $i: NaN found in $name. Aborting simulation.") + for (name, field) in pairs(nc.fields) + CUDA.@allowscalar begin + if isnan(field.data[1, 1, 1]) + t, i = model.clock.time, model.clock.iteration + error("time = $t, iteration = $i: NaN found in $name. Aborting simulation.") + end end end end From 4ad555cc99be41b8a19d3f183092093d8b482973 Mon Sep 17 00:00:00 2001 From: ali-ramadhan Date: Mon, 23 Nov 2020 21:44:40 -0500 Subject: [PATCH 2/6] Add a NaN checker to simulations by default. --- src/Simulations/simulation.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Simulations/simulation.jl b/src/Simulations/simulation.jl index 1219b5a703..9a988f9b29 100644 --- a/src/Simulations/simulation.jl +++ b/src/Simulations/simulation.jl @@ -69,6 +69,8 @@ function Simulation(model; Δt, "recalculate the time step every iteration which can be slow." end + diagnostics[:nan_checker] = NaNChecker(fields=model.velocities, schedule=IterationInterval(1)) + run_time = 0.0 return Simulation(model, Δt, stop_criteria, stop_iteration, stop_time, wall_time_limit, From 3889687e97ba6e920fdc4a7964492d5d0a670f4c Mon Sep 17 00:00:00 2001 From: ali-ramadhan Date: Mon, 23 Nov 2020 21:48:06 -0500 Subject: [PATCH 3/6] Update tests --- test/test_diagnostics.jl | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/test/test_diagnostics.jl b/test/test_diagnostics.jl index c71f888572..49d3253e6c 100644 --- a/test/test_diagnostics.jl +++ b/test/test_diagnostics.jl @@ -1,16 +1,15 @@ using Oceananigans.Diagnostics -function nan_checker_aborts_simulation(arch, FT) +function nan_checker_aborts_simulation(arch) grid = RegularCartesianGrid(size=(4, 2, 1), extent=(1, 1, 1)) - model = IncompressibleModel(grid=grid, architecture=arch, float_type=FT) + model = IncompressibleModel(grid=grid, architecture=arch) + simulation = Simulation(model, Δt=1, stop_iteration=1) + + model.velocities.u[1, 1, 1] = NaN - # It checks for NaNs in w by default. - nc = NaNChecker(model; schedule=IterationInterval(1), fields=Dict(:w => model.velocities.w.data.parent)) - push!(model.diagnostics, nc) + run!(simulation) - model.velocities.w[3, 2, 1] = NaN - - time_step!(model, 1, 1) + return nothing end TestModel(::GPU, FT, ν=1.0, Δx=0.5) = @@ -68,7 +67,7 @@ get_time(model) = model.clock.time function diagnostics_getindex(arch, FT) model = TestModel(arch, FT) simulation = Simulation(model, Δt=0, stop_iteration=0) - nc = NaNChecker(model; schedule=IterationInterval(1), fields=Dict(:w => model.velocities.w.data.parent)) + nc = NaNChecker(model, schedule=IterationInterval(1), fields=model.velocities) simulation.diagnostics[:nc] = nc return simulation.diagnostics[1] == nc end @@ -77,9 +76,9 @@ function diagnostics_setindex(arch, FT) model = TestModel(arch, FT) simulation = Simulation(model, Δt=0, stop_iteration=0) - nc1 = NaNChecker(model; schedule=IterationInterval(1), fields=Dict(:w => model.velocities.w.data.parent)) - nc2 = NaNChecker(model; schedule=IterationInterval(2), fields=Dict(:u => model.velocities.u.data.parent)) - nc3 = NaNChecker(model; schedule=IterationInterval(3), fields=Dict(:v => model.velocities.v.data.parent)) + nc1 = NaNChecker(model, schedule=IterationInterval(1), fields=model.velocities) + nc2 = NaNChecker(model, schedule=IterationInterval(2), fields=model.velocities) + nc3 = NaNChecker(model, schedule=IterationInterval(3), fields=model.velocities) push!(simulation.diagnostics, nc1, nc2) simulation.diagnostics[2] = nc3 @@ -93,9 +92,7 @@ end for arch in archs @testset "NaN Checker [$(typeof(arch))]" begin @info " Testing NaN Checker [$(typeof(arch))]" - for FT in float_types - @test_throws ErrorException nan_checker_aborts_simulation(arch, FT) - end + @test_throws ErrorException nan_checker_aborts_simulation(arch) end end From b79fb7527a008c95cdf6706a3b40b3563671c4a4 Mon Sep 17 00:00:00 2001 From: ali-ramadhan Date: Tue, 24 Nov 2020 08:03:07 -0500 Subject: [PATCH 4/6] Compromise and check all u for NaN every 100 iterations --- src/Diagnostics/nan_checker.jl | 8 +++----- src/Simulations/simulation.jl | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Diagnostics/nan_checker.jl b/src/Diagnostics/nan_checker.jl index 70255687ad..7cf7f813e1 100644 --- a/src/Diagnostics/nan_checker.jl +++ b/src/Diagnostics/nan_checker.jl @@ -6,17 +6,15 @@ end """ NaNChecker(; schedule, fields) -Returns a `NaNChecker` that checks for a `NaN` in the first grid points of `fields` -when `schedule` actuates. `fields` should be a named tuple. - -The simulation is aborted if a `NaN` is found. +Returns a `NaNChecker` that checks for a `NaN` anywhere in `fields` when `schedule` actuates. +`fields` should be a named tuple. The simulation is aborted if a `NaN` is found. """ NaNChecker(model=nothing; schedule, fields) = NaNChecker(schedule, fields) function run_diagnostic!(nc::NaNChecker, model) for (name, field) in pairs(nc.fields) CUDA.@allowscalar begin - if isnan(field.data[1, 1, 1]) + if any(isnan, field.data.parent) t, i = model.clock.time, model.clock.iteration error("time = $t, iteration = $i: NaN found in $name. Aborting simulation.") end diff --git a/src/Simulations/simulation.jl b/src/Simulations/simulation.jl index 9a988f9b29..c9ede12294 100644 --- a/src/Simulations/simulation.jl +++ b/src/Simulations/simulation.jl @@ -69,7 +69,8 @@ function Simulation(model; Δt, "recalculate the time step every iteration which can be slow." end - diagnostics[:nan_checker] = NaNChecker(fields=model.velocities, schedule=IterationInterval(1)) + diagnostics[:nan_checker] = NaNChecker(fields=(u=model.velocities.u,), + schedule=IterationInterval(100)) run_time = 0.0 From a03f3bbabd530ec400ebb0601c4024cfb80ff660 Mon Sep 17 00:00:00 2001 From: ali-ramadhan Date: Tue, 24 Nov 2020 08:36:37 -0500 Subject: [PATCH 5/6] Fix diagnostics_getindex test --- test/test_diagnostics.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_diagnostics.jl b/test/test_diagnostics.jl index 49d3253e6c..9e32704c72 100644 --- a/test/test_diagnostics.jl +++ b/test/test_diagnostics.jl @@ -69,7 +69,9 @@ function diagnostics_getindex(arch, FT) simulation = Simulation(model, Δt=0, stop_iteration=0) nc = NaNChecker(model, schedule=IterationInterval(1), fields=model.velocities) simulation.diagnostics[:nc] = nc - return simulation.diagnostics[1] == nc + + # The first diagnostic is the NaN checker. + return simulation.diagnostics[2] == nc end function diagnostics_setindex(arch, FT) From e3afd055f6349b731f8a4146bc08b57e1fd54edc Mon Sep 17 00:00:00 2001 From: ali-ramadhan Date: Tue, 24 Nov 2020 16:43:35 -0500 Subject: [PATCH 6/6] Check for NaN every iteration_interval --- src/Simulations/simulation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Simulations/simulation.jl b/src/Simulations/simulation.jl index c9ede12294..27b4430854 100644 --- a/src/Simulations/simulation.jl +++ b/src/Simulations/simulation.jl @@ -70,7 +70,7 @@ function Simulation(model; Δt, end diagnostics[:nan_checker] = NaNChecker(fields=(u=model.velocities.u,), - schedule=IterationInterval(100)) + schedule=IterationInterval(iteration_interval)) run_time = 0.0