From 29b1b01b54b2a6fca5541b5cc09caba8d0fdc20a Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 11 Nov 2020 13:46:53 +0100 Subject: [PATCH 1/2] Type stabilization of ProductArray (#247) * type stabilization of ProductArray * formatting * new test for ProducteArray * version bumped --- Project.toml | 2 +- src/manifolds/VectorBundle.jl | 3 +-- src/product_representations.jl | 3 ++- test/product_manifold.jl | 16 ++++++++++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index d4b48c4ee1..ad15395453 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.4.4" +version = "0.4.5" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/manifolds/VectorBundle.jl b/src/manifolds/VectorBundle.jl index f991f09c49..2d0cf3564b 100644 --- a/src/manifolds/VectorBundle.jl +++ b/src/manifolds/VectorBundle.jl @@ -918,8 +918,7 @@ function `f` for representing an operation with result in the vector space `fibe for manifold `M` on given arguments (passed at a tuple). """ function allocate_result_type(::VectorBundleFibers, f, args::NTuple{N,Any}) where {N} - T = typeof(reduce(+, one(number_eltype(eti)) for eti in args)) - return T + return typeof(mapreduce(eti -> one(number_eltype(eti)), +, args)) end Base.size(x::FVector) = size(x.data) diff --git a/src/product_representations.jl b/src/product_representations.jl index 80864b0290..2bef7da577 100644 --- a/src/product_representations.jl +++ b/src/product_representations.jl @@ -269,7 +269,8 @@ function Base.:*( return ProductArray(ShapeSpec, a * v.data, v.reshapers) end -number_eltype(::Type{ProductArray{TM,TData,TV}}) where {TM,TData,TV} = eltype(TData) +number_eltype(a::ProductArray) = number_eltype(a.data) +number_eltype(::Type{TPA}) where {TM,TData,TPA<:ProductArray{TM,TData}} = eltype(TData) function _show_component(io::IO, v; pre = "", head = "") sx = sprint(show, "text/plain", v, context = io, sizehint = 0) diff --git a/test/product_manifold.jl b/test/product_manifold.jl index ed2270d5f8..5a6c72afc5 100644 --- a/test/product_manifold.jl +++ b/test/product_manifold.jl @@ -2,6 +2,13 @@ include("utils.jl") struct NotImplementedReshaper <: Manifolds.AbstractReshaper end +function parray(M, x) + return Manifolds.ProductArray( + Manifolds.ShapeSpecification(Manifolds.StaticReshaper(), M.manifolds...), + x, + ) +end + @testset "Product manifold" begin @test_throws MethodError ProductManifold() M1 = Sphere(2) @@ -422,6 +429,15 @@ struct NotImplementedReshaper <: Manifolds.AbstractReshaper end @test submanifold_component(pts[1], Val(1)) === pts[1].parts[1] @test submanifold_components(Mse, pts[1]) === pts[1].parts @test submanifold_components(pts[1]) === pts[1].parts + + p_inf = parray(Mse, randn(5)) + @test (@inferred ManifoldsBase.allocate_result_type( + Mse, + Manifolds.log, + (p_inf, p_inf), + )) === Float64 + @test (@inferred number_eltype(typeof(p_inf))) === Float64 + @test pts_prod[1] .+ fill(1.0, 5) == pts_prod[1] .+ 1.0 end @testset "ProductRepr" begin From 3ffe741fdc9c435e4b67837dbbb67919752b4d43 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sat, 14 Nov 2020 15:38:28 +0100 Subject: [PATCH 2/2] Some work on making groups embedded (#251) * some work on making groups embedded * additionally to just TransparentIsometric the AbstractGroupManifold should also be transparent wrt vector/coordinates. * runs formatter. * fixed special orthogonal test * Update src/manifolds/Rotations.jl Co-authored-by: Seth Axen * version bumped Co-authored-by: Ronny Bergmann Co-authored-by: Seth Axen --- Project.toml | 2 +- src/groups/group.jl | 38 ++++++++++++++++++++++--------- src/manifolds/Rotations.jl | 5 +++- test/groups/special_orthogonal.jl | 2 +- 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index ad15395453..a224b5425c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.4.5" +version = "0.4.6" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/groups/group.jl b/src/groups/group.jl index f3f2dd8ce4..82149e3c22 100644 --- a/src/groups/group.jl +++ b/src/groups/group.jl @@ -29,8 +29,8 @@ Abstract type for a Lie group, a group that is also a smooth manifold with an implement at least [`inv`](@ref), [`identity`](@ref), [`compose`](@ref), and [`translate_diff`](@ref). """ -abstract type AbstractGroupManifold{𝔽,O<:AbstractGroupOperation} <: - AbstractDecoratorManifold{𝔽} end +abstract type AbstractGroupManifold{𝔽,O<:AbstractGroupOperation,T<:AbstractEmbeddingType} <: + AbstractEmbeddedManifold{𝔽,T} end """ GroupManifold{𝔽,M<:Manifold{𝔽},O<:AbstractGroupOperation} <: AbstractGroupManifold{𝔽,O} @@ -45,7 +45,7 @@ Group manifolds by default forward metric-related operations to the wrapped mani GroupManifold(manifold, op) """ struct GroupManifold{𝔽,M<:Manifold{𝔽},O<:AbstractGroupOperation} <: - AbstractGroupManifold{𝔽,O} + AbstractGroupManifold{𝔽,O,TransparentIsometricEmbedding} manifold::M op::O end @@ -67,11 +67,13 @@ function base_group(M::Manifold) end base_group(G::AbstractGroupManifold) = G -decorator_group_dispatch(M::Manifold) = Val(false) +base_manifold(G::GroupManifold) = G.manifold + +decorator_group_dispatch(::Manifold) = Val(false) function decorator_group_dispatch(M::AbstractDecoratorManifold) return decorator_group_dispatch(decorated_manifold(M)) end -decorator_group_dispatch(M::AbstractGroupManifold) = Val(true) +decorator_group_dispatch(::AbstractGroupManifold) = Val(true) function is_group_decorator(M::Manifold) return _extract_val(decorator_group_dispatch(M)) @@ -85,6 +87,22 @@ if VERSION ≥ v"1.3" (::Type{T})(M::Manifold) where {T<:AbstractGroupOperation} = GroupManifold(M, T()) end +function decorator_transparent_dispatch( + ::typeof(get_coordinates!), + ::AbstractGroupManifold, + args..., +) + return Val(:transparent) +end +function decorator_transparent_dispatch( + ::typeof(get_vector!), + ::AbstractGroupManifold, + args..., +) + return Val(:transparent) +end + + ################### # Action directions ################### @@ -322,14 +340,12 @@ for MT in GROUP_MANIFOLD_BASIS_DISAMBIGUATION ) end -@decorator_transparent_fallback :transparent function check_manifold_point( - G::AbstractGroupManifold, - e::Identity; - kwargs..., -) +manifold_dimension(G::GroupManifold) = manifold_dimension(G.manifold) + +function check_manifold_point(G::AbstractGroupManifold, e::Identity; kwargs...) return DomainError(e, "The identity element $(e) does not belong to $(G).") end -@decorator_transparent_fallback :transparent function check_manifold_point( +function check_manifold_point( G::GT, e::Identity{GT}; kwargs..., diff --git a/src/manifolds/Rotations.jl b/src/manifolds/Rotations.jl index 1bede6186a..aa26b6be0e 100644 --- a/src/manifolds/Rotations.jl +++ b/src/manifolds/Rotations.jl @@ -79,7 +79,10 @@ function check_manifold_point(M::Rotations{N}, p; kwargs...) where {N} return DomainError(det(p), "The determinant of $p has to be +1 but it is $(det(p))") end if !isapprox(transpose(p) * p, one(p); kwargs...) - return DomainError(norm(p), "$p has to be orthogonal but it's not") + return DomainError( + norm(transpose(p) * p - one(p)), + "$p must be orthogonal but it's not at kwargs $kwargs", + ) end return nothing end diff --git a/test/groups/special_orthogonal.jl b/test/groups/special_orthogonal.jl index 8f8f98d82c..8e3f843d33 100644 --- a/test/groups/special_orthogonal.jl +++ b/test/groups/special_orthogonal.jl @@ -90,7 +90,7 @@ include("group_utils.jl") inverse_retraction_methods = inverse_retraction_methods, exp_log_atol_multiplier = 20, retraction_atol_multiplier = 12, - is_tangent_atol_multiplier = 1, + is_tangent_atol_multiplier = 1.2, ) @test injectivity_radius(G) == injectivity_radius(M)