Skip to content

Commit

Permalink
RFC: Introduce TypedCallable
Browse files Browse the repository at this point in the history
TypedCallable provides a wrapper for callable objects, with the following benefits:
    1. Enforced type-stability (for concrete AT/RT types)
    2. Fast calling convention (frequently < 10 ns / call)
    3. Normal Julia dispatch semantics (sees new Methods, etc.) + invoke_latest
    4. Pre-compilation support (including `--trim` compatibility)

It can be used like this:
```julia
callbacks = @TypedCallable{(::Int,::Int)->Bool}[]

register_callback!(callbacks, f::F) where {F<:Function} =
    push!(callbacks, @TypedCallable f(::Int,::Int)::Bool)

register_callback!(callbacks, (x,y)->(x == y))
register_callback!(callbacks, (x,y)->(x != y))

@Btime callbacks[rand(1:2)](1,1)
```

This is very similar to the existing `FunctionWrappers.jl`, but there
are a few key differences:
  - Better type support: TypedCallable supports the full range of Julia
    types (incl. Varargs), and it has access to all of Julia's "internal"
    calling conventions so calls are fast (and allocation-free) for a
    wider range of input types
  - Improved dispatch handling: The `@cfunction` functionality used by
    FunctionWrappers has several dispatch bugs, which cause wrappers to
    occasionally not see new Methods. These bugs are fixed (or soon to
    be fixed) for TypedCallable.
  - Pre-compilation support including for `juliac` / `--trim` (#55047)

Many of the improvements here are actually thanks to the `OpaqueClosure`
introduced by @Keno - This type just builds on top of OpaqueClosure to
provide an interface with Julia's usual dispatch semantics.

Co-authored-by: Gabriel Baraldi <[email protected]>
  • Loading branch information
topolarity and gbaraldi committed Jul 13, 2024
1 parent 2b140ba commit 8b40992
Showing 1 changed file with 194 additions and 0 deletions.
194 changes: 194 additions & 0 deletions base/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,197 @@ function generate_opaque_closure(@nospecialize(sig), @nospecialize(rt_lb), @nosp
return ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint),
sig, rt_lb, rt_ub, mod, src, lineno, file, nargs, isva, env, do_compile, isinferred)
end

struct Slot{T} end
struct Splat{T}
value::T
end

