Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow standalone dotted operators #35706

Closed
wants to merge 15 commits into from
8 changes: 7 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastOp

## Computing the result's axes: deprecated name
const broadcast_axes = axes
Expand Down Expand Up @@ -1261,4 +1261,10 @@ end
end
@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args)

struct BroadcastOp{F} <: Function
simeonschaub marked this conversation as resolved.
Show resolved Hide resolved
f::F
end

@inline (op::BroadcastOp)(x...) = op.f.(x...)
stevengj marked this conversation as resolved.
Show resolved Hide resolved

end # module
10 changes: 6 additions & 4 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1884,13 +1884,15 @@
`(= ,lhs ,rhs)))

(define (expand-forms e)
(if (or (atom? e) (memq (car e) '(quote inert top core globalref outerref module toplevel ssavalue null true false meta using import export thismodule toplevel-only)))
e
(let ((ex (get expand-table (car e) #f)))
(cond
((atom? e) (if (dotop-named? e) `(call (top BroadcastOp) ,(undotop e)) e))
((memq (car e) '(quote inert top core globalref outerref module toplevel ssavalue null true false meta using import export thismodule toplevel-only))
e)
(else (let ((ex (get expand-table (car e) #f)))
(if ex
(ex e)
(cons (car e)
(map expand-forms (cdr e)))))))
(map expand-forms (cdr e))))))))

;; table mapping expression head to a function expanding that form
(define expand-table
Expand Down
4 changes: 4 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,3 +923,7 @@ k4 = similar(u)
f(a,b,c,d,e) = @. a = a + 1*(b+c+d+e)
@allocated f(u,k1,k2,k3,k4)
@test (@allocated f(u,k1,k2,k3,k4)) == 0

@test identity(.+) == Broadcast.BroadcastOp(+)
@test identity.(.*) == Broadcast.BroadcastOp(*)
@test map(.+, [[1,2], [3,4]], [5, 6]) == [[6,7], [9,10]]