Skip to content
This repository has been archived by the owner on Mar 1, 2023. It is now read-only.

Commit

Permalink
Merge #1931
Browse files Browse the repository at this point in the history
1931: Fix NamedTuple keys, add error handling r=charleskawczynski a=charleskawczynski

### Description

After #1929, there is some mismatch with the keys and values returned from `flattened_named_tuple` and thus `single_stack_diagnostics` (e.g., IIRC, mixing length). cc @yairchn @ilopezgp. This PR should fix this mismatch. Specifically, adding

```julia
    length(keys_) == length(vals) || error("key-value mismatch")
```

to `flattened_named_tuple` errors on master due to improper handling of `SHermitianCompact` and `Diagonal` arrays. This required some changes / additions to `flattened_nt_vals`.



Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
bors[bot] and charleskawczynski authored Jan 25, 2021
2 parents 0313049 + 6daf821 commit fe19efd
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 21 deletions.
8 changes: 8 additions & 0 deletions src/Utilities/SingleStackUtils/single_stack_diagnostics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
using ..Orientations
import ..VariableTemplates: flattened_named_tuple
using ..VariableTemplates

# Sometimes `NodalStack` returns local states
# that is `nothing`. Here, we return `nothing`
# to preserve the keys (e.g., `hyperdiff`)
# when misssing.
flattened_named_tuple(v::Nothing, ft::FlattenType = FlattenArr()) = nothing

