Skip to content

Commit

Permalink
Type stabilization of ProductArray (#247)
Browse files Browse the repository at this point in the history
* type stabilization of ProductArray

* formatting

* new test for ProducteArray

* version bumped
  • Loading branch information
mateuszbaran authored Nov 11, 2020
1 parent 71cc388 commit 29b1b01
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.4.4"
version = "0.4.5"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
3 changes: 1 addition & 2 deletions src/manifolds/VectorBundle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -918,8 +918,7 @@ function `f` for representing an operation with result in the vector space `fibe
for manifold `M` on given arguments (passed at a tuple).
"""
function allocate_result_type(::VectorBundleFibers, f, args::NTuple{N,Any}) where {N}
T = typeof(reduce(+, one(number_eltype(eti)) for eti in args))
return T
return typeof(mapreduce(eti -> one(number_eltype(eti)), +, args))
end

Base.size(x::FVector) = size(x.data)
Expand Down
3 changes: 2 additions & 1 deletion src/product_representations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ function Base.:*(
return ProductArray(ShapeSpec, a * v.data, v.reshapers)
end

number_eltype(::Type{ProductArray{TM,TData,TV}}) where {TM,TData,TV} = eltype(TData)
number_eltype(a::ProductArray) = number_eltype(a.data)
number_eltype(::Type{TPA}) where {TM,TData,TPA<:ProductArray{TM,TData}} = eltype(TData)

function _show_component(io::IO, v; pre = "", head = "")
sx = sprint(show, "text/plain", v, context = io, sizehint = 0)
Expand Down
16 changes: 16 additions & 0 deletions test/product_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ include("utils.jl")

struct NotImplementedReshaper <: Manifolds.AbstractReshaper end

function parray(M, x)
return Manifolds.ProductArray(
Manifolds.ShapeSpecification(Manifolds.StaticReshaper(), M.manifolds...),
x,
)
end

@testset "Product manifold" begin
@test_throws MethodError ProductManifold()
M1 = Sphere(2)
Expand Down Expand Up @@ -422,6 +429,15 @@ struct NotImplementedReshaper <: Manifolds.AbstractReshaper end
@test submanifold_component(pts[1], Val(1)) === pts[1].parts[1]
@test submanifold_components(Mse, pts[1]) === pts[1].parts
@test submanifold_components(pts[1]) === pts[1].parts

p_inf = parray(Mse, randn(5))
@test (@inferred ManifoldsBase.allocate_result_type(
Mse,
Manifolds.log,
(p_inf, p_inf),
)) === Float64
@test (@inferred number_eltype(typeof(p_inf))) === Float64
@test pts_prod[1] .+ fill(1.0, 5) == pts_prod[1] .+ 1.0
end

@testset "ProductRepr" begin
Expand Down

2 comments on commit 29b1b01

@mateuszbaran
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/24488

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.5 -m "<description of version>" 29b1b01b54b2a6fca5541b5cc09caba8d0fdc20a
git push origin v0.4.5

Please sign in to comment.