Skip to content

Commit

Permalink
[BlockSparseArrays] Direct sum/cat (#1579)
Browse files Browse the repository at this point in the history
* [BlockSparseArrays] Direct sum/`cat`

* [NDTensors] Bump to v0.3.64
  • Loading branch information
mtfishman authored Nov 14, 2024
1 parent cf050da commit 57994ff
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 1 deletion.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.63"
version = "0.3.64"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/map.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("blocksparsearrayinterface/cat.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
Expand All @@ -17,6 +18,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl")
include("abstractblocksparsearray/broadcast.jl")
include("abstractblocksparsearray/map.jl")
include("abstractblocksparsearray/linearalgebra.jl")
include("abstractblocksparsearray/cat.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# TODO: Change to `AnyAbstractBlockSparseArray`.
function Base.cat(as::BlockSparseArrayLike...; dims)
# TODO: Use `sparse_cat` instead, currently
# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
return blocksparse_cat(as...; dims)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat!

# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`.
# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`.
function SparseArrayInterface.axis_cat(
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
)
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))
end

# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
sparse_cat!(blocks(a_dest), blocks.(as)...; dims)
return a_dest
end

# TODO: Delete this in favor of `sparse_cat`, currently
# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
function blocksparse_cat(as::AbstractArray...; dims)
a_dest = allocate_cat_output(as...; dims)
blocksparse_cat!(a_dest, as...; dims)
return a_dest
end
27 changes: 27 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test a1' * a2 Array(a1)' * Array(a2)
@test dot(a1, a2) a1' * a2
end
@testset "cat" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))
a2 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)]))))

a_dest = cat(a1, a2; dims=1)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 2)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 2)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=2)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(1, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(1, 4)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=(1, 2))
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
end
@testset "TensorAlgebra" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include("sparsearrayinterface/broadcast.jl")
include("sparsearrayinterface/conversion.jl")
include("sparsearrayinterface/wrappers.jl")
include("sparsearrayinterface/zero.jl")
include("sparsearrayinterface/cat.jl")
include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl")
include("abstractsparsearray/abstractsparsearray.jl")
include("abstractsparsearray/abstractsparsematrix.jl")
Expand All @@ -24,6 +25,7 @@ include("abstractsparsearray/broadcast.jl")
include("abstractsparsearray/map.jl")
include("abstractsparsearray/baseinterface.jl")
include("abstractsparsearray/convert.jl")
include("abstractsparsearray/cat.jl")
include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl")
include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl")
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# TODO: Change to `AnyAbstractSparseArray`.
function Base.cat(as::SparseArrayLike...; dims)
return sparse_cat(as...; dims)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
unval(x) = x
unval(::Val{x}) where {x} = x

# TODO: Assert that `a1` and `a2` start at one.
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
function axis_cat(
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return axis_cat(axis_cat(a1, a2), a_rest...)
end
function cat_axes(as::AbstractArray...; dims)
return ntuple(length(first(axes.(as)))) do dim
return if dim in unval(dims)
axis_cat(map(axes -> axes[dim], axes.(as))...)
else
axes(first(as))[dim]
end
end
end

function allocate_cat_output(as::AbstractArray...; dims)
eltype_dest = promote_type(eltype.(as)...)
axes_dest = cat_axes(as...; dims)
# TODO: Promote the block types of the inputs rather than using
# just the first input.
# TODO: Make this customizable with `cat_similar`.
# TODO: Base the zero element constructor on those of the inputs,
# for example block sparse arrays.
return similar(first(as), eltype_dest, axes_dest...)
end

# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
# This is very similar to the `Base.cat` implementation but handles zero values better.
function cat_offset!(
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
)
inds = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
end
a_dest[inds...] = a1
new_offsets = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
end
cat_offset!(a_dest, new_offsets, a_rest...; dims)
return a_dest
end
function cat_offset!(a_dest::AbstractArray, offsets; dims)
return a_dest
end

# TODO: Define a generic `cat!` function.
function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
offsets = ntuple(zero, ndims(a_dest))
# TODO: Fill `a_dest` with zeros if needed.
cat_offset!(a_dest, offsets, as...; dims)
return a_dest
end

function sparse_cat(as::AbstractArray...; dims)
a_dest = allocate_cat_output(as...; dims)
sparse_cat!(a_dest, as...; dims)
return a_dest
end
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,31 @@ function sparse_setindex!(a::AbstractArray, value, I::Vararg{Int})
return a
end

# Fix ambiguity error
function sparse_setindex!(a::AbstractArray, value)
sparse_setindex!(a, value, CartesianIndex())
return a
end

# Linear indexing
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex{1})
sparse_setindex!(a, value, CartesianIndices(a)[I])
return a
end

# Slicing
# TODO: Make this handle more general slicing operations,
# base it off of `ArrayLayouts.sub_materialize`.
function sparse_setindex!(a::AbstractArray, value, I::AbstractUnitRange...)
inds = CartesianIndices(I)
for i in stored_indices(value)
if i in CartesianIndices(inds)
a[inds[i]] = value[i]
end
end
return a
end

# Handle trailing indices
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex)
t = Tuple(I)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,38 @@ using Test: @test, @testset
@test a_dest isa SparseArray{elt}
@test SparseArrayInterface.nstored(a_dest) == 2

# cat
a1 = SparseArray{elt}(2, 3)
a1[1, 2] = 12
a1[2, 1] = 21
a2 = SparseArray{elt}(2, 3)
a2[1, 1] = 11
a2[2, 2] = 22

a_dest = cat(a1, a2; dims=1)
@test size(a_dest) == (4, 3)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 1] == a2[1, 1]
@test a_dest[4, 2] == a2[2, 2]

a_dest = cat(a1, a2; dims=2)
@test size(a_dest) == (2, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[1, 4] == a2[1, 1]
@test a_dest[2, 5] == a2[2, 2]

a_dest = cat(a1, a2; dims=(1, 2))
@test size(a_dest) == (4, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 4] == a2[1, 1]
@test a_dest[4, 5] == a2[2, 2]

## # Sparse matrix of matrix multiplication
## TODO: Make this work, seems to require
## a custom zero constructor.
Expand Down

2 comments on commit 57994ff

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=NDTensors

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119367

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a NDTensors-v0.3.64 -m "<description of version>" 57994ff8ea3a869ec0a457fe766032faee7941b4
git push origin NDTensors-v0.3.64

Please sign in to comment.