Skip to content

Commit

Permalink
improve cat design / performance (#49322)
Browse files Browse the repository at this point in the history
This used to make a lot of references to design issues with the
SparseArrays package (#2326 /
#20815), which result in a
non-sensical dispatch arrangement, and contribute to a slow loading
experience do to the nonsense Unions that must be checked by subtyping.
  • Loading branch information
vtjnash authored Jul 13, 2023
1 parent dcca46b commit 5a922fa
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 95 deletions.
50 changes: 24 additions & 26 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ function _typed_hcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
for j = 1:nargs
Aj = A[j]
if size(Aj, 1) != nrows
throw(ArgumentError("number of rows of each array must match (got $(map(x->size(x,1), A)))"))
throw(DimensionMismatch("number of rows of each array must match (got $(map(x->size(x,1), A)))"))
end
dense &= isa(Aj,Array)
nd = ndims(Aj)
Expand Down Expand Up @@ -1686,7 +1686,7 @@ function _typed_vcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
ncols = size(A[1], 2)
for j = 2:nargs
if size(A[j], 2) != ncols
throw(ArgumentError("number of columns of each array must match (got $(map(x->size(x,2), A)))"))
throw(DimensionMismatch("number of columns of each array must match (got $(map(x->size(x,2), A)))"))
end
end
B = similar(A[1], T, nrows, ncols)
Expand Down Expand Up @@ -1984,16 +1984,14 @@ julia> cat(1, [2], [3;;]; dims=Val(2))

# The specializations for 1 and 2 inputs are important
# especially when running with --inline=no, see #11158
# The specializations for Union{AbstractVecOrMat,Number} are necessary
# to have more specialized methods here than in LinearAlgebra/uniformscaling.jl
vcat(A::AbstractArray) = cat(A; dims=Val(1))
vcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(1))
vcat(A::AbstractArray...) = cat(A...; dims=Val(1))
vcat(A::Union{AbstractVecOrMat,Number}...) = cat(A...; dims=Val(1))
vcat(A::Union{AbstractArray,Number}...) = cat(A...; dims=Val(1))
hcat(A::AbstractArray) = cat(A; dims=Val(2))
hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2))
hcat(A::AbstractArray...) = cat(A...; dims=Val(2))
hcat(A::Union{AbstractVecOrMat,Number}...) = cat(A...; dims=Val(2))
hcat(A::Union{AbstractArray,Number}...) = cat(A...; dims=Val(2))

