Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add minimalvec #22

Merged
merged 4 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Invalidations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
steps:
- uses: julia-actions/setup-julia@v2
with:
version: '1'
version: 'lts'
- uses: actions/checkout@v4
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-invalidations@v1
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "VectorInterface"
uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
authors = ["Jutho Haegeman <[email protected]> and contributors"]
version = "0.4.9"
version = "0.5"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -10,7 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Aqua = "0.6, 0.7, 0.8"
LinearAlgebra = "1"
Test = "1"
TestExtras = "0.2"
TestExtras = "0.2,0.3"
julia = "1"

[extras]
Expand Down
3 changes: 3 additions & 0 deletions src/VectorInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,7 @@ include("namedtuple.jl")
# General fallback implementation: comes with warning and some overhead
include("fallbacks.jl")

# Minimal vector type for testing
include("minimalvec.jl")

end
10 changes: 5 additions & 5 deletions src/fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ function scalartype(T::Type)
elT === T && throw(ArgumentError(_error_message(scalartype, T)))
return scalartype(elT)
end
# should this try to use `eltype` instead? e.g. scalartype(T) = scalartype(eltype(T))

# zerovector & zerovector!!
#---------------------------
Expand Down Expand Up @@ -102,8 +101,8 @@ end
function scale!(y, x, α::Number)
T = Tuple{typeof(y),typeof(x),typeof(α)}
@warn _warn_message(scale!, T) maxlog = 1
if applicable(LinearAlgebra.mul!, y, x, α)
return LinearAlgebra.mul!(y, x, α)
if applicable(LinearAlgebra.mul!, y, x, α, true, false)
return LinearAlgebra.mul!(y, x, α, true, false)
else
throw(ArgumentError(_error_message(scale!, T)))
end
Expand All @@ -112,8 +111,9 @@ end
function scale!!(y, x, α::Number)
T = Tuple{typeof(y),typeof(x),typeof(α)}
@warn _warn_message(scale!!, T) maxlog = 1
if applicable(LinearAlgebra.mul!, y, x, α) && promote_scale(y, x, α) <: scalartype(y)
return LinearAlgebra.mul!(y, x, α)
if applicable(LinearAlgebra.mul!, y, x, α, true, false) &&
promote_scale(y, x, α) <: scalartype(y)
return LinearAlgebra.mul!(y, x, α, true, false)
else
α_Ty = α * one(scalartype(y))
if applicable(*, x, α_Ty)
Expand Down
110 changes: 110 additions & 0 deletions src/minimalvec.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
MinimalVec{M,V<:AbstractVector}
MinimalVec{M}(vec::V) where {M,V<:AbstractVector}

Wraps a vector of type `V<:AbstractVector` in such a way that the wrapper only supports the
minimal interface put forward by VectorInterface.jl. The type parameter `M` can take the
values `true` or `false` and determines whether the vector behaves as a mutable vector that
supports in-place operations (`M == true`) or whether it behaves as an immutable or
static vector (`M == false`).

This wrapper can be used to test whether an algorithm is implemented using only the minimal
interface of VectorInterface.jl, without relying on other methods that would for example
be available for `AbstractVector` or `AbstractArray`.

To unwrap the contents of a `v::MinimalVec` instance, the field access `v.vec` can be used.

See also [`MinimalMVec`](@ref) and [`MinimalSVec`](@ref) for convenience constructors.
"""
struct MinimalVec{M,V<:AbstractVector}
vec::V
function MinimalVec{M,V}(vec::V) where {M,V}
M isa Bool || throw(ArgumentError("first type parameter must be `true` or `false`"))
return new{M,V}(vec)
end
MinimalVec{M}(vec::V) where {M,V} = MinimalVec{M,V}(vec)
end
"""
const MinimalMVec = MinimalVec{true}
MinimalMVec(v::AbstractVector)

Type alias for `MinimalVec{true}`, representing a vector wrapper that implements the minimal
interface of VectorInterface.jl, including in-place operations (!-methods).

See also [`MinimalVec`](@ref) and [`MinimalSVec`](@ref).
"""
const MinimalMVec{V} = MinimalVec{true,V}
"""
const MinimalSVec = MinimalVec{false}
MinimalSVec(v::AbstractVector)

Type alias for `MinimalVec{false}`, representing a vector wrapper that implements the
minimal interface of VectorInterface.jl, excluding in-place operations (!-methods).

