Skip to content

Commit

Permalink
simplify ffi wrappers (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 authored Aug 16, 2024
1 parent 1a36626 commit 123b0e8
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 39 deletions.
14 changes: 0 additions & 14 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1019,16 +1019,6 @@ pub(crate) unsafe fn enzyme_ad(
// A really simple check
assert!(src_num_args <= target_num_args);

// create enzyme typetrees
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");

let input_tts =
item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect();
let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);

let type_analysis: EnzymeTypeAnalysisRef =
unsafe {CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0)};

Expand Down Expand Up @@ -1066,8 +1056,6 @@ pub(crate) unsafe fn enzyme_ad(
src_fnc,
args_activity,
ret_activity,
input_tts,
output_tt,
void_ret,
),
DiffMode::Reverse => enzyme_rust_reverse_diff(
Expand All @@ -1076,8 +1064,6 @@ pub(crate) unsafe fn enzyme_ad(
src_fnc,
args_activity,
ret_activity,
input_tts,
output_tt,
),
_ => unreachable!(),
};
Expand Down
4 changes: 0 additions & 4 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def:

#[allow(unused)]
pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, attrs: AutoDiffAttrs, i: usize) {
//pub mode: DiffMode,
//pub ret_activity: DiffActivity,
//pub input_activity: Vec<DiffActivity>,
let inputs = attrs.input_activity;
let outputs = attrs.ret_activity;
let ad_name = match attrs.mode {
Expand Down Expand Up @@ -136,7 +133,6 @@ pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Contex
let num_args = llvm::LLVMCountParams(wrapper_fn);
let mut args = Vec::with_capacity(num_args as usize + 1);
args.push(val);
// metadata !"enzyme_const"
let enzyme_const = llvm::LLVMMDStringInContext(llcx, "enzyme_const".as_ptr() as *const c_char, 12);
let enzyme_out = llvm::LLVMMDStringInContext(llcx, "enzyme_out".as_ptr() as *const c_char, 10);
let enzyme_dup = llvm::LLVMMDStringInContext(llcx, "enzyme_dup".as_ptr() as *const c_char, 10);
Expand Down
23 changes: 2 additions & 21 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -845,15 +845,12 @@ pub enum LLVMVerifierFailureAction {
LLVMReturnStatusAction,
}

#[allow(dead_code)]
pub(crate) unsafe fn enzyme_rust_forward_diff(
logic_ref: EnzymeLogicRef,
type_analysis: EnzymeTypeAnalysisRef,
fnc: &Value,
input_diffactivity: Vec<DiffActivity>,
ret_diffactivity: DiffActivity,
_input_tts: Vec<TypeTree>,
_output_tt: TypeTree,
void_ret: bool,
) -> (&Value, Vec<usize>) {
let ret_activity = cdiffe_from(ret_diffactivity);
Expand Down Expand Up @@ -882,9 +879,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
};
trace!("ret_primary_ret: {}", &ret_primary_ret);

//let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];

// We don't support volatile / extern / (global?) values.
// Just because I didn't had time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_activity.len()];
Expand All @@ -900,9 +894,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
let tree_tmp = TypeTree::new();
let mut args_tree = vec![tree_tmp.inner; input_activity.len()];

//let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()];
//let ret_tt = std::ptr::null_mut();
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
let ret_tt = TypeTree::new();
let dummy_type = CFnTypeInfo {
Arguments: args_tree.as_mut_ptr(),
Expand Down Expand Up @@ -944,8 +935,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
fnc: &Value,
rust_input_activity: Vec<DiffActivity>,
ret_activity: DiffActivity,
input_tts: Vec<TypeTree>,
_output_tt: TypeTree,
) -> (&Value, Vec<usize>) {
let (primary_ret, ret_activity) = match ret_activity {
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
Expand All @@ -971,8 +960,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
input_activity.push(cdiffe_from(x));
}

//let args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();

// We don't support volatile / extern / (global?) values.
// Just because I didn't had time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_activity.len()];
Expand All @@ -982,14 +969,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
assert!(num_fnc_args == input_activity.len() as u32);
let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 };

let mut known_values = vec![kv_tmp; input_tts.len()];
let mut known_values = vec![kv_tmp; input_activity.len()];

let tree_tmp = TypeTree::new();
let mut args_tree = vec![tree_tmp.inner; input_tts.len()];
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
let mut args_tree = vec![tree_tmp.inner; input_activity.len()];
let ret_tt = TypeTree::new();
//let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()];
//let ret_tt = std::ptr::null_mut();
let dummy_type = CFnTypeInfo {
Arguments: args_tree.as_mut_ptr(),
Return: ret_tt.inner,
Expand Down Expand Up @@ -1029,9 +1013,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
}

extern "C" {
// TODO: can I just ignore the non void return
// EraseFromParent doesn't exist :(
//pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value;
// Enzyme
pub fn LLVMRustAddFncParamAttr<'a>(
F: &'a Value,
Expand Down

0 comments on commit 123b0e8

Please sign in to comment.