From 4246ddb42b18ef43958774b13621f6160f7992a8 Mon Sep 17 00:00:00 2001 From: Kristoffer Date: Mon, 20 Nov 2023 11:26:08 +0100 Subject: [PATCH] run `register_llvm_rules` in a module where optimizations are disabled --- src/compiler.jl | 2 +- src/rules/llvmrules.jl | 210 +++++++++++++++++++++-------------------- 2 files changed, 111 insertions(+), 101 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dd4b156629..0ab52a726d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2450,7 +2450,7 @@ function __init__() API.EnzymeSetUndefinedValueForType(@cfunction( julia_undef_value_for_type, LLVM.API.LLVMValueRef, (LLVM.API.LLVMTypeRef,UInt8))) register_alloc_rules() - register_llvm_rules() + @time LLVMRules.register_llvm_rules() end # Define EnzymeTarget diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 1ae10a3e57..cfddeacb09 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1018,6 +1018,7 @@ end function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=nothing) for variant in variants + @show augfwd_handler, rev_handler, fwd_handler if augfwd_handler !== nothing && rev_handler !== nothing API.EnzymeRegisterCallHandler(variant, augfwd_handler, rev_handler) end @@ -1027,6 +1028,13 @@ function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=no end end +module LLVMRules + +@eval Base.Experimental.@compiler_options compile=min optimize=0 infer=false + +import ..LLVM, ..API, ..register_handler!, ..GradientUtils +import ..Compiler + macro augfunc(f) :(@cfunction((B, OrigCI, gutils, normalR, shadowR, tapeR) -> begin UInt8($f(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR, tapeR)::Bool) @@ -1051,200 +1059,202 @@ end @inline function register_llvm_rules() register_handler!( ("julia.call",), - @augfunc(jlcall_augfwd), - @revfunc(jlcall_rev), - @fwdfunc(jlcall_fwd), + @augfunc(Compiler.jlcall_augfwd), + @revfunc(Compiler.jlcall_rev), + @fwdfunc(Compiler.jlcall_fwd), ) register_handler!( ("julia.call2",), - @augfunc(jlcall2_augfwd), - @revfunc(jlcall2_rev), - @fwdfunc(jlcall2_fwd), + @augfunc(Compiler.jlcall2_augfwd), + @revfunc(Compiler.jlcall2_rev), + @fwdfunc(Compiler.jlcall2_fwd), ) register_handler!( ("jl_apply_generic", "ijl_apply_generic"), - @augfunc(generic_augfwd), - @revfunc(generic_rev), - @fwdfunc(generic_fwd), + @augfunc(Compiler.generic_augfwd), + @revfunc(Compiler.generic_rev), + @fwdfunc(Compiler.generic_fwd), ) register_handler!( ("jl_invoke", "ijl_invoke", "jl_f_invoke"), - @augfunc(invoke_augfwd), - @revfunc(invoke_rev), - @fwdfunc(invoke_fwd), + @augfunc(Compiler.invoke_augfwd), + @revfunc(Compiler.invoke_rev), + @fwdfunc(Compiler.invoke_fwd), ) register_handler!( ("jl_f__apply_latest", "jl_f__call_latest"), - @augfunc(apply_latest_augfwd), - @revfunc(apply_latest_rev), - @fwdfunc(apply_latest_fwd), + @augfunc(Compiler.apply_latest_augfwd), + @revfunc(Compiler.apply_latest_rev), + @fwdfunc(Compiler.apply_latest_fwd), ) register_handler!( ("jl_threadsfor",), - @augfunc(threadsfor_augfwd), - @revfunc(threadsfor_rev), - @fwdfunc(threadsfor_fwd), + @augfunc(Compiler.threadsfor_augfwd), + @revfunc(Compiler.threadsfor_rev), + @fwdfunc(Compiler.threadsfor_fwd), ) register_handler!( ("jl_pmap",), - @augfunc(pmap_augfwd), - @revfunc(pmap_rev), - @fwdfunc(pmap_fwd), + @augfunc(Compiler.pmap_augfwd), + @revfunc(Compiler.pmap_rev), + @fwdfunc(Compiler.pmap_fwd), ) register_handler!( ("jl_new_task", "ijl_new_task"), - @augfunc(newtask_augfwd), - @revfunc(newtask_rev), - @fwdfunc(newtask_fwd), + @augfunc(Compiler.newtask_augfwd), + @revfunc(Compiler.newtask_rev), + @fwdfunc(Compiler.newtask_fwd), ) register_handler!( ("jl_set_task_threadpoolid", "ijl_set_task_threadpoolid"), - @augfunc(set_task_tid_augfwd), - @revfunc(set_task_tid_rev), - @fwdfunc(set_task_tid_fwd), + @augfunc(Compiler.set_task_tid_augfwd), + @revfunc(Compiler.set_task_tid_rev), + @fwdfunc(Compiler.set_task_tid_fwd), ) register_handler!( ("jl_enq_work",), - @augfunc(enq_work_augfwd), - @revfunc(enq_work_rev), - @fwdfunc(enq_work_fwd) + @augfunc(Compiler.enq_work_augfwd), + @revfunc(Compiler.enq_work_rev), + @fwdfunc(Compiler.enq_work_fwd) ) register_handler!( ("enzyme_custom",), - @augfunc(enzyme_custom_augfwd), - @revfunc(enzyme_custom_rev), - @fwdfunc(enzyme_custom_fwd) + @augfunc(Compiler.enzyme_custom_augfwd), + @revfunc(Compiler.enzyme_custom_rev), + @fwdfunc(Compiler.enzyme_custom_fwd) ) register_handler!( ("jl_wait",), - @augfunc(wait_augfwd), - @revfunc(wait_rev), - @fwdfunc(wait_fwd), + @augfunc(Compiler.wait_augfwd), + @revfunc(Compiler.wait_rev), + @fwdfunc(Compiler.wait_fwd), ) register_handler!( ("jl_","jl_breakpoint"), - @augfunc(noop_augfwd), - @revfunc(duplicate_rev), - @fwdfunc(noop_fwd), + @augfunc(Compiler.noop_augfwd), + @revfunc(Compiler.duplicate_rev), + @fwdfunc(Compiler.noop_fwd), ) register_handler!( ("jl_array_copy","ijl_array_copy"), - @augfunc(arraycopy_augfwd), - @revfunc(arraycopy_rev), - @fwdfunc(arraycopy_fwd), + @augfunc(Compiler.arraycopy_augfwd), + @revfunc(Compiler.arraycopy_rev), + @fwdfunc(Compiler.arraycopy_fwd), ) register_handler!( ("jl_reshape_array","ijl_reshape_array"), - @augfunc(arrayreshape_augfwd), - @revfunc(arrayreshape_rev), - @fwdfunc(arrayreshape_fwd), + @augfunc(Compiler.arrayreshape_augfwd), + @revfunc(Compiler.arrayreshape_rev), + @fwdfunc(Compiler.arrayreshape_fwd), ) register_handler!( ("jl_f_setfield","ijl_f_setfield"), - @augfunc(setfield_augfwd), - @revfunc(setfield_rev), - @fwdfunc(setfield_fwd), + @augfunc(Compiler.setfield_augfwd), + @revfunc(Compiler.setfield_rev), + @fwdfunc(Compiler.setfield_fwd), ) register_handler!( ("jl_box_float32","ijl_box_float32", "jl_box_float64", "ijl_box_float64"), - @augfunc(boxfloat_augfwd), - @revfunc(boxfloat_rev), - @fwdfunc(boxfloat_fwd), + @augfunc(Compiler.boxfloat_augfwd), + @revfunc(Compiler.boxfloat_rev), + @fwdfunc(Compiler.boxfloat_fwd), ) register_handler!( ("jl_f_tuple","ijl_f_tuple"), - @augfunc(f_tuple_augfwd), - @revfunc(f_tuple_rev), - @fwdfunc(f_tuple_fwd), + @augfunc(Compiler.f_tuple_augfwd), + @revfunc(Compiler.f_tuple_rev), + @fwdfunc(Compiler.f_tuple_fwd), ) register_handler!( ("jl_eqtable_get","ijl_eqtable_get"), - @augfunc(eqtableget_augfwd), - @revfunc(eqtableget_rev), - @fwdfunc(eqtableget_fwd), + @augfunc(Compiler.eqtableget_augfwd), + @revfunc(Compiler.eqtableget_rev), + @fwdfunc(Compiler.eqtableget_fwd), ) register_handler!( ("jl_eqtable_put","ijl_eqtable_put"), - @augfunc(eqtableput_augfwd), - @revfunc(eqtableput_rev), - @fwdfunc(eqtableput_fwd), + @augfunc(Compiler.eqtableput_augfwd), + @revfunc(Compiler.eqtableput_rev), + @fwdfunc(Compiler.eqtableput_fwd), ) register_handler!( ("jl_idtable_rehash","ijl_idtable_rehash"), - @augfunc(idtablerehash_augfwd), - @revfunc(idtablerehash_rev), - @fwdfunc(idtablerehash_fwd), + @augfunc(Compiler.idtablerehash_augfwd), + @revfunc(Compiler.idtablerehash_rev), + @fwdfunc(Compiler.idtablerehash_fwd), ) register_handler!( ("jl_f__apply_iterate","ijl_f__apply_iterate"), - @augfunc(apply_iterate_augfwd), - @revfunc(apply_iterate_rev), - @fwdfunc(apply_iterate_fwd), + @augfunc(Compiler.apply_iterate_augfwd), + @revfunc(Compiler.apply_iterate_rev), + @fwdfunc(Compiler.apply_iterate_fwd), ) register_handler!( ("jl_f__svec_ref","ijl_f__svec_ref"), - @augfunc(f_svec_ref_augfwd), - @revfunc(f_svec_ref_rev), - @fwdfunc(f_svec_ref_fwd), + @augfunc(Compiler.f_svec_ref_augfwd), + @revfunc(Compiler.f_svec_ref_rev), + @fwdfunc(Compiler.f_svec_ref_fwd), ) register_handler!( ("jl_new_structv","ijl_new_structv"), - @augfunc(new_structv_augfwd), - @revfunc(new_structv_rev), - @fwdfunc(new_structv_fwd), + @augfunc(Compiler.new_structv_augfwd), + @revfunc(Compiler.new_structv_rev), + @fwdfunc(Compiler.new_structv_fwd), ) register_handler!( ("jl_get_binding_or_error", "ijl_get_binding_or_error"), - @augfunc(get_binding_or_error_augfwd), - @revfunc(get_binding_or_error_rev), - @fwdfunc(get_binding_or_error_fwd), + @augfunc(Compiler.get_binding_or_error_augfwd), + @revfunc(Compiler.get_binding_or_error_rev), + @fwdfunc(Compiler.get_binding_or_error_fwd), ) register_handler!( ("jl_gc_add_finalizer_th","ijl_gc_add_finalizer_th", "jl_gc_add_ptr_finalizer","ijl_gc_add_ptr_finalizer"), - @augfunc(finalizer_augfwd), - @revfunc(finalizer_rev), - @fwdfunc(finalizer_fwd), + @augfunc(Compiler.finalizer_augfwd), + @revfunc(Compiler.finalizer_rev), + @fwdfunc(Compiler.finalizer_fwd), ) register_handler!( ("jl_array_grow_end","ijl_array_grow_end"), - @augfunc(jl_array_grow_end_augfwd), - @revfunc(jl_array_grow_end_rev), - @fwdfunc(jl_array_grow_end_fwd), + @augfunc(Compiler.jl_array_grow_end_augfwd), + @revfunc(Compiler.jl_array_grow_end_rev), + @fwdfunc(Compiler.jl_array_grow_end_fwd), ) register_handler!( ("jl_array_del_end","ijl_array_del_end"), - @augfunc(jl_array_del_end_augfwd), - @revfunc(jl_array_del_end_rev), - @fwdfunc(jl_array_del_end_fwd), + @augfunc(Compiler.jl_array_del_end_augfwd), + @revfunc(Compiler.jl_array_del_end_rev), + @fwdfunc(Compiler.jl_array_del_end_fwd), ) register_handler!( ("jl_f_getfield","ijl_f_getfield"), - @augfunc(jl_getfield_augfwd), - @revfunc(jl_getfield_rev), - @fwdfunc(jl_getfield_fwd), + @augfunc(Compiler.jl_getfield_augfwd), + @revfunc(Compiler.jl_getfield_rev), + @fwdfunc(Compiler.jl_getfield_fwd), ) register_handler!( ("ijl_get_nth_field_checked","jl_get_nth_field_checked"), - @augfunc(jl_nthfield_augfwd), - @revfunc(jl_nthfield_rev), - @fwdfunc(jl_nthfield_fwd), + @augfunc(Compiler.jl_nthfield_augfwd), + @revfunc(Compiler.jl_nthfield_rev), + @fwdfunc(Compiler.jl_nthfield_fwd), ) register_handler!( ("jl_array_sizehint","ijl_array_sizehint"), - @augfunc(jl_array_sizehint_augfwd), - @revfunc(jl_array_sizehint_rev), - @fwdfunc(jl_array_sizehint_fwd), + @augfunc(Compiler.jl_array_sizehint_augfwd), + @revfunc(Compiler.jl_array_sizehint_rev), + @fwdfunc(Compiler.jl_array_sizehint_fwd), ) register_handler!( ("jl_array_ptr_copy","ijl_array_ptr_copy"), - @augfunc(jl_array_ptr_copy_augfwd), - @revfunc(jl_array_ptr_copy_rev), - @fwdfunc(jl_array_ptr_copy_fwd), + @augfunc(Compiler.jl_array_ptr_copy_augfwd), + @revfunc(Compiler.jl_array_ptr_copy_rev), + @fwdfunc(Compiler.jl_array_ptr_copy_fwd), ) register_handler!( (), - @augfunc(jl_unhandled_augfwd), - @revfunc(jl_unhandled_rev), - @fwdfunc(jl_unhandled_fwd), + @augfunc(Compiler.jl_unhandled_augfwd), + @revfunc(Compiler.jl_unhandled_rev), + @fwdfunc(Compiler.jl_unhandled_fwd), ) -end \ No newline at end of file +end + +end