Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Use the new non-recursive codegen from JuliaLang/julia#25984.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed May 4, 2018
1 parent c4676d1 commit 8b7e9d4
Showing 1 changed file with 52 additions and 89 deletions.
141 changes: 52 additions & 89 deletions src/jit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,6 @@ export cufunction
# main code generation functions
#

function module_setup(mod::LLVM.Module)
# NOTE: NVPTX::TargetMachine's data layout doesn't match the NVPTX user guide,
# so we specify it ourselves
if Int === Int64
triple!(mod, "nvptx64-nvidia-cuda")
datalayout!(mod, "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
else
triple!(mod, "nvptx-nvidia-cuda")
datalayout!(mod, "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
end

# add debug info metadata
push!(metadata(mod), "llvm.module.flags",
MDNode([ConstantInt(Int32(1)), # llvm::Module::Error
MDString("Debug Info Version"),
ConstantInt(DEBUG_METADATA_VERSION())]))
end

# make function names safe for PTX
safe_fn(fn::String) = replace(fn, r"[^aA-zZ0-9_]"=>"_")
safe_fn(f::Core.Function) = safe_fn(String(typeof(f).name.mt.name))
Expand All @@ -48,7 +30,6 @@ end

function irgen(@nospecialize(f), @nospecialize(tt))
# get the method instance
isa(f, Core.Builtin) && throw(ArgumentError("argument is not a generic function"))
world = typemax(UInt)
meth = which(f, tt)
sig_tt = Tuple{typeof(f), tt.parameters...}
Expand All @@ -59,99 +40,81 @@ function irgen(@nospecialize(f), @nospecialize(tt))
(Any, Any, Any, UInt), meth, ti, env, world)

# set-up the compiler interface
function hook_module_setup(ref::Ptr{Cvoid})
ref = convert(LLVM.API.LLVMModuleRef, ref)
module_setup(LLVM.Module(ref))
end
function hook_raise_exception(insblock::Ptr{Cvoid}, ex::Ptr{Cvoid})
insblock = convert(LLVM.API.LLVMValueRef, insblock)
ex = convert(LLVM.API.LLVMValueRef, ex)
raise_exception(BasicBlock(insblock), Value(ex))
end
dependencies = Vector{LLVM.Module}()
function hook_module_activation(ref::Ptr{Cvoid})
ref = convert(LLVM.API.LLVMModuleRef, ref)
push!(dependencies, LLVM.Module(ref))
end
params = Base.CodegenParams(track_allocations=false,
code_coverage=false,
static_alloc=false,
prefer_specsig=true,
module_setup=hook_module_setup,
module_activation=hook_module_activation,
raise_exception=hook_raise_exception)

# get the code
mod = let
ref = ccall(:jl_get_llvmf_defn, LLVM.API.LLVMValueRef,
(Any, UInt, Bool, Bool, Base.CodegenParams),
linfo, world, #=wrapper=#false, #=optimize=#false, params)
if ref == C_NULL
error("could not compile the specified method")
end

llvmf = LLVM.Function(ref)
LLVM.parent(llvmf)
end

# the main module should contain a single jfptr_ function definition,
# e.g. jlcall_kernel_vadd_62977
definitions = filter(f->!isdeclaration(f), functions(mod))
wrapper = let
fs = if VERSION >= v"0.7.0-DEV.4747"
collect(filter(f->startswith(LLVM.name(f), "jfptr_"), definitions))
else
collect(filter(f->startswith(LLVM.name(f), "jlcall_"), definitions))
end
@assert length(fs) == 1
fs[1]
# generate IR
native_code = ccall(:jl_create_native, Ptr{Cvoid},
(Vector{Core.MethodInstance}, Base.CodegenParams),
[linfo], params)
@assert native_code != C_NULL
llvm_mod_ref = ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
(Ptr{Cvoid},), native_code)
@assert llvm_mod_ref != C_NULL
llvm_mod = LLVM.Module(llvm_mod_ref)

# get the top-level function index
api = Ref{UInt8}(typemax(UInt8))
llvm_func_idx = Ref{UInt32}()
llvm_specfunc_idx = Ref{UInt32}()
ccall(:jl_get_function_id, Nothing,
(Ptr{Cvoid}, Ptr{Core.MethodInstance}, Ptr{UInt8}, Ptr{UInt32}, Ptr{UInt32}),
native_code, Ref(linfo), api, llvm_func_idx, llvm_specfunc_idx)
@assert api[] != typemax(api[])

# get the top-level function
llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
(Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1)
@assert llvm_func_ref != C_NULL
llvm_func = LLVM.Function(llvm_func_ref)
llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
(Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1)
@assert llvm_specfunc_ref != C_NULL
llvm_specfunc = LLVM.Function(llvm_specfunc_ref)

# configure the module
# NOTE: NVPTX::TargetMachine's data layout doesn't match the NVPTX user guide,
# so we specify it ourselves
if Int === Int64
triple!(llvm_mod, "nvptx64-nvidia-cuda")
datalayout!(llvm_mod, "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
else
triple!(llvm_mod, "nvptx-nvidia-cuda")
datalayout!(llvm_mod, "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
end

# the jlcall wrapper function should point us to the actual entry-point,
# e.g. julia_kernel_vadd_62984
entry_tag = let
m = if VERSION >= v"0.7.0-DEV.4747"
match(r"jfptr_(.+)_\d+", LLVM.name(wrapper))
else
match(r"jlcall_(.+)_\d+", LLVM.name(wrapper))
end
@assert m != nothing
m.captures[1]
end
unsafe_delete!(mod, wrapper)
entry = let
re = Regex("julia_$(entry_tag)_\\d+")
llvmcall_re = Regex("julia_$(entry_tag)_\\d+u\\d+")
fs = collect(filter(f->occursin(re, LLVM.name(f)) &&
!occursin(llvmcall_re, LLVM.name(f)), definitions))
if length(fs) != 1
error("Could not find single entry-point for $entry_tag (available functions: ",
join(map(f->LLVM.name(f), definitions), ", "), ")")
# clean up incompatibilities
for llvm_func in functions(llvm_mod)
# remove non-specsig functions
fn = LLVM.name(llvm_func)
if startswith(fn, "jlcall_")
unsafe_delete!(llvm_mod, llvm_func)
continue
end
fs[1]
end

# link in dependent modules
link!.(Ref(mod), dependencies)

# clean up incompatibilities
for llvmf in functions(mod)
# only occurs in debug builds
delete!(function_attributes(llvmf), EnumAttribute("sspreq", 0, jlctx[]))
delete!(function_attributes(llvm_func), EnumAttribute("sspreq", 0, jlctx[]))

# make function names safe for ptxas
# (LLVM ought to do this, see eg. D17738 and D19126), but fails
# TODO: fix all globals?
llvmfn = LLVM.name(llvmf)
if !isdeclaration(llvmf)
llvmfn′ = safe_fn(llvmf)
if llvmfn != llvmfn′
LLVM.name!(llvmf, llvmfn′)
if !isdeclaration(llvm_func)
fn = safe_fn(llvm_func)
if fn != fn
LLVM.name!(llvm_func, fn)
end
end
end

return mod, entry
return llvm_mod, llvm_specfunc
end

# promote a function to a kernel
Expand All @@ -163,7 +126,7 @@ function promote_kernel!(mod::LLVM.Module, entry_f::LLVM.Function, @nospecialize
kernel = wrap_entry!(mod, entry_f, tt);


# property annotations TODO: belongs in irgen? doesn't maxntidx doesn't appear in ptx code?
# property annotations

annotations = LLVM.Value[kernel]

Expand Down

0 comments on commit 8b7e9d4

Please sign in to comment.