Skip to content

Commit

Permalink
Merge #80
Browse files Browse the repository at this point in the history
80: Perf: Specialize `fn` in `rmap` and `rmaptype` r=simonbyrne a=kpamnany

Also avoid use of `map` on `NamedTuple`s as it doesn't specialize `f`.

Co-authored-by: K Pamnany <[email protected]>
Co-authored-by: Simon Byrne <[email protected]>
  • Loading branch information
3 people authored Aug 3, 2021
2 parents 1ce475c + 8d95413 commit 0f6e922
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 28 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/Linux-UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ jobs:
timeout-minutes: 60
strategy:
fail-fast: true

matrix:
julia-version: ['1.6', '1.7-nightly']

env:
CLIMATEMACHINE_SETTINGS_FIX_RNG_SEED: "true"

Expand All @@ -30,14 +32,14 @@ jobs:
- name: Set up Julia
uses: julia-actions/setup-julia@latest
with:
version: 1.6
version: ${{ matrix.julia-version }}

- name: Cache artifacts
uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
Expand Down
22 changes: 19 additions & 3 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,25 @@ function Base.Broadcast.broadcasted(
# wrap in a Field so that the axes line up correctly (it just get's unwraped so effectively a no-op)
Base.Broadcast.broadcasted(
fs,
LinearAlgebra.norm,
Geometry._norm,
arg,
Field(space.local_geometry, space),
hasproperty(space, :local_geometry) ?
Field(space.local_geometry, space) : Ref(nothing),
)
end
function Base.Broadcast.broadcasted(
fs::AbstractFieldStyle,
::typeof(LinearAlgebra.norm_sqr),
arg,
)
space = Fields.axes(arg)
# wrap in a Field so that the axes line up correctly (it just get's unwraped so effectively a no-op)
Base.Broadcast.broadcasted(
fs,
Geometry._norm_sqr,
arg,
hasproperty(space, :local_geometry) ?
Field(space.local_geometry, space) : Ref(nothing),
)
end

Expand All @@ -155,7 +171,7 @@ function Base.Broadcast.broadcasted(
# wrap in a Field so that the axes line up correctly (it just get's unwraped so effectively a no-op)
Base.Broadcast.broadcasted(
fs,
LinearAlgebra.cross,
Geometry._cross,
arg1,
arg2,
Field(space.local_geometry, space),
Expand Down
28 changes: 19 additions & 9 deletions src/Geometry/vectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function Base.Broadcast.broadcast_shape(
end

Base.LinearIndices(axs::NTuple{N, AbstractAxis}) where {N} =
LinearIndices(map(ax -> ax.range, axs))
LinearIndices(map(ax -> first(ax.range):last(ax.range), axs))

struct CovariantAxis{R} <: AbstractAxis
range::R
Expand Down Expand Up @@ -248,33 +248,43 @@ curl_result_type(::Type{V}) where {V <: Cartesian12Vector{FT}} where {FT} =
curl_result_type(::Type{V}) where {V <: Covariant3Vector{FT}} where {FT} =
Contravariant12Vector{FT}

function norm²(uᵢ::Covariant12Vector, local_geometry::LocalGeometry)
_norm_sqr(x, local_geometry) = LinearAlgebra.norm_sqr(x)

function _norm_sqr(u::Contravariant3Vector, local_geometry::LocalGeometry)
LinearAlgebra.norm_sqr(u.u³)
end
function _norm_sqr(uᵢ::Covariant12Vector, local_geometry::LocalGeometry)
u = Cartesian12Vector(uᵢ, local_geometry)
_norm_sqr(u, local_geometry)
end

function _norm_sqr(uᵢ::Covariant12Vector, local_geometry::LocalGeometry)
u = Cartesian12Vector(uᵢ, local_geometry)
norm²(u, local_geometry)
_norm_sqr(u, local_geometry)
end

function norm²(u::Cartesian12Vector, local_geometry::LocalGeometry)
function _norm_sqr(u::Cartesian12Vector, local_geometry::LocalGeometry)
abs2(u.u1) + abs2(u.u2)
end

LinearAlgebra.norm(u::CustomAxisFieldVector, local_geometry::LocalGeometry) =
sqrt(norm²(u, local_geometry))
_norm(u::CustomAxisFieldVector, local_geometry) =
sqrt(_norm_sqr(u, local_geometry))

function LinearAlgebra.cross(
function _cross(
uⁱ::Contravariant12Vector,
v::Contravariant3Vector,
local_geometry::LocalGeometry,
)
Covariant12Vector(uⁱ.* v.u³, -uⁱ.* v.u³)
end

function LinearAlgebra.cross(
function _cross(
u::Cartesian12Vector,
v::Contravariant3Vector,
local_geometry::LocalGeometry,
)
uⁱ = Contravariant12Vector(u, local_geometry)
LinearAlgebra.cross(uⁱ, v, local_geometry)
_cross(uⁱ, v, local_geometry)
end

# tensors
Expand Down
27 changes: 14 additions & 13 deletions src/RecursiveApply/RecursiveApply.jl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,29 @@ export ⊞, ⊠, ⊟
Recursively apply `fn` to each element of `X`
"""
rmap(fn, X) = fn(X)
rmap(fn, X, Y) = fn(X, Y)
rmap(fn, X::Tuple) = map(x -> rmap(fn, x), X)
rmap(fn, X::Tuple, Y::Tuple) = map((x, y) -> rmap(fn, x, y), X, Y)
rmap(fn, X::NamedTuple) = map(x -> rmap(fn, x), X)
rmap(fn, X::NamedTuple{names}, Y::NamedTuple{names}) where {names} =
map((x, y) -> rmap(fn, x, y), X, Y)
rmap(fn::F, X) where {F} = fn(X)
rmap(fn::F, X, Y) where {F} = fn(X, Y)
rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X)
rmap(fn::F, X::Tuple, Y::Tuple) where {F} = map((x, y) -> rmap(fn, x, y), X, Y)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))
rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y)))

"""
rmaptype(fn, T)
The return type of `rmap(fn, X::T)`.
"""
rmaptype(fn, ::Type{T}) where {T} = fn(T)
rmaptype(fn, ::Type{T}) where {T <: Tuple} =
rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T)
rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} =
Tuple{map(fn, tuple(T.parameters...))...}
rmaptype(fn, ::Type{T}) where {T <: NamedTuple{names, tup}} where {names, tup} =
rmaptype(
fn::F,
::Type{T},
) where {F, T <: NamedTuple{names, tup}} where {names, tup} =
NamedTuple{names, rmaptype(fn, tup)}


"""
rmul(w, X)
w ⊠ X
Expand Down Expand Up @@ -66,7 +69,6 @@ const ⊟ = rsub

rdiv(X, w::Number) = rmap(x -> x / w, X)


"""
rmuladd(w, X, Y)
Expand All @@ -76,7 +78,6 @@ rmuladd(w::Number, X, Y) = rmap((x, y) -> muladd(w, x, y), X, Y)
rmuladd(X, w::Number, Y) = rmap((x, y) -> muladd(x, w, y), X, Y)
rmuladd(w::Number, x::Number, y::Number) = muladd(w, x, y)


"""
rmatmul1(W, S, i, j)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
GaussQuadrature = "d54b0c1a-921d-58e0-8e36-89d8069c0969"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
JETTest = "a79fb612-4a80-4749-a9bd-c2faab13da61"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand Down
16 changes: 16 additions & 0 deletions test/recursive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using Test, JETTest

using ClimaCore.RecursiveApply

for x in [
1.0,
1.0f0,
(1.0, 2.0),
(1.0f0, 2.0f0),
(a = 1.0, b = (x1 = 2.0, x2 = 3.0)),
(a = 1.0f0, b = (x1 = 2.0f0, x2 = 3.0f0)),
]
@test_nodispatch 2 x
@test_nodispatch x x
@test_nodispatch RecursiveApply.rdiv(x, 3)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Base: operator_associativity

include("recursive.jl")
include("data1d.jl")
include("data.jl")
include("grid.jl")
Expand Down

0 comments on commit 0f6e922

Please sign in to comment.