-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
better internal representation for points etc. on product manifolds
- Loading branch information
1 parent
d8f1766
commit 4698967
Showing
11 changed files
with
244 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
docs/build | ||
Manifest.toml | ||
benchmark/tune.json | ||
benchmark/results.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
using Manifolds | ||
using BenchmarkTools | ||
using StaticArrays | ||
using LinearAlgebra | ||
using Random | ||
|
||
# Define a parent BenchmarkGroup to contain our suite | ||
SUITE = BenchmarkGroup() | ||
|
||
Random.seed!(12334) | ||
|
||
function add_manifold(M::Manifold, pts, name; | ||
test_tangent_vector_broadcasting = true, | ||
retraction_methods = [], | ||
inverse_retraction_methods = []) | ||
|
||
SUITE["manifolds"][name] = BenchmarkGroup() | ||
tv = log(M, pts[1], pts[2]) | ||
tv1 = log(M, pts[1], pts[2]) | ||
tv2 = log(M, pts[1], pts[3]) | ||
p = similar(pts[1]) | ||
SUITE["manifolds"][name]["similar"] = @benchmarkable similar($(pts[1])) | ||
SUITE["manifolds"][name]["log"] = @benchmarkable log($M, $(pts[1]), $(pts[2])) | ||
SUITE["manifolds"][name]["log!"] = @benchmarkable log!($M, $tv, $(pts[1]), $(pts[2])) | ||
SUITE["manifolds"][name]["exp"] = @benchmarkable exp($M, $(pts[1]), $tv1) | ||
SUITE["manifolds"][name]["exp"] = @benchmarkable exp!($M, $p, $(pts[1]), $tv1) | ||
SUITE["manifolds"][name]["norm"] = @benchmarkable norm($M, $(pts[1]), $tv1) | ||
SUITE["manifolds"][name]["inner"] = @benchmarkable inner($M, $(pts[1]), $tv1, $tv2) | ||
SUITE["manifolds"][name]["distance"] = @benchmarkable distance($M, $(pts[1]), $(pts[2])) | ||
SUITE["manifolds"][name]["isapprox (pt)"] = @benchmarkable isapprox($M, $(pts[1]), $(pts[2])) | ||
SUITE["manifolds"][name]["isapprox (tv)"] = @benchmarkable isapprox($M, $(pts[1]), $tv1, $tv2) | ||
SUITE["manifolds"][name]["2 * tv1 + 3 * tv2"] = @benchmarkable 2 * $tv1 + 3 * $tv2 | ||
if test_tangent_vector_broadcasting | ||
SUITE["manifolds"][name]["tv = 2 .* tv1 .+ 3 .* tv2"] = @benchmarkable $tv = 2 .* $tv1 .+ 3 .* $tv2 | ||
SUITE["manifolds"][name]["tv .= 2 .* tv1 .+ 3 .* tv2"] = @benchmarkable $tv .= 2 .* $tv1 .+ 3 .* $tv2 | ||
end | ||
end | ||
|
||
# General manifold benchmarks | ||
function add_manifold_benchmarks() | ||
|
||
SUITE["manifolds"] = BenchmarkGroup() | ||
|
||
s2 = Manifolds.Sphere(2) | ||
array_s2 = ArrayManifold(s2) | ||
r2 = Manifolds.Euclidean(2) | ||
|
||
add_manifold(s2, | ||
[Size(2)([1.0, 1.0]), | ||
Size(2)([-2.0, 3.0]), | ||
Size(2)([3.0, -2.0])], | ||
"Euclidean{2} -- SizedArray") | ||
|
||
add_manifold(r2, | ||
[MVector{2,Float64}([1.0, 1.0]), | ||
MVector{2,Float64}([-2.0, 3.0]), | ||
MVector{2,Float64}([3.0, -2.0])], | ||
"Euclidean{2} -- MVector") | ||
|
||
add_manifold(s2, | ||
[Size(3)([1.0, 0.0, 0.0]), | ||
Size(3)([0.0, 1.0, 0.0]), | ||
Size(3)([0.0, 0.0, 1.0])], | ||
"Sphere{2} -- SizedArray") | ||
|
||
add_manifold(array_s2, | ||
[Size(3)([1.0, 0.0, 0.0]), | ||
Size(3)([0.0, 1.0, 0.0]), | ||
Size(3)([0.0, 0.0, 1.0])], | ||
"ArrayManifold{Sphere{2}} -- SizedArray"; | ||
test_tangent_vector_broadcasting = false) | ||
|
||
so2 = Manifolds.Rotations(2) | ||
angles = (0.0, π/2, 2π/3) | ||
add_manifold(so2, | ||
[Size(2, 2)([cos(ϕ) sin(ϕ); -sin(ϕ) cos(ϕ)]) for ϕ in angles], | ||
"Rotations(2) -- SizedArray") | ||
|
||
m_prod = Manifolds.ProductManifold(s2, r2) | ||
|
||
pts_prd_base = [[1.0, 0.0, 0.0, 0.0, 0.0], | ||
[0.0, 1.0, 0.0, 1.0, 0.0], | ||
[0.0, 0.0, 1.0, 0.0, 0.1]] | ||
pts_prod = map(p -> Manifolds.ProductView(m_prod, p), pts_prd_base) | ||
|
||
add_manifold(m_prod, pts_prod, "ProductManifold") | ||
end | ||
|
||
add_manifold_benchmarks() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
using PkgBenchmark | ||
|
||
#--track-allocation=all | ||
config = BenchmarkConfig(id = nothing, | ||
juliacmd = `julia -O3`, | ||
env = Dict("JULIA_NUM_THREADS" => 4)) | ||
|
||
results = benchmarkpkg("Manifolds", config) | ||
export_markdown("benchmark/results.md", results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,11 @@ import Base: isapprox, | |
angle, | ||
eltype, | ||
similar, | ||
getindex, | ||
setindex!, | ||
size, | ||
convert, | ||
dataids, | ||
+, | ||
-, | ||
* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,12 +26,81 @@ function ProductManifold(manifolds...) | |
k += len | ||
end | ||
TRanges = tuple(ranges...) | ||
TSizes = tuple(sizes...) | ||
TSizes = Tuple{map(s -> Tuple{s...}, sizes)...} | ||
return ProductManifold{typeof(manifolds), TRanges, TSizes}(manifolds) | ||
end | ||
|
||
struct ProductView{TM<:ProductManifold,T,N,TData<:AbstractArray{T,N},TV<:Tuple} <: AbstractArray{T,N} | ||
M::TM | ||
data::TData | ||
views::TV | ||
end | ||
|
||
# The two-argument version of this constructor is substantially faster than | ||
# the generic one. | ||
function ProductView(M::ProductManifold{TM, TRanges, Tuple{Size1, Size2}}, data::TData) where {TM, TRanges, Size1, Size2, T, N, TData<:AbstractArray{T,N}} | ||
#println("PV2") | ||
views = (SizedAbstractArray{Size1}(view(data, TRanges[1])), | ||
SizedAbstractArray{Size2}(view(data, TRanges[2]))) | ||
return ProductView{typeof(M), T, N, TData, typeof(views)}(M, data, views) | ||
end | ||
|
||
function ProductView(M::ProductManifold{TM, TRanges, TSizes}, data::TData) where {TM, TRanges, TSizes, TData<:AbstractArray} | ||
#println("PVn") | ||
views = map(ziptuples(TRanges, TSizes)) do t | ||
SizedAbstractArray{t[2]}(view(data, t[1])) | ||
end | ||
return ProductView{ProductManifold{TM, TRanges, TSizes}, TData, typeof(views)}(M, data, views) | ||
end | ||
|
||
Base.BroadcastStyle(::Type{<:ProductView}) = Broadcast.ArrayStyle{ProductView}() | ||
|
||
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ProductView}}, ::Type{ElType}) where ElType | ||
# Scan the inputs for the ProductView: | ||
A = find_pv(bc) | ||
return ProductView(A.M, similar(A.data, ElType)) | ||
end | ||
|
||
Base.dataids(x::ProductView) = Base.dataids(x.data) | ||
|
||
""" | ||
find_pv(x...) | ||
`A = find_pv(x...)` returns the first `ProductView` among the arguments. | ||
""" | ||
find_pv(bc::Base.Broadcast.Broadcasted) = find_pv(bc.args) | ||
find_pv(args::Tuple) = find_pv(find_pv(args[1]), Base.tail(args)) | ||
find_pv(x) = x | ||
find_pv(a::ProductView, rest) = a | ||
find_pv(::Any, rest) = find_pv(rest) | ||
|
||
size(x::ProductView) = size(x.data) | ||
Base.@propagate_inbounds getindex(x::ProductView, i) = getindex(x.data, i) | ||
Base.@propagate_inbounds setindex!(x::ProductView, val, i) = setindex!(x.data, val, i) | ||
|
||
(+)(v1::ProductView, v2::ProductView) = ProductView(v1.M, v1.data + v2.data) | ||
(-)(v1::ProductView, v2::ProductView) = ProductView(v1.M, v1.data - v2.data) | ||
(-)(v::ProductView) = ProductView(v.M, -v.data) | ||
(*)(a::Number, v::ProductView) = ProductView(v.M, a*v.data) | ||
|
||
eltype(::Type{ProductView{TM, TData, TV}}) where {TM, TData, TV} = eltype(TData) | ||
similar(x::ProductView) = ProductView(x.M, similar(x.data)) | ||
similar(x::ProductView, ::Type{T}) where T = ProductView(x.M, similar(x.data, T)) | ||
|
||
function isapprox(M::ProductManifold, x, y; kwargs...) | ||
return mapreduce(&, ziptuples(M.manifolds, x.views, y.views)) do t | ||
return isapprox(t[1], t[2], t[3]; kwargs...) | ||
end | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong. |
||
end | ||
|
||
function isapprox(M::ProductManifold, x, v, w; kwargs...) | ||
return mapreduce(&, ziptuples(M.manifolds, x.views, v.views, w.views)) do t | ||
return isapprox(t[1], t[2], t[3], t[4]; kwargs...) | ||
end | ||
end | ||
|
||
function representation_size(M::ProductManifold, ::Type{T}) where {T} | ||
return (mapreduce(m -> representation_size(m, T), +, M.manifolds),) | ||
return (mapreduce(m -> sum(representation_size(m, T)), +, M.manifolds),) | ||
end | ||
|
||
manifold_dimension(M::ProductManifold) = sum(map(m -> manifold_dimension(m), M.manifolds)) | ||
|
@@ -48,47 +117,22 @@ function inverse_local_metric(M::MetricManifold{<:ProductManifold,ProductMetric} | |
error("TODO") | ||
end | ||
|
||
function uview_element(x::AbstractArray, range, shape::Size{S}) where S | ||
return SizedAbstractArray{Tuple{S...}}(uview(x, range)) | ||
end | ||
|
||
function suview_element(x::AbstractArray, range, shape::Size) | ||
return reshape(uview(x, range), shape) | ||
end | ||
|
||
function det_local_metric(M::MetricManifold{ProductManifold{<:Manifold,TRanges,TSizes},ProductMetric}, x) where {TRanges, TSizes} | ||
dets = map(ziptuples(M.manifolds, TRanges, TSizes)) do d | ||
return det_local_metric(d[1], view_element(x, d[2], d[3])) | ||
end | ||
function det_local_metric(M::MetricManifold{ProductManifold, ProductMetric}, x::ProductView) | ||
dets = map(det_local_metric, M.manifolds, x.views) | ||
return prod(dets) | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
mateuszbaran
Author
Member
|
||
end | ||
|
||
function inner(M::ProductManifold{TM, TRanges, TSizes}, x, v, w) where {TM, TRanges, TSizes} | ||
subproducts = map(ziptuples(M.manifolds, TRanges, TSizes)) do t | ||
inner(t[1], | ||
suview_element(x, t[2], t[3]), | ||
suview_element(v, t[2], t[3]), | ||
suview_element(w, t[2], t[3])) | ||
end | ||
function inner(M::ProductManifold, x::ProductView, v::ProductView, w::ProductView) | ||
subproducts = map(inner, M.manifolds, x.views, v.views, w.views) | ||
return sum(subproducts) | ||
end | ||
|
||
function exp!(M::ProductManifold{TM, TRanges, TSizes}, y, x, v) where {TM, TRanges, TSizes} | ||
map(ziptuples(M.manifolds, TRanges, TSizes)) do t | ||
exp!(t[1], | ||
uview_element(y, t[2], t[3]), | ||
suview_element(x, t[2], t[3]), | ||
suview_element(v, t[2], t[3])) | ||
end | ||
function exp!(M::ProductManifold, y::ProductView, x::ProductView, v::ProductView) | ||
map(exp!, M.manifolds, y.views, x.views, v.views) | ||
return y | ||
end | ||
|
||
function log!(M::ProductManifold{TM, TRanges, TSizes}, v, x, y) where {TM, TRanges, TSizes} | ||
map(ziptuples(M.manifolds, TRanges, TSizes)) do t | ||
log!(t[1], | ||
uview_element(v, t[2], t[3]), | ||
suview_element(x, t[2], t[3]), | ||
suview_element(y, t[2], t[3])) | ||
end | ||
function log!(M::ProductManifold, v::ProductView, x::ProductView, y::ProductView) | ||
map(log!, M.manifolds, v.views, x.views, y.views) | ||
return v | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
include("utils.jl") | ||
|
||
@testset "Euclidean" begin | ||
M = Manifolds.Euclidean(3) | ||
types = [Vector{Float64}, | ||
SizedVector{3, Float64}, | ||
MVector{3, Float64}, | ||
Vector{Float32}, | ||
SizedVector{3, Float32}, | ||
MVector{3, Float32}, | ||
Vector{Double64}, | ||
MVector{3, Double64}, | ||
SizedVector{3, Double64}] | ||
for T in types | ||
@testset "Type $T" begin | ||
pts = [convert(T, [1.0, 0.0, 0.0]), | ||
convert(T, [0.0, 1.0, 0.0]), | ||
convert(T, [0.0, 0.0, 1.0])] | ||
test_manifold(M, | ||
pts, | ||
test_reverse_diff = isa(T, Vector)) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
You could save some effort here by skipping after the first
false
.and similarly below.