Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convenience kwargs in JLD2OutputWriter constructor for averaging output #887

Merged
merged 5 commits into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 84 additions & 19 deletions src/OutputWriters/jld2_output_writer.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Printf
using JLD2
using Oceananigans.Utils
using Oceananigans.Diagnostics: WindowedTimeAverage
import Oceananigans.Diagnostics: get_kernel

"""
JLD2OutputWriter{I, T, O, IF, IN, KW} <: AbstractOutputWriter
Expand All @@ -26,41 +28,102 @@ end
noinit(args...) = nothing

"""
JLD2OutputWriter(model, outputs; iteration_interval=nothing, time_interval=nothing, dir=".",
prefix="", init=noinit, including=[:grid, :coriolis, :buoyancy, :closure],
part=1, max_filesize=Inf, force=false, async=false, verbose=false, jld2_kw=Dict{Symbol, Any}())
JLD2OutputWriter(model, outputs; prefix,
dir = ".",
iteration_interval = nothing,
time_interval = nothing,
time_averaging_window = nothing,
time_averaging_stride = 1,
max_filesize = Inf,
force = false,
init = noinit,
async = false,
verbose = false,
including = [:grid, :coriolis, :buoyancy, :closure],
part = 1,
jld2_kw = Dict{Symbol, Any}())

Construct a `JLD2OutputWriter` that writes `label, func` pairs in `outputs` (which can be a `Dict` or `NamedTuple`)
to a JLD2 file, where `label` is a symbol that labels the output and `func` is a function of the form `func(model)`
that returns the data to be saved.

Keyword arguments
=================
- `prefix`: Descriptive filename prefixed to all output files.

- `dir`: Directory to save output to.
Default: "." (current working directory).

- `iteration_interval`: Save output every `n` model iterations.

- `time_interval`: Save output every `t` units of model clock time.
- `dir`: Directory to save output to. Default: "." (current working directory).
- `prefix`: Descriptive filename prefixed to all output files. Default: "".
- `init`: A function of the form `init(file, model)` that runs when a JLD2 output file is initialized.
Default: `noinit(args...) = nothing`.
- `including`: List of model properties to save with every file.
Default: `[:grid, :coriolis, :buoyancy, :closure]`
- `part`: The starting part number used if `max_filesize` is finite. Default: 1.

- `time_averaging_window`: Specifies a time window over which each member of `output` is averaged before
being saved. For this each member of output is converted to
`Oceananigans.Diagnostics.WindowedTimeAverage`.
Default `nothing` indicates no averaging.

- `time_averaging_stride`: Specifies a iteration 'stride' between the calculation of each `output` during
time-averaging. Longer strides means that output is calculated less frequently,
and that the resulting time-average is less accurate.
Default: 1.

- `max_filesize`: The writer will stop writing to the output file once the file size exceeds `max_filesize`,
and write to a new one with a consistent naming scheme ending in `part1`, `part2`, etc. Defaults to `Inf`.
- `force`: Remove existing files if their filenames conflict. Default: `false`.
- `async`: Write output asynchronously. Default: `false`.
and write to a new one with a consistent naming scheme ending in `part1`, `part2`, etc.
Defaults to `Inf`.

- `force`: Remove existing files if their filenames conflict.
Default: `false`.

- `init`: A function of the form `init(file, model)` that runs when a JLD2 output file is initialized.
Default: `noinit(args...) = nothing`.

- `async`: Write output asynchronously.
Default: `false`.

- `verbose`: Log what the output writer is doing with statistics on compute/write times and file sizes.
Default: `false`.
Default: `false`.

- `including`: List of model properties to save with every file.
Default: `[:grid, :coriolis, :buoyancy, :closure]`

- `part`: The starting part number used if `max_filesize` is finite.
Default: 1.

