Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow standalone dotted operators #35706

Closed
wants to merge 15 commits into from
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ New language features
`Base.Experimental.@optlevel n`. For code that is not performance-critical, setting
this to 0 or 1 can provide significant latency improvements ([#34896]).

* Standalone "dotted" operators now get lowered to `Base.BroadcastOp(op)` and can be passed to
higher-order functions, i.e. `.op` is functionally equivalent to `(x...) -> op.(x...)`.
([#34156], [#35706])
simeonschaub marked this conversation as resolved.
Show resolved Hide resolved

Language changes
----------------

Expand Down
31 changes: 30 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastOp

## Computing the result's axes: deprecated name
const broadcast_axes = axes
Expand Down Expand Up @@ -1261,4 +1261,33 @@ end
end
@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args)

"""
BroadcastOp{F} <: Function
simeonschaub marked this conversation as resolved.
Show resolved Hide resolved

Represents the "dotted" version of an operator, which broadcasts the operator over its
arguments, so `BroadcastOp(op)` is functionally equivalent to `(x...) -> op.(x...)`.
simeonschaub marked this conversation as resolved.
Show resolved Hide resolved

Can be created by just passing an operator preceded by a dot to a higher-order function.

# Examples
```jldoctest
julia> a = [reshape(i:i+3, 2, 2) for i in [1, 5]];
simeonschaub marked this conversation as resolved.
Show resolved Hide resolved

julia> b = [reshape(i:i+3, 2, 2) for i in [9, 13]];

julia> map(.*, a, b)
2-element Array{Array{Int64,2},1}:
[9 33; 20 48]
[65 105; 84 128]

julia> Base.BroadcastOp(+)(a, b) == a .+ b
true
```
"""
struct BroadcastOp{F} <: Function
simeonschaub marked this conversation as resolved.
Show resolved Hide resolved
f::F
end

@inline (op::BroadcastOp)(x...) = op.f.(x...)
stevengj marked this conversation as resolved.
Show resolved Hide resolved

end # module
14 changes: 10 additions & 4 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1884,13 +1884,19 @@
`(= ,lhs ,rhs)))

(define (expand-forms e)
(if (or (atom? e) (memq (car e) '(quote inert top core globalref outerref module toplevel ssavalue null true false meta using import export thismodule toplevel-only)))
e
(let ((ex (get expand-table (car e) #f)))
(cond
;; if atom is a dotted operator .op, lower to BroadcastOp(op)
((atom? e)
(if (dotop-named? e)
`(call (top BroadcastOp) ,(undotop e))
e))
((memq (car e) '(quote inert top core globalref outerref module toplevel ssavalue null true false meta using import export thismodule toplevel-only))
e)
(else (let ((ex (get expand-table (car e) #f)))
(if ex
(ex e)
(cons (car e)
(map expand-forms (cdr e)))))))
(map expand-forms (cdr e))))))))

;; table mapping expression head to a function expanding that form
(define expand-table
Expand Down
4 changes: 4 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,3 +923,7 @@ k4 = similar(u)
f(a,b,c,d,e) = @. a = a + 1*(b+c+d+e)
@allocated f(u,k1,k2,k3,k4)
@test (@allocated f(u,k1,k2,k3,k4)) == 0

@test identity(.+) == Broadcast.BroadcastOp(+)
@test identity.(.*) == Broadcast.BroadcastOp(*)
@test map(.+, [[1,2], [3,4]], [5, 6]) == [[6,7], [9,10]]