This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Adapt to nonrecursive codegen. [ci skip]
maleadt committed Feb 27, 2020
1 parent cf7e5af commit 0ecfa88
Showing 3 changed files with 57 additions and 166 deletions.
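In short, compile_method_instance no longer drives Julia's codegen one method at a time through jl_get_llvmf_defn and module-activation hooks; it now asks the nonrecursive driver for a single module up front and looks up the entry functions in it. A condensed sketch of the new flow, pieced together from the irgen.jl hunks below (method_instance, world and params as in compile_method_instance; a sketch, not a drop-in replacement for the actual code):

    # emit IR for the entry point and everything it reaches, into one module
    native_code = ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{Core.MethodInstance}, Base.CodegenParams),
                        [method_instance], params)
    llvm_mod = LLVM.Module(ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                                 (Ptr{Cvoid},), native_code))

    # map the entry method instance back to its LLVM functions
    code = Core.Compiler.inf_for_methodinstance(method_instance, world, world)
    func_idx, specfunc_idx = Ref{Int32}(-1), Ref{Int32}(-1)
    ccall(:jl_get_function_id, Nothing, (Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}),
          native_code, code, func_idx, specfunc_idx)
    llvm_specfunc = LLVM.Function(ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                        (Ptr{Cvoid}, UInt32), native_code, specfunc_idx[] - 1))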
169 changes: 51 additions & 118 deletions src/compiler/irgen.jl
@@ -58,47 +58,10 @@ function compile_method_instance(job::CompilerJob, method_instance::Core.MethodI
end

# set-up the compiler interface
last_method_instance = nothing
call_stack = Vector{Core.MethodInstance}()
dependencies = MultiDict{Core.MethodInstance,LLVM.Function}()
function hook_module_setup(ref::Ptr{Cvoid})
ref = convert(LLVM.API.LLVMModuleRef, ref)
ir = LLVM.Module(ref)
module_setup(ir)
end
function hook_module_activation(ref::Ptr{Cvoid})
ref = convert(LLVM.API.LLVMModuleRef, ref)
ir = LLVM.Module(ref)
postprocess(ir)

# find the function that this module defines
llvmfs = filter(llvmf -> !isdeclaration(llvmf) &&
linkage(llvmf) == LLVM.API.LLVMExternalLinkage,
collect(functions(ir)))

llvmf = nothing
if length(llvmfs) == 1
llvmf = first(llvmfs)
elseif length(llvmfs) > 1
llvmfs = filter!(llvmf -> startswith(LLVM.name(llvmf), "julia_"), llvmfs)
if length(llvmfs) == 1
llvmf = first(llvmfs)
end
end

@compiler_assert llvmf !== nothing job

insert!(dependencies, last_method_instance, llvmf)
end
function hook_emit_function(method_instance, code, world)
call_stack = [method_instance]
function hook_emit_function(method_instance, code)
push!(call_stack, method_instance)

# check for recursion
if method_instance in call_stack[1:end-1]
throw(KernelError(job, "recursion is currently not supported";
bt=backtrace(job, call_stack)))
end

# check for Base functions that exist in CUDAnative too
# FIXME: this might be too coarse
method = method_instance.def
@@ -115,17 +78,14 @@ function compile_method_instance(job::CompilerJob, method_instance::Core.MethodI
end
end
end
function hook_emitted_function(method, code, world)
function hook_emitted_function(method, code)
@compiler_assert last(call_stack) == method job
last_method_instance = pop!(call_stack)
pop!(call_stack)
end
param_kwargs = [:cached => false,
:track_allocations => false,
param_kwargs = [:track_allocations => false,
:code_coverage => false,
:static_alloc => false,
:prefer_specsig => true,
:module_setup => hook_module_setup,
:module_activation => hook_module_activation,
:emit_function => hook_emit_function,
:emitted_function => hook_emitted_function]
if LLVM.version() >= v"8.0" && VERSION >= v"1.3.0-DEV.547"
@@ -150,63 +110,55 @@ function compile_method_instance(job::CompilerJob, method_instance::Core.MethodI
end
params = Base.CodegenParams(;param_kwargs...)

# get the code
ref = ccall(:jl_get_llvmf_defn, LLVM.API.LLVMValueRef,
(Any, UInt, Bool, Bool, Base.CodegenParams),
method_instance, world, #=wrapper=#false, #=optimize=#false, params)
if ref == C_NULL
throw(InternalCompilerError(job, "the Julia compiler could not generate LLVM IR"))
# generate IR
native_code = ccall(:jl_create_native, Ptr{Cvoid},
(Vector{Core.MethodInstance}, Base.CodegenParams),
[method_instance], params)
@assert native_code != C_NULL
llvm_mod_ref = ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
(Ptr{Cvoid},), native_code)
@assert llvm_mod_ref != C_NULL
llvm_mod = LLVM.Module(llvm_mod_ref)

# get the top-level code
code = Core.Compiler.inf_for_methodinstance(method_instance, world, world)

# get the top-level function index
llvm_func_idx = Ref{Int32}(-1)
llvm_specfunc_idx = Ref{Int32}(-1)
ccall(:jl_breakpoint, Nothing, ())
ccall(:jl_get_function_id, Nothing,
(Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}),
native_code, code, llvm_func_idx, llvm_specfunc_idx)
@assert llvm_func_idx[] != -1
@assert llvm_specfunc_idx[] != -1

# get the top-level function
llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
(Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1)
@assert llvm_func_ref != C_NULL
llvm_func = LLVM.Function(llvm_func_ref)
llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
(Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1)
@assert llvm_specfunc_ref != C_NULL
llvm_specfunc = LLVM.Function(llvm_specfunc_ref)

# configure the module
# NOTE: NVPTX::TargetMachine's data layout doesn't match the NVPTX user guide,
# so we specify it ourselves
if Int === Int64
triple!(llvm_mod, "nvptx64-nvidia-cuda")
datalayout!(llvm_mod, "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
else
triple!(llvm_mod, "nvptx-nvidia-cuda")
datalayout!(llvm_mod, "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
end
llvmf = LLVM.Function(ref)
ir = LLVM.parent(llvmf)
postprocess(ir)

return llvmf, dependencies
return llvm_specfunc, llvm_mod
end

function irgen(job::CompilerJob, method_instance::Core.MethodInstance, world)
entry, dependencies = @timeit_debug to "emission" compile_method_instance(job, method_instance, world)
mod = LLVM.parent(entry)

# link in dependent modules
@timeit_debug to "linking" begin
# we disable Julia's compilation cache not to poison it with GPU-specific code.
# as a result, we might get multiple modules for a single method instance.
cache = Dict{String,String}()

for called_method_instance in keys(dependencies)
llvmfs = dependencies[called_method_instance]

# link the first module
llvmf = popfirst!(llvmfs)
llvmfn = LLVM.name(llvmf)
link!(mod, LLVM.parent(llvmf))

# process subsequent duplicate modules
for dup_llvmf in llvmfs
if Base.JLOptions().debug_level >= 2
# link them too, to ensure accurate backtrace reconstruction
link!(mod, LLVM.parent(dup_llvmf))
else
# don't link them, but note the called function name in a cache
dup_llvmfn = LLVM.name(dup_llvmf)
cache[dup_llvmfn] = llvmfn
end
end
end

# resolve function declarations with cached entries
for llvmf in filter(isdeclaration, collect(functions(mod)))
llvmfn = LLVM.name(llvmf)
if haskey(cache, llvmfn)
def_llvmfn = cache[llvmfn]
replace_uses!(llvmf, functions(mod)[def_llvmfn])

@compiler_assert isempty(uses(llvmf)) job
unsafe_delete!(LLVM.parent(llvmf), llvmf)
end
end
end
entry, mod = @timeit_debug to "emission" compile_method_instance(job, method_instance, world)

# clean up incompatibilities
@timeit_debug to "clean-up" for llvmf in functions(mod)
@@ -215,28 +167,9 @@ function irgen(job::CompilerJob, method_instance::Core.MethodInstance, world)
# only occurs in debug builds
delete!(function_attributes(llvmf), EnumAttribute("sspstrong", 0, JuliaContext()))

# rename functions
# rename functions to be safe for ptxas
# FIXME: Base already does this, but leaves in % which isn't safe
if !isdeclaration(llvmf)
# Julia disambiguates local functions by prefixing with `#\d#`.
# since we don't use a global function namespace, get rid of those tags.
if occursin(r"^julia_#\d+#", llvmfn)
llvmfn′ = replace(llvmfn, r"#\d+#"=>"")
if !haskey(functions(mod), llvmfn′)
LLVM.name!(llvmf, llvmfn′)
llvmfn = llvmfn′
end
end

# anonymous functions are just named `#\d`, make that somewhat more readable
m = match(r"_#(\d+)_", llvmfn)
if m !== nothing
llvmfn′ = replace(llvmfn, m.match=>"_anonymous$(m.captures[1])_")
LLVM.name!(llvmf, llvmfn′)
llvmfn = llvmfn′
end

# finally, make function names safe for ptxas
# (LLVM should do this, but fails, see e.g. D17738 and D19126)
llvmfn′ = safe_name(llvmfn)
if llvmfn != llvmfn′
LLVM.name!(llvmf, llvmfn′)
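With the cross-module linking and the julia_#N# renaming gone from irgen, the remaining clean-up pass only has to make symbol names acceptable to ptxas (the safe_name call kept above). As a rough illustration of what that sanitization amounts to, and not CUDAnative's actual safe_name implementation, one could map every character outside a conservative identifier alphabet to an underscore:

    # illustrative sketch only: keep roughly [A-Za-z0-9_$] and replace the rest
    # (e.g. '#' or '%', which ptxas rejects) with an underscore
    sanitize_ptx_name(name::AbstractString) = replace(name, r"[^A-Za-z0-9_$]" => "_")

    sanitize_ptx_name("julia_#34#kernel")  # == "julia__34_kernel"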
4 changes: 2 additions & 2 deletions test/base.jl
@@ -14,7 +14,7 @@ function post17057_parent(arr::Ptr{Int64})
end

# bug: default module activation segfaulted on NULL child function if cached=false
params = Base.CodegenParams(cached=false)
params = Base.CodegenParams()
if VERSION >= v"1.1.0-DEV.762"
_dump_function(post17057_parent, Tuple{Ptr{Int64}},
#=native=#false, #=wrapper=#false, #=strip=#false,
@@ -31,4 +31,4 @@ end

############################################################################################

end
end
50 changes: 4 additions & 46 deletions test/codegen.jl
@@ -11,7 +11,7 @@
ir = sprint(io->CUDAnative.code_llvm(io, valid_kernel, Tuple{}; optimize=false, dump_module=true))

# module should contain our function + a generic call wrapper
@test occursin(r"define void @.*julia_valid_kernel.*\(\)", ir)
@test occursin(r"define\ .* void\ @.*julia_valid_kernel.*\(\)"x, ir)
@test !occursin("define %jl_value_t* @jlcall_", ir)

# there should be no debug metadata
@@ -130,21 +130,6 @@ end
CUDAnative.code_llvm(devnull, D32593, Tuple{CuDeviceVector{D32593_struct,AS.Global}})
end

@testset "kernel names" begin
regular() = return
closure = ()->return

function test_name(f, name; kwargs...)
code = sprint(io->CUDAnative.code_llvm(io, f, Tuple{}; kwargs...))
@test occursin(name, code)
end

test_name(regular, "julia_regular")
test_name(regular, "ptxcall_regular"; kernel=true)
test_name(closure, "julia_anonymous")
test_name(closure, "ptxcall_anonymous"; kernel=true)
end

@testset "PTX TBAA" begin
load(ptr) = unsafe_load(ptr)
store(ptr) = unsafe_store!(ptr, 0)
@@ -251,7 +236,7 @@ end
end

asm = sprint(io->CUDAnative.code_ptx(io, parent, Tuple{Int64}))
@test occursin(r"call.uni\s+julia_child_"m, asm)
@test occursin(r"call.uni\s+julia_.*child_"m, asm)
end

@testset "kernel functions" begin
@@ -309,15 +294,15 @@ end
end

asm = sprint(io->CUDAnative.code_ptx(io, parent1, Tuple{Int}))
@test occursin(r".func julia_child_", asm)
@test occursin(r".func julia_.*child_", asm)

function parent2(i)
child(i+1)
return
end

asm = sprint(io->CUDAnative.code_ptx(io, parent2, Tuple{Int}))
@test occursin(r".func julia_child_", asm)
@test occursin(r".func julia_.*child_", asm)
end

@testset "child function reuse bis" begin
@@ -381,21 +366,6 @@ end
CUDAnative.code_ptx(devnull, kernel, Tuple{Float64})
end

@testset "kernel names" begin
regular() = nothing
closure = ()->nothing

function test_name(f, name; kwargs...)
code = sprint(io->CUDAnative.code_ptx(io, f, Tuple{}; kwargs...))
@test occursin(name, code)
end

test_name(regular, "julia_regular")
test_name(regular, "ptxcall_regular"; kernel=true)
test_name(closure, "julia_anonymous")
test_name(closure, "ptxcall_anonymous"; kernel=true)
end

@testset "exception arguments" begin
function kernel(a)
unsafe_store!(a, trunc(Int, unsafe_load(a)))
@@ -473,18 +443,6 @@ end

# some validation happens in the emit_function hook, which is called by code_llvm

@testset "recursion" begin
@eval recurse_outer(i) = i > 0 ? i : recurse_inner(i)
@eval @noinline recurse_inner(i) = i < 0 ? i : recurse_outer(i)

@test_throws_message(CUDAnative.KernelError, CUDAnative.code_llvm(devnull, recurse_outer, Tuple{Int})) do msg
occursin("recursion is currently not supported", msg) &&
occursin("[1] recurse_outer", msg) &&
occursin("[2] recurse_inner", msg) &&
occursin("[3] recurse_outer", msg)
end
end

@testset "base intrinsics" begin
foobar(i) = sin(i)

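A note on the test changes above: since irgen no longer strips the julia_#N# tags or renames anonymous closures (see the renaming code removed from irgen.jl), generated symbols may carry extra components between the julia_ prefix and the function name, hence the looser julia_.*child_ patterns and the removal of the "kernel names" testsets. A standalone check in the same spirit, with hypothetical child/caller functions and assuming the child call survives inlining:

    using Test, CUDAnative

    @noinline child(i) = i + 1
    function caller(ptr)
        unsafe_store!(ptr, child(unsafe_load(ptr)))
        return
    end

    asm = sprint(io -> CUDAnative.code_ptx(io, caller, Tuple{Ptr{Int64}}))
    # match loosely: extra name components may appear after the julia_ prefix
    @test occursin(r"julia_.*child_", asm)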