See also [`MinimalVec`](@ref) and [`MinimalMVec`](@ref).
"""
const MinimalSVec{V} = MinimalVec{false,V}

MinimalMVec(v::AbstractVector) = MinimalVec{true}(v)
MinimalSVec(v::AbstractVector) = MinimalVec{false}(v)

_ismutable(::Type{MinimalVec{M,V}}) where {V,M} = M
_ismutable(v::MinimalVec) = _ismutable(typeof(v))

scalartype(::Type{<:MinimalVec{M,V}}) where {M,V} = scalartype(V)

function zerovector(v::MinimalVec, S::Type{<:Number})
return MinimalVec{_ismutable(v)}(zerovector(v.vec, S))
end
function zerovector!(v::MinimalMVec{V}) where {V}
zerovector!(v.vec)
return v
end
zerovector!!(v::MinimalVec) = _ismutable(v) ? zerovector!(v) : zerovector(v)

function scale(v::MinimalVec, α::Number)
return MinimalVec{_ismutable(v)}(scale(v.vec, α))
end
function scale!(v::MinimalMVec{V}, α::Number) where {V}
scale!(v.vec, α)
return v
end
function scale!!(v::MinimalVec, α::Number)
if _ismutable(v)
w = scale!!(v.vec, α)
return w === v.vec ? v : MinimalMVec(w)
else
return scale(v, α)
end
end
function scale!(w::MinimalMVec{V}, v::MinimalMVec{W}, α::Number) where {V,W}
scale!(w.vec, v.vec, α)
return w
end
function scale!!(w::MinimalVec, v::MinimalVec, α::Number)
if _ismutable(w)
wvec = scale!!(w.vec, v.vec, α)
return wvec === w.vec ? w : MinimalMVec(wvec)
else
return scale(v, α * one(scalartype(w)))
end
end

function add(y::MinimalVec, x::MinimalVec, α::Number, β::Number)
return MinimalVec{_ismutable(y)}(add(y.vec, x.vec, α, β))
end
function add!(y::MinimalMVec{W}, x::MinimalMVec{V}, α::Number, β::Number) where {W,V}
add!(y.vec, x.vec, α, β)
return y
end
function add!!(y::MinimalVec, x::MinimalVec, α::Number, β::Number)
if _ismutable(y)
yvec = add!!(y.vec, x.vec, α, β)
return yvec === y.vec ? y : MinimalMVec(yvec)
else
return add(y, x, α, β)
end
end

inner(x::MinimalVec, y::MinimalVec) = inner(x.vec, y.vec)
LinearAlgebra.norm(x::MinimalVec) = LinearAlgebra.norm(x.vec)
160 changes: 160 additions & 0 deletions test/minimalmvec.jl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle, this is quite an extensive testset for something that is simply wrapping Vector. We could simply have:

@testset "scale" begin
    a = rand(2)
    alfa = rand()
    @test scale(a, alfa) == scale(MinimalMVec(a), alfa).vec 
    ...
end

Along with some simple checks of type stability. I am also fine with just leaving this as is, but less code is easier to maintain 😉

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is true, it's copy pasted code, with small adjustments for the specific case.

It did help me with finding some problem in our fallbacks though. Base has a mul!(::Any, :Any, :Any) method, so testing applicable(mul!, x, y, alpha) was kind of pointless. That Tuple{Any,Any,Any} method assumes this is a matrix multiplication call, and just adds 2 more arguments true and false playing the role of alpha and beta. But that is of course pointless if you want to do what we call scale! or add!, and alpha is already the third argument. I guess this just illustrates what I dislike about the current mul! method and why I started VectorInterface.jl in the first place.

It actually also helped me track down another error in our MinimalVec implementation, where our interface promises that scale!!(y, x, alpha) will always work. So even if y is mutable, but happens to have a scalar type that is not compatible with the scalar type of x * alpha, the method should not complain and return a new vector (which actually has scalar type that is the promotion of that of y, x and alpha)`. This was broken and resulting in an error in the original implementation.

So I would favor keeping it for now.

Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
module Minimal
using VectorInterface
using VectorInterface: MinimalVec, MinimalMVec
using Test
using TestExtras

deepcollect(x::MinimalVec) = x.vec
deepcollect(x::Number) = x

x = MinimalMVec(randn(3))
y = MinimalMVec(randn(3))

@testset "scalartype" begin
s = @constinferred scalartype(x)
@test s == Float64
end

@testset "zerovector" begin
z = @constinferred zerovector(x)
@test all(iszero, deepcollect(z))
@test all(deepcollect(z) .=== zero(scalartype(x)))
z1 = @constinferred zerovector!!(deepcopy(x))
@test all(deepcollect(z1) .=== zero(scalartype(x)))
z2 = @constinferred zerovector!(deepcopy(x))
@test all(deepcollect(z2) .=== zero(scalartype(x)))

z3 = @constinferred zerovector(x, ComplexF64)
@test all(deepcollect(z3) .=== zero(ComplexF64))
z4 = @constinferred zerovector!!(deepcopy(x), ComplexF64)
@test all(deepcollect(z4) .=== zero(ComplexF64))
@test_throws MethodError zerovector!(deepcopy(x), ComplexF64)
end

@testset "scale" begin
α = randn()
z = @constinferred scale(x, α)
@test all(deepcollect(z) .== α .* deepcollect(x))

z2 = @constinferred scale!!(deepcopy(x), α)
@test deepcollect(z2) ≈ (α .* deepcollect(x))
xcopy = deepcopy(x)
z2 = @constinferred scale!!(deepcopy(y), xcopy, α)
@test deepcollect(z2) ≈ (α .* deepcollect(x))
@test all(deepcollect(xcopy) .== deepcollect(x))

z3 = @constinferred scale!(deepcopy(x), α)
@test deepcollect(z3) ≈ (α .* deepcollect(x))
xcopy = deepcopy(x)
z3 = @constinferred scale!(zerovector(x), xcopy, α)
@test deepcollect(z3) ≈ (α .* deepcollect(x))
@test all(deepcollect(xcopy) .== deepcollect(x))

α = randn(ComplexF64)
z4 = @constinferred scale(x, α)
@test deepcollect(z4) ≈ (α .* deepcollect(x))
xcopy = deepcopy(x)
z5 = @constinferred scale!!(xcopy, α)
@test deepcollect(z5) ≈ (α .* deepcollect(x))
@test all(deepcollect(xcopy) .== deepcollect(x))
@test_throws InexactError scale!(xcopy, α)

α = randn(ComplexF64)
xcopy = deepcopy(x)
z6 = @constinferred scale!!(zerovector(x), xcopy, α)
@test deepcollect(z6) ≈ (α .* deepcollect(x))
@test all(deepcollect(xcopy) .== deepcollect(x))
@test_throws InexactError scale!(zerovector(x), xcopy, α)

xz = @constinferred zerovector(x, ComplexF64)
z6 = @constinferred scale!!(xz, xcopy, α)
@test deepcollect(z6) ≈ (α .* deepcollect(x))
@test all(deepcollect(xcopy) .== deepcollect(x))

z7 = @constinferred scale!(xz, xcopy, α)
@test deepcollect(z7) ≈ (α .* deepcollect(x))
@test all(deepcollect(xcopy) .== deepcollect(x))

ycomplex = zerovector(y, ComplexF64)
α = randn(Float64)
xcopy = deepcopy(x)
z8 = @constinferred scale!!(ycomplex, xcopy, α)
@test z8 === ycomplex
@test all(deepcollect(z8) .== α .* deepcollect(xcopy))
end

@testset "add" begin
α, β = randn(2)
z = @constinferred add(y, x)
@test all(deepcollect(z) .== deepcollect(x) .+ deepcollect(y))
z = @constinferred add(y, x, α)
# for some reason, on some Julia versions on some platforms, but only in test mode
# there is a small floating point discrepancy, which makes the following test fail:
# @test all(deepcollect(z) .== muladd.(deepcollect(x), α, deepcollect(y)))
@test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y))
z = @constinferred add(y, x, α, β)
# for some reason, on some Julia versions on some platforms, but only in test mode
# there is a small floating point discrepancy, which makes the following test fail:
# @test all(deepcollect(z) .== muladd.(deepcollect(x), α, deepcollect(y) .* β))
@test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y) .* β)

α, β = randn(2)
xcopy = deepcopy(x)
z2 = @constinferred add!!(deepcopy(y), xcopy)
@test deepcollect(z2) ≈ (deepcollect(x) .+ deepcollect(y))
@test all(deepcollect(xcopy) .== deepcollect(x))
z2 = @constinferred add!!(deepcopy(y), xcopy, α)
@test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
@test all(deepcollect(xcopy) .== deepcollect(x))
z2 = @constinferred add!!(deepcopy(y), xcopy, α, β)
@test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))
@test all(deepcollect(xcopy) .== deepcollect(x))

α, β = randn(2)
z3 = @constinferred add!(deepcopy(y), xcopy)
@test deepcollect(z3) ≈ (deepcollect(y) .+ deepcollect(x))
@test all(deepcollect(xcopy) .== deepcollect(x))
z3 = @constinferred add!(deepcopy(y), xcopy, α)
@test all(deepcollect(xcopy) .== deepcollect(x))
@test deepcollect(z3) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
z3 = @constinferred add!(deepcopy(y), xcopy, α, β)
@test deepcollect(z3) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))
@test all(deepcollect(xcopy) .== deepcollect(x))

α, β = randn(ComplexF64, 2)
z4 = @constinferred add(y, x, α)
@test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
z4 = @constinferred add(y, x, α, β)
@test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))

α, β = randn(ComplexF64, 2)
z5 = @constinferred add!!(deepcopy(y), xcopy, α)
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
@test all(deepcollect(xcopy) .== deepcollect(x))
z5 = @constinferred add!!(deepcopy(y), xcopy, α, β)
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))
@test all(deepcollect(xcopy) .== deepcollect(x))

α, β = randn(ComplexF64, 2)
z5 = @constinferred add!!(deepcopy(y), xcopy, α)
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
@test all(deepcollect(xcopy) .== deepcollect(x))
z5 = @constinferred add!!(deepcopy(y), xcopy, α, β)
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))
@test all(deepcollect(xcopy) .== deepcollect(x))

α, β = randn(ComplexF64, 2)
@test_throws InexactError add!(deepcopy(y), xcopy, α)
@test_throws InexactError add!(deepcopy(y), xcopy, α, β)
end

@testset "inner" begin
s = @constinferred inner(x, y)
@test s ≈ inner(deepcollect(x), deepcollect(y))

α, β = randn(ComplexF64, 2)
s2 = @constinferred inner(scale(x, α), scale(y, β))
@test s2 ≈ inner(α * deepcollect(x), β * deepcollect(y))
end

end
Loading
Loading