Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dot-broadcasting for short-circuiting ops .&& and .|| #39594

Merged
merged 7 commits into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ New language features
* `(; a, b) = x` can now be used to destructure properties `a` and `b` of `x`. This syntax is equivalent to `a = getproperty(x, :a)`
and similarly for `b`. ([#39285])
* Implicit multiplication by juxtaposition is now allowed for radical symbols (e.g., `x√y` and `x∛y`). ([#40173])
* The short-circuiting operators `&&` and `||` can now be dotted to participate in broadcast fusion
as `.&&` and `.||`. ([#39594])

Language changes
----------------
Expand Down
29 changes: 19 additions & 10 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction, andand, oror

## Computing the result's axes: deprecated name
const broadcast_axes = axes
Expand Down Expand Up @@ -179,6 +179,21 @@ function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Arg
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
end

struct AndAnd end
andand = AndAnd()
broadcasted(::AndAnd, a, b) = broadcasted((a, b) -> a && b, a, b)
function broadcasted(::AndAnd, a, bc::Broadcasted)
bcf = flatten(bc)
mbauman marked this conversation as resolved.
Show resolved Hide resolved
broadcasted((a, args...) -> a && bcf.f(args...), a, bcf.args...)
end
struct OrOr end
const oror = OrOr()
broadcasted(::OrOr, a, b) = broadcasted((a, b) -> a || b, a, b)
function broadcasted(::OrOr, a, bc::Broadcasted)
bcf = flatten(bc)
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
end

Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)

Expand Down Expand Up @@ -1257,15 +1272,9 @@ function __dot__(x::Expr)
tmp = x.head === :(<:) ? :.<: : :.>:
Expr(:call, tmp, dotargs...)
else
if x.head === :&& || x.head === :||
error("""
Using `&&` and `||` is disallowed in `@.` expressions.
Use `&` or `|` for elementwise logical operations.
""")
end
head = string(x.head)
if last(head) == '=' && first(head) != '.'
Expr(Symbol('.',head), dotargs...)
head = String(x.head)::String
if last(head) == '=' && first(head) != '.' || head == "&&" || head == "||"
Expr(Symbol('.', head), dotargs...)
else
Expr(x.head, dotargs...)
end
Expand Down
8 changes: 4 additions & 4 deletions src/julia-parser.scm
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
(define prec-pair (add-dots '(=>)))
(define prec-conditional '(?))
(define prec-arrow (add-dots '(← → ↔ ↚ ↛ ↞ ↠ ↢ ↣ ↦ ↤ ↮ ⇎ ⇍ ⇏ ⇐ ⇒ ⇔ ⇴ ⇶ ⇷ ⇸ ⇹ ⇺ ⇻ ⇼ ⇽ ⇾ ⇿ ⟵ ⟶ ⟷ ⟹ ⟺ ⟻ ⟼ ⟽ ⟾ ⟿ ⤀ ⤁ ⤂ ⤃ ⤄ ⤅ ⤆ ⤇ ⤌ ⤍ ⤎ ⤏ ⤐ ⤑ ⤔ ⤕ ⤖ ⤗ ⤘ ⤝ ⤞ ⤟ ⤠ ⥄ ⥅ ⥆ ⥇ ⥈ ⥊ ⥋ ⥎ ⥐ ⥒ ⥓ ⥖ ⥗ ⥚ ⥛ ⥞ ⥟ ⥢ ⥤ ⥦ ⥧ ⥨ ⥩ ⥪ ⥫ ⥬ ⥭ ⥰ ⧴ ⬱ ⬰ ⬲ ⬳ ⬴ ⬵ ⬶ ⬷ ⬸ ⬹ ⬺ ⬻ ⬼ ⬽ ⬾ ⬿ ⭀ ⭁ ⭂ ⭃ ⭄ ⭇ ⭈ ⭉ ⭊ ⭋ ⭌ ← → ⇜ ⇝ ↜ ↝ ↩ ↪ ↫ ↬ ↼ ↽ ⇀ ⇁ ⇄ ⇆ ⇇ ⇉ ⇋ ⇌ ⇚ ⇛ ⇠ ⇢ ↷ ↶ ↺ ↻ --> <-- <-->)))
(define prec-lazy-or '(|\|\||))
(define prec-lazy-and '(&&))
(define prec-lazy-or (add-dots '(|\|\||)))
(define prec-lazy-and (add-dots '(&&)))
(define prec-comparison
(append! '(in isa)
(add-dots '(> < >= ≥ <= ≤ == === ≡ != ≠ !== ≢ ∈ ∉ ∋ ∌ ⊆ ⊈ ⊂ ⊄ ⊊ ∝ ∊ ∍ ∥ ∦ ∷ ∺ ∻ ∽ ∾ ≁ ≃ ≂ ≄ ≅ ≆ ≇ ≈ ≉ ≊ ≋ ≌ ≍ ≎ ≐ ≑ ≒ ≓ ≖ ≗ ≘ ≙ ≚ ≛ ≜ ≝ ≞ ≟ ≣ ≦ ≧ ≨ ≩ ≪ ≫ ≬ ≭ ≮ ≯ ≰ ≱ ≲ ≳ ≴ ≵ ≶ ≷ ≸ ≹ ≺ ≻ ≼ ≽ ≾ ≿ ⊀ ⊁ ⊃ ⊅ ⊇ ⊉ ⊋ ⊏ ⊐ ⊑ ⊒ ⊜ ⊩ ⊬ ⊮ ⊰ ⊱ ⊲ ⊳ ⊴ ⊵ ⊶ ⊷ ⋍ ⋐ ⋑ ⋕ ⋖ ⋗ ⋘ ⋙ ⋚ ⋛ ⋜ ⋝ ⋞ ⋟ ⋠ ⋡ ⋢ ⋣ ⋤ ⋥ ⋦ ⋧ ⋨ ⋩ ⋪ ⋫ ⋬ ⋭ ⋲ ⋳ ⋴ ⋵ ⋶ ⋷ ⋸ ⋹ ⋺ ⋻ ⋼ ⋽ ⋾ ⋿ ⟈ ⟉ ⟒ ⦷ ⧀ ⧁ ⧡ ⧣ ⧤ ⧥ ⩦ ⩧ ⩪ ⩫ ⩬ ⩭ ⩮ ⩯ ⩰ ⩱ ⩲ ⩳ ⩵ ⩶ ⩷ ⩸ ⩹ ⩺ ⩻ ⩼ ⩽ ⩾ ⩿ ⪀ ⪁ ⪂ ⪃ ⪄ ⪅ ⪆ ⪇ ⪈ ⪉ ⪊ ⪋ ⪌ ⪍ ⪎ ⪏ ⪐ ⪑ ⪒ ⪓ ⪔ ⪕ ⪖ ⪗ ⪘ ⪙ ⪚ ⪛ ⪜ ⪝ ⪞ ⪟ ⪠ ⪡ ⪢ ⪣ ⪤ ⪥ ⪦ ⪧ ⪨ ⪩ ⪪ ⪫ ⪬ ⪭ ⪮ ⪯ ⪰ ⪱ ⪲ ⪳ ⪴ ⪵ ⪶ ⪷ ⪸ ⪹ ⪺ ⪻ ⪼ ⪽ ⪾ ⪿ ⫀ ⫁ ⫂ ⫃ ⫄ ⫅ ⫆ ⫇ ⫈ ⫉ ⫊ ⫋ ⫌ ⫍ ⫎ ⫏ ⫐ ⫑ ⫒ ⫓ ⫔ ⫕ ⫖ ⫗ ⫘ ⫙ ⫷ ⫸ ⫹ ⫺ ⊢ ⊣ ⟂ <: >:))))
Expand Down Expand Up @@ -111,8 +111,8 @@

; operators that are special forms, not function names
(define syntactic-operators
(append! (add-dots '(= += -= *= /= //= |\\=| ^= ÷= %= <<= >>= >>>= |\|=| &= ⊻=))
'(:= $= && |\|\|| |.| ... ->)))
(append! (add-dots '(&& |\|\|| = += -= *= /= //= |\\=| ^= ÷= %= <<= >>= >>>= |\|=| &= ⊻=))
'(:= $= |.| ... ->)))
(define syntactic-unary-operators '($ & |::|))

(define syntactic-op? (Set syntactic-operators))
Expand Down
9 changes: 9 additions & 0 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,10 @@
e))))
((and (pair? e) (eq? (car e) 'comparison))
(dot-to-fuse (expand-compare-chain (cdr e)) top))
((and (pair? e) (eq? (car e) '.&&))
(make-fuse '(top andand) (cdr e)))
((and (pair? e) (eq? (car e) '|.\|\||))
(make-fuse '(top oror) (cdr e)))
(else e)))
(let ((e (dot-to-fuse rhs #t)) ; an expression '(fuse func args) if expr is a dot call
(lhs-view (ref-to-view lhs))) ; x[...] expressions on lhs turn in to view(x, ...) to update x in-place
Expand Down Expand Up @@ -2125,6 +2129,11 @@
;; e = (|.| f x)
(expand-fuse-broadcast '() e)))

'.&&
(lambda (e) (expand-fuse-broadcast '() e))
'|.\|\||
(lambda (e) (expand-fuse-broadcast '() e))

'.=
(lambda (e)
(expand-fuse-broadcast (cadr e) (caddr e)))
Expand Down
29 changes: 29 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,35 @@ p0 = copy(p)
@test repr(.!) == "Base.Broadcast.BroadcastFunction(!)"
@test eval(:(.+)) == Base.BroadcastFunction(+)

@testset "Issue #5187: Broadcasting of short-circuiting ops" begin
ex = Meta.parse("A .< 1 .|| A .> 2")
@test ex == :((A .< 1) .|| (A .> 2))
@test ex.head == :.||
ex = Meta.parse("A .< 1 .&& A .> 2")
@test ex == :((A .< 1) .&& (A .> 2))
@test ex.head == :.&&

A = -1:4
@test (A .< 1 .|| A .> 2) == [true, true, false, false, true, true]
@test (A .>= 1 .&& A .<= 2) == [false, false, true, true, false, false]

mutable struct F5187; x; end
(f::F5187)(x) = (f.x += x)
@test (iseven.(1:4) .&& (F5187(0)).(ones(4))) == [false, 1, false, 2]
@test (iseven.(1:4) .|| (F5187(0)).(ones(4))) == [1, true, 2, true]
r = 1:4; o = ones(4); f = F5187(0);
@test (@. iseven(r) && f(o)) == [false, 1, false, 2]
@test (@. iseven(r) || f(o)) == [3, true, 4, true]

@test (iseven.(1:8) .&& iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
@test (iseven.(1:8) .|| iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
r = 1:8; o = ones(8); f1 = F5187(0); f2 = F5187(0)
@test (@. iseven(r) && iseven(f1(o)) && f2(o)) == [false,false,false,1,false,false,false,2]
@test (@. iseven(r) || iseven(f1(o)) || f2(o)) == [3,true,true,true,4,true,true,true]
@test (iseven.(1:8) .&& iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
@test (iseven.(1:8) .|| iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
end

@testset "Issue #28382: inferrability of broadcast with Union eltype" begin
@test isequal([1, 2] .+ [3.0, missing], [4.0, missing])
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Expand Down