From b050fd3753a1bd7f53ab8f7de1da0a81982fe010 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 11 Aug 2024 14:58:04 -0400 Subject: [PATCH] boxfloat fixup --- src/rules/llvmrules.jl | 61 +++++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 664d643af0..7b27cf298e 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -589,7 +589,12 @@ end @register_fwd function boxfloat_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) - if is_constant_value(gutils, orig) + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true end @@ -616,7 +621,12 @@ end @register_aug function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) origops = collect(operands(orig)) width = get_width(gutils) - if is_constant_value(gutils, orig) + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true end @@ -642,30 +652,37 @@ end end @register_rev function boxfloat_rev(B, orig, gutils, tape) + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 + return nothing + end + origops = collect(operands(orig)) width = get_width(gutils) - if !is_constant_value(gutils, orig) - ip = lookup_value(gutils, invert_pointer(gutils, orig, B), B) - flt = value_type(origops[1]) - if width == 1 - ipc = bitcast!(B, ip, LLVM.PointerType(flt, addrspace(value_type(orig)))) + ip = lookup_value(gutils, invert_pointer(gutils, orig, B), B) + flt = value_type(origops[1]) + if width == 1 + ipc = bitcast!(B, ip, LLVM.PointerType(flt, addrspace(value_type(orig)))) + ld = load!(B, flt, ipc) + store!(B, ConstantFP(flt, 0.0), ipc) + if !is_constant_value(gutils, origops[1]) + API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], ld, B, flt) + end + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) + for idx in 1:width + ipc = extract_value!(B, ip, idx-1) + ipc = bitcast!(B, ipc, LLVM.PointerType(flt, addrspace(value_type(orig)))) ld = load!(B, flt, ipc) store!(B, ConstantFP(flt, 0.0), ipc) - if !is_constant_value(gutils, origops[1]) - API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], ld, B, flt) - end - else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) - for idx in 1:width - ipc = extract_value!(B, ip, idx-1) - ipc = bitcast!(B, ipc, LLVM.PointerType(flt, addrspace(value_type(orig)))) - ld = load!(B, flt, ipc) - store!(B, ConstantFP(flt, 0.0), ipc) - shadowres = insert_value!(B, shadowres, ld, idx-1) - end - if !is_constant_value(gutils, origops[1]) - API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], shadowret, B, flt) - end + shadowres = insert_value!(B, shadowres, ld, idx-1) + end + if !is_constant_value(gutils, origops[1]) + API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], shadowret, B, flt) end end return nothing