Skip to content

Commit

Permalink
Uniform scaling cat with zero dimensions (#29457)
Browse files Browse the repository at this point in the history
* Uniform scaling cat with zero dimensions

* Missing type assertion in test

(cherry picked from commit 4093dbf)
  • Loading branch information
mfalt authored and KristofferC committed Oct 29, 2018
1 parent f7eeed4 commit 6da9a6c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
25 changes: 13 additions & 12 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,18 @@ promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling}}})
for (f,dim,name) in ((:hcat,1,"rows"), (:vcat,2,"cols"))
@eval begin
function $f(A::Union{AbstractVecOrMat,UniformScaling}...)
n = 0
n = -1
for a in A
if !isa(a, UniformScaling)
@assert !has_offset_axes(a)
na = size(a,$dim)
n > 0 && n != na &&
n >= 0 && n != na &&
throw(DimensionMismatch(string("number of ", $name,
" of each array must match (got ", n, " and ", na, ")")))
n = na
end
end
n == 0 && throw(ArgumentError($("$f of only UniformScaling objects cannot determine the matrix size")))
n == -1 && throw(ArgumentError($("$f of only UniformScaling objects cannot determine the matrix size")))
return $f(promote_to_arrays(fill(n,length(A)),1, promote_to_array_type(A), A...)...)
end
end
Expand All @@ -262,20 +262,20 @@ function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScalin
@assert !has_offset_axes(A...)
nr = length(rows)
sum(rows) == length(A) || throw(ArgumentError("mismatch between row sizes and number of arguments"))
n = zeros(Int, length(A))
n = fill(-1, length(A))
needcols = false # whether we also need to infer some sizes from the column count
j = 0
for i = 1:nr # infer UniformScaling sizes from row counts, if possible:
ni = 0 # number of rows in this block-row
ni = -1 # number of rows in this block-row, -1 indicates unknown
for k = 1:rows[i]
if !isa(A[j+k], UniformScaling)
na = size(A[j+k], 1)
ni > 0 && ni != na &&
ni >= 0 && ni != na &&
throw(DimensionMismatch("mismatch in number of rows"))
ni = na
end
end
if ni > 0
if ni >= 0
for k = 1:rows[i]
n[j+k] = ni
end
Expand All @@ -285,21 +285,22 @@ function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScalin
j += rows[i]
end
if needcols # some sizes still unknown, try to infer from column count
nc = j = 0
nc = -1
j = 0
for i = 1:nr
nci = 0
rows[i] > 0 && n[j+1] == 0 && continue # column count unknown in this row
rows[i] > 0 && n[j+1] == -1 && (j += rows[i]; continue)
for k = 1:rows[i]
nci += isa(A[j+k], UniformScaling) ? n[j+k] : size(A[j+k], 2)
end
nc > 0 && nc != nci && throw(DimensionMismatch("mismatch in number of columns"))
nc >= 0 && nc != nci && throw(DimensionMismatch("mismatch in number of columns"))
nc = nci
j += rows[i]
end
nc == 0 && throw(ArgumentError("sizes of UniformScalings could not be inferred"))
nc == -1 && throw(ArgumentError("sizes of UniformScalings could not be inferred"))
j = 0
for i = 1:nr
if rows[i] > 0 && n[j+1] == 0 # this row consists entirely of UniformScalings
if rows[i] > 0 && n[j+1] == -1 # this row consists entirely of UniformScalings
nci = nc ÷ rows[i]
nci * rows[i] != nc && throw(DimensionMismatch("indivisible UniformScaling sizes"))
for k = 1:rows[i]
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/test/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,26 @@ end
for T in (Matrix, SparseMatrixCSC)
A = T(rand(3,4))
B = T(rand(3,3))
C = T(rand(0,3))
D = T(rand(2,0))
@test (hcat(A, 2I))::T == hcat(A, Matrix(2I, 3, 3))
@test (vcat(A, 2I))::T == vcat(A, Matrix(2I, 4, 4))
@test (hcat(C, 2I))::T == C
@test (vcat(D, 2I))::T == D
@test (hcat(I, 3I, A, 2I))::T == hcat(Matrix(I, 3, 3), Matrix(3I, 3, 3), A, Matrix(2I, 3, 3))
@test (vcat(I, 3I, A, 2I))::T == vcat(Matrix(I, 4, 4), Matrix(3I, 4, 4), A, Matrix(2I, 4, 4))
@test (hvcat((2,1,2), B, 2I, I, 3I, 4I))::T ==
hvcat((2,1,2), B, Matrix(2I, 3, 3), Matrix(I, 6, 6), Matrix(3I, 3, 3), Matrix(4I, 3, 3))
@test hvcat((3,1), C, C, I, 3I)::T == hvcat((2,1), C, C, Matrix(3I, 6,6))
@test hvcat((2,2,2), I, 2I, 3I, 4I, C, C)::T ==
hvcat((2,2,2), Matrix(I, 3, 3), Matrix(2I, 3,3 ), Matrix(3I, 3,3), Matrix(4I, 3,3), C, C)
@test hvcat((2,2,4), C, C, I, 2I, 3I, 4I, 5I, D)::T ==
hvcat((2,2,4), C, C, Matrix(I, 3, 3), Matrix(2I,3,3),
Matrix(3I, 2, 2), Matrix(4I, 2, 2), Matrix(5I,2,2), D)
@test (hvcat((2,3,2), B, 2I, C, C, I, 3I, 4I))::T ==
hvcat((2,2,2), B, Matrix(2I, 3, 3), C, C, Matrix(3I, 3, 3), Matrix(4I, 3, 3))
@test hvcat((3,2,1), C, C, I, B ,3I, 2I)::T ==
hvcat((2,2,1), C, C, B, Matrix(3I,3,3), Matrix(2I,6,6))
end
end

Expand Down

0 comments on commit 6da9a6c

Please sign in to comment.