Skip to content

Commit

Permalink
broadcasting for UMVTVector (#339)
Browse files Browse the repository at this point in the history
* broadcasting for UMVTVector

* a test and a comment
  • Loading branch information
mateuszbaran authored Mar 22, 2021
1 parent 2cfc9ec commit 3c83090
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.4.19"
version = "0.4.20"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand Down
57 changes: 57 additions & 0 deletions src/manifolds/FixedRankMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,63 @@ Base.:-(v::UMVTVector) = UMVTVector(-v.U, -v.M, -v.Vt)
Base.:+(v::UMVTVector) = UMVTVector(v.U, v.M, v.Vt)
Base.:(==)(v::UMVTVector, w::UMVTVector) = (v.U == w.U) && (v.M == w.M) && (v.Vt == w.Vt)

Base.copy(v::UMVTVector) = UMVTVector(copy(v.U), copy(v.M), copy(v.Vt))

# Tuple-like broadcasting of UMVTVector

function Broadcast.BroadcastStyle(::Type{<:UMVTVector})
return Broadcast.Style{UMVTVector}()
end
function Broadcast.BroadcastStyle(
::Broadcast.AbstractArrayStyle{0},
b::Broadcast.Style{UMVTVector},
)
return b
end

Broadcast.instantiate(bc::Broadcast.Broadcasted{Broadcast.Style{UMVTVector},Nothing}) = bc
function Broadcast.instantiate(bc::Broadcast.Broadcasted{Broadcast.Style{UMVTVector}})
Broadcast.check_broadcast_axes(bc.axes, bc.args...)
return bc
end

Broadcast.broadcastable(v::UMVTVector) = v

@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{UMVTVector}})
return UMVTVector(
@inbounds(Broadcast._broadcast_getindex(bc, Val(:U))),
@inbounds(Broadcast._broadcast_getindex(bc, Val(:M))),
@inbounds(Broadcast._broadcast_getindex(bc, Val(:Vt))),
)
end

Base.@propagate_inbounds function Broadcast._broadcast_getindex(
v::UMVTVector,
::Val{I},
) where {I}
return getfield(v, I)
end

Base.axes(::UMVTVector) = ()

@inline function Base.copyto!(
dest::UMVTVector,
bc::Broadcast.Broadcasted{Broadcast.Style{UMVTVector}},
)
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && bc.args isa Tuple{UMVTVector} # only a single input argument to broadcast!
A = bc.args[1]
return copyto!(dest, A)
end
bc′ = Broadcast.preprocess(dest, bc)
copyto!(dest.U, Broadcast._broadcast_getindex(bc′, Val(:U)))
copyto!(dest.M, Broadcast._broadcast_getindex(bc′, Val(:M)))
copyto!(dest.Vt, Broadcast._broadcast_getindex(bc′, Val(:Vt)))
return dest
end

####

@doc raw"""
check_manifold_point(M::FixedRankMatrices{m,n,k}, p; kwargs...)
Expand Down
17 changes: 17 additions & 0 deletions test/fixed_rank.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,23 @@ include("utils.jl")
@test oneP == one(x)
oneV = UMVTVector(one(zeros(3, 3)), one(zeros(2, 2)), one(zeros(2, 2)), 2)
@test oneV == one(v)

# broadcasting
@test axes(w) === ()
wc = copy(w)
# test that the copy is equal to the original, but represented by
# a new array
@test wc.U !== w.U
@test wc.U == w.U
wb = w .+ v .* 2
@test wb isa UMVTVector
@test wb == w + v * 2

wb .= 2 .* w .+ v
@test wb == 2 * w + v

wb .= w
@test wb == w
end
test_manifold(
M,
Expand Down

0 comments on commit 3c83090

Please sign in to comment.