Skip to content

Commit

Permalink
Normalize indices in promote_shape error messages (#41311)
Browse files Browse the repository at this point in the history
Seeing implementation details like `Base.OneTo` in error messages may
be confusing to some users (cf discussion in #39242,
[discourse](https://discourse.julialang.org/t/promote-shape-dimension-mismatch/57529/)).

This PR turns
```julia
julia> ones(2, 3) + ones(3, 2)
ERROR: DimensionMismatch("dimensions must match: a has dims (Base.OneTo(2), Base.OneTo(3)), b has dims (Base.OneTo(3), Base.OneTo(2)), mismatch at 1")
```
into
```julia
julia> ones(2, 3) + ones(3, 2)
ERROR: DimensionMismatch("a has size (2, 3), b has size (3, 2), mismatch at dim 1")
```

Fixes #40118. 

(This is basically #40124, but redone because I made a mess rebasing).

---------

Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
tpapp and vtjnash authored Feb 4, 2024
1 parent 831cc14 commit 47663bd
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 24 deletions.
57 changes: 36 additions & 21 deletions base/indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,49 @@ IndexStyle(::IndexStyle, ::IndexStyle) = IndexCartesian()

# Base case: two zero-dimensional shapes are trivially compatible; the promoted shape is ().
promote_shape(::Tuple{}, ::Tuple{}) = ()

# Shared error path for promote_shape mismatches, hiding implementation details
# such as `Base.OneTo` from the message (see #40118). When `b === nothing` the
# second operand is omitted; `i`, if supplied, names the mismatching dimension.
# Throws `DimensionMismatch`; never returns.
function throw_promote_shape_mismatch(a::Tuple, b::Union{Nothing,Tuple}, i = nothing)
    # All-OneTo axes carry no extra information beyond their lengths, so show
    # them as plain sizes: `(2, 3)` instead of `(Base.OneTo(2), Base.OneTo(3))`.
    if a isa Tuple{Vararg{Base.OneTo}} && (b === nothing || b isa Tuple{Vararg{Base.OneTo}})
        a = map(lastindex, a)::Dims
        b === nothing || (b = map(lastindex, b)::Dims)
    end
    _has_axes = !(a isa Dims && (b === nothing || b isa Dims))
    if _has_axes
        # General axes: display each as a plain range, stripping custom axis types.
        _normalize(d) = map(x -> firstindex(x):lastindex(x), d)
        a = _normalize(a)
        b === nothing || (b = _normalize(b))
        _things = "axes "
    else
        _things = "size "
    end
    msg = IOBuffer()
    print(msg, "a has ", _things)
    print(msg, a)
    if b !== nothing
        print(msg, ", b has ", _things)
        print(msg, b)
    end
    if i !== nothing
        print(msg, ", mismatch at dim ", i)
    end
    throw(DimensionMismatch(String(take!(msg))))
end

# Promote two one-dimensional shapes: they must agree exactly.
function promote_shape(a::Tuple{Int,}, b::Tuple{Int,})
    if a[1] != b[1]
        throw_promote_shape_mismatch(a, b)
    end
    return a
end

# Promote a 2-d shape against a 1-d shape: the first extents must agree and the
# trailing dimension of `a` must be a singleton. Returns the longer shape `a`.
# NOTE: the pasted diff had retained the pre-#41311 `throw(DimensionMismatch(...))`
# lines alongside the new unified call, so a mismatch raised the old
# "has dims"-style message; only the unified error path is kept here.
function promote_shape(a::Tuple{Int,Int}, b::Tuple{Int,})
    (a[1] != b[1] || a[2] != 1) && throw_promote_shape_mismatch(a, b)
    return a
end

# Symmetric case: delegate to the (2-d, 1-d) method with the arguments swapped.
function promote_shape(a::Tuple{Int,}, b::Tuple{Int,Int})
    return promote_shape(b, a)
end

# Promote two 2-d shapes: both extents must agree exactly.
# NOTE: the pasted diff had retained the pre-#41311 `throw(DimensionMismatch(...))`
# lines alongside the new unified call, so a mismatch raised the old
# "has dims"-style message; only the unified error path is kept here.
function promote_shape(a::Tuple{Int, Int}, b::Tuple{Int, Int})
    (a[1] != b[1] || a[2] != b[2]) && throw_promote_shape_mismatch(a, b)
    return a
end

Expand Down Expand Up @@ -153,14 +176,10 @@ function promote_shape(a::Dims, b::Dims)
return promote_shape(b, a)
end
for i=1:length(b)
if a[i] != b[i]
throw(DimensionMismatch("dimensions must match: a has dims $a, b has dims $b, mismatch at $i"))
end
a[i] != b[i] && throw_promote_shape_mismatch(a, b, i)
end
for i=length(b)+1:length(a)
if a[i] != 1
throw(DimensionMismatch("dimensions must match: a has dims $a, must have singleton at dim $i"))
end
a[i] != 1 && throw_promote_shape_mismatch(a, nothing, i)
end
return a
end
Expand All @@ -174,14 +193,10 @@ function promote_shape(a::Indices, b::Indices)
return promote_shape(b, a)
end
for i=1:length(b)
if a[i] != b[i]
throw(DimensionMismatch("dimensions must match: a has dims $a, b has dims $b, mismatch at $i"))
end
a[i] != b[i] && throw_promote_shape_mismatch(a, b, i)
end
for i=length(b)+1:length(a)
if a[i] != 1:1
throw(DimensionMismatch("dimensions must match: a has dims $a, must have singleton at dim $i"))
end
a[i] != 1:1 && throw_promote_shape_mismatch(a, nothing, i)
end
return a
end
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, Ab
end
for d in ndims(Ay)+1:ndims(z)
# Similar error to what Ay + z would give, to match (Any,Any,Any) method:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
Ay .+ z
Expand All @@ -197,7 +197,7 @@ function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, Ab
end
for d in 3:ndims(z)
# Similar error to (u*v) + z:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
(u .* v) .+ z
Expand Down
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ end
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
function map(f, A::StructuredMatrix, Bs::StructuredMatrix...)
    sz = size(A)
    # Unlike broadcast, map requires all arguments to have exactly the same size;
    # report the first offender with the unified promote_shape error message.
    # (The pasted diff had retained the old blanket
    # `all(map(...)) || throw(DimensionMismatch("dimensions must match"))` check,
    # which fired first with a less informative message; it is dropped here.)
    for B in Bs
        size(B) == sz || Base.throw_promote_shape_mismatch(sz, size(B))
    end
    return f.(A, Bs...)
end
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ end
@test map!(*, Z, X, Y) == broadcast(*, fX, fY)
end
end
# these would be valid for broadcast, but not for map
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)))
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)), D)
@test_throws DimensionMismatch map(+, D, D, Diagonal(rand(1)))
@test_throws DimensionMismatch map(+, Diagonal(rand(1)), D, D)
end

@testset "Issue #33397" begin
Expand Down

0 comments on commit 47663bd

Please sign in to comment.