diff --git a/src/jit.jl b/src/jit.jl
index 9c1b0aca..ac42a0f9 100644
--- a/src/jit.jl
+++ b/src/jit.jl
@@ -7,24 +7,6 @@ export cufunction
 # main code generation functions
 #
 
-function module_setup(mod::LLVM.Module)
-    # NOTE: NVPTX::TargetMachine's data layout doesn't match the NVPTX user guide,
-    #       so we specify it ourselves
-    if Int === Int64
-        triple!(mod, "nvptx64-nvidia-cuda")
-        datalayout!(mod, "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
-    else
-        triple!(mod, "nvptx-nvidia-cuda")
-        datalayout!(mod, "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
-    end
-
-    # add debug info metadata
-    push!(metadata(mod), "llvm.module.flags",
-         MDNode([ConstantInt(Int32(1)), # llvm::Module::Error
-                 MDString("Debug Info Version"),
-                 ConstantInt(DEBUG_METADATA_VERSION())]))
-end
-
 # make function names safe for PTX
 safe_fn(fn::String) = replace(fn, r"[^aA-zZ0-9_]"=>"_")
 safe_fn(f::Core.Function) = safe_fn(String(typeof(f).name.mt.name))
@@ -48,7 +30,6 @@ end
 
 function irgen(@nospecialize(f), @nospecialize(tt))
     # get the method instance
-    isa(f, Core.Builtin) && throw(ArgumentError("argument is not a generic function"))
     world = typemax(UInt)
     meth = which(f, tt)
     sig_tt = Tuple{typeof(f), tt.parameters...}
@@ -59,96 +40,85 @@ function irgen(@nospecialize(f), @nospecialize(tt))
                   (Any, Any, Any, UInt), meth, ti, env, world)
 
     # set-up the compiler interface
-    function hook_module_setup(ref::Ptr{Cvoid})
-        ref = convert(LLVM.API.LLVMModuleRef, ref)
-        module_setup(LLVM.Module(ref))
-    end
     function hook_raise_exception(insblock::Ptr{Cvoid}, ex::Ptr{Cvoid})
         insblock = convert(LLVM.API.LLVMValueRef, insblock)
         ex = convert(LLVM.API.LLVMValueRef, ex)
         raise_exception(BasicBlock(insblock), Value(ex))
     end
-    dependencies = Vector{LLVM.Module}()
-    function hook_module_activation(ref::Ptr{Cvoid})
-        ref = convert(LLVM.API.LLVMModuleRef, ref)
-        push!(dependencies, LLVM.Module(ref))
-    end
     params = Base.CodegenParams(track_allocations=false,
                                 code_coverage=false,
                                 static_alloc=false,
                                 prefer_specsig=true,
-                                module_setup=hook_module_setup,
-                                module_activation=hook_module_activation,
                                 raise_exception=hook_raise_exception)
 
-    # get the code
-    mod = let
-        ref = ccall(:jl_get_llvmf_defn, LLVM.API.LLVMValueRef,
-                    (Any, UInt, Bool, Bool, Base.CodegenParams),
-                    linfo, world, #=wrapper=#false, #=optimize=#false, params)
-        if ref == C_NULL
-            error("could not compile the specified method")
-        end
-
-        llvmf = LLVM.Function(ref)
-        LLVM.parent(llvmf)
-    end
-
-    # the main module should contain a single jlcall_ function definition,
-    # e.g. jlcall_kernel_vadd_62977
-    definitions = filter(f->!isdeclaration(f), functions(mod))
-    wrapper = let
-        fs = collect(filter(f->startswith(LLVM.name(f), "jlcall_"), definitions))
-        @assert length(fs) == 1
-        fs[1]
-    end
-
-    # the jlcall wrapper function should point us to the actual entry-point,
-    # e.g. julia_kernel_vadd_62984
-    entry_tag = let
-        m = match(r"jlcall_(.+)_\d+", LLVM.name(wrapper))
-        @assert m != nothing
-        m.captures[1]
-    end
-    unsafe_delete!(mod, wrapper)
-    entry = let
-        re = Regex("julia_$(entry_tag)_\\d+")
-        llvmcall_re = Regex("julia_$(entry_tag)_\\d+u\\d+")
-        fs = collect(filter(f->contains(LLVM.name(f), re) &&
-                               !contains(LLVM.name(f), llvmcall_re), definitions))
-        if length(fs) != 1
-            error("Could not find single entry-point for $entry_tag (available functions: ",
-                  join(map(f->LLVM.name(f), definitions), ", "), ")")
-        end
-        fs[1]
+    # generate IR
+    native_code = ccall(:jl_create_native, Ptr{Cvoid},
+                        (Vector{Core.MethodInstance}, Base.CodegenParams),
+                        [linfo], params)
+    @assert native_code != C_NULL
+    llvm_mod_ref = ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
+                         (Ptr{Cvoid},), native_code)
+    @assert llvm_mod_ref != C_NULL
+    llvm_mod = LLVM.Module(llvm_mod_ref)
+
+    # get the top-level function index
+    # (`api` starts out as a sentinel value, so we can check that the call filled it in)
+    api = Ref{UInt8}(typemax(UInt8))
+    llvm_func_idx = Ref{UInt32}()
+    llvm_specfunc_idx = Ref{UInt32}()
+    ccall(:jl_get_function_id, Nothing,
+          (Ptr{Cvoid}, Ptr{Core.MethodInstance}, Ptr{UInt8}, Ptr{UInt32}, Ptr{UInt32}),
+          native_code, Ref(linfo), api, llvm_func_idx, llvm_specfunc_idx)
+    @assert api[] != typemax(api[])
+
+    # get the top-level function
+    # NOTE(review): indices are decremented before lookup, presumably one-based -- confirm
+    llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
+                          (Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1)
+    @assert llvm_func_ref != C_NULL
+    llvm_func = LLVM.Function(llvm_func_ref)
+    llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
+                              (Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1)
+    @assert llvm_specfunc_ref != C_NULL
+    llvm_specfunc = LLVM.Function(llvm_specfunc_ref)
+
+    # configure the module
+    # NOTE: NVPTX::TargetMachine's data layout doesn't match the NVPTX user guide,
+    #       so we specify it ourselves
+    if Int === Int64
+        triple!(llvm_mod, "nvptx64-nvidia-cuda")
+        datalayout!(llvm_mod, "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
+    else
+        triple!(llvm_mod, "nvptx-nvidia-cuda")
+        datalayout!(llvm_mod, "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64")
     end
 
-    # link in dependent modules
-    for dep in dependencies
-        if VERSION < v"0.7.0-DEV.2513"
-            module_setup(dep)
+    # clean up incompatibilities
+    # (iterate over a snapshot of the function list, since the loop deletes functions)
+    for llvm_func in collect(functions(llvm_mod))
+        # remove non-specsig functions
+        fn = LLVM.name(llvm_func)
+        if startswith(fn, "jlcall_")
+            unsafe_delete!(llvm_mod, llvm_func)
+            continue
         end
-        link!(mod, dep)
-    end
 
-    # clean up incompatibilities
-    for llvmf in functions(mod)
         # only occurs in debug builds
-        delete!(function_attributes(llvmf), EnumAttribute("sspreq", 0, jlctx[]))
+        delete!(function_attributes(llvm_func), EnumAttribute("sspreq", 0, jlctx[]))
 
         # make function names safe for ptxas
         # (LLVM ought to do this, see eg. D17738 and D19126), but fails
         # TODO: fix all globals?
-        llvmfn = LLVM.name(llvmf)
-        if !isdeclaration(llvmf)
-            llvmfn′ = safe_fn(llvmf)
-            if llvmfn != llvmfn′
-                LLVM.name!(llvmf, llvmfn′)
+        if !isdeclaration(llvm_func)
+            # rename only when sanitizing actually changed the name
+            safe_name = safe_fn(llvm_func)
+            if fn != safe_name
+                LLVM.name!(llvm_func, safe_name)
             end
         end
     end
 
-    return mod, entry
+    return llvm_mod, llvm_specfunc
 end
 
 # promote a function to a kernel
@@ -160,7 +130,7 @@ function promote_kernel!(mod::LLVM.Module, entry_f::LLVM.Function, @nospecialize
     kernel = wrap_entry!(mod, entry_f, tt);
 
 
-    # property annotations TODO: belongs in irgen? doesn't maxntidx doesn't appear in ptx code?
+    # property annotations
 
     annotations = LLVM.Value[kernel]
 
@@ -480,6 +450,7 @@ function compile_function(@nospecialize(func), @nospecialize(tt), cap::VersionNu
         for e in errors
             warn("Encountered incompatible LLVM IR for $sig at capability $cap: ", e)
         end
+        println(mod)
         error("LLVM IR generated for $sig at capability $cap is not compatible")
     end
 