-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add minimalvec #22
Merged
Merged
add minimalvec #22
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "VectorInterface" | ||
uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" | ||
authors = ["Jutho Haegeman <[email protected]> and contributors"] | ||
version = "0.4.9" | ||
version = "0.5" | ||
|
||
[deps] | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
|
@@ -10,7 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | |
Aqua = "0.6, 0.7, 0.8" | ||
LinearAlgebra = "1" | ||
Test = "1" | ||
TestExtras = "0.2" | ||
TestExtras = "0.2,0.3" | ||
julia = "1" | ||
|
||
[extras] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
""" | ||
MinimalVec{M,V<:AbstractVector} | ||
MinimalVec{M}(vec::V) where {M,V<:AbstractVector} | ||
|
||
Wraps a vector of type `V<:AbstractVector` in such a way that the wrapper only supports the | ||
minimal interface put forward by VectorInterface.jl. The type parameter `M` can take the | ||
values `true` or `false` and determines whether the vector behaves as a mutable vector that | ||
supports in-place operations (`M == true`) or whether it behaves as an immutable or | ||
static vector (`M == false`). | ||
|
||
This wrapper can be used to test whether an algorithm is implemented using only the minimal | ||
interface of VectorInterface.jl, without relying on other methods that would for example | ||
be available for `AbstractVector` or `AbstractArray`. | ||
|
||
To unwrap the contents of a `v::MinimalVec` instance, the field access `v.vec` can be used. | ||
|
||
See also [`MinimalMVec`](@ref) and [`MinimalSVec`](@ref) for convenience constructors. | ||
""" | ||
struct MinimalVec{M,V<:AbstractVector} | ||
vec::V | ||
function MinimalVec{M,V}(vec::V) where {M,V} | ||
M isa Bool || throw(ArgumentError("first type parameter must be `true` or `false`")) | ||
return new{M,V}(vec) | ||
end | ||
MinimalVec{M}(vec::V) where {M,V} = MinimalVec{M,V}(vec) | ||
end | ||
""" | ||
const MinimalMVec = MinimalVec{true} | ||
MinimalMVec(v::AbstractVector) | ||
|
||
Type alias for `MinimalVec{true}`, representing a vector wrapper that implements the minimal | ||
interface of VectorInterface.jl, including in-place operations (!-methods). | ||
|
||
See also [`MinimalVec`](@ref) and [`MinimalSVec`](@ref). | ||
""" | ||
const MinimalMVec{V} = MinimalVec{true,V} | ||
""" | ||
const MinimalSVec = MinimalVec{false} | ||
MinimalSVec(v::AbstractVector) | ||
|
||
Type alias for `MinimalVec{false}`, representing a vector wrapper that implements the | ||
minimal interface of VectorInterface.jl, excluding in-place operations (!-methods). | ||
|
||
See also [`MinimalVec`](@ref) and [`MinimalMVec`](@ref). | ||
""" | ||
const MinimalSVec{V} = MinimalVec{false,V} | ||
|
||
MinimalMVec(v::AbstractVector) = MinimalVec{true}(v) | ||
MinimalSVec(v::AbstractVector) = MinimalVec{false}(v) | ||
|
||
_ismutable(::Type{MinimalVec{M,V}}) where {V,M} = M | ||
_ismutable(v::MinimalVec) = _ismutable(typeof(v)) | ||
|
||
scalartype(::Type{<:MinimalVec{M,V}}) where {M,V} = scalartype(V) | ||
|
||
function zerovector(v::MinimalVec, S::Type{<:Number}) | ||
return MinimalVec{_ismutable(v)}(zerovector(v.vec, S)) | ||
end | ||
function zerovector!(v::MinimalMVec{V}) where {V} | ||
zerovector!(v.vec) | ||
return v | ||
end | ||
zerovector!!(v::MinimalVec) = _ismutable(v) ? zerovector!(v) : zerovector(v) | ||
|
||
function scale(v::MinimalVec, α::Number) | ||
return MinimalVec{_ismutable(v)}(scale(v.vec, α)) | ||
end | ||
function scale!(v::MinimalMVec{V}, α::Number) where {V} | ||
scale!(v.vec, α) | ||
return v | ||
end | ||
function scale!!(v::MinimalVec, α::Number) | ||
if _ismutable(v) | ||
w = scale!!(v.vec, α) | ||
return w === v.vec ? v : MinimalMVec(w) | ||
else | ||
return scale(v, α) | ||
end | ||
end | ||
function scale!(w::MinimalMVec{V}, v::MinimalMVec{W}, α::Number) where {V,W} | ||
scale!(w.vec, v.vec, α) | ||
return w | ||
end | ||
function scale!!(w::MinimalVec, v::MinimalVec, α::Number) | ||
if _ismutable(w) | ||
wvec = scale!!(w.vec, v.vec, α) | ||
return wvec === w.vec ? w : MinimalMVec(wvec) | ||
else | ||
return scale(v, α * one(scalartype(w))) | ||
end | ||
end | ||
|
||
function add(y::MinimalVec, x::MinimalVec, α::Number, β::Number) | ||
return MinimalVec{_ismutable(y)}(add(y.vec, x.vec, α, β)) | ||
end | ||
function add!(y::MinimalMVec{W}, x::MinimalMVec{V}, α::Number, β::Number) where {W,V} | ||
add!(y.vec, x.vec, α, β) | ||
return y | ||
end | ||
function add!!(y::MinimalVec, x::MinimalVec, α::Number, β::Number) | ||
if _ismutable(y) | ||
yvec = add!!(y.vec, x.vec, α, β) | ||
return yvec === y.vec ? y : MinimalMVec(yvec) | ||
else | ||
return add(y, x, α, β) | ||
end | ||
end | ||
|
||
inner(x::MinimalVec, y::MinimalVec) = inner(x.vec, y.vec) | ||
LinearAlgebra.norm(x::MinimalVec) = LinearAlgebra.norm(x.vec) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
module Minimal | ||
using VectorInterface | ||
using VectorInterface: MinimalVec, MinimalMVec | ||
using Test | ||
using TestExtras | ||
|
||
deepcollect(x::MinimalVec) = x.vec | ||
deepcollect(x::Number) = x | ||
|
||
x = MinimalMVec(randn(3)) | ||
y = MinimalMVec(randn(3)) | ||
|
||
@testset "scalartype" begin | ||
s = @constinferred scalartype(x) | ||
@test s == Float64 | ||
end | ||
|
||
@testset "zerovector" begin | ||
z = @constinferred zerovector(x) | ||
@test all(iszero, deepcollect(z)) | ||
@test all(deepcollect(z) .=== zero(scalartype(x))) | ||
z1 = @constinferred zerovector!!(deepcopy(x)) | ||
@test all(deepcollect(z1) .=== zero(scalartype(x))) | ||
z2 = @constinferred zerovector!(deepcopy(x)) | ||
@test all(deepcollect(z2) .=== zero(scalartype(x))) | ||
|
||
z3 = @constinferred zerovector(x, ComplexF64) | ||
@test all(deepcollect(z3) .=== zero(ComplexF64)) | ||
z4 = @constinferred zerovector!!(deepcopy(x), ComplexF64) | ||
@test all(deepcollect(z4) .=== zero(ComplexF64)) | ||
@test_throws MethodError zerovector!(deepcopy(x), ComplexF64) | ||
end | ||
|
||
@testset "scale" begin | ||
α = randn() | ||
z = @constinferred scale(x, α) | ||
@test all(deepcollect(z) .== α .* deepcollect(x)) | ||
|
||
z2 = @constinferred scale!!(deepcopy(x), α) | ||
@test deepcollect(z2) ≈ (α .* deepcollect(x)) | ||
xcopy = deepcopy(x) | ||
z2 = @constinferred scale!!(deepcopy(y), xcopy, α) | ||
@test deepcollect(z2) ≈ (α .* deepcollect(x)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
z3 = @constinferred scale!(deepcopy(x), α) | ||
@test deepcollect(z3) ≈ (α .* deepcollect(x)) | ||
xcopy = deepcopy(x) | ||
z3 = @constinferred scale!(zerovector(x), xcopy, α) | ||
@test deepcollect(z3) ≈ (α .* deepcollect(x)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
α = randn(ComplexF64) | ||
z4 = @constinferred scale(x, α) | ||
@test deepcollect(z4) ≈ (α .* deepcollect(x)) | ||
xcopy = deepcopy(x) | ||
z5 = @constinferred scale!!(xcopy, α) | ||
@test deepcollect(z5) ≈ (α .* deepcollect(x)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
@test_throws InexactError scale!(xcopy, α) | ||
|
||
α = randn(ComplexF64) | ||
xcopy = deepcopy(x) | ||
z6 = @constinferred scale!!(zerovector(x), xcopy, α) | ||
@test deepcollect(z6) ≈ (α .* deepcollect(x)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
@test_throws InexactError scale!(zerovector(x), xcopy, α) | ||
|
||
xz = @constinferred zerovector(x, ComplexF64) | ||
z6 = @constinferred scale!!(xz, xcopy, α) | ||
@test deepcollect(z6) ≈ (α .* deepcollect(x)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
z7 = @constinferred scale!(xz, xcopy, α) | ||
@test deepcollect(z7) ≈ (α .* deepcollect(x)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
ycomplex = zerovector(y, ComplexF64) | ||
α = randn(Float64) | ||
xcopy = deepcopy(x) | ||
z8 = @constinferred scale!!(ycomplex, xcopy, α) | ||
@test z8 === ycomplex | ||
@test all(deepcollect(z8) .== α .* deepcollect(xcopy)) | ||
end | ||
|
||
@testset "add" begin | ||
α, β = randn(2) | ||
z = @constinferred add(y, x) | ||
@test all(deepcollect(z) .== deepcollect(x) .+ deepcollect(y)) | ||
z = @constinferred add(y, x, α) | ||
# for some reason, on some Julia versions on some platforms, but only in test mode | ||
# there is a small floating point discrepancy, which makes the following test fail: | ||
# @test all(deepcollect(z) .== muladd.(deepcollect(x), α, deepcollect(y))) | ||
@test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y)) | ||
z = @constinferred add(y, x, α, β) | ||
# for some reason, on some Julia versions on some platforms, but only in test mode | ||
# there is a small floating point discrepancy, which makes the following test fail: | ||
# @test all(deepcollect(z) .== muladd.(deepcollect(x), α, deepcollect(y) .* β)) | ||
@test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y) .* β) | ||
|
||
α, β = randn(2) | ||
xcopy = deepcopy(x) | ||
z2 = @constinferred add!!(deepcopy(y), xcopy) | ||
@test deepcollect(z2) ≈ (deepcollect(x) .+ deepcollect(y)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
z2 = @constinferred add!!(deepcopy(y), xcopy, α) | ||
@test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
z2 = @constinferred add!!(deepcopy(y), xcopy, α, β) | ||
@test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
α, β = randn(2) | ||
z3 = @constinferred add!(deepcopy(y), xcopy) | ||
@test deepcollect(z3) ≈ (deepcollect(y) .+ deepcollect(x)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
z3 = @constinferred add!(deepcopy(y), xcopy, α) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
@test deepcollect(z3) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) | ||
z3 = @constinferred add!(deepcopy(y), xcopy, α, β) | ||
@test deepcollect(z3) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
α, β = randn(ComplexF64, 2) | ||
z4 = @constinferred add(y, x, α) | ||
@test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) | ||
z4 = @constinferred add(y, x, α, β) | ||
@test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) | ||
|
||
α, β = randn(ComplexF64, 2) | ||
z5 = @constinferred add!!(deepcopy(y), xcopy, α) | ||
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
z5 = @constinferred add!!(deepcopy(y), xcopy, α, β) | ||
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
α, β = randn(ComplexF64, 2) | ||
z5 = @constinferred add!!(deepcopy(y), xcopy, α) | ||
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
z5 = @constinferred add!!(deepcopy(y), xcopy, α, β) | ||
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) | ||
@test all(deepcollect(xcopy) .== deepcollect(x)) | ||
|
||
α, β = randn(ComplexF64, 2) | ||
@test_throws InexactError add!(deepcopy(y), xcopy, α) | ||
@test_throws InexactError add!(deepcopy(y), xcopy, α, β) | ||
end | ||
|
||
@testset "inner" begin | ||
s = @constinferred inner(x, y) | ||
@test s ≈ inner(deepcollect(x), deepcollect(y)) | ||
|
||
α, β = randn(ComplexF64, 2) | ||
s2 = @constinferred inner(scale(x, α), scale(y, β)) | ||
@test s2 ≈ inner(α * deepcollect(x), β * deepcollect(y)) | ||
end | ||
|
||
end |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In principle, this is quite an extensive testset for something that is simply wrapping
Vector
. We could simply have:Along with some simple checks of type stability. I am also fine with just leaving this as is, but less code is easier to maintain 😉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is true, it's copy pasted code, with small adjustments for the specific case.
It did help me with finding some problem in our fallbacks though. Base has a
mul!(::Any, :Any, :Any)
method, so testingapplicable(mul!, x, y, alpha)
was kind of pointless. ThatTuple{Any,Any,Any}
method assumes this is a matrix multiplication call, and just adds 2 more argumentstrue
andfalse
playing the role of alpha and beta. But that is of course pointless if you want to do what we callscale!
oradd!
, andalpha
is already the third argument. I guess this just illustrates what I dislike about the currentmul!
method and why I started VectorInterface.jl in the first place.It actually also helped me track down another error in our
MinimalVec
implementation, where our interface promises thatscale!!(y, x, alpha)
will always work. So even ify
is mutable, but happens to have a scalar type that is not compatible with the scalar type of x * alpha, the method should not complain and return a new vector (which actually has scalar type that is the promotion of that of y, x and alpha)`. This was broken and resulting in an error in the original implementation.So I would favor keeping it for now.