typed_vcat(T::Type, A::AbstractArray) = _cat_t(Val(1), T, A)
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(1), T, A, B)
Expand Down Expand Up @@ -2055,8 +2053,8 @@ julia> hvcat((2,2,2), a,b,c,d,e,f) == hvcat(2, a,b,c,d,e,f)
true
```
"""
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat...) = typed_hvcat(promote_eltype(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat{T}...) where {T} = typed_hvcat(T, rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray...) = typed_hvcat(promote_eltype(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray{T}...) where {T} = typed_hvcat(T, rows, xs...)

function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as::AbstractVecOrMat...) where T
nbr = length(rows) # number of block rows
Expand Down Expand Up @@ -2084,16 +2082,16 @@ function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as::AbstractVecOrMat..
Aj = as[a+j-1]
szj = size(Aj,2)
if size(Aj,1) != szi
throw(ArgumentError("mismatched height in block row $(i) (expected $szi, got $(size(Aj,1)))"))
throw(DimensionMismatch("mismatched height in block row $(i) (expected $szi, got $(size(Aj,1)))"))
end
if c-1+szj > nc
throw(ArgumentError("block row $(i) has mismatched number of columns (expected $nc, got $(c-1+szj))"))
throw(DimensionMismatch("block row $(i) has mismatched number of columns (expected $nc, got $(c-1+szj))"))
end
out[r:r-1+szi, c:c-1+szj] = Aj
c += szj
end
if c != nc+1
throw(ArgumentError("block row $(i) has mismatched number of columns (expected $nc, got $(c-1))"))
throw(DimensionMismatch("block row $(i) has mismatched number of columns (expected $nc, got $(c-1))"))
end
r += szi
a += rows[i]
Expand All @@ -2115,7 +2113,7 @@ function hvcat(rows::Tuple{Vararg{Int}}, xs::T...) where T<:Number
k = 1
@inbounds for i=1:nr
if nc != rows[i]
throw(ArgumentError("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
throw(DimensionMismatch("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
end
for j=1:nc
a[i,j] = xs[k]
Expand Down Expand Up @@ -2144,14 +2142,14 @@ end
hvcat(rows::Tuple{Vararg{Int}}, xs::Number...) = typed_hvcat(promote_typeof(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)
# the following method is needed to provide a more specific one compared to LinearAlgebra/uniformscaling.jl
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{AbstractVecOrMat,Number}...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{AbstractArray,Number}...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)

function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, xs::Number...) where T
nr = length(rows)
nc = rows[1]
for i = 2:nr
if nc != rows[i]
throw(ArgumentError("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
throw(DimensionMismatch("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
end
end
hvcat_fill!(Matrix{T}(undef, nr, nc), xs)
Expand Down Expand Up @@ -2319,7 +2317,7 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
Ndim += cat_size(as[i], N)
nd = max(nd, cat_ndims(as[i]))
for d 1:N - 1
cat_size(as[1], d) == cat_size(as[i], d) || throw(ArgumentError("mismatched size along axis $d in element $i"))
cat_size(as[1], d) == cat_size(as[i], d) || throw(DimensionMismatch("mismatched size along axis $d in element $i"))
end
end

Expand All @@ -2346,7 +2344,7 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
nd = max(nd, cat_ndims(as[i]))
for d 1:N-1
cat_size(as[i], d) == 1 ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
throw(DimensionMismatch("all dimensions of element $i other than $N must be of length 1"))
end
end

Expand Down Expand Up @@ -2463,7 +2461,7 @@ function _typed_hvncat_dims(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as
for dd 1:N
dd == d && continue
if cat_size(as[startelementi], dd) != cat_size(as[i], dd)
throw(ArgumentError("incompatible shape in element $i"))
throw(DimensionMismatch("incompatible shape in element $i"))
end
end
end
Expand Down Expand Up @@ -2500,18 +2498,18 @@ function _typed_hvncat_dims(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as
elseif currentdims[d] < outdims[d] # dimension in progress
break
else # exceeded dimension
throw(ArgumentError("argument $i has too many elements along axis $d"))
throw(DimensionMismatch("argument $i has too many elements along axis $d"))
end
end
end
elseif currentdims[d1] > outdims[d1] # exceeded dimension
throw(ArgumentError("argument $i has too many elements along axis $d1"))
throw(DimensionMismatch("argument $i has too many elements along axis $d1"))
end
end

outlen = prod(outdims)
elementcount == outlen ||
throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)"))
throw(DimensionMismatch("mismatched number of elements; expected $(outlen), got $(elementcount)"))

# copy into final array
A = cat_similar(as[1], T, outdims)
Expand Down Expand Up @@ -2572,8 +2570,8 @@ function _typed_hvncat_shape(::Type{T}, shape::NTuple{N, Tuple}, row_first, as::
if d == 1 || i == 1 || wasstartblock
currentdims[d] += dsize
elseif dsize != cat_size(as[i - 1], ad)
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"))
throw(DimensionMismatch("argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"))
end

wasstartblock = blockcounts[d] == 1 # remember for next dimension
Expand All @@ -2583,15 +2581,15 @@ function _typed_hvncat_shape(::Type{T}, shape::NTuple{N, Tuple}, row_first, as::
if outdims[d] == -1
outdims[d] = currentdims[d]
elseif outdims[d] != currentdims[d]
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"))
throw(DimensionMismatch("argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"))
end
currentdims[d] = 0
blockcounts[d] = 0
shapepos[d] += 1
d > 1 && (blockcounts[d - 1] == 0 ||
throw(ArgumentError("shape in level $d is inconsistent; level counts must nest \
evenly into each other")))
throw(DimensionMismatch("shape in level $d is inconsistent; level counts must nest \
evenly into each other")))
end
end
end
Expand Down
12 changes: 0 additions & 12 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2041,18 +2041,6 @@ function vcat(arrays::Vector{T}...) where T
end
vcat(A::Vector...) = cat(A...; dims=Val(1)) # more special than SparseArrays's vcat

# disambiguation with LinAlg/special.jl
# Union{Number,Vector,Matrix} is for LinearAlgebra._DenseConcatGroup
# VecOrMat{T} is for LinearAlgebra._TypedDenseConcatGroup
hcat(A::Union{Number,Vector,Matrix}...) = cat(A...; dims=Val(2))
hcat(A::VecOrMat{T}...) where {T} = typed_hcat(T, A...)
vcat(A::Union{Number,Vector,Matrix}...) = cat(A...; dims=Val(1))
vcat(A::VecOrMat{T}...) where {T} = typed_vcat(T, A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{Number,Vector,Matrix}...) =
typed_hvcat(promote_eltypeof(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::VecOrMat{T}...) where {T} =
typed_hvcat(T, rows, xs...)

_cat(n::Integer, x::Integer...) = reshape([x...], (ntuple(Returns(1), n-1)..., length(x)))

## find ##
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
e59c1c57b97e17a73eba758d65022bd7
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ad88ebe77aaf1580e6d7ee7649ac5b812a23b9d9bf947f26babe9dd79902f6da11aa69bf63f22f67f6eae92a4c6e665cc3b950bb7c648c623e9cb4b9cb4daac4
26 changes: 5 additions & 21 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,27 +330,11 @@ end
==(A::Bidiagonal, B::SymTridiagonal) = iszero(_evview(B)) && iszero(A.ev) && A.dv == B.dv
==(B::SymTridiagonal, A::Bidiagonal) = A == B

# concatenation
const _SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}
const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{T,A}
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
const _Annotated_Typed_DenseArrays{T} = Union{_Triangular_DenseArrays{T}, _Symmetric_DenseArrays{T}, _Hermitian_DenseArrays{T}}
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpose{T,Vector{T}}, Matrix{T}, _Annotated_Typed_DenseArrays{T}}

promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix

Base._cat(dims, xs::_DenseConcatGroup...) = Base._cat_t(dims, promote_eltype(xs...), xs...)
vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...)
hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...)
# For performance, specially handle the case where the matrices/vectors have homogeneous eltype
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...)
vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)
# TODO: remove these deprecations (used by SparseArrays in the past)
const _DenseConcatGroup = Union{}
const _SpecialArrays = Union{}

promote_to_array_type(::Tuple) = Matrix

# factorizations
function cholesky(S::RealHermSymComplexHerm{<:Real,<:SymTridiagonal}, ::NoPivot = NoPivot(); check::Bool = true)
Expand Down
14 changes: 6 additions & 8 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ end
# so that we can re-use this code for sparse-matrix hcat etcetera.
promote_to_arrays_(n::Int, ::Type, a::Number) = a
promote_to_arrays_(n::Int, ::Type{Matrix}, J::UniformScaling{T}) where {T} = Matrix(J, n, n)
promote_to_arrays_(n::Int, ::Type, A::AbstractVecOrMat) = A
promote_to_arrays_(n::Int, ::Type, A::AbstractArray) = A
promote_to_arrays(n,k, ::Type) = ()
promote_to_arrays(n,k, ::Type{T}, A) where {T} = (promote_to_arrays_(n[k], T, A),)
promote_to_arrays(n,k, ::Type{T}, A, B) where {T} =
Expand All @@ -417,17 +417,16 @@ promote_to_arrays(n,k, ::Type{T}, A, B, C) where {T} =
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays_(n[k+2], T, C))
promote_to_arrays(n,k, ::Type{T}, A, B, Cs...) where {T} =
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays(n,k+2, T, Cs...)...)
promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling,Number}}}) = Matrix

_us2number(A) = A
_us2number(J::UniformScaling) = J.λ

for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"))
@eval begin
@inline $f(A::Union{AbstractVecOrMat,UniformScaling}...) = $_f(A...)
@inline $f(A::Union{AbstractArray,UniformScaling}...) = $_f(A...)
# if there's a Number present, J::UniformScaling must be 1x1-dimensional
@inline $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...) = $f(map(_us2number, A)...)
function $_f(A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
@inline $f(A::Union{AbstractArray,UniformScaling,Number}...) = $f(map(_us2number, A)...)
function $_f(A::Union{AbstractArray,UniformScaling,Number}...; array_type = promote_to_array_type(A))
n = -1
for a in A
if !isa(a, UniformScaling)
Expand All @@ -445,9 +444,8 @@ for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"
end
end

hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling}...) = _hvcat(rows, A...)
hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling,Number}...) = _hvcat(rows, A...)
function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractArray,UniformScaling,Number}...) = _hvcat(rows, A...)
function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractArray,UniformScaling,Number}...; array_type = promote_to_array_type(A))
require_one_based_indexing(A...)
nr = length(rows)
sum(rows) == length(A) || throw(ArgumentError("mismatch between row sizes and number of arguments"))
Expand Down
2 changes: 1 addition & 1 deletion stdlib/SparseArrays.version
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SPARSEARRAYS_BRANCH = main
SPARSEARRAYS_SHA1 = 2c7f4d6d839e9a97027454a037bfa004c1eb34b0
SPARSEARRAYS_SHA1 = b4b0e721ada6e8cf5f6391aff4db307be69b0401
SPARSEARRAYS_GIT_URL := https://github.com/JuliaSparse/SparseArrays.jl.git
SPARSEARRAYS_TAR_URL = https://api.github.com/repos/JuliaSparse/SparseArrays.jl/tarball/$1
Loading

0 comments on commit 5a922fa

Please sign in to comment.