From 5dd01ddfe73dd03515b52e6499b27c1b61c71364 Mon Sep 17 00:00:00 2001 From: "Gregory L. Wagner" Date: Fri, 8 Nov 2024 08:55:46 -0700 Subject: [PATCH] Use on_architecture more liberally with boundary conditions (#3893) * Use on_architecture more liberally with boundary conditions * Add test for FieldTimeSeries when CuArray bcs * Too many ends * Fix location of y-momentum flux * Rm data after test * Bump version * Bump Enzyme version * Downgrade Enzyme --------- Co-authored-by: Navid C. Constantinou --- Project.toml | 4 +-- .../field_boundary_conditions.jl | 21 ++++++++---- src/OutputReaders/field_time_series.jl | 28 ++++++++------- src/OutputWriters/output_writer_utils.jl | 8 ++--- test/test_output_readers.jl | 34 +++++++++++++++++++ 5 files changed, 70 insertions(+), 25 deletions(-) diff --git a/Project.toml b/Project.toml index 4d80cee5b7..30ba117d7a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Oceananigans" uuid = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09" authors = ["Climate Modeling Alliance and contributors"] -version = "0.93.2" +version = "0.93.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -50,7 +50,7 @@ CubedSphere = "0.2, 0.3" Dates = "1.9" Distances = "0.10" DocStringExtensions = "0.8, 0.9" -Enzyme = "0.13.3" +Enzyme = "0.13.13" FFTW = "1" Glob = "1.3" IncompleteLU = "0.2" diff --git a/src/BoundaryConditions/field_boundary_conditions.jl b/src/BoundaryConditions/field_boundary_conditions.jl index a42beeee2c..a5a9862b1d 100644 --- a/src/BoundaryConditions/field_boundary_conditions.jl +++ b/src/BoundaryConditions/field_boundary_conditions.jl @@ -39,12 +39,12 @@ default_auxiliary_bc(::LeftConnected, ::Face) = nothing ##### mutable struct FieldBoundaryConditions{W, E, S, N, B, T, I} - west :: W - east :: E - south :: S - north :: N - bottom :: B - top :: T + west :: W + east :: E + south :: S + north :: N + bottom :: B + top :: T immersed :: I end @@ -65,6 +65,15 @@ FieldBoundaryConditions(indices::Tuple, ::Nothing) = nothing window_boundary_conditions(::Colon, left, right) = left, right window_boundary_conditions(::UnitRange, left, right) = nothing, nothing +on_architecture(arch, fbcs::FieldBoundaryConditions) = + FieldBoundaryConditions(on_architecture(arch, fbcs.west), + on_architecture(arch, fbcs.east), + on_architecture(arch, fbcs.south), + on_architecture(arch, fbcs.north), + on_architecture(arch, fbcs.bottom), + on_architecture(arch, fbcs.top), + on_architecture(arch, fbcs.immersed)) + """ FieldBoundaryConditions(; kwargs...) diff --git a/src/OutputReaders/field_time_series.jl b/src/OutputReaders/field_time_series.jl index a07377e22c..0600bec040 100644 --- a/src/OutputReaders/field_time_series.jl +++ b/src/OutputReaders/field_time_series.jl @@ -214,16 +214,16 @@ Base.length(backend::PartlyInMemory) = backend.length ##### mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW} <: AbstractField{LX, LY, LZ, G, ET, 4} - data :: D - grid :: G - backend :: K + data :: D + grid :: G + backend :: K boundary_conditions :: B - indices :: I - times :: χ - path :: P - name :: N - time_indexing :: TI - reader_kw :: KW + indices :: I + times :: χ + path :: P + name :: N + time_indexing :: TI + reader_kw :: KW function FieldTimeSeries{LX, LY, LZ}(data::D, grid::G, @@ -460,10 +460,6 @@ function FieldTimeSeries(path::String, name::String; isnothing(times) && (times = [file["timeseries/t/$i"] for i in iterations]) isnothing(location) && (Location = file["timeseries/$name/serialized/location"]) - if boundary_conditions isa UnspecifiedBoundaryConditions - boundary_conditions = file["timeseries/$name/serialized/boundary_conditions"] - end - indices = try file["timeseries/$name/serialized/indices"] catch @@ -480,6 +476,12 @@ function FieldTimeSeries(path::String, name::String; end end + if boundary_conditions isa UnspecifiedBoundaryConditions + boundary_conditions = file["timeseries/$name/serialized/boundary_conditions"] + boundary_conditions = on_architecture(architecture, boundary_conditions) + end + + # This should be removed eventually... (4/5/2022) grid = try on_architecture(architecture, grid) diff --git a/src/OutputWriters/output_writer_utils.jl b/src/OutputWriters/output_writer_utils.jl index af6c6f6834..7b79755b3b 100644 --- a/src/OutputWriters/output_writer_utils.jl +++ b/src/OutputWriters/output_writer_utils.jl @@ -98,7 +98,7 @@ function saveproperty!(file, address, bcs::FieldBoundaryConditions) if bc.condition isa Function || bc.condition isa ContinuousBoundaryFunction file[address * "/$boundary/condition"] = missing else - file[address * "/$boundary/condition"] = bc.condition + file[address * "/$boundary/condition"] = on_architecture(CPU(), bc.condition) end end end @@ -130,13 +130,13 @@ function serializeproperty!(file, address, grid::DistributedGrid) file[address] = on_architecture(cpu_arch, grid) end -function serializeproperty!(file, address, p::FieldBoundaryConditions) +function serializeproperty!(file, address, fbcs::FieldBoundaryConditions) # TODO: it'd be better to "filter" `FieldBoundaryCondition` and then serialize # rather than punting with `missing` instead. - if has_reference(Function, p) + if has_reference(Function, fbcs) file[address] = missing else - file[address] = p + file[address] = on_architecture(CPU(), fbcs) end end diff --git a/test/test_output_readers.jl b/test/test_output_readers.jl index 7d4657cd81..805fc1c8c0 100644 --- a/test/test_output_readers.jl +++ b/test/test_output_readers.jl @@ -221,6 +221,40 @@ end @test v1[2] isa Field end end + + if arch isa GPU + @testset "FieldTimeSeries with CuArray boundary conditions [$(typeof(arch))]" begin + @info " Testing FieldTimeSeries with CuArray boundary conditions..." + + x = y = z = (0, 1) + grid = RectilinearGrid(GPU(); size=(1, 1, 1), x, y, z) + + τx = CuArray(zeros(size(grid)...)) + τy = Field{Center, Face, Nothing}(grid) + u_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(τx)) + v_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(τy)) + model = NonhydrostaticModel(; grid, boundary_conditions = (; u=u_bcs, v=v_bcs)) + simulation = Simulation(model; Δt=1, stop_iteration=1) + + simulation.output_writers[:jld2] = JLD2OutputWriter(model, model.velocities, + filename = "test_cuarray_bc.jld2", + schedule=IterationInterval(1), + overwrite_existing = true) + + run!(simulation) + + ut = FieldTimeSeries("test_cuarray_bc.jld2", "u") + vt = FieldTimeSeries("test_cuarray_bc.jld2", "v") + @test ut.boundary_conditions.top.classification isa Flux + @test ut.boundary_conditions.top.condition isa Array + + τy_ow = vt.boundary_conditions.top.condition + @test τy_ow isa Field{Center, Face, Nothing} + @test architecture(τy_ow) isa CPU + @test parent(τy_ow) isa Array + rm("test_cuarray_bc.jld2") + end + end end for arch in archs