From b1bd4a60b6d5e466e7c42e458988ed54f6585189 Mon Sep 17 00:00:00 2001 From: K Pamnany Date: Fri, 23 Jul 2021 16:50:02 -0400 Subject: [PATCH 1/4] Perf: Specialize `fn` in `rmap` and `rmaptype` Also avoid use of `map` on `NamedTuple`s as it doesn't specialize `f`. --- src/RecursiveApply/RecursiveApply.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) mode change 100644 => 100755 src/RecursiveApply/RecursiveApply.jl diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl old mode 100644 new mode 100755 index e4127d0a7b..756b31e3a8 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -14,26 +14,29 @@ export ⊞, ⊠, ⊟ Recursively apply `fn` to each element of `X` """ -rmap(fn, X) = fn(X) -rmap(fn, X, Y) = fn(X, Y) -rmap(fn, X::Tuple) = map(x -> rmap(fn, x), X) -rmap(fn, X::Tuple, Y::Tuple) = map((x, y) -> rmap(fn, x, y), X, Y) -rmap(fn, X::NamedTuple) = map(x -> rmap(fn, x), X) -rmap(fn, X::NamedTuple{names}, Y::NamedTuple{names}) where {names} = - map((x, y) -> rmap(fn, x, y), X, Y) +rmap(fn::F, X) where {F} = fn(X) +rmap(fn::F, X, Y) where {F} = fn(X, Y) +rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X) +rmap(fn::F, X::Tuple, Y::Tuple) where {F} = map((x, y) -> rmap(fn, x, y), X, Y) +rmap(fn::F, X::NamedTuple{names}) where {F, names} = + NamedTuple{names}(rmap(fn, Tuple(X))) +rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = + NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) """ rmaptype(fn, T) The return type of `rmap(fn, X::T)`. """ -rmaptype(fn, ::Type{T}) where {T} = fn(T) -rmaptype(fn, ::Type{T}) where {T <: Tuple} = +rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) +rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = Tuple{map(fn, tuple(T.parameters...))...} -rmaptype(fn, ::Type{T}) where {T <: NamedTuple{names, tup}} where {names, tup} = +rmaptype( + fn::F, + ::Type{T}, +) where {F, T <: NamedTuple{names, tup}} where {names, tup} = NamedTuple{names, rmaptype(fn, tup)} - """ rmul(w, X) w ⊠ X @@ -66,7 +69,6 @@ const ⊟ = rsub rdiv(X, w::Number) = rmap(x -> x / w, X) - """ rmuladd(w, X, Y) @@ -76,7 +78,6 @@ rmuladd(w::Number, X, Y) = rmap((x, y) -> muladd(w, x, y), X, Y) rmuladd(X, w::Number, Y) = rmap((x, y) -> muladd(x, w, y), X, Y) rmuladd(w::Number, x::Number, y::Number) = muladd(w, x, y) - """ rmatmul1(W, S, i, j) From cf42eb6805d2a93a4c156f7b500e7ae4afe5043a Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Mon, 2 Aug 2021 13:32:37 -0700 Subject: [PATCH 2/4] add JETTest, test on 1.7 --- .github/workflows/Linux-UnitTests.yml | 8 +++++--- test/Project.toml | 1 + test/recursive.jl | 16 ++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 23 insertions(+), 3 deletions(-) create mode 100644 test/recursive.jl diff --git a/.github/workflows/Linux-UnitTests.yml b/.github/workflows/Linux-UnitTests.yml index 51410bad49..93459d972d 100644 --- a/.github/workflows/Linux-UnitTests.yml +++ b/.github/workflows/Linux-UnitTests.yml @@ -14,7 +14,9 @@ jobs: timeout-minutes: 60 strategy: fail-fast: true - + matrix: + julia-version: ['1.6', '1.7-nightly'] + env: CLIMATEMACHINE_SETTINGS_FIX_RNG_SEED: "true" @@ -30,14 +32,14 @@ jobs: - name: Set up Julia uses: julia-actions/setup-julia@latest with: - version: 1.6 + version: ${{ matrix.julia-version }} - name: Cache artifacts uses: actions/cache@v1 env: cache-name: cache-artifacts with: - path: ~/.julia/artifacts + path: ~/.julia/artifacts key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} restore-keys: | ${{ runner.os }}-test-${{ env.cache-name }}- diff --git a/test/Project.toml b/test/Project.toml index 823b915be7..ebecc918e9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" GaussQuadrature = "d54b0c1a-921d-58e0-8e36-89d8069c0969" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" +JETTest = "a79fb612-4a80-4749-a9bd-c2faab13da61" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" diff --git a/test/recursive.jl b/test/recursive.jl new file mode 100644 index 0000000000..8ab0c34b23 --- /dev/null +++ b/test/recursive.jl @@ -0,0 +1,16 @@ +using Test, JETTest + +using ClimaCore.RecursiveApply + +for x in [ + 1.0, + 1.0f0, + (1.0, 2.0), + (1.0f0, 2.0f0), + (a = 1.0, b = (x1 = 2.0, x2 = 3.0)), + (a = 1.0f0, b = (x1 = 2.0f0, x2 = 3.0f0)), +] + @test_nodispatch 2 ⊠ x + @test_nodispatch x ⊞ x + @test_nodispatch RecursiveApply.rdiv(x, 3) +end diff --git a/test/runtests.jl b/test/runtests.jl index b78c9677d4..289276b636 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Base: operator_associativity +include("recursive.jl") include("data1d.jl") include("data.jl") include("grid.jl") From 6abd5c86fd6d40c1beb39d40980ad2234a901486 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Mon, 2 Aug 2021 15:15:18 -0700 Subject: [PATCH 3/4] fix LinearIndices --- src/Geometry/vectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Geometry/vectors.jl b/src/Geometry/vectors.jl index 8abe4482e3..39c61aac38 100644 --- a/src/Geometry/vectors.jl +++ b/src/Geometry/vectors.jl @@ -30,7 +30,7 @@ function Base.Broadcast.broadcast_shape( end Base.LinearIndices(axs::NTuple{N, AbstractAxis}) where {N} = - LinearIndices(map(ax -> ax.range, axs)) + LinearIndices(map(ax -> first(ax.range):last(ax.range), axs)) struct CovariantAxis{R} <: AbstractAxis range::R From 8d95413ce8c14c6e563e7ca3df3dac11b887d038 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Tue, 3 Aug 2021 11:50:36 -0700 Subject: [PATCH 4/4] rework norms --- src/Fields/broadcast.jl | 22 +++++++++++++++++++--- src/Geometry/vectors.jl | 26 ++++++++++++++++++-------- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index 2537b40247..855bf84025 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -139,9 +139,25 @@ function Base.Broadcast.broadcasted( # wrap in a Field so that the axes line up correctly (it just get's unwraped so effectively a no-op) Base.Broadcast.broadcasted( fs, - LinearAlgebra.norm, + Geometry._norm, arg, - Field(space.local_geometry, space), + hasproperty(space, :local_geometry) ? + Field(space.local_geometry, space) : Ref(nothing), + ) +end +function Base.Broadcast.broadcasted( + fs::AbstractFieldStyle, + ::typeof(LinearAlgebra.norm_sqr), + arg, +) + space = Fields.axes(arg) + # wrap in a Field so that the axes line up correctly (it just get's unwraped so effectively a no-op) + Base.Broadcast.broadcasted( + fs, + Geometry._norm_sqr, + arg, + hasproperty(space, :local_geometry) ? + Field(space.local_geometry, space) : Ref(nothing), ) end @@ -155,7 +171,7 @@ function Base.Broadcast.broadcasted( # wrap in a Field so that the axes line up correctly (it just get's unwraped so effectively a no-op) Base.Broadcast.broadcasted( fs, - LinearAlgebra.cross, + Geometry._cross, arg1, arg2, Field(space.local_geometry, space), diff --git a/src/Geometry/vectors.jl b/src/Geometry/vectors.jl index 39c61aac38..b3b71be6cc 100644 --- a/src/Geometry/vectors.jl +++ b/src/Geometry/vectors.jl @@ -248,19 +248,29 @@ curl_result_type(::Type{V}) where {V <: Cartesian12Vector{FT}} where {FT} = curl_result_type(::Type{V}) where {V <: Covariant3Vector{FT}} where {FT} = Contravariant12Vector{FT} -function norm²(uᵢ::Covariant12Vector, local_geometry::LocalGeometry) +_norm_sqr(x, local_geometry) = LinearAlgebra.norm_sqr(x) + +function _norm_sqr(u::Contravariant3Vector, local_geometry::LocalGeometry) + LinearAlgebra.norm_sqr(u.u³) +end +function _norm_sqr(uᵢ::Covariant12Vector, local_geometry::LocalGeometry) + u = Cartesian12Vector(uᵢ, local_geometry) + _norm_sqr(u, local_geometry) +end + +function _norm_sqr(uᵢ::Covariant12Vector, local_geometry::LocalGeometry) u = Cartesian12Vector(uᵢ, local_geometry) - norm²(u, local_geometry) + _norm_sqr(u, local_geometry) end -function norm²(u::Cartesian12Vector, local_geometry::LocalGeometry) +function _norm_sqr(u::Cartesian12Vector, local_geometry::LocalGeometry) abs2(u.u1) + abs2(u.u2) end -LinearAlgebra.norm(u::CustomAxisFieldVector, local_geometry::LocalGeometry) = - sqrt(norm²(u, local_geometry)) +_norm(u::CustomAxisFieldVector, local_geometry) = + sqrt(_norm_sqr(u, local_geometry)) -function LinearAlgebra.cross( +function _cross( uⁱ::Contravariant12Vector, v::Contravariant3Vector, local_geometry::LocalGeometry, @@ -268,13 +278,13 @@ function LinearAlgebra.cross( Covariant12Vector(uⁱ.u² * v.u³, -uⁱ.u¹ * v.u³) end -function LinearAlgebra.cross( +function _cross( u::Cartesian12Vector, v::Contravariant3Vector, local_geometry::LocalGeometry, ) uⁱ = Contravariant12Vector(u, local_geometry) - LinearAlgebra.cross(uⁱ, v, local_geometry) + _cross(uⁱ, v, local_geometry) end # tensors