# Args is a Tuple{Vararg{Union{Slot{T},Some{T}}}} where Slot{T} represents
# an uncurried argument slot, and Some{T} represents an argument to curry.
@noinline @generated function Core.OpaqueClosure(Args::Tuple, ::Slot{RT}) where RT
AT = Any[]
call = Expr(:call)
extracted = Expr[]
closure_args = Expr(:tuple)
for (i, T) in enumerate(Args.parameters)
v = Symbol("arg", i)
is_splat = T <: Splat
if is_splat # TODO: check position
push!(call.args, :($v...))
T = T.parameters[1]
else
push!(call.args, v)
end
if T <: Some
push!(extracted, :($v = something(Args[$i])))
elseif T <: Slot
SlotT = T.parameters[1]
push!(AT, is_splat ? Vararg{SlotT} : SlotT)
push!(closure_args.args, call.args[end])
else @assert false end
end
AT = Tuple{AT...}
return Base.remove_linenums!(quote
$(extracted...)
$(Expr(:opaque_closure, AT, RT, RT, #= allow_partial =# false, :(($(closure_args))->@inline $(call))))
end)
end

"""
TypedCallable{AT,RT}
TypedCallable provides a wrapper for callable objects, with the following benefits:
1. Enforced type-stability (for concrete AT/RT types)
2. Fast calling convention (frequently < 10 ns / call)
3. Normal Julia dispatch semantics (sees new Methods, etc.) + invoke_latest
4. Full pre-compilation support (including `--trim` compatibility)
## Examples
```julia
const callbacks = @TypedCallable{(::Int,::Int)->Bool}[]
register_callback!(callbacks, f::F) where {F<:Function} =
push!(callbacks, @TypedCallable f(::Int,::Int)::Bool)
register_callback!(callbacks, (x,y)->(x == y))
register_callback!(callbacks, (x,y)->(x != y))
# Calling a random (or runtime-known) callback is fast!
@btime callbacks[rand(1:2)](1,1)
```
# Extended help
### As an invalidation barrier
TypedCallable can also be used as an "invalidation barrier", since the caller of a
TypedCallable is not affected by any invalidations of its callee(s). This doesn't
completely cure the original invalidation, but it stops it from propagating all the
way through your code.
This can be especially helpful, e.g., when calling back to user-provided functions
whose invalidations you may have no control over.
"""
mutable struct TypedCallable{AT,RT}
@atomic oc::Base.RefValue{Core.OpaqueClosure{AT,RT}}
const task::Union{Task,Nothing}
const build_oc::Function
end

function Base.show(io::IO, tc::Base.Experimental.TypedCallable)
A, R = typeof(tc).parameters
Base.print(io, "@TypedCallable{")
Base.show_tuple_as_call(io, Symbol(""), A; hasfirst=false)
Base.print(io, "->◌::", R, "}()")
end

function rebuild_in_world!(@nospecialize(self::TypedCallable), world::UInt)
oc = Base.invoke_in_world(world, self.build_oc)
@atomic :release self.oc = Base.Ref(oc)
return oc
end

@inline function (self::TypedCallable{AT,RT})(args...) where {AT,RT}
invoke_world = if self.task === nothing
Base.get_world_counter() # Base.unsafe_load(cglobal(:jl_world_counter, UInt), :acquire) ?
elseif self.task === Base.current_task()
Base.tls_world_age()
else
error("TypedCallable{...} was called from a different task than it was created in.")
end
oc = (@atomic :acquire self.oc)[]
if oc.world != invoke_world
oc = @noinline rebuild_in_world!(self, invoke_world)::Core.OpaqueClosure{AT,RT}
end
return oc(args...)
end

function _TypedCallable_type(ex)
type_err = "Invalid @TypedCallable expression: $(ex)\nExpected \"@TypedCallable{(::T,::U,...)->RT}\""

# Unwrap {...}
(length(ex.args) != 1) && error(type_err)
ex = ex.args[1]

# Unwrap (...)->RT
!(Base.isexpr(ex, :->) && length(ex.args) == 2) && error(type_err)
tuple_, rt = ex.args
if !(Base.isexpr(tuple_, :tuple) && all((x)->Base.isexpr(x, :(::)), tuple_.args))
# note: (arg::T, ...) is specifically allowed (the "arg" part is unused)
error(type_err)
end
!Base.isexpr(rt, :block) && error(type_err)

# Remove any LineNumberNodes inserted by lowering
filter!((x)->!isa(x,Core.LineNumberNode), rt.args)
(length(rt.args) != 1) && error(type_err)

# Build args
AT = Expr[esc(last(x.args)) for x in tuple_.args]
RT = rt.args[1]

# Unwrap ◌::T to T
if Base.isexpr(RT, :(::)) && length(RT.args) == 2 && RT.args[1] == :◌
RT = RT.args[2]
end

return :($TypedCallable{Tuple{$(AT...)}, $(esc(RT))})
end

function _TypedCallable_closure(ex)
if Base.isexpr(ex, :call)
error("""
Invalid @TypedCallable expression: $(ex)
An explicit return type assert is required (e.g. "@TypedCallable f(...)::RT")
""")
end

call_, RT = ex.args
if !Base.isexpr(call_, :call)
error("""Invalid @TypedCallable expression: $(ex)
The supported syntax is:
@TypedCallable{(::T,::U,...)->RT} (to construct the type)
@TypedCallable f(x,::T,...)::RT (to construct the TypedCallable)
""")
end
oc_args = map(call_.args) do arg
is_splat = Base.isexpr(arg, :(...))
arg = is_splat ? arg.args[1] : arg
transformed = if Base.isexpr(arg, :(::))
if length(arg.args) == 1 # it's a "slot"
slot_ty = esc(only(arg.args))
:(Slot{$slot_ty}())
elseif length(arg.args) == 2
(arg, ty) = arg.args
:(Some{$(esc(ty))}($(esc(arg))))
else @assert false end
else
:(Some($(esc(arg))))
end
return is_splat ? Expr(:call, Splat, transformed) : transformed
end
# TODO: kwargs support
RT = :(Slot{$(esc(RT))}())
invoke_latest = true # expose as flag?
task = invoke_latest ? nothing : :(Base.current_task())
return quote
build_oc = ()->Core.OpaqueClosure(($(oc_args...),), $(RT))
$(TypedCallable)(Ref(build_oc()), $task, build_oc)
end
end

macro TypedCallable(ex)
if Base.isexpr(ex, :braces)
return _TypedCallable_type(ex)
elseif Base.isexpr(ex, :call) || (Base.isexpr(ex, :(::)) && length(ex.args) == 2)
return _TypedCallable_closure(ex)
else
error("""Invalid @TypedCallable expression: $(ex)
The supported syntax is:
@TypedCallable{(::T,::U,...)->RT} (to construct the type)
@TypedCallable f(x,::T,...)::RT (to construct the TypedCallable)
""")
end
end

0 comments on commit 8b40992

Please sign in to comment.