Skip to content

Commit

Permalink
simplify foreigncall handling (#554)
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC authored Dec 8, 2022
1 parent 39d04c7 commit be125b1
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 107 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CodeTracking = "0.5.9, 1"
julia = "1.6"

[extras]
CassetteOverlay = "d78b62d4-37fa-4a6f-acd8-2f19986eb9ee"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
Expand All @@ -29,4 +30,4 @@ Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DataFrames", "Dates", "DeepDiffs", "Distributed", "FunctionWrappers", "HTTP", "LinearAlgebra", "Logging", "Mmap", "PyCall", "SHA", "SparseArrays", "Tensors", "Test"]
test = ["CassetteOverlay", "DataFrames", "Dates", "DeepDiffs", "Distributed", "FunctionWrappers", "HTTP", "LinearAlgebra", "Logging", "Mmap", "PyCall", "SHA", "SparseArrays", "Tensors", "Test"]
165 changes: 59 additions & 106 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ end

function lookup_getproperties(a::Expr)
if a.head === :call && length(a.args) == 3 &&
a.args[1] isa QuoteNode && a.args[1].value === Base.getproperty &&
a.args[1] isa QuoteNode && a.args[1].value === Base.getproperty &&
a.args[2] isa QuoteNode && a.args[2].value isa Module &&
a.args[3] isa QuoteNode && a.args[3].value isa Symbol
return lookup_global_ref(Core.GlobalRef(a.args[2].value, a.args[3].value))
Expand Down Expand Up @@ -179,7 +179,6 @@ function optimize!(code::CodeInfo, scope)
# Replace :llvmcall and :foreigncall with compiled variants. See
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
foreigncalls_idx = Int[]
delete_idxs = Int[]
for (idx, stmt) in enumerate(code.code)
# Foregincalls can be rhs of assignments
if isexpr(stmt, :(=))
Expand All @@ -190,36 +189,18 @@ function optimize!(code::CodeInfo, scope)
# Check for :llvmcall
arg1 = stmt.args[1]
if (arg1 === :llvmcall || lookup_stmt(code.code, arg1) === Base.llvmcall) && isempty(sparams) && scope isa Method
nargs = length(stmt.args)-4
# Call via `invokelatest` to avoid compiling it until we need it
delete_idx = Base.invokelatest(build_compiled_call!, stmt, Base.llvmcall, code, idx, nargs, sparams, evalmod)
delete_idx === nothing && error("llvmcall must be compiled, but exited early from build_compiled_call!")
Base.invokelatest(build_compiled_llvmcall!, stmt, code, idx, evalmod)
push!(foreigncalls_idx, idx)
append!(delete_idxs, delete_idx)
end
elseif stmt.head === :foreigncall && scope isa Method
nargs = length(stmt.args[3]::SimpleVector)
# Call via `invokelatest` to avoid compiling it until we need it
delete_idx = Base.invokelatest(build_compiled_call!, stmt, :ccall, code, idx, nargs, sparams, evalmod)
if delete_idx !== nothing
push!(foreigncalls_idx, idx)
append!(delete_idxs, delete_idx)
end
Base.invokelatest(build_compiled_foreigncall!, stmt, code, sparams, evalmod)
push!(foreigncalls_idx, idx)
end
end
end

if !isempty(delete_idxs)
ssalookup = compute_ssa_mapping_delete_statements!(code, delete_idxs)
let lkup = ssalookup
foreigncalls_idx = map(x -> lkup[x], foreigncalls_idx)
end
deleteat!(codelocs(code), delete_idxs)
deleteat!(code.code, delete_idxs)
code.ssavaluetypes = length(code.code)
renumber_ssa!(code.code, ssalookup)
end

## Un-nest :call expressions (so that there will be only one :call per line)
# This will allow us to re-use args-buffers rather than having to allocate new ones each time.
old_code, old_codelocs = code.code, codelocs(code)
Expand Down Expand Up @@ -273,66 +254,52 @@ function parametric_type_to_expr(@nospecialize(t::Type))
return t
end

# Handle :llvmcall & :foreigncall (issue #28)
function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams::Vector{Symbol}, evalmod)
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
delete_idx = Int[]
if fcall === :ccall
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]::SimpleVector
# delete cconvert and unsafe_convert calls and forward the original values, since
# the same conversions will be applied within the generated compiled variant of this :foreigncall anyway
args = []
for (atype, arg) in zip(ArgType, stmt.args[6:6+nargs-1])
if atype === Any
push!(args, arg)
else
if arg isa SSAValue
unsafe_convert_expr = code.code[arg.id]::Expr
push!(delete_idx, arg.id) # delete the unsafe_convert
cconvert_val = unsafe_convert_expr.args[3]
if isa(cconvert_val, SSAValue)
push!(delete_idx, cconvert_val.id) # delete the cconvert
newarg = (code.code[cconvert_val.id]::Expr).args[3]
push!(args, newarg)
else
@assert isa(cconvert_val, SlotNumber)
push!(args, cconvert_val)
end
elseif arg isa SlotNumber
idx = findfirst(code.code) do expr
Meta.isexpr(expr, :(=)) || return false
lhs = expr.args[1]
return lhs isa SlotNumber && lhs.id === arg.id
end::Int
unsafe_convert_expr = code.code[idx]::Expr
push!(delete_idx, idx) # delete the unsafe_convert
push!(args, unsafe_convert_expr.args[2])
else
error("unexpected foreigncall argument type encountered: $(typeof(arg))")
end
end
end
else
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
idxstart = idx
for i = 2:4
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
end
frame.pc = idxstart
if idxstart < idx
while true
pc = step_expr!(Compiled(), frame)
pc === idx && break
pc === nothing && error("this should never happen")
end
function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
idxstart = idx
for i = 2:4
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
end
frame.pc = idxstart
if idxstart < idx
while true
pc = step_expr!(Compiled(), frame)
pc === idx && break
pc === nothing && error("this should never happen")
end
cfunc, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])::DataType
args = stmt.args[5:end]
end
llvmir, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])::DataType
args = stmt.args[5:end]
argnames = Any[Symbol(:arg, i) for i = 1:length(args)]
cc_key = (llvmir, RetType, ArgType, evalmod) # compiled call key
f = get(compiled_calls, cc_key, nothing)
if f === nothing
methname = gensym("compiled_llvmcall")
def = :(
function $methname($(argnames...))
return $(Base.llvmcall)($llvmir, $RetType, $ArgType, $(argnames...))
end)
f = Core.eval(evalmod, def)
compiled_calls[cc_key] = f
end

