diff --git a/Project.toml b/Project.toml index d4b48c4ee1..ad15395453 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.4.4" +version = "0.4.5" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/manifolds/VectorBundle.jl b/src/manifolds/VectorBundle.jl index f991f09c49..2d0cf3564b 100644 --- a/src/manifolds/VectorBundle.jl +++ b/src/manifolds/VectorBundle.jl @@ -918,8 +918,7 @@ function `f` for representing an operation with result in the vector space `fibe for manifold `M` on given arguments (passed at a tuple). """ function allocate_result_type(::VectorBundleFibers, f, args::NTuple{N,Any}) where {N} - T = typeof(reduce(+, one(number_eltype(eti)) for eti in args)) - return T + return typeof(mapreduce(eti -> one(number_eltype(eti)), +, args)) end Base.size(x::FVector) = size(x.data) diff --git a/src/product_representations.jl b/src/product_representations.jl index 80864b0290..2bef7da577 100644 --- a/src/product_representations.jl +++ b/src/product_representations.jl @@ -269,7 +269,8 @@ function Base.:*( return ProductArray(ShapeSpec, a * v.data, v.reshapers) end -number_eltype(::Type{ProductArray{TM,TData,TV}}) where {TM,TData,TV} = eltype(TData) +number_eltype(a::ProductArray) = number_eltype(a.data) +number_eltype(::Type{TPA}) where {TM,TData,TPA<:ProductArray{TM,TData}} = eltype(TData) function _show_component(io::IO, v; pre = "", head = "") sx = sprint(show, "text/plain", v, context = io, sizehint = 0) diff --git a/test/product_manifold.jl b/test/product_manifold.jl index ed2270d5f8..5a6c72afc5 100644 --- a/test/product_manifold.jl +++ b/test/product_manifold.jl @@ -2,6 +2,13 @@ include("utils.jl") struct NotImplementedReshaper <: Manifolds.AbstractReshaper end +function parray(M, x) + return Manifolds.ProductArray( + Manifolds.ShapeSpecification(Manifolds.StaticReshaper(), M.manifolds...), + x, + ) +end + @testset "Product manifold" begin @test_throws MethodError ProductManifold() M1 = Sphere(2) @@ -422,6 +429,15 @@ struct NotImplementedReshaper <: Manifolds.AbstractReshaper end @test submanifold_component(pts[1], Val(1)) === pts[1].parts[1] @test submanifold_components(Mse, pts[1]) === pts[1].parts @test submanifold_components(pts[1]) === pts[1].parts + + p_inf = parray(Mse, randn(5)) + @test (@inferred ManifoldsBase.allocate_result_type( + Mse, + Manifolds.log, + (p_inf, p_inf), + )) === Float64 + @test (@inferred number_eltype(typeof(p_inf))) === Float64 + @test pts_prod[1] .+ fill(1.0, 5) == pts_prod[1] .+ 1.0 end @testset "ProductRepr" begin