-
Notifications
You must be signed in to change notification settings - Fork 195
/
binary_operations.jl
188 lines (146 loc) · 6.36 KB
/
binary_operations.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Registry of operator names turned into binary `AbstractOperation`s by `@binary`,
# which pushes `Symbol(op)` onto it for every operator it processes.
const binary_operators = Set{Any}()
"""
    BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G}

Lazy, abstract representation of a binary operation on `AbstractField`s.

Holds the operator `op`, the operands `a` and `b`, the interpolation operators `▶a`
and `▶b` (which bring the operands to a common location), the interpolation operator
`▶op` (which brings the result to `(X, Y, Z)`), and the `grid`.
"""
struct BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G}
    op   :: O
    a    :: A
    b    :: B
    ▶a   :: IA
    ▶b   :: IB
    ▶op  :: IΩ
    grid :: G

    """
        BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid)

    Build the lazy operation `op(▶a(a), ▶b(b))`, whose result is then interpolated by
    `▶op` to `(X, Y, Z)`; `▶a` and `▶b` interpolate `a` and `b` to a common location.

    Throws an `ArgumentError` when any of `X`, `Y`, `Z` is `Nothing`.
    """
    function BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid) where {X, Y, Z}
        if X === Nothing || Y === Nothing || Z === Nothing
            throw(ArgumentError("Nothing locations are invalid! " *
                                "Cannot construct BinaryOperation at ($X, $Y, $Z)."))
        end

        # Concrete types of every stored component, in field order.
        Top, Ta, Tb = typeof(op), typeof(a), typeof(b)
        Tia, Tib, Tiop = typeof(▶a), typeof(▶b), typeof(▶op)
        Tgrid = typeof(grid)

        return new{X, Y, Z, Top, Ta, Tb, Tia, Tib, Tiop, Tgrid}(op, a, b, ▶a, ▶b, ▶op, grid)
    end
end
# Index the lazy operation at (i, j, k) by delegating to the result interpolator `▶op`,
# handing it the grid, the operator, the operand interpolators and the operands.
@inline function Base.getindex(β::BinaryOperation, i, j, k)
    return β.▶op(i, j, k, β.grid, β.op, β.▶a, β.▶b, β.a, β.b)
end
#####
##### BinaryOperation construction
#####
"""
    _binary_operation(Lc, op, a, b, La, Lb, Lab, grid)

Build a `BinaryOperation` for `op` acting on `a` (located at `La`) and `b` (located at
`Lb`). Both operands are interpolated to `Lab`, where `op` acts, and the result is
interpolated to `Lc`.
"""
function _binary_operation(Lc, op, a, b, La, Lb, Lab, grid)
    X, Y, Z = Lc

    return BinaryOperation{X, Y, Z}(op, a, b,
                                    interpolation_operator(La, Lab),
                                    interpolation_operator(Lb, Lab),
                                    interpolation_operator(Lab, Lc),
                                    grid)
end
"""
    define_binary_operator(op)

Return (without evaluating) an expression that defines methods of the function named
`op` which build lazy binary operations on `AbstractField`s:

- a pointwise kernel `op(i, j, k, grid, ▶a, ▶b, a, b)`;
- `op(Lc, Lab, a, b)`, acting at location `Lab` with the result interpolated to `Lc`;
- convenience methods that infer `Lab` (and `Lc`) from the operands' locations, and
  that wrap plain functions of `(x, y, z)` in `FunctionField`s.
"""
function define_binary_operator(op)
    return quote
        import Oceananigans.Grids: AbstractGrid
        import Oceananigans.Fields: AbstractField

        local location = Oceananigans.Fields.location
        local FunctionField = Oceananigans.Fields.FunctionField
        local AF = AbstractField

        # Pointwise kernel: apply `op` to both operands after interpolating them
        # with `▶a` and `▶b` at index (i, j, k).
        @inline $op(i, j, k, grid::AbstractGrid, ▶a, ▶b, a, b) =
            @inbounds $op(▶a(i, j, k, grid, a), ▶b(i, j, k, grid, b))

        """
            $($op)(Lc, Lab, a, b)

        Returns an abstract representation of the operator `$($op)` acting on `a` and `b` at
        location `Lab`, and subsequently interpolated to location `Lc`.
        """
        function $op(Lc::Tuple, Lab::Tuple, a, b)
            La = location(a)
            Lb = location(b)
            grid = Oceananigans.AbstractOperations.validate_grid(a, b)
            return Oceananigans.AbstractOperations._binary_operation(Lc, $op, a, b,
                                                                     La, Lb, Lab, grid)
        end

        # Choice of the location where `op` acts: `Lc` by default; the other operand's
        # location when one operand is a number; the shared location when both fields
        # live at the same location.
        $op(Lc::Tuple, a, b) = $op(Lc, Lc, a, b)
        $op(Lc::Tuple, a::Number, b) = $op(Lc, location(b), a, b)
        $op(Lc::Tuple, a, b::Number) = $op(Lc, location(a), a, b)
        $op(Lc::Tuple, a::AF{X, Y, Z}, b::AF{X, Y, Z}) where {X, Y, Z} =
            $op(Lc, location(a), a, b)

        # Sugar for mixing in functions of (x, y, z)
        $op(Lc::Tuple, a::Function, b::AF) = $op(Lc, FunctionField(Lc, a, b.grid), b)
        $op(Lc::Tuple, a::AF, b::Function) = $op(Lc, a, FunctionField(Lc, b, a.grid))

        # Sugary versions with default locations
        $op(a::AF, b::AF) = $op(location(a), a, b)
        $op(a::AF, b::Number) = $op(location(a), a, b)
        $op(a::Number, b::AF) = $op(location(b), a, b)
        $op(a::AF, b::Function) = $op(location(a), a, FunctionField(location(a), b, a.grid))
        $op(a::Function, b::AF) = $op(location(b), FunctionField(location(b), a, b.grid), b)
    end