stmt.args[1] = QuoteNode(f)
stmt.head = :call
deleteat!(stmt.args, 2:length(stmt.args))
append!(stmt.args, args)
end


# Handle :llvmcall & :foreigncall (issue #28)
function build_compiled_foreigncall!(stmt::Expr, code, sparams::Vector{Symbol}, evalmod)
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]::SimpleVector

dynamic_ccall = false
if isa(cfunc, Expr) # specification by tuple, e.g., (:clock, "libc")
oldcfunc = nothing
if isa(cfunc, Expr) # specification by tuple, e.g., (:clock, "libc")
cfunc = something(static_eval(cfunc), cfunc)
end
if isa(cfunc, Symbol)
Expand All @@ -348,14 +315,12 @@ function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams:
@assert length(RetType) == 1
RetType = RetType[1]
end
args = stmt.args[6:end]
# When the ccall is dynamic we pass the pointer as an argument so can reuse the function
cc_key = (dynamic_ccall ? :ptr : cfunc, RetType, ArgType, evalmod, length(sparams)) # compiled call key
cc_key = ((dynamic_ccall ? :ptr : cfunc), RetType, ArgType, evalmod, length(sparams)) # compiled call key
f = get(compiled_calls, cc_key, nothing)
argnames = Any[Symbol(:arg, i) for i = 1:nargs]
if f === nothing
if fcall === :ccall
ArgType = Expr(:tuple, Any[parametric_type_to_expr(t) for t in ArgType::SimpleVector]...)
end
ArgType = Expr(:tuple, Any[parametric_type_to_expr(t) for t in ArgType::SimpleVector]...)
RetType = parametric_type_to_expr(RetType)
# #285: test whether we can evaluate an type constraints on parametric expressions
# this essentially comes down to having the names be available in CompiledCalls,
Expand All @@ -366,31 +331,19 @@ function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams:
catch
return nothing
end
argnames = Any[Symbol(:arg, i) for i = 1:length(args)]
wrapargs = copy(argnames)
if dynamic_ccall
pushfirst!(wrapargs, cfunc)
end
for sparam in sparams
push!(wrapargs, :(::$TVal{$sparam}))
end
methname = gensym("compiledcall")
calling_convention = stmt.args[5]
if calling_convention === :(:llvmcall)
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, llvmcall, $RetType, $ArgType, $(argnames...))
end)
elseif calling_convention === :(:stdcall)
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, stdcall, $RetType, $ArgType, $(argnames...))
end)
else
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, $RetType, $ArgType, $(argnames...))
end)
if dynamic_ccall
pushfirst!(wrapargs, cfunc)
end
methname = gensym("compiled_ccall")
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $(Expr(:foreigncall, cfunc, RetType, stmt.args[3:5]..., argnames...))
end)
f = Core.eval(evalmod, def)
compiled_calls[cc_key] = f
end
Expand All @@ -404,7 +357,7 @@ function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams:
for i in 1:length(sparams)
push!(stmt.args, :($TVal($(Expr(:static_parameter, i)))))
end
return delete_idx
return nothing
end

function replace_coretypes!(src; rev::Bool=false)
Expand Down
17 changes: 17 additions & 0 deletions test/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -937,3 +937,20 @@ end
@static if isdefined(Base.Experimental, Symbol("@opaque"))
@test @interpret (Base.Experimental.@opaque x->3*x)(4) == 12
end

# CassetteOverlay, issue #552
@static if VERSION >= v"1.8"
using CassetteOverlay
end

@static if VERSION >= v"1.8"
function foo()
x = IdDict()
x[:foo] = 1
end
@MethodTable SinTable;
@testset "CassetteOverlay" begin
pass = @overlaypass SinTable;
@test (@interpret pass(foo)) == 1
end
end

0 comments on commit be125b1

Please sign in to comment.