Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Even faster storage handling #212

Merged
merged 41 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
b1c53a1
Even faster storage handling
mateuszbaran Jan 31, 2023
cf192bd
minor improvement
mateuszbaran Jan 31, 2023
2a319d8
cleanup and asserts
mateuszbaran Jan 31, 2023
dba15cb
fixing things and addressing code review
mateuszbaran Feb 1, 2023
284782a
some general improvements and docs
mateuszbaran Feb 1, 2023
e9b3d17
a bit more docs and type bounds
mateuszbaran Feb 1, 2023
85dc5f8
unify constructors for CG coefficients
mateuszbaran Feb 1, 2023
5f4059a
Update one docstring.
kellertuer Feb 3, 2023
49f0031
simplify initialization marking
mateuszbaran Feb 5, 2023
5983e5a
unified storage accessors with fallbacks
mateuszbaran Feb 6, 2023
d95a1c9
new StoreStateAction constructor
mateuszbaran Feb 8, 2023
2a6de87
storage tests
mateuszbaran Feb 9, 2023
a13f920
even more tests
mateuszbaran Feb 9, 2023
18136a0
adapt to scaled gradients and some minor changes
mateuszbaran Feb 15, 2023
5d5bc81
Update DebugChange to a unified field name scheme and tries to use th…
kellertuer Feb 16, 2023
71ce596
addressing review
mateuszbaran Feb 16, 2023
e7270fb
Set p_init and X_init as keyword arguments.
kellertuer Feb 16, 2023
c744c5d
Merge branch 'master' into mbaran/even-faster-storage
kellertuer Feb 18, 2023
471382a
new StoreStateAction constructor
mateuszbaran Feb 18, 2023
96f9199
remove the old complicated constructor.
kellertuer Feb 19, 2023
1b4c108
Adapt RecordChange.
kellertuer Feb 19, 2023
9b93cc5
StoreStateAction accepts vectors of symbols in kwargs
mateuszbaran Feb 19, 2023
2acf3f0
adapt the remaining ones.
kellertuer Feb 19, 2023
e0e42f8
Merge branch 'mbaran/even-faster-storage' of github.com:JuliaManifold…
kellertuer Feb 19, 2023
0d00e83
Update Documentation of StoreStateAction.
kellertuer Feb 19, 2023
400bc3a
includes deprecated tests.
kellertuer Feb 19, 2023
ee6dce9
fix storage for circle
mateuszbaran Feb 19, 2023
f2b8a91
improve test coverage, fixes a bug.
kellertuer Feb 19, 2023
d884a63
Merge branch 'mbaran/even-faster-storage' of github.com:JuliaManifold…
kellertuer Feb 19, 2023
013e3c6
stabilize a test.
kellertuer Feb 19, 2023
136b0a0
Adds a changelog.
kellertuer Feb 19, 2023
2fb8a32
faster storage update; test one edge case
mateuszbaran Feb 19, 2023
9341816
Update Changelog.
kellertuer Feb 19, 2023
703acf3
Merge branch 'mbaran/even-faster-storage' of github.com:JuliaManifold…
kellertuer Feb 19, 2023
568c51f
small cleanup
mateuszbaran Feb 19, 2023
ede0d6c
fix nonmutating storage
mateuszbaran Feb 20, 2023
6062387
use more storage keys
mateuszbaran Feb 20, 2023
5f3e98d
add a few stati I missed in the last PR (Sorrey).
kellertuer Feb 20, 2023
eba6978
a few small quasi Newton optimizations
mateuszbaran Feb 21, 2023
a0c5bbb
revert p_old storage in quasi Newton
mateuszbaran Feb 21, 2023
f4172a0
Finalize 0.4.8
kellertuer Feb 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 192 additions & 145 deletions src/plans/conjugate_gradient_plan.jl

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/plans/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ debug for the amount of change of the iterate (stored in `get_iterate(o)` of the
during the last iteration. See [`DebugEntryChange`](@ref) for the general case

# Keyword Parameters
* `storage` – (`StoreStateAction( (:Iterate,) )`) – (eventually shared) the storage of the previous action
* `storage` – (`StoreStateAction( [:Gradient] )`) – (eventually shared) the storage of the previous action
* `prefix` – (`"Last Change:"`) prefix of the debug output (ignored if you set `format`)
* `io` – (`stdout`) default steream to print the debug to.
* `format` - ( `"$prefix %f"`) format to print the output using an sprintf format.
Expand All @@ -154,7 +154,7 @@ mutable struct DebugChange{TInvRetr<:AbstractInverseRetractionMethod} <: DebugAc
storage::StoreStateAction
invretr::TInvRetr
function DebugChange(;
storage::StoreStateAction=StoreStateAction((:Iterate,)),
storage::StoreStateAction=StoreStateAction([:Iterate]),
io::IO=stdout,
prefix::String="Last Change: ",
format::String="$(prefix)%f",
Expand Down Expand Up @@ -183,7 +183,7 @@ debug for the amount of change of the gradient (stored in `get_gradient(o)` of t
during the last iteration. See [`DebugEntryChange`](@ref) for the general case

# Keyword Parameters
* `storage` – (`StoreStateAction( (:Gradient,) )`) – (eventually shared) the storage of the previous action
* `storage` – (`StoreStateAction( [:Gradient] )`) – (eventually shared) the storage of the previous action
* `prefix` – (`"Last Change:"`) prefix of the debug output (ignored if you set `format`)
* `io` – (`stdout`) default steream to print the debug to.
* `format` - ( `"$prefix %f"`) format to print the output using an sprintf format.
Expand All @@ -193,7 +193,7 @@ mutable struct DebugGradientChange <: DebugAction
format::String
storage::StoreStateAction
function DebugGradientChange(;
storage::StoreStateAction=StoreStateAction((:Gradient,)),
storage::StoreStateAction=StoreStateAction([:Iterate, :Gradient]),
io::IO=stdout,
prefix::String="Last Change: ",
format::String="$(prefix)%f",
Expand Down Expand Up @@ -379,11 +379,11 @@ mutable struct DebugEntryChange <: DebugAction
function DebugEntryChange(
f::Symbol,
d;
storage::StoreStateAction=StoreStateAction((f,)),
storage::StoreStateAction=StoreStateAction([f]),
prefix::String="Change of $f:",
format::String="$prefix%s",
io::IO=stdout,
initial_value::T where {T}=NaN,
initial_value::Any=NaN,
)
if !isa(initial_value, Number) || !isnan(initial_value) #set initial value
update_storage!(storage, Dict(f => initial_value))
Expand Down
8 changes: 4 additions & 4 deletions src/plans/record.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ mutable struct RecordChange{TInvRetr<:AbstractInverseRetractionMethod} <: Record
storage::StoreStateAction
invretr::TInvRetr
function RecordChange(
a::StoreStateAction=StoreStateAction((:Iterate,));
a::StoreStateAction=StoreStateAction([:Iterate]);
manifold::AbstractManifold=DefaultManifold(1),
invretr::AbstractInverseRetractionMethod=default_inverse_retraction_method(
manifold
Expand All @@ -339,7 +339,7 @@ mutable struct RecordChange{TInvRetr<:AbstractInverseRetractionMethod} <: Record
end
function RecordChange(
p,
a::StoreStateAction=StoreStateAction((:Iterate,));
a::StoreStateAction=StoreStateAction([:Iterate]);
manifold::AbstractManifold=DefaultManifold(1),
invretr::AbstractInverseRetractionMethod=default_inverse_retraction_method(
manifold, typeof(p)
Expand Down Expand Up @@ -403,11 +403,11 @@ mutable struct RecordEntryChange <: RecordAction
field::Symbol
distance::Any
storage::StoreStateAction
function RecordEntryChange(f::Symbol, d, a::StoreStateAction=StoreStateAction((f,)))
function RecordEntryChange(f::Symbol, d, a::StoreStateAction=StoreStateAction([f]))
return new(Float64[], f, d, a)
end
function RecordEntryChange(
v::T where {T}, f::Symbol, d, a::StoreStateAction=StoreStateAction((f,))
v::T where {T}, f::Symbol, d, a::StoreStateAction=StoreStateAction([f])
)
update_storage!(a, Dict(f => v))
return new(Float64[], f, d, a)
Expand Down
194 changes: 164 additions & 30 deletions src/plans/solver_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,29 @@ iteration, i.e. acts on `(p,o,i)`, where `p` is a [`AbstractManoptProblem`](@ref

# Fields
* `values` – a dictionary to store interims values based on certain `Symbols`
* `keys` – an `NTuple` of `Symbols` to refer to fields of `AbstractManoptSolverState`
* `keys` – a `Vector` of `Symbols` to refer to fields of `AbstractManoptSolverState`
* `point_values` – a `NamedTuple` of mutable values of points on a manifold to be stored in
`StoreStateAction`. Manifold is later determined by `AbstractManoptProblem` passed
to `update_storage!`.
* `point_init` – a `NamedTuple` of boolean values indicating whether a point in
`point_values` with matching key has been already initialized to a value. When it is
false, it corresponds to a general value not being stored for the key present in the
vector `keys`.
* `tangent_values` – a `NamedTuple` of mutable values of tangent vectors on a manifold to be
stored in `StoreStateAction`. Manifold is later determined by `AbstractManoptProblem`
passed to `update_storage!`. It is not specified at which point the vectors are tangent
but for storage it should not matter.
* `vector_init` – a `NamedTuple` of boolean values indicating whether a tangent vector in
`tangent_values` with matching key has been already initialized to a value. When it is
false, it corresponds to a general value not being stored for the key present in the
vector `keys`.
* `once` – whether to update the internal values only once per iteration
* `lastStored` – last iterate, where this `AbstractStateAction` was called (to determine `once`
* `lastStored` – last iterate, where this `AbstractStateAction` was called (to determine `once`)

To handle the general storage, use `get_storage` and `has_storage`. For the point storage
use `get_point_storage` and `has_point_storage`. For tangent vector storage use
`get_tangent_storage` and `has_tangent_storage`. Point and tangent storage have been
optimized to be more efficient

# Constructiors

Expand All @@ -205,65 +225,128 @@ iteration, i.e. acts on `(p,o,i)`, where `p` is a [`AbstractManoptProblem`](@ref
Initialize the Functor to an (empty) set of keys, where `once` determines
whether more that one update per iteration are effective

AbstractStateAction(keys, once=true])
function StoreStateAction(
general_keys::Vector{Symbol}=Symbol[],
point_values::NamedTuple=NamedTuple(),
tangent_values::NamedTuple=NamedTuple(),
once::Bool=true,
))
kellertuer marked this conversation as resolved.
Show resolved Hide resolved

Initialize the Functor to a set of keys, where the dictionary is initialized to
be empty. Further, `once` determines whether more that one update per iteration
are effective, otherwise only the first update is stored, all others are ignored.
Make a copy of points and tangent vectors passed to `point_values` and `tangent_values`
for later storage respective fields.
"""
mutable struct StoreStateAction <: AbstractStateAction
values::Dict{Symbol,<:Any}
keys::NTuple{N,Symbol} where {N}
mutable struct StoreStateAction{
TPS<:NamedTuple,TXS<:NamedTuple,TPI<:NamedTuple,TTI<:NamedTuple
} <: AbstractStateAction
values::Dict{Symbol,Any}
keys::Vector{Symbol} # for values
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
point_values::TPS
tangent_values::TXS
point_init::TPI
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
tangent_init::TTI
once::Bool
last_stored::Int
function StoreStateAction(
keys::NTuple{N,Symbol} where {N}=NTuple{0,Symbol}(), once=true
general_keys::Vector{Symbol}=Symbol[],
point_values::NamedTuple=NamedTuple(),
tangent_values::NamedTuple=NamedTuple(),
once::Bool=true,
)
return new(Dict{Symbol,Any}(), keys, once, -1)
point_init = NamedTuple{keys(point_values)}(map(u -> false, keys(point_values)))
tangent_init = NamedTuple{keys(tangent_values)}(
map(u -> false, keys(tangent_values))
)
point_values_copy = NamedTuple{keys(point_values)}(
map(u -> copy(point_values[u]), keys(point_values))
)
tangent_values_copy = NamedTuple{keys(tangent_values)}(
map(u -> copy(tangent_values[u]), keys(tangent_values))
)
return new{
typeof(point_values),
typeof(tangent_values),
typeof(point_init),
typeof(tangent_init),
}(
Dict{Symbol,Any}(),
general_keys,
point_values_copy,
tangent_values_copy,
point_init,
tangent_init,
once,
-1,
)
end
end
function (a::StoreStateAction)(
amp::AbstractManoptProblem, s::AbstractManoptSolverState, i::Int
)
#update values (maybe only once)
if !a.once || a.last_stored != i
for key in a.keys
if key === :Iterate
M = get_manifold(amp)
a.values[key] = copy(M, get_iterate(s))
elseif key === :Gradient
M = get_manifold(amp)
p = get_iterate(s)
a.values[key] = copy(M, p, get_gradient(s))
elseif hasproperty(s, key)
a.values[key] = deepcopy(getproperty(s, key))
end
end
update_storage!(a, amp, s)
end
return a.last_stored = i
end

"""
get_storage(a,key)
get_storage(a::AbstractStateAction, key::Symbol)

return the internal value of the [`AbstractStateAction`](@ref) `a` at the
Return the internal value of the [`AbstractStateAction`](@ref) `a` at the
`Symbol` `key`.
"""
get_storage(a::AbstractStateAction, key) = a.values[key]
get_storage(a::AbstractStateAction, key::Symbol) = a.values[key]

"""
get_point_storage(a::AbstractStateAction, key::Symbol)

Return the internal value of the [`AbstractStateAction`](@ref) `a` at the
`Symbol` `key` that represents a point.
"""
get_point_storage(a::AbstractStateAction, key::Symbol) = a.point_values[key]

"""
get_storage(a,key)
get_tangent_storage(a::AbstractStateAction, key::Symbol)

return whether the [`AbstractStateAction`](@ref) `a` has a value stored at the
Return the internal value of the [`AbstractStateAction`](@ref) `a` at the
`Symbol` `key` that represents a tangent vector.
"""
get_tangent_storage(a::AbstractStateAction, key::Symbol) = a.tangent_values[key]

"""
has_storage(a::AbstractStateAction, key::Symbol)

Return whether the [`AbstractStateAction`](@ref) `a` has a value stored at the
`Symbol` `key`.
"""
has_storage(a::AbstractStateAction, key) = haskey(a.values, key)
has_storage(a::AbstractStateAction, key::Symbol) = haskey(a.values, key)

"""
update_storage!(a, s)
has_point_storage(a::AbstractStateAction, key::Symbol)

update the [`AbstractStateAction`](@ref) `a` internal values to the ones given on
Return whether the [`AbstractStateAction`](@ref) `a` has a point value stored at the
`Symbol` `key`.
"""
has_point_storage(a::AbstractStateAction, key::Symbol) = a.point_init[key]

"""
has_tangent_storage(a::AbstractStateAction, key::Symbol)

Return whether the [`AbstractStateAction`](@ref) `a` has a point value stored at the
`Symbol` `key`.
"""
has_tangent_storage(a::AbstractStateAction, key::Symbol) = a.tangent_init[key]

"""
update_storage!(a::AbstractStateAction, s::AbstractManoptSolverState)

Update the [`AbstractStateAction`](@ref) `a` internal values to the ones given on
the [`AbstractManoptSolverState`](@ref) `s`.

Warning: it does not update point and tangent vector storage.
"""
function update_storage!(a::AbstractStateAction, s::AbstractManoptSolverState)
for key in a.keys
Expand All @@ -278,14 +361,65 @@ function update_storage!(a::AbstractStateAction, s::AbstractManoptSolverState)
return a.keys
end

function _storage_key_true(nt::NamedTuple)
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
return map(key -> NamedTuple{(key,),Tuple{Bool}}(true), keys(nt))
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
end

"""
update_storage!(a::AbstractStateAction, amp::AbstractManoptProblem, s::AbstractManoptSolverState)

Update the [`AbstractStateAction`](@ref) `a` internal values to the ones given on
the [`AbstractManoptSolverState`](@ref) `s`.
Optimized using the information from `amp`
"""
function update_storage!(
a::AbstractStateAction, amp::AbstractManoptProblem, s::AbstractManoptSolverState
)
for key in a.keys
if key === :Iterate
a.values[key] = deepcopy(get_iterate(s))
elseif key === :Gradient
a.values[key] = deepcopy(get_gradient(s))
else
a.values[key] = deepcopy(getproperty(s, key))
end
end

M = get_manifold(amp)

pt_kts = _storage_key_true(a.point_values)
map(keys(a.point_values), pt_kts) do key, kt
if key === :Iterate
copyto!(M, a.point_values[key], get_iterate(s))
else
copyto!(
M, a.point_values[key], getproperty(s, key)::typeof(a.point_values[key])
)
end
a.point_init = merge(a.point_init, kt)
end
tv_kts = _storage_key_true(a.tangent_values)
map(keys(a.tangent_values), tv_kts) do key, kt
if key === :Gradient
copyto!(M, a.tangent_values[key], get_gradient(s))
else
copyto!(
M, a.tangent_values[key], getproperty(s, key)::typeof(a.tangent_values[key])
)
end
a.tangent_init = merge(a.tangent_init, kt)
end
return a.keys
end

"""
update_storage!(a, d)
update_storage!(a::AbstractStateAction, d::Dict{Symbol,<:Any})

Update the [`AbstractStateAction`](@ref) `a` internal values to the ones given in
the dictionary `d`. The values are merged, where the values from `d` are preferred.
"""
function update_storage!(a::AbstractStateAction, d::Dict{Symbol,<:Any})
merge!(a.values, d)
# update keys
return a.keys = Tuple(keys(a.values))
return a.keys = collect(keys(a.values))
end
Loading