end
"""
    @binary op1 op2 op3...

Turn each binary function in the list `(op1, op2, op3...)`
into a binary operator on `Oceananigans.Fields` for use in `AbstractOperations`.

Note: a binary function is a function with two arguments: for example, `+(x, y)` is a
binary function.

Also note: a binary function in `Base` must be imported to be extended: use
`import Base: op; @binary op`.

Example
=======

```jldoctest
julia> using Oceananigans, Oceananigans.AbstractOperations, Oceananigans.Grids

julia> plus_or_times(x, y) = x < 0 ? x + y : x * y
plus_or_times (generic function with 1 method)

julia> @binary plus_or_times;

julia> c, d = (Field(Cell, Cell, Cell, CPU(), RegularCartesianGrid(size=(1, 1, 1), extent=(1, 1, 1))) for i = 1:2);

julia> plus_or_times(c, d)
BinaryOperation at (Cell, Cell, Cell)
├── grid: RegularCartesianGrid{Float64, Periodic, Periodic, Bounded}(Nx=1, Ny=1, Nz=1)
│ └── domain: x ∈ [0.0, 1.0], y ∈ [0.0, 1.0], z ∈ [-1.0, 0.0]
└── tree:
plus_or_times at (Cell, Cell, Cell) via identity
├── Field located at (Cell, Cell, Cell)
└── Field located at (Cell, Cell, Cell)
```
"""
macro binary(ops...)
    expr = Expr(:block)

    for op in ops
        # Escaped so the operator methods are defined in the caller's module.
        push!(expr.args, esc(define_binary_operator(op)))

        # Record the operator's name in both operator registries.
        register_op = quote
            push!(Oceananigans.AbstractOperations.operators, Symbol($op))
            push!(Oceananigans.AbstractOperations.binary_operators, Symbol($op))
        end

        push!(expr.args, esc(register_op))
    end

    return expr
end
#####
##### Architecture inference for BinaryOperation
#####
# The architecture of a binary operation is resolved from its two operands.
function architecture(β::BinaryOperation)
    return architecture(β.a, β.b)
end
"""
    architecture(a, b)

Return the architecture shared by `a` and `b`. If either reports `nothing`, the
other's architecture is returned.

Throws an `ArgumentError` if `a` and `b` report two different, non-`nothing`
architectures.
"""
function architecture(a, b)
    arch_a = architecture(a)
    arch_b = architecture(b)

    arch_a === arch_b && return arch_a
    isnothing(arch_a) && return arch_b
    isnothing(arch_b) && return arch_a

    throw(ArgumentError("Operation involves fields on architectures $arch_a and $arch_b"))
end
#####
##### Nested computations
#####
"""
    compute!(β::BinaryOperation)

Call `compute!` on each operand of `β`, first `β.a` and then `β.b`. Returns `nothing`.
"""
function compute!(β::BinaryOperation)
    foreach(compute!, (β.a, β.b))
    return nothing
end
#####
##### GPU capabilities
#####
"""
    Adapt.adapt_structure(to, binary::BinaryOperation{X, Y, Z})

Rebuild `binary` for the target `to` (e.g. for use on the GPU). Its operator, operands
and interpolation operators are each passed through `Adapt.adapt(to, ⋅)`; the grid is
carried over unchanged.
"""
function Adapt.adapt_structure(to, binary::BinaryOperation{X, Y, Z}) where {X, Y, Z}
    adapted(x) = Adapt.adapt(to, x)

    return BinaryOperation{X, Y, Z}(adapted(binary.op), adapted(binary.a), adapted(binary.b),
                                    adapted(binary.▶a), adapted(binary.▶b), adapted(binary.▶op),
                                    binary.grid)
end