Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use on_architecture more liberally with boundary conditions #3893

Merged
merged 11 commits into from
Nov 8, 2024
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Oceananigans"
uuid = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09"
authors = ["Climate Modeling Alliance and contributors"]
version = "0.93.2"
version = "0.93.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -50,7 +50,7 @@ CubedSphere = "0.2, 0.3"
Dates = "1.9"
Distances = "0.10"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13.3"
Enzyme = "0.13.13"
FFTW = "1"
Glob = "1.3"
IncompleteLU = "0.2"
Expand Down
21 changes: 15 additions & 6 deletions src/BoundaryConditions/field_boundary_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ default_auxiliary_bc(::LeftConnected, ::Face) = nothing
#####

mutable struct FieldBoundaryConditions{W, E, S, N, B, T, I}
west :: W
east :: E
south :: S
north :: N
bottom :: B
top :: T
west :: W
east :: E
south :: S
north :: N
bottom :: B
top :: T
immersed :: I
end

Expand All @@ -65,6 +65,15 @@ FieldBoundaryConditions(indices::Tuple, ::Nothing) = nothing
window_boundary_conditions(::Colon, left, right) = left, right
window_boundary_conditions(::UnitRange, left, right) = nothing, nothing

on_architecture(arch, fbcs::FieldBoundaryConditions) =
FieldBoundaryConditions(on_architecture(arch, fbcs.west),
on_architecture(arch, fbcs.east),
on_architecture(arch, fbcs.south),
on_architecture(arch, fbcs.north),
on_architecture(arch, fbcs.bottom),
on_architecture(arch, fbcs.top),
on_architecture(arch, fbcs.immersed))

"""
FieldBoundaryConditions(; kwargs...)

Expand Down
28 changes: 15 additions & 13 deletions src/OutputReaders/field_time_series.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,16 +214,16 @@ Base.length(backend::PartlyInMemory) = backend.length
#####

mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW} <: AbstractField{LX, LY, LZ, G, ET, 4}
data :: D
grid :: G
backend :: K
data :: D
grid :: G
backend :: K
boundary_conditions :: B
indices :: I
times :: χ
path :: P
name :: N
time_indexing :: TI
reader_kw :: KW
indices :: I
times :: χ
path :: P
name :: N
time_indexing :: TI
reader_kw :: KW

function FieldTimeSeries{LX, LY, LZ}(data::D,
grid::G,
Expand Down Expand Up @@ -460,10 +460,6 @@ function FieldTimeSeries(path::String, name::String;
isnothing(times) && (times = [file["timeseries/t/$i"] for i in iterations])
isnothing(location) && (Location = file["timeseries/$name/serialized/location"])

if boundary_conditions isa UnspecifiedBoundaryConditions
boundary_conditions = file["timeseries/$name/serialized/boundary_conditions"]
end

indices = try
file["timeseries/$name/serialized/indices"]
catch
Expand All @@ -480,6 +476,12 @@ function FieldTimeSeries(path::String, name::String;
end
end

if boundary_conditions isa UnspecifiedBoundaryConditions
boundary_conditions = file["timeseries/$name/serialized/boundary_conditions"]
boundary_conditions = on_architecture(architecture, boundary_conditions)
end


# This should be removed eventually... (4/5/2022)
grid = try
on_architecture(architecture, grid)
Expand Down
8 changes: 4 additions & 4 deletions src/OutputWriters/output_writer_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ function saveproperty!(file, address, bcs::FieldBoundaryConditions)
if bc.condition isa Function || bc.condition isa ContinuousBoundaryFunction
file[address * "/$boundary/condition"] = missing
else
file[address * "/$boundary/condition"] = bc.condition
file[address * "/$boundary/condition"] = on_architecture(CPU(), bc.condition)
end
end
end
Expand Down Expand Up @@ -130,13 +130,13 @@ function serializeproperty!(file, address, grid::DistributedGrid)
file[address] = on_architecture(cpu_arch, grid)
end

function serializeproperty!(file, address, p::FieldBoundaryConditions)
function serializeproperty!(file, address, fbcs::FieldBoundaryConditions)
# TODO: it'd be better to "filter" `FieldBoundaryCondition` and then serialize
# rather than punting with `missing` instead.
if has_reference(Function, p)
if has_reference(Function, fbcs)
file[address] = missing
else
file[address] = p
file[address] = on_architecture(CPU(), fbcs)
end
end

Expand Down
34 changes: 34 additions & 0 deletions test/test_output_readers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,40 @@ end
@test v1[2] isa Field
end
end

if arch isa GPU
@testset "FieldTimeSeries with CuArray boundary conditions [$(typeof(arch))]" begin
@info " Testing FieldTimeSeries with CuArray boundary conditions..."

x = y = z = (0, 1)
grid = RectilinearGrid(GPU(); size=(1, 1, 1), x, y, z)

τx = CuArray(zeros(size(grid)...))
τy = Field{Center, Face, Nothing}(grid)
u_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(τx))
v_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(τy))
model = NonhydrostaticModel(; grid, boundary_conditions = (; u=u_bcs, v=v_bcs))
simulation = Simulation(model; Δt=1, stop_iteration=1)

simulation.output_writers[:jld2] = JLD2OutputWriter(model, model.velocities,
filename = "test_cuarray_bc.jld2",
schedule=IterationInterval(1),
overwrite_existing = true)

run!(simulation)

ut = FieldTimeSeries("test_cuarray_bc.jld2", "u")
vt = FieldTimeSeries("test_cuarray_bc.jld2", "v")
@test ut.boundary_conditions.top.classification isa Flux
@test ut.boundary_conditions.top.condition isa Array

τy_ow = vt.boundary_conditions.top.condition
@test τy_ow isa Field{Center, Face, Nothing}
@test architecture(τy_ow) isa CPU
@test parent(τy_ow) isa Array
rm("test_cuarray_bc.jld2")
end
end
end

for arch in archs
Expand Down