From a5575a0fa58ac43eac2fae4325bcac5997d4bc1c Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 7 Feb 2022 12:38:28 +0000 Subject: [PATCH 1/6] define IteratorSize for array style broadcasted --- base/broadcast.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/base/broadcast.jl b/base/broadcast.jl index 7c32e6893268f..38fa337278edb 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -262,6 +262,7 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s) end Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}() +Base.IteratorSize(::Type{<:Broadcasted{<:AbstractArrayStyle{N}, Nothing}}) where {N} = Base.HasShape{N}() Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown() ## Instantiation fills in the "missing" fields in Broadcasted. From e0511da81852c6c84a12ce0a501e370cf8f687a5 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 7 Feb 2022 12:38:52 +0000 Subject: [PATCH 2/6] test collected broadcasted objects retain their shape --- test/broadcast.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/broadcast.jl b/test/broadcast.jl index 39af6e20b9a08..57f8bd0ec8f00 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -855,6 +855,12 @@ let @test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc) end +# issue 43847: collect preserves shape of broadcasted +let + bc = Broadcast.broadcasted(*, [1 2; 3 4], 2) + @test size(collect(bc)) == size(bc) + end + # issue #31295 let a = rand(5), b = rand(5), c = copy(a) view(identity(a), 1:3) .+= view(b, 1:3) From f0049ba6c475d3bd5681edb29c0fd93ce19888c1 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 14 Feb 2022 11:04:47 +0000 Subject: [PATCH 3/6] IteratorSize for ArrayStyle broadcasts that don't propagate dims --- base/broadcast.jl | 1 + test/broadcast.jl | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 38fa337278edb..ea257d25f71e2 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -263,6 +263,7 @@ end Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}() Base.IteratorSize(::Type{<:Broadcasted{<:AbstractArrayStyle{N}, Nothing}}) where {N} = Base.HasShape{N}() +Base.IteratorSize(::Type{<:Broadcasted{<:ArrayStyle, Nothing, <:Any, <:Tuple{T, N}}}) where {T, N} = Base.HasShape{ndims(T)}() Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown() ## Instantiation fills in the "missing" fields in Broadcasted. diff --git a/test/broadcast.jl b/test/broadcast.jl index 57f8bd0ec8f00..b71370c206f4d 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -858,7 +858,11 @@ end # issue 43847: collect preserves shape of broadcasted let bc = Broadcast.broadcasted(*, [1 2; 3 4], 2) - @test size(collect(bc)) == size(bc) + @test collect(Iterators.product(bc, bc)) == collect(Iterators.product(copy(bc), copy(bc))) + + a1 = AD1(rand(2,3)) + bc1 = Broadcast.broadcasted(*, a1, 2) + @test collect(Iterators.product(bc1, bc1)) == collect(Iterators.product(copy(bc1), copy(bc1))) end # issue #31295 From 81efab9628429e4211e2d985a4b612a91dd40420 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Thu, 17 Feb 2022 12:33:12 +0000 Subject: [PATCH 4/6] Generalise IteratorSize definition for broadcasted --- base/broadcast.jl | 14 +++++++++++--- test/broadcast.jl | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index ea257d25f71e2..8ee4673b83854 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -261,9 +261,17 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s) return (bc[i], (s[1], newstate)) end -Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}() -Base.IteratorSize(::Type{<:Broadcasted{<:AbstractArrayStyle{N}, Nothing}}) where {N} = Base.HasShape{N}() -Base.IteratorSize(::Type{<:Broadcasted{<:ArrayStyle, Nothing, <:Any, <:Tuple{T, N}}}) where {T, N} = Base.HasShape{ndims(T)}() +Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}() +Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2)) +function Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N} + N isa Integer && return N + _maxndims(fieldtype(BC, 2)) +end +_maxndims(T) = mapfoldl(_ndims, max, _fieldtypes(T)) +_fieldtypes(T) = ntuple(Base.Fix1(fieldtype,T), Val(fieldcount(T))) # Base.fieldtypes is not stable. +_ndims(x) = ndims(x) +_ndims(::Type{<:Tuple}) = 1 + Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown() ## Instantiation fills in the "missing" fields in Broadcasted. diff --git a/test/broadcast.jl b/test/broadcast.jl index b71370c206f4d..7b1d3538fbf90 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -863,6 +863,24 @@ let a1 = AD1(rand(2,3)) bc1 = Broadcast.broadcasted(*, a1, 2) @test collect(Iterators.product(bc1, bc1)) == collect(Iterators.product(copy(bc1), copy(bc1))) + + # using ndims of second arg + bc2 = Broadcast.broadcasted(*, 2, a1) + @test collect(Iterators.product(bc2, bc2)) == collect(Iterators.product(copy(bc2), copy(bc2))) + + # >2 args + bc3 = Broadcast.broadcasted(*, a1, 3, a1) + @test collect(Iterators.product(bc3, bc3)) == collect(Iterators.product(copy(bc3), copy(bc3))) + + # including a tuple and custom array type + bc4 = Broadcast.broadcasted(*, (1,2,3), AD1(rand(3))) + @test collect(Iterators.product(bc4, bc4)) == collect(Iterators.product(copy(bc4), copy(bc4))) + + # testing ArrayConflict + @test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict} + @test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} + + @test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}() end # issue #31295 From 1b6ffdad7a0bde53dfd5e4d423945eb9fe4028bb Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 21 Feb 2022 10:16:12 +0000 Subject: [PATCH 5/6] support itertor size for nested broadcasts using @pure --- base/broadcast.jl | 5 ++--- test/broadcast.jl | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 8ee4673b83854..20873adbf1bd9 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -244,7 +244,7 @@ Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian() Base.LinearIndices(bc::Broadcasted{<:Any,<:Tuple{Any}}) = LinearIndices(axes(bc))::LinearIndices{1} -Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N +Base.ndims(bc::Broadcasted) = ndims(typeof(bc)) Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N Base.size(bc::Broadcasted) = map(length, axes(bc)) @@ -267,8 +267,7 @@ function Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) wh N isa Integer && return N _maxndims(fieldtype(BC, 2)) end -_maxndims(T) = mapfoldl(_ndims, max, _fieldtypes(T)) -_fieldtypes(T) = ntuple(Base.Fix1(fieldtype,T), Val(fieldcount(T))) # Base.fieldtypes is not stable. +Base.@pure _maxndims(T) = mapfoldl(_ndims, max, fieldtypes(T)) _ndims(x) = ndims(x) _ndims(::Type{<:Tuple}) = 1 diff --git a/test/broadcast.jl b/test/broadcast.jl index 7b1d3538fbf90..8e577fa69bbf1 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -881,6 +881,11 @@ let @test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} @test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}() + + # inference on nested + bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3))) + bc_nest = Base.broadcasted(+, bc , bc) + @test @inferred(Base.IteratorSize(bc_nest)) === Base.HasShape{1}() end # issue #31295 From 70fc3cdc11b086fc6c70595006d2a8398d5d7e6b Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 4 Mar 2022 14:28:52 +0000 Subject: [PATCH 6/6] define _maxndims methods for small tuples to help inference --- base/broadcast.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 20873adbf1bd9..1896e5edad105 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -263,11 +263,15 @@ end Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}() Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2)) -function Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N} - N isa Integer && return N - _maxndims(fieldtype(BC, 2)) +Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N + +_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T)))) +_maxndims(::Type{<:Tuple{T}}) where {T} = ndims(T) +_maxndims(::Type{<:Tuple{T}}) where {T<:Tuple} = _ndims(T) +function _maxndims(::Type{<:Tuple{T, S}}) where {T, S} + return T<:Tuple || S<:Tuple ? max(_ndims(T), _ndims(S)) : max(ndims(T), ndims(S)) end -Base.@pure _maxndims(T) = mapfoldl(_ndims, max, fieldtypes(T)) + _ndims(x) = ndims(x) _ndims(::Type{<:Tuple}) = 1