Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preallocate GPU interpolant #75

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 69 additions & 10 deletions ext/InterpolationsRegridderExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct InterpolationsRegridder{
SPACE <: ClimaCore.Spaces.AbstractSpace,
FIELD <: ClimaCore.Fields.Field,
BC,
GITP,
} <: Regridders.AbstractRegridder

"""ClimaCore.Space where the output Field will be defined"""
Expand All @@ -22,6 +23,14 @@ struct InterpolationsRegridder{

"""Tuple of extrapolation conditions as accepted by Interpolations.jl"""
extrapolation_bc::BC

# This is needed because Adapt moves from CPU to GPU and allocates new memory.
"""Dictionary of preallocated areas of memory where to store the GPU interpolant (if
needed). Every time new data/dimensions are used in regrid, a new entry in the
dictionary is created. The keys of the dictionary a tuple of tuple
`(size(dimensions), size(data))`, with `dimensions` and `data` defined in `regrid`.
"""
_gpuitps::GITP
end

# Note, we swap Lat and Long! This is because according to the CF conventions longitude
Expand Down Expand Up @@ -58,6 +67,8 @@ function Regridders.InterpolationsRegridder(
)
coordinates = ClimaCore.Fields.coordinate_field(target_space)

num_dimensions = length(propertynames(coordinates))

if isnothing(extrapolation_bc)
extrapolation_bc = ()
if eltype(coordinates) <: ClimaCore.Geometry.LatLongPoint
Expand All @@ -69,9 +80,42 @@ function Regridders.InterpolationsRegridder(
end
end

return InterpolationsRegridder(target_space, coordinates, extrapolation_bc)
num_dimensions == length(extrapolation_bc) || error(
"Number of boundary conditions does not match the number of dimensions",
)

# Let's figure out the type of _gpuitps by creating a simple spline
FT = ClimaCore.Spaces.undertype(target_space)
dimensions = ntuple(_ -> [zero(FT), one(FT)], num_dimensions)
data = zeros(FT, ntuple(_ -> 2, num_dimensions))
itp = _create_linear_spline(FT, data, dimensions, extrapolation_bc)
fake_gpuitp = Adapt.adapt(ClimaComms.array_type(target_space), itp)
gpuitps = Dict((size.(dimensions), size(data)) => fake_gpuitp)

return InterpolationsRegridder(
target_space,
coordinates,
extrapolation_bc,
gpuitps,
)
end

"""
_create_linear_spline(regridder::InterpolationsRegridder, data, dimensions)

Create a linear spline for the given data on the given dimension (on the CPU).
"""
function _create_linear_spline(FT, data, dimensions, extrapolation_bc)
dimensions_FT = map(d -> FT.(d), dimensions)

# Make a linear spline
return Intp.extrapolate(
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
extrapolation_bc,
)
end


"""
regrid(regridder::InterpolationsRegridder, data, dimensions)::Field

Expand All @@ -81,16 +125,31 @@ This function is allocating.
"""
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
FT = ClimaCore.Spaces.undertype(regridder.target_space)
dimensions_FT = map(d -> FT.(d), dimensions)

# Make a linear spline
itp = Intp.extrapolate(
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
regridder.extrapolation_bc,
)
itp =
_create_linear_spline(FT, data, dimensions, regridder.extrapolation_bc)

key = (size.(dimensions), size(data))

if haskey(regridder._gpuitps, key)
for (k, k_new) in zip(
regridder._gpuitps[key].itp.knots,
Adapt.adapt(
ClimaComms.array_type(regridder.target_space),
itp.itp.knots,
),
)
k .= k_new
end
regridder._gpuitps[key].itp.coefs .= Adapt.adapt(
ClimaComms.array_type(regridder.target_space),
itp.itp.coefs,
)
else
regridder._gpuitps[key] =
Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
end

# Move it to GPU (if needed)
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
gpuitp = regridder._gpuitps[key]

return map(regridder.coordinates) do coord
gpuitp(totuple(coord)...)
Expand Down
10 changes: 9 additions & 1 deletion test/TestTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ function make_spherical_space(FT; context = ClimaComms.context())
boundary_names = (:bottom, :top),
)
vertmesh = ClimaCore.Meshes.IntervalMesh(vertdomain, nelems = zelem)
vert_center_space = ClimaCore.Spaces.CenterFiniteDifferenceSpace(vertmesh)
if pkgversion(ClimaCore) >= v"0.14.10"
vert_center_space = ClimaCore.Spaces.CenterFiniteDifferenceSpace(
ClimaComms.device(context),
vertmesh,
)
else
vert_center_space =
ClimaCore.Spaces.CenterFiniteDifferenceSpace(vertmesh)
end

horzdomain = ClimaCore.Domains.SphereDomain(radius)
horzmesh = ClimaCore.Meshes.EquiangularCubedSphere(horzdomain, helem)
Expand Down
7 changes: 2 additions & 5 deletions test/data_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,10 @@ ClimaComms.init(context)
target_space;
regridder_type = :InterpolationsRegridder,
file_reader_kwargs = (; preprocess_func = (data) -> 0.0 * data),
regridder_kwargs = (;
extrapolation_bc = (Intp.Flat(), Intp.Flat(), Intp.Flat())
),
regridder_kwargs = (; extrapolation_bc = (Intp.Flat(), Intp.Flat())),
)

@test data_handler.regridder.extrapolation_bc ==
(Intp.Flat(), Intp.Flat(), Intp.Flat())
@test data_handler.regridder.extrapolation_bc == (Intp.Flat(), Intp.Flat())
field = DataHandling.regridded_snapshot(data_handler)
@test extrema(field) == (0.0, 0.0)
end
Expand Down
11 changes: 11 additions & 0 deletions test/regridders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ end
extrapolation_bc,
)

# Test num_dimensions != length(extrapolation_bc)
@test_throws ErrorException Regridders.InterpolationsRegridder(
hv_center_space;
extrapolation_bc = (
Interpolations.Periodic(),
Interpolations.Flat(),
Interpolations.Flat(),
Interpolations.Flat(),
),
)

regridded_lat = Regridders.regrid(reg_hv, data_lat3D, dimensions3D)
regridded_lon = Regridders.regrid(reg_hv, data_lon3D, dimensions3D)
regridded_z = Regridders.regrid(reg_hv, data_z3D, dimensions3D)
Expand Down
Loading