Skip to content

Commit

Permalink
Defines many identity's to avoid recursion when compiling AbstractOpe…
Browse files Browse the repository at this point in the history
…rations (#1595)

* Uses a closure to define identity interpolation

* This is crazy, but heres my number...

* Random identities

* Maybe we need different numbers for different functions

* Precompile those identities

* Better comments

* There was no number

* Trim trailing numbers from all the identical identity functions

* Actually trim trailing numbers from identity operators without bugs

* Lets try 10 identities rather than 30

* Further reduce number of identities and uncomment tests

* Moar identities

* Its important that my number is not 0

* Back to 6 identities

* Some hacks to try to get AveragedField to compile in kernels

* Still have to skip some computations with AveragedField

* Update test_implicit_free_surface_solver.jl
  • Loading branch information
glwagner authored Apr 17, 2021
1 parent 98cd4f7 commit e293068
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 36 deletions.
15 changes: 12 additions & 3 deletions src/AbstractOperations/show_abstract_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ for op_string in ("UnaryOperation", "BinaryOperation", "MultiaryOperation", "Der
end
end

function show_interp(op)
op_str = string(op)
if op_str[1:8] == "identity"
return "identity"
else
return op_str
end
end

short_show(operation::AbstractOperation) = string(operation_name(operation), " at ", show_location(operation))

Base.show(io::IO, operation::AbstractOperation) =
Expand All @@ -31,15 +40,15 @@ get_tree_padding(depth, nesting) = "    "^(depth-nesting) * "│   "^nesting
function tree_show(unary::UnaryOperation{X, Y, Z}, depth, nesting) where {X, Y, Z}
padding = get_tree_padding(depth, nesting)

return string(unary.op, " at ", show_location(X, Y, Z), " via ", unary.▶, '\n',
return string(unary.op, " at ", show_location(X, Y, Z), " via ", show_interp(unary.▶), '\n',
padding, "└── ", tree_show(unary.arg, depth+1, nesting))
end

"Return a string representaion of a `BinaryOperation` leaf within a tree visualization of an `AbstractOperation`."
function tree_show(binary::BinaryOperation{X, Y, Z}, depth, nesting) where {X, Y, Z}
padding = get_tree_padding(depth, nesting)

return string(binary.op, " at ", show_location(X, Y, Z), " via ", binary.▶op, '\n',
return string(binary.op, " at ", show_location(X, Y, Z), " via ", show_interp(binary.▶op), '\n',
padding, "├── ", tree_show(binary.a, depth+1, nesting+1), '\n',
padding, "└── ", tree_show(binary.b, depth+1, nesting))
end
Expand All @@ -59,6 +68,6 @@ end
function tree_show(deriv::Derivative{X, Y, Z}, depth, nesting) where {X, Y, Z}
padding = get_tree_padding(depth, nesting)

return string(deriv.∂, " at ", show_location(X, Y, Z), " via ", deriv.▶, '\n',
return string(deriv.∂, " at ", show_location(X, Y, Z), " via ", show_interp(deriv.▶), '\n',
padding, "└── ", tree_show(deriv.arg, depth+1, nesting))
end
30 changes: 16 additions & 14 deletions src/Fields/reduced_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,28 @@ abstract type AbstractReducedField{X, Y, Z, A, G, N} <: AbstractField{X, Y, Z, A

const ARF = AbstractReducedField

@inline Base.getindex( r::ARF{Nothing, Y, Z}, i, j, k) where {Y, Z} = @inbounds r.data[1, j, k]
@inline Base.setindex!(r::ARF{Nothing, Y, Z}, d, i, j, k) where {Y, Z} = @inbounds r.data[1, j, k] = d
const Loc = Union{Face, Center}

@inline Base.getindex( r::ARF{X, Nothing, Z}, i, j, k) where {X, Z} = @inbounds r.data[i, 1, k]
@inline Base.setindex!(r::ARF{X, Nothing, Z}, d, i, j, k) where {X, Z} = @inbounds r.data[i, 1, k] = d
@propagate_inbounds Base.getindex( r::ARF{Nothing, <:Loc, <:Loc}, i, j, k) = @inbounds r.data[1, j, k]
@propagate_inbounds Base.setindex!(r::ARF{Nothing, <:Loc, <:Loc}, d, i, j, k) = @inbounds r.data[1, j, k] = d

@inline Base.getindex( r::ARF{X, Y, Nothing}, i, j, k) where {X, Y} = @inbounds r.data[i, j, 1]
@inline Base.setindex!(r::ARF{X, Y, Nothing}, d, i, j, k) where {X, Y} = @inbounds r.data[i, j, 1] = d
@propagate_inbounds Base.getindex( r::ARF{<:Loc, Nothing, <:Loc}, i, j, k) = @inbounds r.data[i, 1, k]
@propagate_inbounds Base.setindex!(r::ARF{<:Loc, Nothing, <:Loc}, d, i, j, k) = @inbounds r.data[i, 1, k] = d

@inline Base.getindex( r::ARF{X, Nothing, Nothing}, i, j, k) where X = @inbounds r.data[i, 1, 1]
@inline Base.setindex!(r::ARF{X, Nothing, Nothing}, d, i, j, k) where X = @inbounds r.data[i, 1, 1] = d
@propagate_inbounds Base.getindex( r::ARF{<:Loc, <:Loc, Nothing}, i, j, k) = @inbounds r.data[i, j, 1]
@propagate_inbounds Base.setindex!(r::ARF{<:Loc, <:Loc, Nothing}, d, i, j, k) = @inbounds r.data[i, j, 1] = d

@inline Base.getindex( r::ARF{Nothing, Y, Nothing}, i, j, k) where Y = @inbounds r.data[1, j, 1]
@inline Base.setindex!(r::ARF{Nothing, Y, Nothing}, d, i, j, k) where Y = @inbounds r.data[1, j, 1] = d
@propagate_inbounds Base.getindex( r::ARF{<:Loc, Nothing, Nothing}, i, j, k) = @inbounds r.data[i, 1, 1]
@propagate_inbounds Base.setindex!(r::ARF{<:Loc, Nothing, Nothing}, d, i, j, k) = @inbounds r.data[i, 1, 1] = d

@inline Base.getindex( r::ARF{Nothing, Nothing, Z}, i, j, k) where Z = @inbounds r.data[1, 1, k]
@inline Base.setindex!(r::ARF{Nothing, Nothing, Z}, d, i, j, k) where Z = @inbounds r.data[1, 1, k] = d
@propagate_inbounds Base.getindex( r::ARF{Nothing, <:Loc, Nothing}, i, j, k) = @inbounds r.data[1, j, 1]
@propagate_inbounds Base.setindex!(r::ARF{Nothing, <:Loc, Nothing}, d, i, j, k) = @inbounds r.data[1, j, 1] = d

@inline Base.getindex( r::ARF{Nothing, Nothing, Nothing}, i, j, k) = @inbounds r.data[1, 1, 1]
@inline Base.setindex!(r::ARF{Nothing, Nothing, Nothing}, d, i, j, k) = @inbounds r.data[1, 1, 1] = d
@propagate_inbounds Base.getindex( r::ARF{Nothing, Nothing, <:Loc}, i, j, k) = @inbounds r.data[1, 1, k]
@propagate_inbounds Base.setindex!(r::ARF{Nothing, Nothing, <:Loc}, d, i, j, k) = @inbounds r.data[1, 1, k] = d

@propagate_inbounds Base.getindex( r::ARF{Nothing, Nothing, Nothing}, i, j, k) = @inbounds r.data[1, 1, 1]
@propagate_inbounds Base.setindex!(r::ARF{Nothing, Nothing, Nothing}, d, i, j, k) = @inbounds r.data[1, 1, 1] = d

fill_halo_regions!(field::AbstractReducedField, arch, args...) =
fill_halo_regions!(field.data, field.boundary_conditions, arch, field.grid, args...; reduced_dimensions=field.dims)
Expand Down
39 changes: 28 additions & 11 deletions src/Operators/interpolation_utils.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
using Random
using Oceananigans.Utils: instantiate
using Oceananigans.Grids: Face, Center

import Base: identity

@inline identity(i, j, k, grid, c) = @inbounds c[i, j, k]
@inline identity(i, j, k, grid, a::Number) = a

"""Evaluate the function `F` with signature `F(i, j, k, grid, args...)` at index `i, j, k` without
interpolation."""
@inline identity(i, j, k, grid, F::TF, args...) where TF<:Function = F(i, j, k, grid, args...)

# Utilities for inferring the interpolation function needed to
# interpolate a field from one location to the next.
interpolation_code(from, to) = interpolation_code(to)
Expand All @@ -34,6 +26,23 @@ for ξ in ("x", "y", "z")
end
end

# It's not Oceananigans for nothing
const number_of_identities = 6 # hopefully enough for Oceananigans (most need just one)

for i = 1:number_of_identities
identity = Symbol(:identity, i)

@eval begin
@inline $identity(i, j, k, grid, c) = @inbounds c[i, j, k]
@inline $identity(i, j, k, grid, a::Number) = a
@inline $identity(i, j, k, grid, F::TF, args...) where TF<:Function = F(i, j, k, grid, args...)
end
end

torus(x, lower, upper) = lower + rem(x - lower, upper - lower, RoundDown)
identify_an_identity(number) = Symbol(:identity, torus(number, 1, number_of_identities))
identity_counter = 0

"""
interpolation_operator(from, to)
Expand All @@ -44,8 +53,12 @@ function interpolation_operator(from, to)
from, to = instantiate.(from), instantiate.(to)
x, y, z = (interpolation_code(X, Y) for (X, Y) in zip(from, to))

# This is crazy, but here's my number...
global identity_counter += 1
identity = identify_an_identity(identity_counter)

if all=== :ᵃ for ξ in (x, y, z))
return identity
return @eval $identity
else
return eval(Symbol(:ℑ, ℑxsym(x), ℑysym(y), ℑzsym(z), x, y, z))
end
Expand All @@ -57,7 +70,11 @@ end
Return the `identity` interpolator function. This is needed to obtain the interpolation
operator for fields that have no intrinsic location, like numbers or functions.
"""
interpolation_operator(::Nothing, to) = identity
function interpolation_operator(::Nothing, to)
global identity_counter += 1
identity = identify_an_identity(identity_counter)
return @eval $identity
end

assumed_field_location(name) = name === :u ? (Face, Center, Center) :
name === :v ? (Center, Face, Center) :
Expand Down
18 changes: 11 additions & 7 deletions test/test_abstract_operations_computed_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -693,25 +693,29 @@ end
@testset "Computations with AveragedFields [$FT, $(typeof(arch))]" begin
@info " Testing computations with AveragedField [$FT, $(typeof(arch))]..."

@test computations_with_averaged_field_derivative(model)

# These don't work on the GPU right now
if arch isa CPU
@test computations_with_averaged_fields(model)
@test computations_with_averaged_field_derivative(model)
else
@test_skip computations_with_averaged_fields(model)
@test_skip computations_with_averaged_field_derivative(model)
end
end

@testset "Computations with ComputedFields [$FT, $(typeof(arch))]" begin
@info " Testing computations with ComputedField [$FT, $(typeof(arch))]..."

# These don't work on the GPU right now
if arch isa CPU
@test computations_with_computed_fields(model)
else
@test_skip computations_with_computed_fields(model)
# Basic compilation test...
u, v, w = model.velocities
@test try
compute!(ComputedField(u + v - w))
true
catch
false
end

@test computations_with_computed_fields(model)
end

@testset "Conditional computation of ComputedField and BuoyancyField [$FT, $(typeof(arch))]" begin
Expand Down
2 changes: 1 addition & 1 deletion test/test_implicit_free_surface_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function run_implicit_free_surface_solver_tests(arch, grid)

# Compare
extrema_tolerance = 1e-10
std_tolerance = 1e-11
std_tolerance = 1e-10

CUDA.@allowscalar begin
@test abs(minimum(left_hand_side[1:Nx, 1:Ny, 1] .- right_hand_side[1:Nx, 1:Ny, 1])) < extrema_tolerance
Expand Down

0 comments on commit e293068

Please sign in to comment.