From 38b38cfdd5b1a5bdf1a613a6180d2aaadf06a632 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Wed, 3 Jul 2019 14:30:35 -0500 Subject: [PATCH] Address part of #28866, prevent array operations from dropping 0d containers (#32122) This is a simple workaround for the handful of elementwise operations that are defined on arrays _without_ the need for explicit broadcast but use broadcasting (with an extra shape check) in their implementation. These were the only affected cases I could find. --- NEWS.md | 4 ++++ base/arraymath.jl | 10 +++++----- base/broadcast.jl | 18 +++++++++++++++++- stdlib/LinearAlgebra/src/adjtrans.jl | 5 ++++- test/arrayops.jl | 17 +++++++++++++++++ 5 files changed, 47 insertions(+), 7 deletions(-) diff --git a/NEWS.md b/NEWS.md index 0310608c30ffd..81fbdcca94bfd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -33,6 +33,10 @@ Standard library changes * `Cmd` interpolation (`` `$(x::Cmd) a b c` `` where) now propagates `x`'s process flags (environment, flags, working directory, etc) if `x` is the first interpolant and errors otherwise ([#24353]). +* Zero-dimensional arrays are now consistently preserved in the return values of mathematical + functions that operate on the array(s) as a whole (and are not explicitly broadcasted across their elements). + Previously, the functions `+`, `-`, `*`, `/`, `conj`, `real` and `imag` returned the unwrapped element + when operating over zero-dimensional arrays ([#32122]). * `IPAddr` subtypes now behave like scalars when used in broadcasting ([#32133]). * `clamp` can now handle missing values ([#31066]). diff --git a/base/arraymath.jl b/base/arraymath.jl index 5a8e1287f3232..7603f4d54ee65 100644 --- a/base/arraymath.jl +++ b/base/arraymath.jl @@ -27,7 +27,7 @@ julia> A conj!(A::AbstractArray{<:Number}) = (@inbounds broadcast!(conj, A, A); A) for f in (:-, :conj, :real, :imag) - @eval ($f)(A::AbstractArray) = broadcast($f, A) + @eval ($f)(A::AbstractArray) = broadcast_preserving_zero_d($f, A) end @@ -36,7 +36,7 @@ end for f in (:+, :-) @eval function ($f)(A::AbstractArray, B::AbstractArray) promote_shape(A, B) # check size compatibility - broadcast($f, A, B) + broadcast_preserving_zero_d($f, A, B) end end @@ -44,15 +44,15 @@ function +(A::Array, Bs::Array...) for B in Bs promote_shape(A, B) # check size compatibility end - broadcast(+, A, Bs...) + broadcast_preserving_zero_d(+, A, Bs...) end for f in (:/, :\, :*) if f != :/ - @eval ($f)(A::Number, B::AbstractArray) = broadcast($f, A, B) + @eval ($f)(A::Number, B::AbstractArray) = broadcast_preserving_zero_d($f, A, B) end if f != :\ - @eval ($f)(A::AbstractArray, B::Number) = broadcast($f, A, B) + @eval ($f)(A::AbstractArray, B::Number) = broadcast_preserving_zero_d($f, A, B) end end diff --git a/base/broadcast.jl b/base/broadcast.jl index 4d5ca6b92fe19..206e2e2da7042 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -11,7 +11,7 @@ using .Base.Cartesian using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias import .Base: copy, copyto!, axes -export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__ +export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d ## Computing the result's axes: deprecated name const broadcast_axes = axes @@ -790,6 +790,22 @@ julia> A """ broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = (materialize!(dest, broadcasted(f, As...)); dest) +""" + broadcast_preserving_zero_d(f, As...) + +Like [`broadcast`](@ref), except in the case of a 0-dimensional result where it returns a 0-dimensional container + +Broadcast automatically unwraps zero-dimensional results to be just the element itself, +but in some cases it is necessary to always return a container — even in the 0-dimensional case. +""" +function broadcast_preserving_zero_d(f, As...) + bc = broadcasted(f, As...) + r = materialize(bc) + return length(axes(bc)) == 0 ? fill!(similar(bc, typeof(r)), r) : r +end +broadcast_preserving_zero_d(f) = fill(f()) +broadcast_preserving_zero_d(f, as::Number...) = fill(f(as...)) + """ Broadcast.materialize(bc) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index 84fa4a30140e4..c980e500adde0 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -229,7 +229,10 @@ quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...)) broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...)) -# TODO unify and allow mixed combinations +# Hack to preserve behavior after #32122; this needs to be done with a broadcast style instead to support dotted fusion +Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...)) +Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...)) +# TODO unify and allow mixed combinations with a broadcast style ### linear algebra diff --git a/test/arrayops.jl b/test/arrayops.jl index c6abfa98d09e1..7a2fa864f543c 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2628,6 +2628,23 @@ Base.view(::T25958, args...) = args @test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end] end +@testset "0-dimensional container operations" begin + for op in (-, conj, real, imag) + @test op(fill(2)) == fill(op(2)) + @test op(fill(1+2im)) == fill(op(1+2im)) + end + for op in (+, -) + @test op(fill(1), fill(2)) == fill(op(1, 2)) + @test op(fill(1), fill(2)) isa AbstractArray{Int, 0} + end + @test fill(1) + fill(2) + fill(3) == fill(1+2+3) + @test fill(1) / 2 == fill(1/2) + @test 2 \ fill(1) == fill(1/2) + @test 2*fill(1) == fill(2) + @test fill(1)*2 == fill(2) +end + + # Fix oneunit bug for unitful arrays @test oneunit([Second(1) Second(2); Second(3) Second(4)]) == [Second(1) Second(0); Second(0) Second(1)]