diff --git a/stdlib/LinearAlgebra/src/uniformscaling.jl b/stdlib/LinearAlgebra/src/uniformscaling.jl index 57037473f4b29b..06c97ea1a34b4c 100644 --- a/stdlib/LinearAlgebra/src/uniformscaling.jl +++ b/stdlib/LinearAlgebra/src/uniformscaling.jl @@ -240,18 +240,18 @@ promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling}}}) for (f,dim,name) in ((:hcat,1,"rows"), (:vcat,2,"cols")) @eval begin function $f(A::Union{AbstractVecOrMat,UniformScaling}...) - n = 0 + n = -1 for a in A if !isa(a, UniformScaling) @assert !has_offset_axes(a) na = size(a,$dim) - n > 0 && n != na && + n >= 0 && n != na && throw(DimensionMismatch(string("number of ", $name, " of each array must match (got ", n, " and ", na, ")"))) n = na end end - n == 0 && throw(ArgumentError($("$f of only UniformScaling objects cannot determine the matrix size"))) + n == -1 && throw(ArgumentError($("$f of only UniformScaling objects cannot determine the matrix size"))) return $f(promote_to_arrays(fill(n,length(A)),1, promote_to_array_type(A), A...)...) end end @@ -262,20 +262,20 @@ function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScalin @assert !has_offset_axes(A...) nr = length(rows) sum(rows) == length(A) || throw(ArgumentError("mismatch between row sizes and number of arguments")) - n = zeros(Int, length(A)) + n = fill(-1, length(A)) needcols = false # whether we also need to infer some sizes from the column count j = 0 for i = 1:nr # infer UniformScaling sizes from row counts, if possible: - ni = 0 # number of rows in this block-row + ni = -1 # number of rows in this block-row, -1 indicates unknown for k = 1:rows[i] if !isa(A[j+k], UniformScaling) na = size(A[j+k], 1) - ni > 0 && ni != na && + ni >= 0 && ni != na && throw(DimensionMismatch("mismatch in number of rows")) ni = na end end - if ni > 0 + if ni >= 0 for k = 1:rows[i] n[j+k] = ni end @@ -285,21 +285,22 @@ function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScalin j += rows[i] end if needcols # some sizes still unknown, try to infer from column count - nc = j = 0 + nc = -1 + j = 0 for i = 1:nr nci = 0 - rows[i] > 0 && n[j+1] == 0 && continue # column count unknown in this row + rows[i] > 0 && n[j+1] == -1 && (j += rows[i]; continue) for k = 1:rows[i] nci += isa(A[j+k], UniformScaling) ? n[j+k] : size(A[j+k], 2) end - nc > 0 && nc != nci && throw(DimensionMismatch("mismatch in number of columns")) + nc >= 0 && nc != nci && throw(DimensionMismatch("mismatch in number of columns")) nc = nci j += rows[i] end - nc == 0 && throw(ArgumentError("sizes of UniformScalings could not be inferred")) + nc == -1 && throw(ArgumentError("sizes of UniformScalings could not be inferred")) j = 0 for i = 1:nr - if rows[i] > 0 && n[j+1] == 0 # this row consists entirely of UniformScalings + if rows[i] > 0 && n[j+1] == -1 # this row consists entirely of UniformScalings nci = nc รท rows[i] nci * rows[i] != nc && throw(DimensionMismatch("indivisible UniformScaling sizes")) for k = 1:rows[i] diff --git a/stdlib/LinearAlgebra/test/uniformscaling.jl b/stdlib/LinearAlgebra/test/uniformscaling.jl index 52271177a1f96f..0852f065ae1838 100644 --- a/stdlib/LinearAlgebra/test/uniformscaling.jl +++ b/stdlib/LinearAlgebra/test/uniformscaling.jl @@ -183,12 +183,26 @@ end for T in (Matrix, SparseMatrixCSC) A = T(rand(3,4)) B = T(rand(3,3)) + C = T(rand(0,3)) + D = T(rand(2,0)) @test (hcat(A, 2I))::T == hcat(A, Matrix(2I, 3, 3)) @test (vcat(A, 2I))::T == vcat(A, Matrix(2I, 4, 4)) + @test (hcat(C, 2I))::T == C + @test (vcat(D, 2I))::T == D @test (hcat(I, 3I, A, 2I))::T == hcat(Matrix(I, 3, 3), Matrix(3I, 3, 3), A, Matrix(2I, 3, 3)) @test (vcat(I, 3I, A, 2I))::T == vcat(Matrix(I, 4, 4), Matrix(3I, 4, 4), A, Matrix(2I, 4, 4)) @test (hvcat((2,1,2), B, 2I, I, 3I, 4I))::T == hvcat((2,1,2), B, Matrix(2I, 3, 3), Matrix(I, 6, 6), Matrix(3I, 3, 3), Matrix(4I, 3, 3)) + @test hvcat((3,1), C, C, I, 3I)::T == hvcat((2,1), C, C, Matrix(3I, 6,6)) + @test hvcat((2,2,2), I, 2I, 3I, 4I, C, C)::T == + hvcat((2,2,2), Matrix(I, 3, 3), Matrix(2I, 3,3 ), Matrix(3I, 3,3), Matrix(4I, 3,3), C, C) + @test hvcat((2,2,4), C, C, I, 2I, 3I, 4I, 5I, D)::T == + hvcat((2,2,4), C, C, Matrix(I, 3, 3), Matrix(2I,3,3), + Matrix(3I, 2, 2), Matrix(4I, 2, 2), Matrix(5I,2,2), D) + @test (hvcat((2,3,2), B, 2I, C, C, I, 3I, 4I))::T == + hvcat((2,2,2), B, Matrix(2I, 3, 3), C, C, Matrix(3I, 3, 3), Matrix(4I, 3, 3)) + @test hvcat((3,2,1), C, C, I, B ,3I, 2I)::T == + hvcat((2,2,1), C, C, B, Matrix(3I,3,3), Matrix(2I,6,6)) end end