"""
single_stack_diagnostics(
Expand Down
17 changes: 11 additions & 6 deletions src/Utilities/VariableTemplates/VariableTemplates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module VariableTemplates
export varsize, Vars, Grad, @vars, varsindex, varsindices

using StaticArrays
using LinearAlgebra

"""
varsindex(S, p::Symbol, [sp::Symbol...])
Expand Down Expand Up @@ -429,10 +430,12 @@ function getpropertyorindex end
# Redirect to Base getproperty/getindex:
Base.@propagate_inbounds getpropertyorindex(t::Tuple, ::Val{i}) where {i} =
Base.getindex(t, i)
Base.@propagate_inbounds getpropertyorindex(a::SArray, ::Val{i}) where {i} =
Base.getindex(a, i)
Base.@propagate_inbounds getpropertyorindex(nt::AbstractVars, s::Symbol) =
Base.getproperty(nt, s)
Base.@propagate_inbounds getpropertyorindex(
a::AbstractArray,
::Val{i},
) where {i} = Base.getindex(a, i)
Base.@propagate_inbounds getpropertyorindex(v::AbstractVars, s::Symbol) =
Base.getproperty(v, s)
Base.@propagate_inbounds getpropertyorindex(
v::AbstractVars,
::Val{i},
Expand All @@ -443,8 +446,10 @@ Base.@propagate_inbounds getpropertyorindex(
v::AbstractVars,
t::Tuple{A},
) where {A} = getpropertyorindex(v, t[1])
Base.@propagate_inbounds getpropertyorindex(a::SArray, t::Tuple{A}) where {A} =
getpropertyorindex(a, t[1])
Base.@propagate_inbounds getpropertyorindex(
a::AbstractArray,
t::Tuple{A},
) where {A} = getpropertyorindex(a, t[1])

# Peel first element from tuple and recurse:
Base.@propagate_inbounds getpropertyorindex(v::AbstractVars, t::Tuple) =
Expand Down
74 changes: 59 additions & 15 deletions src/Utilities/VariableTemplates/flattened_tup_chain.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using LinearAlgebra

export flattened_tup_chain, flattened_named_tuple
export FlattenArr, RetainArr
export FlattenType, FlattenArr, RetainArr

abstract type FlattenType end

Expand All @@ -19,16 +21,22 @@ and `flattened_named_tuple`.
"""
struct RetainArr <: FlattenType end

# The Vars instance has many empty entries.
# Keeping all of the keys results in many
# duplicated values. So, it's best we
# "prune" the tree by removing the keys:
flattened_tup_chain(
::Type{NamedTuple{(), Tuple{}}},
::FlattenType = FlattenArr();
prefix = (Symbol(),),
) = ()

flattened_tup_chain(
::Type{T},
::FlattenType;
prefix = (Symbol(),),
) where {T <: Real} = (prefix,)

flattened_tup_chain(
::Type{T},
::RetainArr;
Expand All @@ -42,9 +50,27 @@ flattened_tup_chain(

flattened_tup_chain(
::Type{T},
::FlattenType;
::RetainArr;
prefix = (Symbol(),),
) where {T <: SHermitianCompact} = (prefix,)
flattened_tup_chain(
::Type{T},
::FlattenType;
prefix = (Symbol(),),
) where {T <: SHermitianCompact} =
ntuple(i -> (prefix..., i), length(StaticArrays.lowertriangletype(T)))

flattened_tup_chain(
::Type{T},
::RetainArr;
prefix = (Symbol(),),
) where {N, TA, T <: Diagonal{N, TA}} = (prefix,)
flattened_tup_chain(
::Type{T},
::FlattenArr;
prefix = (Symbol(),),
) where {N, TA, T <: Diagonal{N, TA}} = ntuple(i -> (prefix..., i), length(TA))

flattened_tup_chain(::Type{T}, ::FlattenType; prefix = (Symbol(),)) where {T} =
(prefix,)

Expand Down Expand Up @@ -79,9 +105,10 @@ flattened_tup_chain(
) where {S} = flattened_tup_chain(S, ft)

"""
flattened_named_tuple(v::AbstractVars, ::FlattenType)
flattened_named_tuple
A flattened NamedTuple, given a `Vars` instance.
A flattened NamedTuple, given a
`Vars` or nested `NamedTuple` instance.
# Example:
Expand All @@ -106,23 +133,40 @@ function flattened_named_tuple(v::AbstractVars, ft::FlattenType = FlattenArr())
ftc = flattened_tup_chain(v, ft)
keys_ = Symbol.(join.(ftc, :_))
vals = map(x -> getproperty(v, wrap_val.(x)), ftc)
length(keys_) == length(vals) || error("key-value mismatch")
return (; zip(keys_, vals)...)
end
flattened_named_tuple(v::Nothing, ft::FlattenType = FlattenArr()) = NamedTuple()

function flattened_named_tuple(nt::NamedTuple, ft::FlattenType = FlattenArr())
ftc = flattened_tup_chain(typeof(nt), ft)
keys_ = Symbol.(join.(ftc, :_))
vals = flattened_nt_vals(nt)
vals = flattened_nt_vals(ft, nt)
length(keys_) == length(vals) || error("key-value mismatch")
return (; zip(keys_, vals)...)
end

flattened_nt_vals(a::NamedTuple) = flattened_nt_vals(Tuple(a))
flattened_nt_vals(a::NamedTuple{(), Tuple{}}) = (nothing,)
flattened_nt_vals(a) = (a,)
flattened_nt_vals(a::NamedTuple, b...) =
tuple(flattened_nt_vals(a)..., flattened_nt_vals(b...)...)
flattened_nt_vals(a::NamedTuple{(), Tuple{}}, b...) =
tuple(nothing, flattened_nt_vals(b...)...)
flattened_nt_vals(a, b...) = tuple(values(a), flattened_nt_vals(b...)...)
flattened_nt_vals(x::Tuple) = flattened_nt_vals(x...)

flattened_nt_vals(::FlattenArr, a::AbstractArray) = tuple(a...)
flattened_nt_vals(::RetainArr, a::AbstractArray) = tuple(a)

flattened_nt_vals(::FlattenArr, a::Diagonal) = tuple(a.diag...)
flattened_nt_vals(::RetainArr, a::Diagonal) = tuple(a.diag)

flattened_nt_vals(::FlattenArr, a::SHermitianCompact) =
tuple(a.lowertriangle...)
flattened_nt_vals(::RetainArr, a::SHermitianCompact) = tuple(a.lowertriangle)

# when we splat an empty tuple `b` into `flattened_nt_vals(ft, b...)`
flattened_nt_vals(::FlattenType) = ()

# for structs
flattened_nt_vals(::FlattenType, a) = (a,)

# Divide and concur:
flattened_nt_vals(ft::FlattenType, a, b...) =
tuple(flattened_nt_vals(ft, a)..., flattened_nt_vals(ft, b...)...)

flattened_nt_vals(ft::FlattenType, a::Tuple) = flattened_nt_vals(ft, a...)

flattened_nt_vals(ft::FlattenType, a::NamedTuple) =
flattened_nt_vals(ft, Tuple(a))
69 changes: 69 additions & 0 deletions test/Utilities/VariableTemplates/test_complex_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Test
using StaticArrays
using ClimateMachine.VariableTemplates
using ClimateMachine.VariableTemplates: wrap_val
import ClimateMachine.VariableTemplates
VT = VariableTemplates

@testset "Test complex models" begin
include("complex_models.jl")
Expand Down Expand Up @@ -91,6 +93,12 @@ using ClimateMachine.VariableTemplates: wrap_val
end
@test fn[j] === "scalar_model.x"

# flattened_tup_chain - empty/generic cases
struct Foo end
@test flattened_tup_chain(NamedTuple{(), Tuple{}}) == ()
@test flattened_tup_chain(Foo, RetainArr()) == ((Symbol(),),)
@test flattened_tup_chain(Foo, FlattenArr()) == ((Symbol(),),)

# flattened_tup_chain - Retain arrays

ftc = flattened_tup_chain(st, RetainArr())
Expand Down Expand Up @@ -230,4 +238,65 @@ using ClimateMachine.VariableTemplates: wrap_val
@test fnt.vector_model_x_3 == 23.0
@test fnt.scalar_model_x == 24.0f0

struct Foo end
nt = (;
nest = (;
v = SVector(1, 2, 3),
nt = (;
shc = SHermitianCompact{3, FT, 6}(collect(1:6)),
f = FT(1.0),
),
d = SDiagonal(collect(1:3)...),
tt = (Foo(), Foo()),
t = Foo(),
),
)
# Test flattened_nt_vals:

@test VT.flattened_nt_vals(RetainArr(), NamedTuple()) == ()
@test VT.flattened_nt_vals(FlattenArr(), NamedTuple()) == ()
@test VT.flattened_nt_vals(RetainArr(), Tuple(NamedTuple())) == ()
@test VT.flattened_nt_vals(FlattenArr(), Tuple(NamedTuple())) == ()

ft = FlattenArr()
@test VT.flattened_nt_vals(ft, nt.nest.nt.f) == (1.0f0,)
@test VT.flattened_nt_vals(ft, nt.nest.nt) ==
(1.0f0, 2.0f0, 3.0f0, 4.0f0, 5.0f0, 6.0f0, 1.0f0)
@test VT.flattened_nt_vals(ft, nt.nest.d) == (1, 2, 3)
@test VT.flattened_nt_vals(ft, nt.nest.t) == (Foo(),)
@test VT.flattened_nt_vals(ft, nt.nest.tt) == (Foo(), Foo())

ft = RetainArr()
@test VT.flattened_nt_vals(ft, nt.nest.nt.f) == (1.0f0,)
@test VT.flattened_nt_vals(ft, nt.nest.nt)[1] ==
nt.nest.nt.shc.lowertriangle
@test VT.flattened_nt_vals(ft, nt.nest.nt)[2] == 1.0f0
@test VT.flattened_nt_vals(ft, nt.nest.d) == (nt.nest.d.diag,)
@test VT.flattened_nt_vals(ft, nt.nest.t) == (Foo(),)
@test VT.flattened_nt_vals(ft, nt.nest.tt) == (Foo(), Foo())

# Test flattened_named_tuple for NamedTuples
fnt = flattened_named_tuple(nt, FlattenArr())
@test fnt.nest_v_1 == 1
@test fnt.nest_v_2 == 2
@test fnt.nest_v_3 == 3
@test fnt.nest_nt_shc_1 == 1.0
@test fnt.nest_nt_shc_2 == 2.0
@test fnt.nest_nt_shc_3 == 3.0
@test fnt.nest_nt_shc_4 == 4.0
@test fnt.nest_nt_shc_5 == 5.0
@test fnt.nest_nt_shc_6 == 6.0
@test fnt.nest_nt_f == 1.0
@test fnt.nest_tt_1 == Foo()
@test fnt.nest_tt_2 == Foo()
@test fnt.nest_t == Foo()

fnt = flattened_named_tuple(nt, RetainArr())
@test fnt.nest_v == SVector(1, 2, 3)
@test fnt.nest_nt_shc == nt.nest.nt.shc.lowertriangle
@test fnt.nest_nt_f == 1.0
@test fnt.nest_tt_1 == Foo()
@test fnt.nest_tt_2 == Foo()
@test fnt.nest_t == Foo()

end

0 comments on commit fe19efd

Please sign in to comment.