- `jld2_kw`: Dict of kwargs to be passed to `jldopen` when data is written.
"""
function JLD2OutputWriter(model, outputs; iteration_interval=nothing, time_interval=nothing,
dir=".", prefix="", init=noinit,
including=[:grid, :coriolis, :buoyancy, :closure],
part=1, max_filesize=Inf, force=false, async=false, verbose=false,
jld2_kw=Dict{Symbol, Any}())
function JLD2OutputWriter(model, outputs; prefix,
dir = ".",
iteration_interval = nothing,
time_interval = nothing,
time_averaging_window = nothing,
time_averaging_stride = 1,
max_filesize = Inf,
force = false,
init = noinit,
async = false,
verbose = false,
including = [:grid, :coriolis, :buoyancy, :closure],
part = 1,
jld2_kw = Dict{Symbol, Any}())

validate_intervals(iteration_interval, time_interval)

# Convert each output to WindowedTimeAverage if time_averaging_window is specified
if !isnothing(time_averaging_window)

!isnothing(iteration_interval) && error("Cannot specify iteration_interval with time_averaging_window.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might throw an ArgumentError as it's more specific and allows us to @test_throws the constructor.


output_names = Tuple(keys(outputs))

averaged_output = Tuple(WindowedTimeAverage(outputs[name]; time_interval = time_interval,
time_window = time_averaging_window,
stride = time_averaging_stride)
for name in output_names)

outputs = NamedTuple{output_names}(averaged_output)
end

mkpath(dir)
filepath = joinpath(dir, prefix * ".jld2")
force && isfile(filepath) && rm(filepath, force=true)
Expand Down Expand Up @@ -161,6 +224,8 @@ end
FieldOutput(field) = FieldOutput(Array, field) # default
(fo::FieldOutput)(model) = fo.return_type(fo.field.data.parent)

get_kernel(kernel::FieldOutput) = parent(kernel.field)

"""
FieldOutputs(fields)

Expand Down
25 changes: 21 additions & 4 deletions test/test_output_writers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Statistics
using NCDatasets
using Oceananigans.BoundaryConditions: PBC, FBC, ZFBC
using Oceananigans.Diagnostics

function run_thermal_bubble_netcdf_tests(arch)
Nx, Ny, Nz = 16, 16, 16
Expand Down Expand Up @@ -518,6 +519,20 @@ function dependencies_added_correctly!(model, windowed_time_average, output_writ
return windowed_time_average ∈ values(simulation.diagnostics)
end

function jld2_time_averaging_window(model)

model.clock.iteration = 0
model.clock.time = 0.0

output = FieldOutputs(model.velocities)

jld2_output_writer = JLD2OutputWriter(model, output, time_interval=4.0, time_averaging_window=2.0,
dir=".", prefix="test", force=true)

outputs_are_time_averaged = Tuple(typeof(out) <: WindowedTimeAverage for out in jld2_output_writer.outputs)

return all(outputs_are_time_averaged)
end

@testset "Output writers" begin
@info "Testing output writers..."
Expand All @@ -543,24 +558,26 @@ end
@hascuda run_cross_architecture_checkpointer_tests(GPU(), CPU())
end

@testset "Output writer 'diagnostic dependencies' [$(typeof(arch))]" begin
@info " Testing output writer diagnostic-dependencies [$(typeof(arch))]..."
@testset "Output writer averaging and 'diagnostic dependencies' [$(typeof(arch))]" begin
@info " Testing output writer time-averaging and diagnostic-dependencies [$(typeof(arch))]..."

grid = RegularCartesianGrid(size=(16, 16, 16), extent=(1, 1, 1))
model = IncompressibleModel(architecture=arch, grid=grid)

@test jld2_time_averaging_window(model)

windowed_time_average = WindowedTimeAverage(model.velocities.u, time_window=2.0, time_interval=4.0)

output = Dict("time_average" => windowed_time_average)
attributes = Dict("time_average" => Dict("longname" => "A time average", "units" => "arbitrary"))
dimensions = Dict("time_average" => ("xF", "yC", "zC"))

# JLD2 test
# JLD2 dependencies test
jld2_output_writer = JLD2OutputWriter(model, output, time_interval=4.0, dir=".", prefix="test", force=true)

@test dependencies_added_correctly!(model, windowed_time_average, jld2_output_writer)

# NetCDF test
# NetCDF dependency test
netcdf_output_writer =
NetCDFOutputWriter(model, output, time_interval=4.0, filename="test.nc", with_halos=true,
output_attributes=attributes, dimensions=dimensions)
Expand Down