diff --git a/runtime/src/debug_helper.rs b/runtime/src/debug_helper.rs index f05e61a..cc63039 100644 --- a/runtime/src/debug_helper.rs +++ b/runtime/src/debug_helper.rs @@ -6,9 +6,12 @@ pub unsafe fn trace_continuation(continuation: *mut Continuation) { while !continuation.is_null() { println!(" \x1b[32m[{:p}]\x1b[0m", continuation); println!(" func: {:p}", (*continuation).func); - println!(" arg_stack_frame_height: {}", (*continuation).arg_stack_frame_height); + println!( + " arg_stack_frame_height: {}", + (*continuation).arg_stack_frame_height + ); println!(" next: {:p}", (*continuation).next); println!(" state: {}", (*continuation).state); continuation = (*continuation).next; } -} \ No newline at end of file +} diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 9d849c6..79508c0 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -1,4 +1,4 @@ -pub mod runtime_utils; +pub mod debug_helper; pub mod runtime; +pub mod runtime_utils; pub mod types; -pub mod debug_helper; \ No newline at end of file diff --git a/runtime/src/runtime_utils.rs b/runtime/src/runtime_utils.rs index 9713b1c..503d939 100644 --- a/runtime/src/runtime_utils.rs +++ b/runtime/src/runtime_utils.rs @@ -1,12 +1,15 @@ +use crate::runtime::HandlerEntry::SimpleOperationMarker; +use crate::runtime::{ + CapturedContinuation, ContImplPtr, Continuation, Eff, Generic, Handler, HandlerEntry, + HandlerType, HandlerTypeOrdinal, RawFuncPtr, ThunkPtr, Uniform, +}; +use crate::types::{UPtr, UniformPtr, UniformType}; use enum_ordinalize::Ordinalize; use std::arch::global_asm; use std::cell::RefCell; use std::iter::Peekable; -use std::ops::{DerefMut}; +use std::ops::DerefMut; use std::slice::IterMut; -use crate::runtime::{Continuation, HandlerEntry, Handler, Uniform, ThunkPtr, Eff, Generic, CapturedContinuation, RawFuncPtr, ContImplPtr, HandlerTypeOrdinal, HandlerType}; -use crate::runtime::HandlerEntry::SimpleOperationMarker; -use crate::types::{UPtr, UniformPtr, UniformType}; // TODO: use custom allocator that allocates through Boehm GC for vecs thread_local!( @@ -14,7 +17,8 @@ thread_local!( ); #[cfg(target_arch = "aarch64")] -global_asm!(r#" +global_asm!( + r#" .global _runtime_mark_handler .global _long_jump @@ -68,7 +72,8 @@ global_asm!(r#" mov x3, x1 mov x2, x0 br x16 -"#); +"# +); extern "C" { /// Store FP, SP, and return address to the handler entry, then invoke the input function. @@ -76,9 +81,9 @@ extern "C" { /// in Cranelifts' tail call convention. /// On ARM64, the first argument starts at x2. pub fn runtime_mark_handler( - input_base_address: *mut Uniform, // x2 - next_continuation: *mut Continuation, // x3 - input_func_ptr: RawFuncPtr, // x4 + input_base_address: *mut Uniform, // x2 + next_continuation: *mut Continuation, // x3 + input_func_ptr: RawFuncPtr, // x4 handler: *const Handler<*mut Uniform>, // x5 ) -> *const Uniform; @@ -86,14 +91,13 @@ extern "C" { /// This function follows the normal call convention. But it ends up calling the next continuation, which is in /// Cranelift's tail call convention. So on ARM64, it needs to shift the arguments. fn long_jump( - next_base_address: *mut Uniform, // x0 + next_base_address: *mut Uniform, // x0 next_continuation: *const Continuation, // x1 - result_ptr: *const Uniform, // x2 - handler: *const Handler<*mut Uniform>, // x3 + result_ptr: *const Uniform, // x2 + handler: *const Handler<*mut Uniform>, // x3 ) -> !; } - static EMPTY_STRUCT: usize = 0; const TRANSFORM_LOADER_NUM_ARGS: usize = 1; @@ -122,11 +126,13 @@ pub unsafe extern "C" fn runtime_word_box() -> *mut usize { ptr } - /// Takes a pointer to a function or thunk, push any arguments to the tip of the stack, and return /// a pointer to the underlying raw function. #[no_mangle] -pub unsafe extern "C" fn runtime_force_thunk(thunk: ThunkPtr, tip_address_ptr: *mut *mut usize) -> RawFuncPtr { +pub unsafe extern "C" fn runtime_force_thunk( + thunk: ThunkPtr, + tip_address_ptr: *mut *mut usize, +) -> RawFuncPtr { let thunk_ptr = thunk.to_normal_ptr(); match UniformType::from_bits(thunk as usize) { UniformType::PPtr => thunk_ptr, @@ -142,7 +148,7 @@ pub unsafe extern "C" fn runtime_force_thunk(thunk: ThunkPtr, tip_address_ptr: * tip_address_ptr.write(tip_address); runtime_force_thunk(next_thunk, tip_address_ptr) } - _ => unreachable!("bad thunk pointer") + _ => unreachable!("bad thunk pointer"), } } @@ -162,11 +168,14 @@ pub unsafe extern "C" fn runtime_alloc_stack() -> *mut usize { } /// Returns the result of the operation in uniform representation -pub unsafe fn debug_helper(base: *const usize, last_result_ptr: *const usize, thunk: *const usize) -> usize { +pub unsafe fn debug_helper( + base: *const usize, + last_result_ptr: *const usize, + thunk: *const usize, +) -> usize { return 1 + 1; } - /// Returns the following results on the argument stack. /// - ptr + 0: the function pointer to the matched handler implementationon /// - ptr + 8: the base address used to find the arguments when invoking the handler implementation @@ -218,10 +227,11 @@ unsafe fn prepare_complex_operation( handler_index: usize, handler_impl: ThunkPtr, ) -> (*const usize, *mut usize, *mut Continuation) { - let handler_entry_fragment = HANDLERS.with(|handler| handler.borrow_mut().split_off(handler_index)); + let handler_entry_fragment = + HANDLERS.with(|handler| handler.borrow_mut().split_off(handler_index)); let matching_handler = match handler_entry_fragment.first().unwrap() { HandlerEntry::Handler(handler) => handler, - _ => panic!("Expect a handler entry") + _ => panic!("Expect a handler entry"), }; // Update the tip continuation so that its height no longer includes the arguments passed to @@ -242,27 +252,37 @@ unsafe fn prepare_complex_operation( (*next_continuation).arg_stack_frame_height += handler_num_args + 2 - TRANSFORM_LOADER_NUM_ARGS; // Copy the stack fragment. - let stack_fragment_end = matching_handler.transform_loader_base_address.add(TRANSFORM_LOADER_NUM_ARGS); + let stack_fragment_end = matching_handler + .transform_loader_base_address + .add(TRANSFORM_LOADER_NUM_ARGS); let stack_fragment_start = handler_call_base_address.add(handler_num_args); let stack_fragment_length = stack_fragment_end.offset_from(stack_fragment_start); assert!(stack_fragment_length >= 0); - let stack_fragment: Vec = std::slice::from_raw_parts(stack_fragment_start, stack_fragment_length as usize).to_vec(); + let stack_fragment: Vec = + std::slice::from_raw_parts(stack_fragment_start, stack_fragment_length as usize).to_vec(); // Copy the handler fragment. let matching_base_address = matching_handler.transform_loader_base_address; - let handler_fragment: Vec> = handler_entry_fragment.into_iter().map(|handler_entry| { - match handler_entry { + let handler_fragment: Vec> = handler_entry_fragment + .into_iter() + .map(|handler_entry| match handler_entry { HandlerEntry::Handler(handler) => { let transform_base_address = handler.transform_loader_base_address; let mut handler: Handler = std::mem::transmute(handler); - handler.transform_loader_base_address = matching_base_address.offset_from(transform_base_address) as usize; + handler.transform_loader_base_address = + matching_base_address.offset_from(transform_base_address) as usize; handler } - _ => panic!("Expect a handler entry") - } - }).collect(); + _ => panic!("Expect a handler entry"), + }) + .collect(); - let captured_continuation_thunk = create_captured_continuation(captured_continuation_thunk_impl, tip_continuation, stack_fragment, handler_fragment); + let captured_continuation_thunk = create_captured_continuation( + captured_continuation_thunk_impl, + tip_continuation, + stack_fragment, + handler_fragment, + ); let mut new_tip_address = stack_fragment_end; @@ -285,7 +305,8 @@ unsafe fn prepare_complex_operation( let tip_address_before_forcing_handler_impl = new_tip_address; let handler_function_ptr = runtime_force_thunk(handler_impl, &mut new_tip_address); // unpacking captured arguments of the handler would affect frame height of the next continuation as well. - (*next_continuation).arg_stack_frame_height += tip_address_before_forcing_handler_impl.offset_from(new_tip_address) as usize; + (*next_continuation).arg_stack_frame_height += + tip_address_before_forcing_handler_impl.offset_from(new_tip_address) as usize; (handler_function_ptr, new_tip_address, next_continuation) } @@ -295,7 +316,8 @@ unsafe fn create_captured_continuation( stack_fragment: Vec, handler_fragment: Vec>, ) -> *mut usize { - let captured_continuation = runtime_alloc((std::mem::size_of::()) / 8) as *mut CapturedContinuation; + let captured_continuation = runtime_alloc((std::mem::size_of::()) / 8) + as *mut CapturedContinuation; *captured_continuation = CapturedContinuation { tip_continuation, @@ -306,11 +328,18 @@ unsafe fn create_captured_continuation( create_captured_continuation_thunk(captured_continuation_thunk_impl, captured_continuation) } -unsafe fn create_captured_continuation_thunk(captured_continuation_thunk_impl: RawFuncPtr, captured_continuation: *mut CapturedContinuation) -> *mut usize { +unsafe fn create_captured_continuation_thunk( + captured_continuation_thunk_impl: RawFuncPtr, + captured_continuation: *mut CapturedContinuation, +) -> *mut usize { let captured_continuation_thunk = runtime_alloc(3); - captured_continuation_thunk.write(UniformType::to_uniform_pptr(captured_continuation_thunk_impl)); + captured_continuation_thunk.write(UniformType::to_uniform_pptr( + captured_continuation_thunk_impl, + )); captured_continuation_thunk.add(1).write(1); - captured_continuation_thunk.add(2).write(UniformType::to_uniform_sptr(captured_continuation)); + captured_continuation_thunk + .add(2) + .write(UniformType::to_uniform_sptr(captured_continuation)); captured_continuation_thunk } @@ -342,13 +371,20 @@ unsafe fn prepare_simple_operation( handler_type: HandlerType, ) -> (*const usize, *mut usize, *mut Continuation) { assert!(handler_type != HandlerType::Complex); - let matching_handler = HANDLERS.with(|handler| match handler.borrow_mut().get_mut(handler_index).unwrap() { - HandlerEntry::Handler(handler) => handler as *mut Handler<*mut Uniform>, - _ => panic!("Expect a handler entry") + let matching_handler = + HANDLERS.with( + |handler| match handler.borrow_mut().get_mut(handler_index).unwrap() { + HandlerEntry::Handler(handler) => handler as *mut Handler<*mut Uniform>, + _ => panic!("Expect a handler entry"), + }, + ); + + HANDLERS.with(|handler| { + handler + .borrow_mut() + .push(SimpleOperationMarker { handler_index }) }); - HANDLERS.with(|handler| handler.borrow_mut().push(SimpleOperationMarker { handler_index })); - let mut tip_address = handler_call_base_address; tip_address = tip_address.sub(1); tip_address.write((*matching_handler).parameter); @@ -370,7 +406,11 @@ unsafe fn prepare_simple_operation( // We can still update the frame height of the next continuation here, but that would be redundant and hence left // out. - (simple_handler_runner_impl_ptr, tip_address, next_continuation) + ( + simple_handler_runner_impl_ptr, + tip_address, + next_continuation, + ) } /// Special function that may do long jump instead of normal return if the result is exceptional. If @@ -391,7 +431,7 @@ pub unsafe extern "C" fn runtime_process_simple_handler_result( handler.parameter = simple_result.handler_parameter; handler as *mut Handler<*mut Uniform> } - _ => panic!("Expect a handler entry") + _ => panic!("Expect a handler entry"), }; // pop the simple handler entry marker let last_entry = handlers.pop().unwrap(); @@ -406,7 +446,7 @@ pub unsafe extern "C" fn runtime_process_simple_handler_result( let ptr = simple_result.result_value; (ptr.value, ptr.tag >> 1) } - _ => unreachable!("bad simple operation type") + _ => unreachable!("bad simple operation type"), }; match tag { @@ -434,10 +474,15 @@ pub unsafe extern "C" fn runtime_process_simple_handler_result( // Here we use the matching handler because the jump needs to restore execution to the state at the matching // handler. This mismatch between fp, sp, and lr with the continuation and argument stack is fine because // the continuation is a disposer loader continuation, which does not care about fp, sp, or lr. - long_jump(next_base_address, next_continuation, result_ptr, matching_handler); - } + long_jump( + next_base_address, + next_continuation, + result_ptr, + matching_handler, + ); + }, 1 => value, - _ => unreachable!("bad simple result tag") + _ => unreachable!("bad simple result tag"), } } @@ -447,7 +492,7 @@ pub unsafe extern "C" fn runtime_pop_handler() -> Uniform { let handler = handlers.borrow_mut().pop().unwrap(); match handler { HandlerEntry::Handler(handler) => handler.parameter, - _ => panic!("Expect a handler entry") + _ => panic!("Expect a handler entry"), } }) } @@ -489,7 +534,9 @@ unsafe fn convert_handler_transformers_to_disposers( }); let mut last_handler = std::ptr::null(); for entry in handlers[matching_handler_index..].iter_mut() { - let HandlerEntry::Handler(handler) = entry else { unreachable!() }; + let HandlerEntry::Handler(handler) = entry else { + unreachable!() + }; let parameter_disposer = handler.parameter_disposer as Uniform; (*handler.transform_loader_continuation).next = next_continuation; (*handler.transform_loader_continuation).func = runtime_disposer_loader_cps_impl; @@ -499,15 +546,21 @@ unsafe fn convert_handler_transformers_to_disposers( // supposed to be consumed. And for debug build, it's set to a very large value so // that it can allow arbitrary increment and decrement without overflow or // underflow. - (*next_continuation).arg_stack_frame_height = next_base_address.offset_from(handler.transform_loader_base_address) as usize; + (*next_continuation).arg_stack_frame_height = + next_base_address.offset_from(handler.transform_loader_base_address) as usize; } next_continuation = handler.transform_loader_continuation; next_base_address = handler.transform_loader_base_address; - handler.transform_loader_base_address.write(parameter_disposer); + handler + .transform_loader_base_address + .write(parameter_disposer); last_handler = handler; } - ((*last_handler).transform_loader_continuation, (*last_handler).transform_loader_base_address) + ( + (*last_handler).transform_loader_continuation, + (*last_handler).transform_loader_base_address, + ) }) } @@ -528,12 +581,16 @@ fn find_matching_handler(eff: Eff, may_be_complex: usize) -> (usize, ThunkPtr, H } for (e, handler_impl, simple_operation_type) in &handler.simple_handler { if unsafe { compare_uniform(eff, *e) } { - return (handler_index, *handler_impl, match simple_operation_type { - 0 => HandlerType::Exceptional, - 1 => HandlerType::Linear, - 2 => HandlerType::Affine, - _ => unreachable!("bad simple operation type") - }); + return ( + handler_index, + *handler_impl, + match simple_operation_type { + 0 => HandlerType::Exceptional, + 1 => HandlerType::Linear, + 2 => HandlerType::Affine, + _ => unreachable!("bad simple operation type"), + }, + ); } } i = handler_index; @@ -547,7 +604,6 @@ fn find_matching_handler(eff: Eff, may_be_complex: usize) -> (usize, ThunkPtr, H }) } - #[no_mangle] pub unsafe extern "C" fn runtime_register_handler( tip_address_ptr: *mut *mut usize, @@ -600,20 +656,26 @@ pub unsafe extern "C" fn runtime_register_handler( general_callee_saved_registers: [0; 10], vector_callee_saved_registers: [0; 8], })); - match handlers.borrow().last().unwrap() - { + match handlers.borrow().last().unwrap() { HandlerEntry::Handler(handler) => handler as *const Handler<*mut Uniform>, - _ => panic!("Expect a handler entry") + _ => panic!("Expect a handler entry"), } }) } #[no_mangle] -pub extern "C" fn runtime_add_handler(handler: &mut Handler<*mut Uniform>, eff: Eff, handler_impl: ThunkPtr, handler_type: HandlerTypeOrdinal) { +pub extern "C" fn runtime_add_handler( + handler: &mut Handler<*mut Uniform>, + eff: Eff, + handler_impl: ThunkPtr, + handler_type: HandlerTypeOrdinal, +) { if handler_type == HandlerType::Complex.ordinal() as usize { handler.complex_handler.push((eff, handler_impl)) } else { - handler.simple_handler.push((eff, handler_impl, handler_type)); + handler + .simple_handler + .push((eff, handler_impl, handler_type)); } } @@ -654,7 +716,6 @@ pub unsafe extern "C" fn runtime_prepare_resume_continuation( tip_address } - /// Returns a pointer pointing to the following: /// - ptr + 0: the function pointer to the resumed continuation /// - ptr + 8: the base address for the resumed continuation to find its arguments @@ -743,7 +804,8 @@ unsafe fn unpack_captured_continuation( let mut handlers = handlers.borrow_mut(); for handler in captured_continuation.handler_fragment.into_iter() { handlers.push(HandlerEntry::Handler(Handler { - transform_loader_base_address: transform_loader_base_address.sub(handler.transform_loader_base_address), + transform_loader_base_address: transform_loader_base_address + .sub(handler.transform_loader_base_address), transform_loader_continuation: handler.transform_loader_continuation, parameter: handler.parameter, parameter_disposer: handler.parameter_disposer, @@ -772,7 +834,10 @@ pub unsafe extern "C" fn runtime_replicate_continuation( parameter: Uniform, frame_pointer: *const u8, stack_pointer: *const u8, - runtime_invoke_cps_function_with_trivial_continuation: extern fn(RawFuncPtr, *mut Uniform) -> *mut Uniform, + runtime_invoke_cps_function_with_trivial_continuation: extern "C" fn( + RawFuncPtr, + *mut Uniform, + ) -> *mut Uniform, captured_continuation_thunk_impl: RawFuncPtr, ) -> *const Uniform { let matching_handler_index = HANDLERS.with(|handlers| handlers.borrow().len()); @@ -795,7 +860,7 @@ pub unsafe extern "C" fn runtime_replicate_continuation( while handlers.len() > matching_handler_index { let handler = match handlers.pop().unwrap() { HandlerEntry::Handler(handler) => handler, - _ => panic!("Expect a handler entry") + _ => panic!("Expect a handler entry"), }; let parameter_replicator = handler.parameter_replicator; if parameter_replicator as Uniform == 0b11 { @@ -807,13 +872,19 @@ pub unsafe extern "C" fn runtime_replicate_continuation( let mut tip_address = base_address.sub(1); tip_address.write(handler.parameter); let replicator_func_ptr = runtime_force_thunk(parameter_replicator, &mut tip_address); - let parameter_pair = runtime_invoke_cps_function_with_trivial_continuation(replicator_func_ptr, tip_address); + let parameter_pair = runtime_invoke_cps_function_with_trivial_continuation( + replicator_func_ptr, + tip_address, + ); parameters.push((parameter_pair.read(), parameter_pair.add(1).read())); } }); let mut cloned_handler_fragment = Vec::new(); - for ((parameter1, parameter2), handler) in parameters.into_iter().zip(captured_continuation.handler_fragment.iter_mut()) { + for ((parameter1, parameter2), handler) in parameters + .into_iter() + .zip(captured_continuation.handler_fragment.iter_mut()) + { handler.parameter = parameter1; let cloned_handler = Handler { transform_loader_continuation: handler.transform_loader_continuation, @@ -833,12 +904,14 @@ pub unsafe extern "C" fn runtime_replicate_continuation( } let mut cloned_continuation = std::ptr::null_mut(); - clone_continuation(captured_continuation.tip_continuation, &mut cloned_continuation, cloned_handler_fragment.iter_mut().peekable()); - - let captured_continuation_thunk1 = create_captured_continuation_thunk( - captured_continuation_thunk_impl, - captured_continuation, + clone_continuation( + captured_continuation.tip_continuation, + &mut cloned_continuation, + cloned_handler_fragment.iter_mut().peekable(), ); + + let captured_continuation_thunk1 = + create_captured_continuation_thunk(captured_continuation_thunk_impl, captured_continuation); let captured_continuation_thunk2 = create_captured_continuation( captured_continuation_thunk_impl, cloned_continuation, @@ -849,10 +922,15 @@ pub unsafe extern "C" fn runtime_replicate_continuation( // Write the last result let last_result_address = base_address.sub(1); let pair_of_captured_continuation_thuns = runtime_alloc(2); - pair_of_captured_continuation_thuns.write(UniformType::to_uniform_sptr(captured_continuation_thunk1)); - pair_of_captured_continuation_thuns.add(1).write(UniformType::to_uniform_sptr(captured_continuation_thunk2)); + pair_of_captured_continuation_thuns + .write(UniformType::to_uniform_sptr(captured_continuation_thunk1)); + pair_of_captured_continuation_thuns + .add(1) + .write(UniformType::to_uniform_sptr(captured_continuation_thunk2)); - last_result_address.write(UniformType::to_uniform_sptr(pair_of_captured_continuation_thuns)); + last_result_address.write(UniformType::to_uniform_sptr( + pair_of_captured_continuation_thuns, + )); // Write the return values of this helper function. let result_ptr = last_result_address.sub(3); @@ -865,7 +943,11 @@ pub unsafe extern "C" fn runtime_replicate_continuation( result_ptr } -unsafe fn clone_continuation(continuation: *mut Continuation, cloned_continuation_ptr: &mut *mut Continuation, mut iter_mut: Peekable>>) { +unsafe fn clone_continuation( + continuation: *mut Continuation, + cloned_continuation_ptr: &mut *mut Continuation, + mut iter_mut: Peekable>>, +) { let continuation_addr = continuation as *const usize; let length = (continuation_addr).sub(1).read(); let cloned_continuation_addr = runtime_alloc(length); @@ -884,7 +966,11 @@ unsafe fn clone_continuation(continuation: *mut Continuation, cloned_continuatio if next_continuation.is_null() { return; } - clone_continuation(next_continuation, &mut (*cloned_continuation).next, iter_mut) + clone_continuation( + next_continuation, + &mut (*cloned_continuation).next, + iter_mut, + ) } unsafe fn compare_uniform(a: Uniform, b: Uniform) -> bool { diff --git a/runtime/src/types.rs b/runtime/src/types.rs index 431b551..debd557 100644 --- a/runtime/src/types.rs +++ b/runtime/src/types.rs @@ -81,7 +81,7 @@ impl UniformPtr<*const usize> for usize { #[derive(Debug, Clone, Copy, PartialEq)] #[repr(C, align(8))] -pub struct UPtr (*const u8, PhantomData); +pub struct UPtr(*const u8, PhantomData); impl<'a, T> UPtr<&'a T> { pub fn as_uniform(&self) -> usize { diff --git a/src/ast/free_var.rs b/src/ast/free_var.rs index 2c1df70..313cc02 100644 --- a/src/ast/free_var.rs +++ b/src/ast/free_var.rs @@ -1,6 +1,6 @@ -use std::collections::{HashMap, HashSet}; use crate::ast::term::{CTerm, VTerm}; use crate::ast::visitor::Visitor; +use std::collections::{HashMap, HashSet}; pub trait HasFreeVar { fn free_vars(&mut self) -> HashSet; @@ -8,7 +8,10 @@ pub trait HasFreeVar { impl HasFreeVar for CTerm { fn free_vars(&mut self) -> HashSet { - let mut visitor = FreeVarVisitor { free_vars: HashSet::new(), binding_count: HashMap::new() }; + let mut visitor = FreeVarVisitor { + free_vars: HashSet::new(), + binding_count: HashMap::new(), + }; visitor.visit_c_term(self, ()); visitor.free_vars } @@ -16,7 +19,10 @@ impl HasFreeVar for CTerm { impl HasFreeVar for VTerm { fn free_vars(&mut self) -> HashSet { - let mut visitor = FreeVarVisitor { free_vars: HashSet::new(), binding_count: HashMap::new() }; + let mut visitor = FreeVarVisitor { + free_vars: HashSet::new(), + binding_count: HashMap::new(), + }; visitor.visit_v_term(self, ()); visitor.free_vars } @@ -31,11 +37,17 @@ impl Visitor for FreeVarVisitor { type Ctx = (); fn add_binding(&mut self, name: usize, _: ()) { - self.binding_count.insert(name, self.binding_count.get(&name).cloned().unwrap_or(0) + 1); + self.binding_count.insert( + name, + self.binding_count.get(&name).cloned().unwrap_or(0) + 1, + ); } fn remove_binding(&mut self, name: usize, _: ()) { - self.binding_count.insert(name, self.binding_count.get(&name).cloned().unwrap_or(0) - 1); + self.binding_count.insert( + name, + self.binding_count.get(&name).cloned().unwrap_or(0) - 1, + ); } fn visit_var(&mut self, _v_term: &VTerm, _: ()) { diff --git a/src/ast/primitive_functions.rs b/src/ast/primitive_functions.rs index 9a441b4..ed80997 100644 --- a/src/ast/primitive_functions.rs +++ b/src/ast/primitive_functions.rs @@ -1,7 +1,7 @@ -use cranelift::prelude::{FunctionBuilder, InstBuilder, Value}; -use phf::phf_map; use crate::ast::term::SpecializedType; use crate::ast::term::VType; +use cranelift::prelude::{FunctionBuilder, InstBuilder, Value}; +use phf::phf_map; use SpecializedType::*; use VType::*; @@ -137,4 +137,3 @@ pub static PRIMITIVE_FUNCTIONS: phf::Map<&'static str, &'static PrimitiveFunctio } }, }; - diff --git a/src/ast/signature.rs b/src/ast/signature.rs index c36ffbb..899cb6e 100644 --- a/src/ast/signature.rs +++ b/src/ast/signature.rs @@ -1,11 +1,11 @@ -use std::cmp::{min, Ordering}; -use std::collections::{HashMap}; -use itertools::Itertools; use crate::ast::free_var::HasFreeVar; -use crate::ast::primitive_functions::{PRIMITIVE_FUNCTIONS, PrimitiveFunction}; +use crate::ast::primitive_functions::{PrimitiveFunction, PRIMITIVE_FUNCTIONS}; use crate::ast::term::{CTerm, CType, Effect, VTerm, VType}; use crate::ast::transformer::Transformer; use crate::ast::visitor::Visitor; +use itertools::Itertools; +use std::cmp::{min, Ordering}; +use std::collections::HashMap; #[derive(Debug, Clone)] pub struct FunctionDefinition { @@ -103,107 +103,186 @@ impl Signature { fn reduce_redundancy(&mut self) { let mut normalizer = RedundancyRemover {}; - self.defs.iter_mut().for_each(|(_, FunctionDefinition { body, .. })| { - normalizer.transform_c_term(body); - }); + self.defs + .iter_mut() + .for_each(|(_, FunctionDefinition { body, .. })| { + normalizer.transform_c_term(body); + }); } - fn specialize_calls(&mut self) { let mut new_defs: Vec<(String, FunctionDefinition)> = Vec::new(); - let specializable_functions: HashMap<_, _> = self.defs.iter() - .filter_map(|(name, FunctionDefinition { args, c_type, need_simple: may_be_simple, .. })| { - if let CType::SpecializedF(_) = c_type && *may_be_simple { - Some((name.clone(), args.len())) - } else { - None - } - }).collect(); + let specializable_functions: HashMap<_, _> = self + .defs + .iter() + .filter_map( + |( + name, + FunctionDefinition { + args, + c_type, + need_simple: may_be_simple, + .. + }, + )| { + if let CType::SpecializedF(_) = c_type + && *may_be_simple + { + Some((name.clone(), args.len())) + } else { + None + } + }, + ) + .collect(); - self.defs.iter_mut().for_each(|(name, FunctionDefinition { body, .. })| { - let mut specializer = CallSpecializer { def_name: name, new_defs: &mut new_defs, primitive_wrapper_counter: 0, specializable_functions: &specializable_functions }; - specializer.transform_c_term(body); - }); + self.defs + .iter_mut() + .for_each(|(name, FunctionDefinition { body, .. })| { + let mut specializer = CallSpecializer { + def_name: name, + new_defs: &mut new_defs, + primitive_wrapper_counter: 0, + specializable_functions: &specializable_functions, + }; + specializer.transform_c_term(body); + }); self.insert_new_defs(new_defs); } /// Assume all local variables are distinct. Also this transformation preserves this property. fn reduce_immediate_redexes(&mut self) { let mut reducer = RedexReducer {}; - self.defs.iter_mut().for_each(|(_, FunctionDefinition { body, .. })| { - reducer.transform_c_term(body); - }); + self.defs + .iter_mut() + .for_each(|(_, FunctionDefinition { body, .. })| { + reducer.transform_c_term(body); + }); } /// Assume all local variables are distinct. Also this transformation preserves this property. fn lift_lambdas(&mut self) { let mut new_defs: Vec<(String, FunctionDefinition)> = Vec::new(); - self.defs.iter_mut().for_each(|(name, FunctionDefinition { args, body, var_bound, .. })| { - let local_var_types = &mut vec![VType::Uniform; *var_bound]; - for (i, ty) in args { - local_var_types[*i] = *ty; - } - let lifter = LambdaLifter { def_name: name, counter: 0, new_defs: &mut new_defs, local_var_types }; - let mut thunk_lifter = lifter; - thunk_lifter.transform_c_term(body); - }); + self.defs.iter_mut().for_each( + |( + name, + FunctionDefinition { + args, + body, + var_bound, + .. + }, + )| { + let local_var_types = &mut vec![VType::Uniform; *var_bound]; + for (i, ty) in args { + local_var_types[*i] = *ty; + } + let lifter = LambdaLifter { + def_name: name, + counter: 0, + new_defs: &mut new_defs, + local_var_types, + }; + let mut thunk_lifter = lifter; + thunk_lifter.transform_c_term(body); + }, + ); self.insert_new_defs(new_defs); } fn insert_new_defs(&mut self, new_defs: Vec<(String, FunctionDefinition)>) { - for (name, FunctionDefinition { - mut args, - mut body, - c_type, - mut var_bound, - need_simple: may_be_simple, - need_cps: may_be_complex, - need_specialized: may_be_specialized, - }) in new_defs.into_iter() { - Self::rename_local_vars_in_def(&mut args, &mut body, &mut var_bound); - self.insert(name, FunctionDefinition { - args, - body, + for ( + name, + FunctionDefinition { + mut args, + mut body, c_type, - var_bound, + mut var_bound, need_simple: may_be_simple, need_cps: may_be_complex, need_specialized: may_be_specialized, - }) + }, + ) in new_defs.into_iter() + { + Self::rename_local_vars_in_def(&mut args, &mut body, &mut var_bound); + self.insert( + name, + FunctionDefinition { + args, + body, + c_type, + var_bound, + need_simple: may_be_simple, + need_cps: may_be_complex, + need_specialized: may_be_specialized, + }, + ) } } fn rename_local_vars(&mut self) { - self.defs.iter_mut().for_each(|(_, FunctionDefinition { args, body, var_bound: max_arg_size, .. })| { - Self::rename_local_vars_in_def(args, body, max_arg_size); - }); + self.defs.iter_mut().for_each( + |( + _, + FunctionDefinition { + args, + body, + var_bound: max_arg_size, + .. + }, + )| { + Self::rename_local_vars_in_def(args, body, max_arg_size); + }, + ); } fn remove_duplicate_defs(&mut self) { let mut def_replacement: HashMap = HashMap::new(); let mut def_content_map = HashMap::new(); - for (name, FunctionDefinition { args, body, c_type, .. }) in self.defs.iter().sorted_by_key(|(name, _)| *name) { - let chosen_name = def_content_map.entry((args, body, c_type)).or_insert_with(|| name); + for ( + name, + FunctionDefinition { + args, body, c_type, .. + }, + ) in self.defs.iter().sorted_by_key(|(name, _)| *name) + { + let chosen_name = def_content_map + .entry((args, body, c_type)) + .or_insert_with(|| name); if *chosen_name != name { def_replacement.insert(name.clone(), chosen_name.clone()); } } for (name, replacement) in def_replacement.iter() { let function_definition = self.defs.remove(name).unwrap(); - let FunctionDefinition { need_simple, need_cps, need_specialized, .. } = function_definition; + let FunctionDefinition { + need_simple, + need_cps, + need_specialized, + .. + } = function_definition; let replacement_function_definition = self.defs.get_mut(replacement).unwrap(); replacement_function_definition.need_simple |= need_simple; replacement_function_definition.need_cps |= need_cps; replacement_function_definition.need_specialized |= need_specialized; } let mut def_replacer = DefReplacer { def_replacement }; - self.defs.iter_mut().for_each(|(_, FunctionDefinition { body, .. })| { - def_replacer.transform_c_term(body); - }); + self.defs + .iter_mut() + .for_each(|(_, FunctionDefinition { body, .. })| { + def_replacer.transform_c_term(body); + }); } - fn rename_local_vars_in_def(args: &mut [(usize, VType)], body: &mut CTerm, var_bound: &mut usize) { - let mut renamer = DistinctVarRenamer { bindings: HashMap::new(), counter: 0 }; + fn rename_local_vars_in_def( + args: &mut [(usize, VType)], + body: &mut CTerm, + var_bound: &mut usize, + ) { + let mut renamer = DistinctVarRenamer { + bindings: HashMap::new(), + counter: 0, + }; for (i, _) in args.iter_mut() { *i = renamer.add_binding(*i); } @@ -218,20 +297,29 @@ impl Visitor for Signature { type Ctx = Effect; fn visit_thunk(&mut self, v_term: &VTerm, context_effect: Effect) { - let VTerm::Thunk { box t, .. } = v_term else { unreachable!() }; + let VTerm::Thunk { box t, .. } = v_term else { + unreachable!() + }; let empty_vec = vec![]; let (name, args) = match t { - CTerm::Redex { function: box CTerm::Def { name, .. }, args, } => (name, args), + CTerm::Redex { + function: box CTerm::Def { name, .. }, + args, + } => (name, args), CTerm::Def { name, .. } => (name, &empty_vec), _ => panic!("all thunks should have been lifted at this point"), }; - args.iter().for_each(|arg| self.visit_v_term(arg, context_effect)); + args.iter() + .for_each(|arg| self.visit_v_term(arg, context_effect)); self.enable(name, FunctionEnablement::Cps); } fn visit_redex(&mut self, c_term: &CTerm, context_effect: Effect) { - let CTerm::Redex { box function, args } = c_term else { unreachable!() }; - args.iter().for_each(|arg| self.visit_v_term(arg, context_effect)); + let CTerm::Redex { box function, args } = c_term else { + unreachable!() + }; + args.iter() + .for_each(|arg| self.visit_v_term(arg, context_effect)); let CTerm::Def { name, effect } = function else { self.visit_c_term(function, context_effect); return; @@ -249,9 +337,10 @@ impl Visitor for Signature { } } - fn visit_def(&mut self, c_term: &CTerm, context_effect: Effect) { - let CTerm::Def { name, effect } = c_term else { unreachable!() }; + let CTerm::Def { name, effect } = c_term else { + unreachable!() + }; let effect = effect.intersect(context_effect); if effect == Effect::Complex { self.enable(name, FunctionEnablement::Cps); @@ -308,23 +397,31 @@ impl Transformer for RedundancyRemover { match c_term { CTerm::Redex { function, args } => { if args.is_empty() { - let mut placeholder = CTerm::Return { value: VTerm::Int { value: 1 } }; + let mut placeholder = CTerm::Return { + value: VTerm::Int { value: 1 }, + }; std::mem::swap(&mut placeholder, function); *c_term = placeholder; } else { let is_nested_redex = matches!(function.as_ref(), CTerm::Redex { .. }); if is_nested_redex { - let mut placeholder = CTerm::Return { value: VTerm::Int { value: 1 } }; + let mut placeholder = CTerm::Return { + value: VTerm::Int { value: 1 }, + }; std::mem::swap(&mut placeholder, c_term); match placeholder { - CTerm::Redex { function, args } => { - match *function { - CTerm::Redex { function: sub_function, args: sub_args } => { - *c_term = CTerm::Redex { function: sub_function, args: sub_args.into_iter().chain(args).collect() }; - } - _ => unreachable!(), + CTerm::Redex { function, args } => match *function { + CTerm::Redex { + function: sub_function, + args: sub_args, + } => { + *c_term = CTerm::Redex { + function: sub_function, + args: sub_args.into_iter().chain(args).collect(), + }; } - } + _ => unreachable!(), + }, _ => unreachable!(), } } @@ -336,16 +433,32 @@ impl Transformer for RedundancyRemover { fn transform_lambda(&mut self, c_term: &mut CTerm) { self.transform_lambda_default(c_term); - let CTerm::Lambda { args, box body, effect } = c_term else { unreachable!() }; - if let CTerm::Lambda { args: sub_args, body: box sub_body, effect: sub_effect } = body { + let CTerm::Lambda { + args, + box body, + effect, + } = c_term + else { + unreachable!() + }; + if let CTerm::Lambda { + args: sub_args, + body: box sub_body, + effect: sub_effect, + } = body + { *effect = effect.union(*sub_effect); args.extend(sub_args.iter().copied()); - let mut placeholder = CTerm::Return { value: VTerm::Int { value: 0 } }; + let mut placeholder = CTerm::Return { + value: VTerm::Int { value: 0 }, + }; std::mem::swap(&mut placeholder, sub_body); *body = placeholder } if args.is_empty() { - let mut placeholder = CTerm::Return { value: VTerm::Int { value: 0 } }; + let mut placeholder = CTerm::Return { + value: VTerm::Int { value: 0 }, + }; std::mem::swap(&mut placeholder, body); *c_term = placeholder; } @@ -362,34 +475,60 @@ struct CallSpecializer<'a> { impl<'a> Transformer for CallSpecializer<'a> { fn transform_redex(&mut self, c_term: &mut CTerm) { self.transform_redex_default(c_term); - let CTerm::Redex { box function, args } = c_term else { unreachable!() }; - let CTerm::Def { name, effect } = function else { return; }; - if let Some((name, PrimitiveFunction { arg_types, return_type, .. })) = PRIMITIVE_FUNCTIONS.get_entry(name) { + let CTerm::Redex { box function, args } = c_term else { + unreachable!() + }; + let CTerm::Def { name, effect } = function else { + return; + }; + if let Some(( + name, + PrimitiveFunction { + arg_types, + return_type, + .. + }, + )) = PRIMITIVE_FUNCTIONS.get_entry(name) + { match arg_types.len().cmp(&args.len()) { Ordering::Greater => { - let primitive_wrapper_name = format!("{}$__primitive_wrapper_{}", self.def_name, self.primitive_wrapper_counter); + let primitive_wrapper_name = format!( + "{}$__primitive_wrapper_{}", + self.def_name, self.primitive_wrapper_counter + ); // Primitive calls cannot be effectful. assert_eq!(*effect, Effect::Simple); - *function = CTerm::Def { name: primitive_wrapper_name.clone(), effect: Effect::Simple }; - self.new_defs.push((primitive_wrapper_name, FunctionDefinition { - args: arg_types.iter().enumerate().map(|(i, t)| (i, *t)).collect(), - body: CTerm::PrimitiveCall { - name, - args: (0..arg_types.len()).map(|index| VTerm::Var { index }).collect(), + *function = CTerm::Def { + name: primitive_wrapper_name.clone(), + effect: Effect::Simple, + }; + self.new_defs.push(( + primitive_wrapper_name, + FunctionDefinition { + args: arg_types.iter().enumerate().map(|(i, t)| (i, *t)).collect(), + body: CTerm::PrimitiveCall { + name, + args: (0..arg_types.len()) + .map(|index| VTerm::Var { index }) + .collect(), + }, + c_type: CType::SpecializedF(*return_type), + var_bound: arg_types.len(), + need_simple: false, + need_cps: false, + need_specialized: false, }, - c_type: CType::SpecializedF(*return_type), - var_bound: arg_types.len(), - need_simple: false, - need_cps: false, - need_specialized: false, - })) + )) } Ordering::Equal => { - let CTerm::Redex { args, .. } = std::mem::replace( - c_term, - CTerm::PrimitiveCall { name, args: vec![] }, - ) else { unreachable!() }; - let CTerm::PrimitiveCall { args: new_args, .. } = c_term else { unreachable!() }; + let CTerm::Redex { args, .. } = + std::mem::replace(c_term, CTerm::PrimitiveCall { name, args: vec![] }) + else { + unreachable!() + }; + let CTerm::PrimitiveCall { args: new_args, .. } = c_term else { + unreachable!() + }; *new_args = args; } Ordering::Less => { @@ -412,18 +551,30 @@ impl<'a> LambdaLifter<'a> { let thunk_def_name = format!("{}$__lambda_{}", self.def_name, self.counter); self.counter += 1; - let redex = - CTerm::Redex { - function: Box::new(CTerm::Def { name: thunk_def_name.clone(), effect }), - args: free_vars.iter().map(|i| VTerm::Var { index: *i }).collect(), - }; + let redex = CTerm::Redex { + function: Box::new(CTerm::Def { + name: thunk_def_name.clone(), + effect, + }), + args: free_vars.iter().map(|i| VTerm::Var { index: *i }).collect(), + }; (thunk_def_name, redex) } - fn create_new_def(&mut self, name: String, free_vars: Vec, args: &[(usize, VType)], body: CTerm) { + fn create_new_def( + &mut self, + name: String, + free_vars: Vec, + args: &[(usize, VType)], + body: CTerm, + ) { let var_bound = *free_vars.iter().max().unwrap_or(&0); let function_definition = FunctionDefinition { - args: free_vars.into_iter().map(|v| (v, self.local_var_types[v])).chain(args.iter().copied()).collect(), + args: free_vars + .into_iter() + .map(|v| (v, self.local_var_types[v])) + .chain(args.iter().copied()) + .collect(), body, c_type: CType::Default, var_bound, @@ -437,10 +588,21 @@ impl<'a> LambdaLifter<'a> { impl<'a> Transformer for LambdaLifter<'a> { fn transform_thunk(&mut self, v_term: &mut VTerm) { - let VTerm::Thunk { t: box c_term, effect } = v_term else { unreachable!() }; + let VTerm::Thunk { + t: box c_term, + effect, + } = v_term + else { + unreachable!() + }; self.transform_c_term(c_term); - if let CTerm::Redex { function: box CTerm::Def { .. }, .. } | CTerm::Def { .. } = c_term { + if let CTerm::Redex { + function: box CTerm::Def { .. }, + .. + } + | CTerm::Def { .. } = c_term + { // There is no need to lift the thunk if it's already a simple function call. return; } @@ -458,13 +620,17 @@ impl<'a> Transformer for LambdaLifter<'a> { let mut free_vars: Vec<_> = c_term.free_vars().iter().copied().collect(); free_vars.sort(); - let CTerm::Lambda { args, effect, .. } = c_term else { unreachable!() }; + let CTerm::Lambda { args, effect, .. } = c_term else { + unreachable!() + }; let (thunk_def_name, mut redex) = self.create_new_redex(&free_vars, *effect); let args = args.clone(); std::mem::swap(c_term, &mut redex); - let CTerm::Lambda { box body, .. } = redex else { unreachable!() }; + let CTerm::Lambda { box body, .. } = redex else { + unreachable!() + }; self.create_new_def(thunk_def_name, free_vars, &args, body); } @@ -475,12 +641,26 @@ struct RedexReducer {} impl Transformer for RedexReducer { fn transform_redex(&mut self, c_term: &mut CTerm) { self.transform_redex_default(c_term); - let CTerm::Redex { box function, args } = c_term else { unreachable!() }; - let CTerm::Lambda { args: lambda_args, body: box lambda_body, .. } = function else { return; }; + let CTerm::Redex { box function, args } = c_term else { + unreachable!() + }; + let CTerm::Lambda { + args: lambda_args, + body: box lambda_body, + .. + } = function + else { + return; + }; let num_args = min(args.len(), lambda_args.len()); let matching_args = args.drain(..num_args).collect::>(); - let matching_lambda_args = lambda_args.drain(..num_args).map(|(index, _)| index).collect::>(); - let mut substitutor = Substitutor { bindings: HashMap::from_iter(matching_lambda_args.into_iter().zip(matching_args)) }; + let matching_lambda_args = lambda_args + .drain(..num_args) + .map(|(index, _)| index) + .collect::>(); + let mut substitutor = Substitutor { + bindings: HashMap::from_iter(matching_lambda_args.into_iter().zip(matching_args)), + }; substitutor.transform_c_term(lambda_body); RedundancyRemover {}.transform_c_term(c_term); // Call the reducer again to reduce the new redex. This is terminating because any loops @@ -490,10 +670,16 @@ impl Transformer for RedexReducer { fn transform_force(&mut self, c_term: &mut CTerm) { self.transform_force_default(c_term); - let CTerm::Force { thunk, .. } = c_term else { unreachable!() }; - let VTerm::Thunk { box t, .. } = thunk else { return; }; + let CTerm::Force { thunk, .. } = c_term else { + unreachable!() + }; + let VTerm::Thunk { box t, .. } = thunk else { + return; + }; // The content of placeholder does not matter here - let mut placeholder = CTerm::Return { value: VTerm::Var { index: 0 } }; + let mut placeholder = CTerm::Return { + value: VTerm::Var { index: 0 }, + }; std::mem::swap(t, &mut placeholder); *c_term = placeholder; // Call the reducer again to reduce the new redex. This is terminating because any loops @@ -508,7 +694,9 @@ struct Substitutor { impl Transformer for Substitutor { fn transform_var(&mut self, v_term: &mut VTerm) { - let VTerm::Var { index } = v_term else { unreachable!() }; + let VTerm::Var { index } = v_term else { + unreachable!() + }; if let Some(replacement) = self.bindings.get(index) { *v_term = replacement.clone(); } @@ -521,9 +709,11 @@ struct DefReplacer { impl Transformer for DefReplacer { fn transform_def(&mut self, c_term: &mut CTerm) { - let CTerm::Def { name, .. } = c_term else { unreachable!() }; + let CTerm::Def { name, .. } = c_term else { + unreachable!() + }; if let Some(replacement) = self.def_replacement.get(name.as_str()) { *name = replacement.clone() } } -} \ No newline at end of file +} diff --git a/src/ast/term.rs b/src/ast/term.rs index 1015390..793c41d 100644 --- a/src/ast/term.rs +++ b/src/ast/term.rs @@ -82,12 +82,23 @@ impl Effect { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum VTerm { - Var { index: usize }, - Thunk { t: Box, effect: Effect }, + Var { + index: usize, + }, + Thunk { + t: Box, + effect: Effect, + }, /// 51 bit integer represented as a machine word with highest bits sign-extended - Int { value: i64 }, - Str { value: String }, - Struct { values: Vec }, + Int { + value: i64, + }, + Str { + value: String, + }, + Struct { + values: Vec, + }, // TODO: the following types are not yet implemented in the AST yet // PrimitiveArray { values: Vec }, // F64 { value: f64 }, @@ -98,30 +109,68 @@ pub enum VTerm { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum CTerm { - Redex { function: Box, args: Vec }, - Return { value: VTerm }, - Force { thunk: VTerm, effect: Effect }, - Let { t: Box, bound_index: usize, body: Box }, + Redex { + function: Box, + args: Vec, + }, + Return { + value: VTerm, + }, + Force { + thunk: VTerm, + effect: Effect, + }, + Let { + t: Box, + bound_index: usize, + body: Box, + }, /// Note: the flag means whether the function has complex effects that need to be handled by /// some handlers. System effects like IO do not count because they appear to be a simple call. /// This value should be conservatively set to true if side effects are unknown. For example, /// the containing redex may not have enough arguments to determine the effects of this /// computation. - Def { name: String, effect: Effect }, + Def { + name: String, + effect: Effect, + }, /// Note on result type, different branches can have different computation types. For example, /// for record instance or function with large elimination, one branch may return a value while /// another consumes more arguments (aka tail call). In this case the result type is /// [CType::Default]. - CaseInt { t: VTerm, result_type: CType, branches: Vec<(i64, CTerm)>, default_branch: Option> }, - Lambda { args: Vec<(usize, VType)>, body: Box, effect: Effect }, - MemGet { base: VTerm, offset: VTerm }, - MemSet { base: VTerm, offset: VTerm, value: VTerm }, + CaseInt { + t: VTerm, + result_type: CType, + branches: Vec<(i64, CTerm)>, + default_branch: Option>, + }, + Lambda { + args: Vec<(usize, VType)>, + body: Box, + effect: Effect, + }, + MemGet { + base: VTerm, + offset: VTerm, + }, + MemSet { + base: VTerm, + offset: VTerm, + value: VTerm, + }, // TODO: implement the following for setting and getting primitive values // PMemGet { base: VTerm, offset: VTerm, p_type: PType }, // PMemSet { base: VTerm, offset: VTerm, value: VTerm, p_type: PType }, - PrimitiveCall { name: &'static str, args: Vec }, + PrimitiveCall { + name: &'static str, + args: Vec, + }, /// Note: effect cannot be pure for operation calls. - OperationCall { eff: VTerm, args: Vec, effect: Effect }, + OperationCall { + eff: VTerm, + args: Vec, + effect: Effect, + }, Handler { parameter: VTerm, /// thunk of lambda: parameter -> 0 @@ -140,4 +189,4 @@ pub enum CTerm { /// the transform continuation between this input and the parent term of this handler. input: VTerm, }, -} \ No newline at end of file +} diff --git a/src/ast/transformer.rs b/src/ast/transformer.rs index 8c2abc4..79c4c4e 100644 --- a/src/ast/transformer.rs +++ b/src/ast/transformer.rs @@ -1,7 +1,9 @@ use crate::ast::term::{CTerm, VTerm}; pub trait Transformer { - fn add_binding(&mut self, name: usize) -> usize { name } + fn add_binding(&mut self, name: usize) -> usize { + name + } fn remove_binding(&mut self, _name: usize) {} @@ -17,13 +19,17 @@ pub trait Transformer { fn transform_var(&mut self, _v_term: &mut VTerm) {} fn transform_thunk(&mut self, v_term: &mut VTerm) { - let VTerm::Thunk { t, .. } = v_term else { unreachable!() }; + let VTerm::Thunk { t, .. } = v_term else { + unreachable!() + }; self.transform_c_term(t); } fn transform_int(&mut self, _v_term: &mut VTerm) {} fn transform_str(&mut self, _v_term: &mut VTerm) {} fn transform_tuple(&mut self, v_term: &mut VTerm) { - let VTerm::Struct { values } = v_term else { unreachable!() }; + let VTerm::Struct { values } = v_term else { + unreachable!() + }; for v in values { self.transform_v_term(v); } @@ -51,13 +57,17 @@ pub trait Transformer { } fn transform_redex_default(&mut self, c_term: &mut CTerm) { - let CTerm::Redex { function, args } = c_term else { unreachable!() }; + let CTerm::Redex { function, args } = c_term else { + unreachable!() + }; self.transform_c_term(function); args.iter_mut().for_each(|arg| self.transform_v_term(arg)); } fn transform_return(&mut self, c_term: &mut CTerm) { - let CTerm::Return { value } = c_term else { unreachable!() }; + let CTerm::Return { value } = c_term else { + unreachable!() + }; self.transform_v_term(value); } @@ -66,12 +76,22 @@ pub trait Transformer { } fn transform_force_default(&mut self, c_term: &mut CTerm) { - let CTerm::Force { thunk, .. } = c_term else { unreachable!() }; + let CTerm::Force { thunk, .. } = c_term else { + unreachable!() + }; self.transform_v_term(thunk); } fn transform_let(&mut self, c_term: &mut CTerm) { - let CTerm::Let { t, body, bound_index: bound_name, .. } = c_term else { unreachable!() }; + let CTerm::Let { + t, + body, + bound_index: bound_name, + .. + } = c_term + else { + unreachable!() + }; self.transform_c_term(t); let old_name = *bound_name; *bound_name = self.add_binding(*bound_name); @@ -82,7 +102,15 @@ pub trait Transformer { fn transform_def(&mut self, _c_term: &mut CTerm) {} fn transform_case_int(&mut self, c_term: &mut CTerm) { - let CTerm::CaseInt { t, branches, default_branch, .. } = c_term else { unreachable!() }; + let CTerm::CaseInt { + t, + branches, + default_branch, + .. + } = c_term + else { + unreachable!() + }; self.transform_v_term(t); for (_, branch) in branches.iter_mut() { self.transform_c_term(branch); @@ -97,33 +125,51 @@ pub trait Transformer { } fn transform_lambda_default(&mut self, c_term: &mut CTerm) { - let CTerm::Lambda { args, body, .. } = c_term else { unreachable!() }; + let CTerm::Lambda { args, body, .. } = c_term else { + unreachable!() + }; let old_args = args.clone(); - args.iter_mut().for_each(|(arg, _)| *arg = self.add_binding(*arg)); + args.iter_mut() + .for_each(|(arg, _)| *arg = self.add_binding(*arg)); self.transform_c_term(body); - old_args.iter().for_each(|(arg, _)| self.remove_binding(*arg)); + old_args + .iter() + .for_each(|(arg, _)| self.remove_binding(*arg)); } fn transform_mem_get(&mut self, c_term: &mut CTerm) { - let CTerm::MemGet { base, offset } = c_term else { unreachable!() }; + let CTerm::MemGet { base, offset } = c_term else { + unreachable!() + }; self.transform_v_term(base); self.transform_v_term(offset); } fn transform_mem_set(&mut self, c_term: &mut CTerm) { - let CTerm::MemSet { base, offset, value } = c_term else { unreachable!() }; + let CTerm::MemSet { + base, + offset, + value, + } = c_term + else { + unreachable!() + }; self.transform_v_term(base); self.transform_v_term(offset); self.transform_v_term(value); } fn transform_primitive_call(&mut self, c_term: &mut CTerm) { - let CTerm::PrimitiveCall { args, .. } = c_term else { unreachable!() }; + let CTerm::PrimitiveCall { args, .. } = c_term else { + unreachable!() + }; args.iter_mut().for_each(|arg| self.transform_v_term(arg)); } fn transform_operation_call(&mut self, c_term: &mut CTerm) { - let CTerm::OperationCall { eff, args, .. } = c_term else { unreachable!() }; + let CTerm::OperationCall { eff, args, .. } = c_term else { + unreachable!() + }; self.transform_v_term(eff); args.iter_mut().for_each(|arg| self.transform_v_term(arg)); } @@ -141,11 +187,18 @@ pub trait Transformer { parameter_replicator, transform, handlers, - input - } = c_term else { unreachable!() }; + input, + } = c_term + else { + unreachable!() + }; self.transform_v_term(parameter); - parameter_disposer.iter_mut().for_each(|disposer| self.transform_v_term(disposer)); - parameter_replicator.iter_mut().for_each(|replicator| self.transform_v_term(replicator)); + parameter_disposer + .iter_mut() + .for_each(|disposer| self.transform_v_term(disposer)); + parameter_replicator + .iter_mut() + .for_each(|replicator| self.transform_v_term(replicator)); self.transform_v_term(transform); for (eff, handler, ..) in handlers.iter_mut() { self.transform_v_term(eff); @@ -153,4 +206,4 @@ pub trait Transformer { } self.transform_v_term(input); } -} \ No newline at end of file +} diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 96e90d0..365ada1 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -19,13 +19,17 @@ pub trait Visitor { fn visit_var(&mut self, _v_term: &VTerm, _ctx: Self::Ctx) {} fn visit_thunk(&mut self, v_term: &VTerm, ctx: Self::Ctx) { - let VTerm::Thunk { t, .. } = v_term else { unreachable!() }; + let VTerm::Thunk { t, .. } = v_term else { + unreachable!() + }; self.visit_c_term(t, ctx); } fn visit_int(&mut self, _v_term: &VTerm, _ctx: Self::Ctx) {} fn visit_str(&mut self, _v_term: &VTerm, _ctx: Self::Ctx) {} fn visit_tuple(&mut self, v_term: &VTerm, ctx: Self::Ctx) { - let VTerm::Struct { values } = v_term else { unreachable!() }; + let VTerm::Struct { values } = v_term else { + unreachable!() + }; for v in values { self.visit_v_term(v, ctx); } @@ -49,23 +53,36 @@ pub trait Visitor { } fn visit_redex(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::Redex { function, args } = c_term else { unreachable!() }; + let CTerm::Redex { function, args } = c_term else { + unreachable!() + }; self.visit_c_term(function, ctx); args.iter().for_each(|arg| self.visit_v_term(arg, ctx)); } fn visit_return(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::Return { value } = c_term else { unreachable!() }; + let CTerm::Return { value } = c_term else { + unreachable!() + }; self.visit_v_term(value, ctx); } fn visit_force(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::Force { thunk, .. } = c_term else { unreachable!() }; + let CTerm::Force { thunk, .. } = c_term else { + unreachable!() + }; self.visit_v_term(thunk, ctx); } fn visit_let(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::Let { t, body, bound_index: bound_name } = c_term else { unreachable!() }; + let CTerm::Let { + t, + body, + bound_index: bound_name, + } = c_term + else { + unreachable!() + }; self.visit_c_term(t, ctx); self.add_binding(*bound_name, ctx); self.visit_c_term(body, ctx); @@ -75,7 +92,15 @@ pub trait Visitor { fn visit_def(&mut self, _c_term: &CTerm, _ctx: Self::Ctx) {} fn visit_case_int(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::CaseInt { t, branches, default_branch, .. } = c_term else { unreachable!() }; + let CTerm::CaseInt { + t, + branches, + default_branch, + .. + } = c_term + else { + unreachable!() + }; self.visit_v_term(t, ctx); for (_, branch) in branches.iter() { self.visit_c_term(branch, ctx); @@ -86,32 +111,48 @@ pub trait Visitor { } fn visit_lambda(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::Lambda { args, body, .. } = c_term else { unreachable!() }; + let CTerm::Lambda { args, body, .. } = c_term else { + unreachable!() + }; args.iter().for_each(|(arg, _)| self.add_binding(*arg, ctx)); self.visit_c_term(body, ctx); - args.iter().for_each(|(arg, _)| self.remove_binding(*arg, ctx)); + args.iter() + .for_each(|(arg, _)| self.remove_binding(*arg, ctx)); } fn visit_mem_get(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::MemGet { base, offset } = c_term else { unreachable!() }; + let CTerm::MemGet { base, offset } = c_term else { + unreachable!() + }; self.visit_v_term(base, ctx); self.visit_v_term(offset, ctx); } fn visit_mem_set(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::MemSet { base, offset, value } = c_term else { unreachable!() }; + let CTerm::MemSet { + base, + offset, + value, + } = c_term + else { + unreachable!() + }; self.visit_v_term(base, ctx); self.visit_v_term(offset, ctx); self.visit_v_term(value, ctx); } fn visit_primitive_call(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::PrimitiveCall { args, .. } = c_term else { unreachable!() }; + let CTerm::PrimitiveCall { args, .. } = c_term else { + unreachable!() + }; args.iter().for_each(|arg| self.visit_v_term(arg, ctx)); } fn visit_operation_call(&mut self, c_term: &CTerm, ctx: Self::Ctx) { - let CTerm::OperationCall { eff, args, .. } = c_term else { unreachable!() }; + let CTerm::OperationCall { eff, args, .. } = c_term else { + unreachable!() + }; self.visit_v_term(eff, ctx); args.iter().for_each(|arg| self.visit_v_term(arg, ctx)); } @@ -123,11 +164,18 @@ pub trait Visitor { parameter_replicator, transform, handlers, - input - } = c_term else { unreachable!() }; + input, + } = c_term + else { + unreachable!() + }; self.visit_v_term(parameter, ctx); - parameter_disposer.iter().for_each(|disposer| self.visit_v_term(disposer, ctx)); - parameter_replicator.iter().for_each(|replicator| self.visit_v_term(replicator, ctx)); + parameter_disposer + .iter() + .for_each(|disposer| self.visit_v_term(disposer, ctx)); + parameter_replicator + .iter() + .for_each(|replicator| self.visit_v_term(replicator, ctx)); self.visit_v_term(transform, ctx); for (eff, handler, ..) in handlers.iter() { self.visit_v_term(eff, ctx); @@ -135,4 +183,4 @@ pub trait Visitor { } self.visit_v_term(input, ctx); } -} \ No newline at end of file +} diff --git a/src/backend/common.rs b/src/backend/common.rs index 700456e..b374a6e 100644 --- a/src/backend/common.rs +++ b/src/backend/common.rs @@ -1,16 +1,16 @@ +use crate::ast::term::{PType, SpecializedType, VType}; +use archon_vm_runtime::runtime_utils::*; use cranelift::codegen::ir::{Endianness, Inst}; use cranelift::codegen::isa::CallConv; use cranelift::frontend::Switch; -use archon_vm_runtime::runtime_utils::*; -use cranelift::prelude::*; use cranelift::prelude::types::{F32, F64, I32, I64}; -use cranelift_jit::{JITBuilder}; +use cranelift::prelude::*; +use cranelift_jit::JITBuilder; use cranelift_module::{DataDescription, DataId, FuncId, Linkage, Module}; -use crate::ast::term::{VType, SpecializedType, PType}; +use enum_map::Enum; use strum_macros::EnumIter; -use enum_map::{Enum}; -use VType::{Specialized, Uniform}; use SpecializedType::{Integer, PrimitivePtr, StructPtr}; +use VType::{Specialized, Uniform}; /// None means the function call is a tail call or returned so no value is returned. pub type TypedValue = (Value, VType); @@ -33,8 +33,8 @@ impl HasType for VType { PType::I32 => I32, PType::F64 => F64, PType::F32 => F32, - } - } + }, + }, } } } @@ -123,16 +123,24 @@ impl BuiltinFunction { BuiltinFunction::PrepareResumeContinuation => "__runtime_prepare_resume_continuation", BuiltinFunction::PrepareDisposeContinuation => "__runtime_prepare_dispose_continuation", BuiltinFunction::ReplicateContinuation => "__runtime_replicate_continuation", - BuiltinFunction::ProcessSimpleHandlerResult => "__runtime_process_simple_handler_result", + BuiltinFunction::ProcessSimpleHandlerResult => { + "__runtime_process_simple_handler_result" + } BuiltinFunction::MarkHandler => "__runtime_mark_handler", BuiltinFunction::TrivialContinuationImpl => "__runtime_trivial_continuation_impl", - BuiltinFunction::CapturedContinuationRecordImpl => "__runtime_captured_continuation_record_impl", + BuiltinFunction::CapturedContinuationRecordImpl => { + "__runtime_captured_continuation_record_impl" + } BuiltinFunction::SimpleHandlerRunnerImpl => "__runtime_simple_handler_runner_impl", BuiltinFunction::TransformLoaderCpsImpl => "__runtime_transform_loader_cps_impl", BuiltinFunction::DisposerLoaderCpsImpl => "__runtime_disposer_loader_cps_impl", - BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation => "__runtime_invoke_cps_function", - BuiltinFunction::SimpleExceptionContinuationImpl => "__runtime_simple_exception_continuation_impl", + BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation => { + "__runtime_invoke_cps_function" + } + BuiltinFunction::SimpleExceptionContinuationImpl => { + "__runtime_simple_exception_continuation_impl" + } } } @@ -146,10 +154,16 @@ impl BuiltinFunction { BuiltinFunction::PopHandler => runtime_pop_handler as *const u8, BuiltinFunction::RegisterHandler => runtime_register_handler as *const u8, BuiltinFunction::AddHandler => runtime_add_handler as *const u8, - BuiltinFunction::PrepareResumeContinuation => runtime_prepare_resume_continuation as *const u8, - BuiltinFunction::PrepareDisposeContinuation => runtime_prepare_dispose_continuation as *const u8, + BuiltinFunction::PrepareResumeContinuation => { + runtime_prepare_resume_continuation as *const u8 + } + BuiltinFunction::PrepareDisposeContinuation => { + runtime_prepare_dispose_continuation as *const u8 + } BuiltinFunction::ReplicateContinuation => runtime_replicate_continuation as *const u8, - BuiltinFunction::ProcessSimpleHandlerResult => runtime_process_simple_handler_result as *const u8, + BuiltinFunction::ProcessSimpleHandlerResult => { + runtime_process_simple_handler_result as *const u8 + } BuiltinFunction::MarkHandler => runtime_mark_handler as *const u8, BuiltinFunction::TrivialContinuationImpl => return, @@ -166,19 +180,33 @@ impl BuiltinFunction { pub fn declare(&self, m: &mut M) -> (FuncId, Signature, Linkage) { let mut sig = m.make_signature(); - let mut declare_func_with_call_conv = |m: &mut M, arg_count: usize, return_count: usize, linkage: Linkage, call_conv: CallConv| { - for _ in 0..arg_count { - sig.params.push(AbiParam::new(I64)); - } - for _ in 0..return_count { - sig.returns.push(AbiParam::new(I64)); - } - sig.call_conv = call_conv; - (m.declare_function(self.func_name(), linkage, &sig).unwrap(), linkage) - }; + let mut declare_func_with_call_conv = + |m: &mut M, + arg_count: usize, + return_count: usize, + linkage: Linkage, + call_conv: CallConv| { + for _ in 0..arg_count { + sig.params.push(AbiParam::new(I64)); + } + for _ in 0..return_count { + sig.returns.push(AbiParam::new(I64)); + } + sig.call_conv = call_conv; + ( + m.declare_function(self.func_name(), linkage, &sig).unwrap(), + linkage, + ) + }; let mut declare_func = |arg_count: usize, return_count: usize, linkage: Linkage| { - declare_func_with_call_conv(m, arg_count, return_count, linkage, m.isa().default_call_conv()) + declare_func_with_call_conv( + m, + arg_count, + return_count, + linkage, + m.isa().default_call_conv(), + ) }; let (func_id, linkage) = match self { @@ -194,15 +222,31 @@ impl BuiltinFunction { BuiltinFunction::PrepareDisposeContinuation => declare_func(7, 1, Linkage::Import), BuiltinFunction::ReplicateContinuation => declare_func(8, 1, Linkage::Import), BuiltinFunction::ProcessSimpleHandlerResult => declare_func(5, 1, Linkage::Import), - BuiltinFunction::MarkHandler => declare_func_with_call_conv(m, 4, 1, Linkage::Import, CallConv::Tail), - - BuiltinFunction::TrivialContinuationImpl => declare_func_with_call_conv(m, 3, 1, Linkage::Local, CallConv::Tail), - BuiltinFunction::CapturedContinuationRecordImpl => declare_func_with_call_conv(m, 2, 1, Linkage::Local, CallConv::Tail), - BuiltinFunction::SimpleHandlerRunnerImpl => declare_func_with_call_conv(m, 2, 1, Linkage::Local, CallConv::Tail), - BuiltinFunction::TransformLoaderCpsImpl => declare_func_with_call_conv(m, 3, 1, Linkage::Local, CallConv::Tail), - BuiltinFunction::DisposerLoaderCpsImpl => declare_func_with_call_conv(m, 3, 1, Linkage::Local, CallConv::Tail), - BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation => declare_func(2, 1, Linkage::Local), - BuiltinFunction::SimpleExceptionContinuationImpl => declare_func_with_call_conv(m, 2, 1, Linkage::Local, CallConv::Tail), + BuiltinFunction::MarkHandler => { + declare_func_with_call_conv(m, 4, 1, Linkage::Import, CallConv::Tail) + } + + BuiltinFunction::TrivialContinuationImpl => { + declare_func_with_call_conv(m, 3, 1, Linkage::Local, CallConv::Tail) + } + BuiltinFunction::CapturedContinuationRecordImpl => { + declare_func_with_call_conv(m, 2, 1, Linkage::Local, CallConv::Tail) + } + BuiltinFunction::SimpleHandlerRunnerImpl => { + declare_func_with_call_conv(m, 2, 1, Linkage::Local, CallConv::Tail) + } + BuiltinFunction::TransformLoaderCpsImpl => { + declare_func_with_call_conv(m, 3, 1, Linkage::Local, CallConv::Tail) + } + BuiltinFunction::DisposerLoaderCpsImpl => { + declare_func_with_call_conv(m, 3, 1, Linkage::Local, CallConv::Tail) + } + BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation => { + declare_func(2, 1, Linkage::Local) + } + BuiltinFunction::SimpleExceptionContinuationImpl => { + declare_func_with_call_conv(m, 2, 1, Linkage::Local, CallConv::Tail) + } }; (func_id, sig, linkage) } @@ -217,21 +261,36 @@ impl BuiltinFunction { let mut builder_ctx = FunctionBuilderContext::new(); let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx); match self { - BuiltinFunction::TrivialContinuationImpl => Self::trivial_continuation_impl(m, &mut builder), - BuiltinFunction::CapturedContinuationRecordImpl => Self::captured_continuation_record_impl(m, &mut builder), - BuiltinFunction::SimpleHandlerRunnerImpl => Self::simple_handler_runner_impl(m, &mut builder), - BuiltinFunction::TransformLoaderCpsImpl => Self::transform_loader_cps_impl(m, &mut builder), - BuiltinFunction::DisposerLoaderCpsImpl => Self::disposer_loader_cps_impl(m, &mut builder), - BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation => Self::invoke_cps_function_with_trivial_continuation(m, &mut builder), - BuiltinFunction::SimpleExceptionContinuationImpl => Self::simple_exception_continuation_impl(m, &mut builder), - _ => { unreachable!() } + BuiltinFunction::TrivialContinuationImpl => { + Self::trivial_continuation_impl(m, &mut builder) + } + BuiltinFunction::CapturedContinuationRecordImpl => { + Self::captured_continuation_record_impl(m, &mut builder) + } + BuiltinFunction::SimpleHandlerRunnerImpl => { + Self::simple_handler_runner_impl(m, &mut builder) + } + BuiltinFunction::TransformLoaderCpsImpl => { + Self::transform_loader_cps_impl(m, &mut builder) + } + BuiltinFunction::DisposerLoaderCpsImpl => { + Self::disposer_loader_cps_impl(m, &mut builder) + } + BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation => { + Self::invoke_cps_function_with_trivial_continuation(m, &mut builder) + } + BuiltinFunction::SimpleExceptionContinuationImpl => { + Self::simple_exception_continuation_impl(m, &mut builder) + } + _ => { + unreachable!() + } } builder.finalize(); m.define_function(func_id, &mut ctx).unwrap(); func_id } - fn trivial_continuation_impl(m: &mut M, builder: &mut FunctionBuilder) { let entry_block = builder.create_block(); builder.append_block_params_for_function_params(entry_block); @@ -244,12 +303,15 @@ impl BuiltinFunction { // trivial continuation should be placing the result. let last_result_ptr = builder.block_params(entry_block)[2]; Self::call_built_in( - m, builder, BuiltinFunction::DebugHelper, + m, + builder, + BuiltinFunction::DebugHelper, &[ builder.block_params(entry_block)[0], builder.block_params(entry_block)[1], - builder.block_params(entry_block)[2] - ]); + builder.block_params(entry_block)[2], + ], + ); builder.ins().return_(&[last_result_ptr]); } @@ -320,32 +382,66 @@ impl BuiltinFunction { m, builder, BuiltinFunction::PrepareResumeContinuation, - &[base_address, next_continuation, captured_continuation, handler_parameter, result, frame_pointer, stack_pointer], + &[ + base_address, + next_continuation, + captured_continuation, + handler_parameter, + result, + frame_pointer, + stack_pointer, + ], ); let prepare_result_ptr = builder.inst_results(inst)[0]; builder.ins().jump(final_block, &[prepare_result_ptr]); // dispose builder.switch_to_block(dispose_block); - let disposer_loader_cps_impl = Self::get_built_in_func_ptr(m, builder, BuiltinFunction::DisposerLoaderCpsImpl); + let disposer_loader_cps_impl = + Self::get_built_in_func_ptr(m, builder, BuiltinFunction::DisposerLoaderCpsImpl); let inst = Self::call_built_in( m, builder, BuiltinFunction::PrepareDisposeContinuation, - &[base_address, next_continuation, captured_continuation, handler_parameter, frame_pointer, stack_pointer, disposer_loader_cps_impl], + &[ + base_address, + next_continuation, + captured_continuation, + handler_parameter, + frame_pointer, + stack_pointer, + disposer_loader_cps_impl, + ], ); let prepare_result_ptr = builder.inst_results(inst)[0]; builder.ins().jump(final_block, &[prepare_result_ptr]); // replicate builder.switch_to_block(replicate_block); - let invoke_cps_function_with_trivial_continuation = Self::get_built_in_func_ptr(m, builder, BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation); - let captured_continuation_record_impl = Self::get_built_in_func_ptr(m, builder, BuiltinFunction::CapturedContinuationRecordImpl); + let invoke_cps_function_with_trivial_continuation = Self::get_built_in_func_ptr( + m, + builder, + BuiltinFunction::InvokeCpsFunctionWithTrivialContinuation, + ); + let captured_continuation_record_impl = Self::get_built_in_func_ptr( + m, + builder, + BuiltinFunction::CapturedContinuationRecordImpl, + ); let inst = Self::call_built_in( m, builder, BuiltinFunction::ReplicateContinuation, - &[base_address, next_continuation, captured_continuation, handler_parameter, frame_pointer, stack_pointer, invoke_cps_function_with_trivial_continuation, captured_continuation_record_impl], + &[ + base_address, + next_continuation, + captured_continuation, + handler_parameter, + frame_pointer, + stack_pointer, + invoke_cps_function_with_trivial_continuation, + captured_continuation_record_impl, + ], ); let prepare_result_ptr = builder.inst_results(inst)[0]; builder.ins().jump(final_block, &[prepare_result_ptr]); @@ -358,11 +454,23 @@ impl BuiltinFunction { builder.seal_block(final_block); builder.switch_to_block(final_block); let prepare_result_ptr = builder.block_params(final_block)[0]; - let next_continuation = builder.ins().load(I64, MemFlags::new(), prepare_result_ptr, 0); - let next_base_address = builder.ins().load(I64, MemFlags::new(), prepare_result_ptr, 8); - let last_result_ptr = builder.ins().load(I64, MemFlags::new(), prepare_result_ptr, 16); - let continuation_impl = builder.ins().load(I64, MemFlags::new(), next_continuation, 0); - builder.ins().return_call_indirect(cps_impl_sig_ref, continuation_impl, &[next_base_address, next_continuation, last_result_ptr]); + let next_continuation = builder + .ins() + .load(I64, MemFlags::new(), prepare_result_ptr, 0); + let next_base_address = builder + .ins() + .load(I64, MemFlags::new(), prepare_result_ptr, 8); + let last_result_ptr = builder + .ins() + .load(I64, MemFlags::new(), prepare_result_ptr, 16); + let continuation_impl = builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 0); + builder.ins().return_call_indirect( + cps_impl_sig_ref, + continuation_impl, + &[next_base_address, next_continuation, last_result_ptr], + ); } fn simple_handler_runner_impl(m: &mut M, builder: &mut FunctionBuilder) { @@ -378,33 +486,62 @@ impl BuiltinFunction { let handler_index = builder.ins().load(I64, MemFlags::new(), base_address, 8); let simple_handler_type = builder.ins().load(I64, MemFlags::new(), base_address, 16); - let trivial_continuation = Self::get_built_in_data(m, builder, BuiltinData::TrivialContinuation); + let trivial_continuation = + Self::get_built_in_data(m, builder, BuiltinData::TrivialContinuation); let new_base_address = builder.ins().iadd_imm(base_address, 24); let sig = create_cps_signature(m); let sig_ref = builder.import_signature(sig); - let inst = builder.ins().call_indirect(sig_ref, handler_function_ptr, &[new_base_address, trivial_continuation]); + let inst = builder.ins().call_indirect( + sig_ref, + handler_function_ptr, + &[new_base_address, trivial_continuation], + ); let result_ptr = builder.inst_results(inst)[0]; let simple_handler_result = builder.ins().load(I64, MemFlags::new(), result_ptr, 0); let simple_handler_result_ptr = builder.ins().band_imm(simple_handler_result, !0b11); - let simple_exception_continuation_impl = Self::get_built_in_func_ptr(m, builder, BuiltinFunction::SimpleExceptionContinuationImpl); - let disposer_loader_cps_impl = Self::get_built_in_func_ptr(m, builder, BuiltinFunction::DisposerLoaderCpsImpl); - let inst = Self::call_built_in(m, builder, BuiltinFunction::ProcessSimpleHandlerResult, &[handler_index, simple_handler_type, simple_handler_result_ptr, simple_exception_continuation_impl, disposer_loader_cps_impl]); + let simple_exception_continuation_impl = Self::get_built_in_func_ptr( + m, + builder, + BuiltinFunction::SimpleExceptionContinuationImpl, + ); + let disposer_loader_cps_impl = + Self::get_built_in_func_ptr(m, builder, BuiltinFunction::DisposerLoaderCpsImpl); + let inst = Self::call_built_in( + m, + builder, + BuiltinFunction::ProcessSimpleHandlerResult, + &[ + handler_index, + simple_handler_type, + simple_handler_result_ptr, + simple_exception_continuation_impl, + disposer_loader_cps_impl, + ], + ); let result = builder.inst_results(inst)[0]; let new_base_address = builder.ins().iadd_imm(result_ptr, 8); - let next_continuation_impl_ptr = builder.ins().load(I64, MemFlags::new(), next_continuation, 0); + let next_continuation_impl_ptr = + builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 0); builder.ins().store(MemFlags::new(), result, result_ptr, 0); let cps_impl_sig = create_cps_impl_signature(m); let cps_impl_sig_ref = builder.import_signature(cps_impl_sig); - builder.ins().return_call_indirect(cps_impl_sig_ref, next_continuation_impl_ptr, &[new_base_address, next_continuation, result_ptr]); + builder.ins().return_call_indirect( + cps_impl_sig_ref, + next_continuation_impl_ptr, + &[new_base_address, next_continuation, result_ptr], + ); } fn transform_loader_cps_impl(m: &mut M, builder: &mut FunctionBuilder) { let entry_block = builder.create_block(); - let tip_address_slot = builder.create_sized_stack_slot(StackSlotData::new(StackSlotKind::ExplicitSlot, 8, 0)); + let tip_address_slot = + builder.create_sized_stack_slot(StackSlotData::new(StackSlotKind::ExplicitSlot, 8, 0)); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); builder.seal_block(entry_block); @@ -423,8 +560,12 @@ impl BuiltinFunction { // write handler parameter and result on stack. The first parameter is written closer to the // base address, so it's lower on the stack. - builder.ins().store(MemFlags::new(), handler_parameter, base_address, -8); - builder.ins().store(MemFlags::new(), last_result, base_address, 0); + builder + .ins() + .store(MemFlags::new(), handler_parameter, base_address, -8); + builder + .ins() + .store(MemFlags::new(), last_result, base_address, 0); // set the tip address to base address + 8 so that the only argument of transform loader // is taken for the final tail call to the transform function. let tip_address = builder.ins().iadd_imm(base_address, -8); @@ -432,29 +573,55 @@ impl BuiltinFunction { // push all the thunk arguments to the stack builder.ins().stack_store(tip_address, tip_address_slot, 0); let tip_address_ptr = builder.ins().stack_addr(I64, tip_address_slot, 0); - let inst = Self::call_built_in(m, builder, BuiltinFunction::ForceThunk, &[transform_thunk, tip_address_ptr]); + let inst = Self::call_built_in( + m, + builder, + BuiltinFunction::ForceThunk, + &[transform_thunk, tip_address_ptr], + ); let transform_ptr = builder.inst_results(inst)[0]; let tip_address = builder.ins().stack_load(I64, tip_address_slot, 0); - let next_continuation = builder.ins().load(I64, MemFlags::new(), current_continuation, 16); + let next_continuation = builder + .ins() + .load(I64, MemFlags::new(), current_continuation, 16); // update frame height of the next continuation to account for the transform arguments // pushed to the stack - let next_continuation_frame_height = builder.ins().load(I64, MemFlags::new(), next_continuation, 8); - let next_continuation_frame_height_delta_bytes = builder.ins().isub(base_address, tip_address); - let next_continuation_frame_height_delta = builder.ins().ushr_imm(next_continuation_frame_height_delta_bytes, 3); - let next_continuation_frame_height = builder.ins().iadd(next_continuation_frame_height, next_continuation_frame_height_delta); - builder.ins().store(MemFlags::new(), next_continuation_frame_height, next_continuation, 8); + let next_continuation_frame_height = + builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 8); + let next_continuation_frame_height_delta_bytes = + builder.ins().isub(base_address, tip_address); + let next_continuation_frame_height_delta = builder + .ins() + .ushr_imm(next_continuation_frame_height_delta_bytes, 3); + let next_continuation_frame_height = builder.ins().iadd( + next_continuation_frame_height, + next_continuation_frame_height_delta, + ); + builder.ins().store( + MemFlags::new(), + next_continuation_frame_height, + next_continuation, + 8, + ); // call the next continuation let sig = create_cps_signature(m); let sig_ref = builder.import_signature(sig); - builder.ins().return_call_indirect(sig_ref, transform_ptr, &[tip_address, next_continuation]); + builder.ins().return_call_indirect( + sig_ref, + transform_ptr, + &[tip_address, next_continuation], + ); } fn disposer_loader_cps_impl(m: &mut M, builder: &mut FunctionBuilder) { let entry_block = builder.create_block(); - let tip_address_slot = builder.create_sized_stack_slot(StackSlotData::new(StackSlotKind::ExplicitSlot, 8, 0)); + let tip_address_slot = + builder.create_sized_stack_slot(StackSlotData::new(StackSlotKind::ExplicitSlot, 8, 0)); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); builder.seal_block(entry_block); @@ -473,7 +640,13 @@ impl BuiltinFunction { let next_continuation_block = builder.create_block(); let disposer_thunk_ptr = builder.ins().band_imm(disposer_thunk, !0b11); - builder.ins().brif(disposer_thunk_ptr, call_disposer_block, &[], next_continuation_block, &[]); + builder.ins().brif( + disposer_thunk_ptr, + call_disposer_block, + &[], + next_continuation_block, + &[], + ); builder.seal_block(call_disposer_block); builder.seal_block(next_continuation_block); @@ -484,52 +657,97 @@ impl BuiltinFunction { builder.switch_to_block(call_disposer_block); // replace the disposer thunk with the parameter. This works since the loader takes exactly one argument, // the disposer thunk. And the disposer thunk also takes exactly one parameter, the handler parameter. - builder.ins().store(MemFlags::new(), handler_parameter, base_address, 0); + builder + .ins() + .store(MemFlags::new(), handler_parameter, base_address, 0); // push all the thunk arguments to the stack builder.ins().stack_store(base_address, tip_address_slot, 0); let tip_address_ptr = builder.ins().stack_addr(I64, tip_address_slot, 0); - let inst = Self::call_built_in(m, builder, BuiltinFunction::ForceThunk, &[disposer_thunk, tip_address_ptr]); + let inst = Self::call_built_in( + m, + builder, + BuiltinFunction::ForceThunk, + &[disposer_thunk, tip_address_ptr], + ); let disposer_ptr = builder.inst_results(inst)[0]; let tip_address = builder.ins().stack_load(I64, tip_address_slot, 0); - let next_continuation = builder.ins().load(I64, MemFlags::new(), current_continuation, 16); + let next_continuation = builder + .ins() + .load(I64, MemFlags::new(), current_continuation, 16); // update frame height of the next continuation to account for the disposer arguments // pushed to the stack - let next_continuation_frame_height = builder.ins().load(I64, MemFlags::new(), next_continuation, 8); - let next_continuation_frame_height_delta_bytes = builder.ins().isub(base_address, tip_address); - let next_continuation_frame_height_delta = builder.ins().ushr_imm(next_continuation_frame_height_delta_bytes, 3); - let next_continuation_frame_height = builder.ins().iadd(next_continuation_frame_height, next_continuation_frame_height_delta); - builder.ins().store(MemFlags::new(), next_continuation_frame_height, next_continuation, 8); + let next_continuation_frame_height = + builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 8); + let next_continuation_frame_height_delta_bytes = + builder.ins().isub(base_address, tip_address); + let next_continuation_frame_height_delta = builder + .ins() + .ushr_imm(next_continuation_frame_height_delta_bytes, 3); + let next_continuation_frame_height = builder.ins().iadd( + next_continuation_frame_height, + next_continuation_frame_height_delta, + ); + builder.ins().store( + MemFlags::new(), + next_continuation_frame_height, + next_continuation, + 8, + ); // call the next continuation let sig = create_cps_signature(m); let sig_ref = builder.import_signature(sig); - builder.ins().return_call_indirect(sig_ref, disposer_ptr, &[tip_address, next_continuation]); - + builder.ins().return_call_indirect( + sig_ref, + disposer_ptr, + &[tip_address, next_continuation], + ); // +-------------------------+ // | next continuation block | // +-------------------------+ // Call next continuation directly if the disposer is null. builder.switch_to_block(next_continuation_block); - let next_continuation = builder.ins().load(I64, MemFlags::new(), current_continuation, 16); - let next_continuation_height = builder.ins().load(I64, MemFlags::new(), next_continuation, 8); + let next_continuation = builder + .ins() + .load(I64, MemFlags::new(), current_continuation, 16); + let next_continuation_height = + builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 8); let next_continuation_height_bytes = builder.ins().ishl_imm(next_continuation_height, 3); - let next_base_address = builder.ins().iadd(base_address, next_continuation_height_bytes); - let next_continuation_impl_ptr = builder.ins().load(I64, MemFlags::new(), next_continuation, 0); + let next_base_address = builder + .ins() + .iadd(base_address, next_continuation_height_bytes); + let next_continuation_impl_ptr = + builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 0); let empty_struct = Self::get_built_in_data(m, builder, BuiltinData::EmptyStruct); // The loader has exactly one argument, hence, the result should be written at the single argument // location. We can't derive the next base address by base_address + 1, though. Because the // caller could have set up other arguments on the stack. - builder.ins().store(MemFlags::new(), empty_struct, base_address, 0); + builder + .ins() + .store(MemFlags::new(), empty_struct, base_address, 0); let sig = create_cps_impl_signature(m); let sig_ref = builder.import_signature(sig); - builder.ins().return_call_indirect(sig_ref, next_continuation_impl_ptr, &[next_base_address, next_continuation, base_address]); + builder.ins().return_call_indirect( + sig_ref, + next_continuation_impl_ptr, + &[next_base_address, next_continuation, base_address], + ); } - fn invoke_cps_function_with_trivial_continuation(m: &mut M, builder: &mut FunctionBuilder) { + fn invoke_cps_function_with_trivial_continuation( + m: &mut M, + builder: &mut FunctionBuilder, + ) { let entry_block = builder.create_block(); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); @@ -538,11 +756,15 @@ impl BuiltinFunction { let func_ptr = builder.block_params(entry_block)[0]; let base_address = builder.block_params(entry_block)[1]; - let trivial_continuation = Self::get_built_in_data(m, builder, BuiltinData::TrivialContinuation); + let trivial_continuation = + Self::get_built_in_data(m, builder, BuiltinData::TrivialContinuation); let sig = create_cps_signature(m); let sig_ref = builder.import_signature(sig); - let inst = builder.ins().call_indirect(sig_ref, func_ptr, &[base_address, trivial_continuation]); + let inst = + builder + .ins() + .call_indirect(sig_ref, func_ptr, &[base_address, trivial_continuation]); let result_ptr = builder.inst_results(inst)[0]; builder.ins().return_(&[result_ptr]); } @@ -557,18 +779,34 @@ impl BuiltinFunction { let current_continuation = builder.block_params(entry_block)[1]; // The exceptional value is stored inside the state field of the continuation. - let exception_value = builder.ins().load(I64, MemFlags::new(), current_continuation, 24); + let exception_value = builder + .ins() + .load(I64, MemFlags::new(), current_continuation, 24); let result_ptr = builder.ins().iadd_imm(base_address, -8); - builder.ins().store(MemFlags::new(), exception_value, result_ptr, 0); - - let next_continuation = builder.ins().load(I64, MemFlags::new(), current_continuation, 16); - let next_func = builder.ins().load(I64, MemFlags::new(), next_continuation, 0); + builder + .ins() + .store(MemFlags::new(), exception_value, result_ptr, 0); + + let next_continuation = builder + .ins() + .load(I64, MemFlags::new(), current_continuation, 16); + let next_func = builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 0); let cps_impl_sig = create_cps_impl_signature(m); let cps_impl_sig_ref = builder.import_signature(cps_impl_sig); - builder.ins().return_call_indirect(cps_impl_sig_ref, next_func, &[base_address, next_continuation, result_ptr]); + builder.ins().return_call_indirect( + cps_impl_sig_ref, + next_func, + &[base_address, next_continuation, result_ptr], + ); } - fn get_built_in_data(m: &mut M, builder: &mut FunctionBuilder, data: BuiltinData) -> Value { + fn get_built_in_data( + m: &mut M, + builder: &mut FunctionBuilder, + data: BuiltinData, + ) -> Value { let data_id = data.declare(m); let data_ref = m.declare_data_in_func(data_id, builder.func); let result = if data.is_tls() { @@ -579,14 +817,23 @@ impl BuiltinFunction { builder.ins().iadd_imm(result, data.offset()) } - fn call_built_in(m: &mut M, builder: &mut FunctionBuilder, func: BuiltinFunction, args: &[Value]) -> Inst { + fn call_built_in( + m: &mut M, + builder: &mut FunctionBuilder, + func: BuiltinFunction, + args: &[Value], + ) -> Inst { let func_id = func.declare(m).0; let func_ref = m.declare_func_in_func(func_id, builder.func); let inst = builder.ins().call(func_ref, args); inst } - fn get_built_in_func_ptr(m: &mut M, builder: &mut FunctionBuilder, func: BuiltinFunction) -> Value { + fn get_built_in_func_ptr( + m: &mut M, + builder: &mut FunctionBuilder, + func: BuiltinFunction, + ) -> Value { let func_id = func.declare(m).0; let func_ref = m.declare_func_in_func(func_id, builder.func); builder.ins().func_addr(I64, func_ref) @@ -625,10 +872,18 @@ impl FunctionFlavor { pub fn create_cps_impl_signature(module: &M) -> Signature { let mut uniform_cps_impl_func_signature = module.make_signature(); - uniform_cps_impl_func_signature.params.push(AbiParam::new(I64)); // base address - uniform_cps_impl_func_signature.params.push(AbiParam::new(I64)); // the current continuation object - uniform_cps_impl_func_signature.params.push(AbiParam::new(I64)); // the last result - uniform_cps_impl_func_signature.returns.push(AbiParam::new(I64)); + uniform_cps_impl_func_signature + .params + .push(AbiParam::new(I64)); // base address + uniform_cps_impl_func_signature + .params + .push(AbiParam::new(I64)); // the current continuation object + uniform_cps_impl_func_signature + .params + .push(AbiParam::new(I64)); // the last result + uniform_cps_impl_func_signature + .returns + .push(AbiParam::new(I64)); uniform_cps_impl_func_signature.call_conv = CallConv::Tail; uniform_cps_impl_func_signature } @@ -652,12 +907,18 @@ impl BuiltinData { fn name(&self) -> &'static str { match self { BuiltinData::TrivialContinuation => "__runtime_trivial_continuation", - BuiltinData::EmptyStruct => "__runtime_empty_struct" + BuiltinData::EmptyStruct => "__runtime_empty_struct", } } pub fn declare(&self, m: &mut M) -> DataId { - m.declare_data(self.name(), Linkage::Local, self.is_writable(), self.is_tls()).unwrap() + m.declare_data( + self.name(), + Linkage::Local, + self.is_writable(), + self.is_tls(), + ) + .unwrap() } pub fn is_tls(&self) -> bool { @@ -679,7 +940,7 @@ impl BuiltinData { match self { // Offset by a machine word since a trivial continuation is an object, whose first word is the object length. BuiltinData::TrivialContinuation => 8, - BuiltinData::EmptyStruct => 8 + BuiltinData::EmptyStruct => 8, } } @@ -688,8 +949,10 @@ impl BuiltinData { let mut data_description = DataDescription::new(); match self { BuiltinData::TrivialContinuation => { - let (trivial_continuation_impl_func_id, ..) = BuiltinFunction::TrivialContinuationImpl.declare(m); - let trivial_continuation_impl_func_ref = m.declare_func_in_data(trivial_continuation_impl_func_id, &mut data_description); + let (trivial_continuation_impl_func_id, ..) = + BuiltinFunction::TrivialContinuationImpl.declare(m); + let trivial_continuation_impl_func_ref = m + .declare_func_in_data(trivial_continuation_impl_func_id, &mut data_description); data_description.set_align(8); // A trivial continuation takes 1 word for object header and 4 words for object body. @@ -721,10 +984,13 @@ impl BuiltinData { } fn to_bytes(m: &mut M, data: Vec) -> Vec { - let data_bytes = data.into_iter().flat_map(|x| match m.isa().endianness() { - Endianness::Little => x.to_le_bytes(), - Endianness::Big => x.to_be_bytes(), - }).collect::>(); + let data_bytes = data + .into_iter() + .flat_map(|x| match m.isa().endianness() { + Endianness::Little => x.to_le_bytes(), + Endianness::Big => x.to_be_bytes(), + }) + .collect::>(); data_bytes } } diff --git a/src/backend/compiler.rs b/src/backend/compiler.rs index 5ff0b27..8cd5136 100644 --- a/src/backend/compiler.rs +++ b/src/backend/compiler.rs @@ -1,18 +1,21 @@ -use std::arch::global_asm; -use std::collections::HashMap; +use crate::ast::signature::FunctionDefinition; +use crate::ast::term::CType; +use crate::backend::common::{ + create_cps_impl_signature, create_cps_signature, BuiltinData, BuiltinFunction, FunctionFlavor, + HasType, +}; +use crate::backend::cps_function_translator::CpsFunctionTranslator; +use crate::backend::simple_function_translator::SimpleFunctionTranslator; use cranelift::codegen::isa::CallConv; +use cranelift::prelude::types::I64; use cranelift::prelude::*; -use cranelift::prelude::types::{I64}; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{DataId, FuncId, Linkage, Module}; use cranelift_object::ObjectModule; -use crate::ast::signature::FunctionDefinition; -use crate::ast::term::{CType}; +use enum_map::EnumMap; +use std::arch::global_asm; +use std::collections::HashMap; use strum::IntoEnumIterator; -use enum_map::{EnumMap}; -use crate::backend::common::{BuiltinData, BuiltinFunction, create_cps_impl_signature, create_cps_signature, FunctionFlavor, HasType}; -use crate::backend::cps_function_translator::{CpsFunctionTranslator}; -use crate::backend::simple_function_translator::SimpleFunctionTranslator; /// The basic JIT class. pub struct Compiler { @@ -63,7 +66,8 @@ impl Default for Compiler { } #[cfg(target_arch = "aarch64")] -global_asm!(r#" +global_asm!( + r#" .global _invoke_compiled_function _invoke_compiled_function: @@ -87,7 +91,8 @@ global_asm!(r#" ldp x19, x20, [sp], #16 ret -"#); +"# +); extern "C" { /// compiled function are often invoked with cranelift's tail call convention , which doesn't @@ -120,7 +125,8 @@ impl Compiler { impl Compiler { fn new(mut module: M) -> Self { - let builtin_functions = EnumMap::from_fn(|e: BuiltinFunction| e.declare_or_define(&mut module)); + let builtin_functions = + EnumMap::from_fn(|e: BuiltinFunction| e.declare_or_define(&mut module)); let builtin_data = EnumMap::from_fn(|e: BuiltinData| e.define(&mut module)); let mut uniform_func_signature = module.make_signature(); @@ -146,34 +152,60 @@ impl Compiler { } } - pub fn compile(&mut self, defs: &[(String, FunctionDefinition)], clir: &mut Option<&mut Vec<(String, String)>>) { + pub fn compile( + &mut self, + defs: &[(String, FunctionDefinition)], + clir: &mut Option<&mut Vec<(String, String)>>, + ) { let mut specialized_function_signatures = HashMap::new(); let mut local_function_arg_types = HashMap::new(); for (name, function_definition) in defs.iter() { local_function_arg_types.insert( name.clone(), - (function_definition.args.iter().map(|(_, v_type)| *v_type).collect::>(), function_definition.c_type), + ( + function_definition + .args + .iter() + .map(|(_, v_type)| *v_type) + .collect::>(), + function_definition.c_type, + ), ); if function_definition.need_cps { let cps_name = FunctionFlavor::Cps.decorate_name(name); - let function = self.module.declare_function(&cps_name, Linkage::Local, &self.uniform_cps_func_signature).unwrap(); + let function = self + .module + .declare_function(&cps_name, Linkage::Local, &self.uniform_cps_func_signature) + .unwrap(); self.local_functions.insert(cps_name, function); let cps_impl_name = FunctionFlavor::CpsImpl.decorate_name(name); - let function = self.module.declare_function(&cps_impl_name, Linkage::Local, &self.uniform_cps_impl_func_signature).unwrap(); + let function = self + .module + .declare_function( + &cps_impl_name, + Linkage::Local, + &self.uniform_cps_impl_func_signature, + ) + .unwrap(); self.local_functions.insert(cps_impl_name, function); } if function_definition.need_simple { // Simple let simple_name = FunctionFlavor::Simple.decorate_name(name); - let function = self.module.declare_function(&simple_name, Linkage::Local, &self.uniform_func_signature).unwrap(); + let function = self + .module + .declare_function(&simple_name, Linkage::Local, &self.uniform_func_signature) + .unwrap(); self.local_functions.insert(simple_name, function); } if function_definition.need_specialized { - let CType::SpecializedF(v_type) = function_definition.c_type else { unreachable!() }; + let CType::SpecializedF(v_type) = function_definition.c_type else { + unreachable!() + }; let mut sig = self.module.make_signature(); sig.call_conv = CallConv::Tail; // The first argument is the base address of the parameter stack, which is useful @@ -184,7 +216,12 @@ impl Compiler { } sig.returns.push(AbiParam::new(v_type.get_type())); let specialized_name = FunctionFlavor::Specialized.decorate_name(name); - self.local_functions.insert(specialized_name.clone(), self.module.declare_function(&specialized_name, Linkage::Local, &sig).unwrap()); + self.local_functions.insert( + specialized_name.clone(), + self.module + .declare_function(&specialized_name, Linkage::Local, &sig) + .unwrap(), + ); specialized_function_signatures.insert(name, sig); } } @@ -192,17 +229,37 @@ impl Compiler { for (name, function_definition) in defs.iter() { if function_definition.need_cps { // CPS - CpsFunctionTranslator::compile_cps_function(name, self, function_definition, &local_function_arg_types, clir, true); + CpsFunctionTranslator::compile_cps_function( + name, + self, + function_definition, + &local_function_arg_types, + clir, + true, + ); } if function_definition.need_simple { // simple - SimpleFunctionTranslator::compile_simple_function(name, self, function_definition, &local_function_arg_types, clir); + SimpleFunctionTranslator::compile_simple_function( + name, + self, + function_definition, + &local_function_arg_types, + clir, + ); } // specialized if function_definition.need_specialized { let sig = specialized_function_signatures.get(name).unwrap(); - SimpleFunctionTranslator::compile_specialized_function(name, self, sig.clone(), function_definition, &local_function_arg_types, clir); + SimpleFunctionTranslator::compile_specialized_function( + name, + self, + sig.clone(), + function_definition, + &local_function_arg_types, + clir, + ); } } @@ -212,12 +269,21 @@ impl Compiler { /// Creates a main wrapper function (named `__main__`) that calls the `__runtime_alloc_stack__`, /// which sets up the parameter stack and invokes the user-defined `main` function. fn generate_main_wrapper(&mut self, clir: &mut Option<&mut Vec<(String, String)>>) { - let main_wrapper_id = self.module.declare_function(MAIN_WRAPPER_NAME, Linkage::Local, &self.uniform_func_signature).unwrap(); - self.local_functions.insert(MAIN_WRAPPER_NAME.to_string(), main_wrapper_id); + let main_wrapper_id = self + .module + .declare_function( + MAIN_WRAPPER_NAME, + Linkage::Local, + &self.uniform_func_signature, + ) + .unwrap(); + self.local_functions + .insert(MAIN_WRAPPER_NAME.to_string(), main_wrapper_id); self.ctx.clear(); self.ctx.func.signature.returns.push(AbiParam::new(I64)); - let mut function_builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); + let mut function_builder = + FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); let entry_block = function_builder.create_block(); function_builder.append_block_params_for_function_params(entry_block); @@ -225,12 +291,19 @@ impl Compiler { function_builder.seal_block(entry_block); let alloc_stack_id = self.builtin_functions[BuiltinFunction::AllocStack]; - let alloc_stack_func_ref = self.module.declare_func_in_func(alloc_stack_id, function_builder.func); + let alloc_stack_func_ref = self + .module + .declare_func_in_func(alloc_stack_id, function_builder.func); let inst = function_builder.ins().call(alloc_stack_func_ref, &[]); let stack_base = function_builder.inst_results(inst)[0]; - let main_id = self.local_functions.get(&FunctionFlavor::Specialized.decorate_name("main")).unwrap(); - let main_func_ref = self.module.declare_func_in_func(*main_id, function_builder.func); + let main_id = self + .local_functions + .get(&FunctionFlavor::Specialized.decorate_name("main")) + .unwrap(); + let main_func_ref = self + .module + .declare_func_in_func(*main_id, function_builder.func); let inst = function_builder.ins().call(main_func_ref, &[stack_base]); let return_value = function_builder.inst_results(inst)[0]; function_builder.ins().return_(&[return_value]); @@ -241,7 +314,12 @@ impl Compiler { self.module.clear_context(&mut self.ctx); } - pub fn define_function(&mut self, name: &str, func_id: FuncId, clir: &mut Option<&mut Vec<(String, String)>>) { + pub fn define_function( + &mut self, + name: &str, + func_id: FuncId, + clir: &mut Option<&mut Vec<(String, String)>>, + ) { if let Some(clir) = clir { clir.push((name.to_owned(), format!("{}", self.ctx.func.display()))); } diff --git a/src/backend/cps_function_translator.rs b/src/backend/cps_function_translator.rs index beae3d6..a866f91 100644 --- a/src/backend/cps_function_translator.rs +++ b/src/backend/cps_function_translator.rs @@ -1,17 +1,17 @@ -use std::collections::{HashMap, HashSet}; -use std::ops::{Deref, DerefMut}; -use cranelift::prelude::{Block, InstBuilder, MemFlags, TrapCode, Value}; -use cranelift::prelude::types::I64; -use cranelift_module::{FuncId, Linkage, Module}; -use cranelift::frontend::Switch; use crate::ast::signature::FunctionDefinition; -use crate::ast::term::{CTerm, CType, Effect, VType}; use crate::ast::term::SpecializedType::Integer; use crate::ast::term::VType::Uniform; +use crate::ast::term::{CTerm, CType, Effect, VType}; use crate::backend::common::{BuiltinFunction, FunctionFlavor, HasType, TypedReturnValue}; use crate::backend::compiler::Compiler; use crate::backend::function_analyzer::FunctionAnalyzer; use crate::backend::simple_function_translator::SimpleFunctionTranslator; +use cranelift::frontend::Switch; +use cranelift::prelude::types::I64; +use cranelift::prelude::{Block, InstBuilder, MemFlags, TrapCode, Value}; +use cranelift_module::{FuncId, Linkage, Module}; +use std::collections::{HashMap, HashSet}; +use std::ops::{Deref, DerefMut}; pub struct CpsFunctionTranslator {} @@ -25,7 +25,14 @@ impl CpsFunctionTranslator { may_be_complex: bool, ) { if !may_be_complex { - SimpleCpsFunctionTranslator::compile_cps_function(name, compiler, function_definition, local_function_arg_types, clir, false); + SimpleCpsFunctionTranslator::compile_cps_function( + name, + compiler, + function_definition, + local_function_arg_types, + clir, + false, + ); return; } let mut function_analyzer = FunctionAnalyzer::new(); @@ -33,10 +40,32 @@ impl CpsFunctionTranslator { let num_blocks = function_analyzer.count; let case_blocks = function_analyzer.case_blocks; if function_analyzer.has_non_tail_complex_effects { - let cps_impl_func_id = ComplexCpsFunctionTranslator::compile_cps_impl_function(name, compiler, function_definition, local_function_arg_types, num_blocks, case_blocks, clir); - ComplexCpsFunctionTranslator::compile_cps_function(name, compiler, function_definition, local_function_arg_types, cps_impl_func_id, clir); + let cps_impl_func_id = ComplexCpsFunctionTranslator::compile_cps_impl_function( + name, + compiler, + function_definition, + local_function_arg_types, + num_blocks, + case_blocks, + clir, + ); + ComplexCpsFunctionTranslator::compile_cps_function( + name, + compiler, + function_definition, + local_function_arg_types, + cps_impl_func_id, + clir, + ); } else { - SimpleCpsFunctionTranslator::compile_cps_function(name, compiler, function_definition, local_function_arg_types, clir, true); + SimpleCpsFunctionTranslator::compile_cps_function( + name, + compiler, + function_definition, + local_function_arg_types, + clir, + true, + ); } } } @@ -80,7 +109,12 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { clir: &mut Option<&mut Vec<(String, String)>>, may_be_complex: bool, ) { - let mut translator = SimpleCpsFunctionTranslator::new(compiler, function_definition, local_function_arg_types, may_be_complex); + let mut translator = SimpleCpsFunctionTranslator::new( + compiler, + function_definition, + local_function_arg_types, + may_be_complex, + ); let typed_return_value = translator.translate_c_term_cps(&function_definition.body, true); match typed_return_value { None => {} @@ -90,14 +124,31 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { function_definition.args.len(), ); let continuation = translator.next_continuation; - invoke_next_continuation_in_the_end(&mut translator, return_value_address, continuation); + invoke_next_continuation_in_the_end( + &mut translator, + return_value_address, + continuation, + ); } } translator.function_translator.function_builder.finalize(); let cps_name = FunctionFlavor::Cps.decorate_name(name); - let func_id = compiler.module.declare_function(&cps_name, Linkage::Local, &compiler.uniform_cps_func_signature).unwrap(); - SimpleFunctionTranslator::define_function(&mut compiler.module, &mut compiler.ctx, &cps_name, func_id, clir); + let func_id = compiler + .module + .declare_function( + &cps_name, + Linkage::Local, + &compiler.uniform_cps_func_signature, + ) + .unwrap(); + SimpleFunctionTranslator::define_function( + &mut compiler.module, + &mut compiler.ctx, + &cps_name, + func_id, + clir, + ); } fn new( @@ -124,7 +175,12 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { // stack grows from higher address to lower address, so parameter list grows in the // reverse order and hence the offset is the index of the parameter in the parameter // list. - let value = translator.function_builder.ins().load(I64, MemFlags::new(), translator.base_address, (i * 8) as i32); + let value = translator.function_builder.ins().load( + I64, + MemFlags::new(), + translator.base_address, + (i * 8) as i32, + ); Some((value, Uniform)) }, ); @@ -136,7 +192,6 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { } } - fn translate_c_term_cps(&mut self, c_term: &CTerm, is_tail: bool) -> TypedReturnValue { match (c_term, is_tail) { (CTerm::Force { thunk, .. }, true) => { @@ -147,9 +202,11 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { let sig_ref = self.function_builder.import_signature(signature); compute_cps_tail_call_base_address(self, next_continuation); let tip_address = self.tip_address; - self.function_builder.ins().return_call_indirect(sig_ref, func_pointer, &[ - tip_address, next_continuation, - ]); + self.function_builder.ins().return_call_indirect( + sig_ref, + func_pointer, + &[tip_address, next_continuation], + ); None } (CTerm::Def { name, effect }, true) => { @@ -158,7 +215,9 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { let next_continuation = self.next_continuation; compute_cps_tail_call_base_address(self, next_continuation); let tip_address = self.tip_address; - self.function_builder.ins().return_call(func_ref, &[tip_address, next_continuation]); + self.function_builder + .ins() + .return_call(func_ref, &[tip_address, next_continuation]); } else { let func_ref = self.get_local_function(name, FunctionFlavor::Simple); let next_continuation = self.next_continuation; @@ -166,16 +225,18 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { let tip_address = self.tip_address; let inst = self.function_builder.ins().call(func_ref, &[tip_address]); let return_ptr = self.function_builder.inst_results(inst)[0]; - invoke_next_continuation_in_the_end(&mut self.function_translator, return_ptr, next_continuation); + invoke_next_continuation_in_the_end( + &mut self.function_translator, + return_ptr, + next_continuation, + ); } None } (CTerm::CaseInt { .. }, _) => { let s = self as *mut SimpleCpsFunctionTranslator; - self.translate_case_int(c_term, is_tail, |c_term, is_tail| { - unsafe { - (*s).translate_c_term_cps(c_term, is_tail) - } + self.translate_case_int(c_term, is_tail, |c_term, is_tail| unsafe { + (*s).translate_c_term_cps(c_term, is_tail) }) } (CTerm::OperationCall { eff, args, .. }, true) => { @@ -194,13 +255,18 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { } else { Effect::Simple }; - self.translate_redex(c_term, is_tail, context_effect, |c_term, is_tail| { - unsafe { - (*s).translate_c_term_cps(c_term, is_tail) - } + self.translate_redex(c_term, is_tail, context_effect, |c_term, is_tail| unsafe { + (*s).translate_c_term_cps(c_term, is_tail) }) } - (CTerm::Let { box t, bound_index, box body }, _) => { + ( + CTerm::Let { + box t, + bound_index, + box body, + }, + _, + ) => { let t_value = self.translate_c_term_cps(t, false); self.local_vars[*bound_index] = t_value; self.translate_c_term_cps(body, is_tail) @@ -211,13 +277,19 @@ impl<'a, M: Module> SimpleCpsFunctionTranslator<'a, M> { self.translate_handler(is_tail, c_term, next_continuation) } // These terms cannot return a computation so we just extract the return value with the simple translator. - (CTerm::Return { .. } | CTerm::Lambda { .. } | CTerm::MemGet { .. } | CTerm::MemSet { .. } | - CTerm::PrimitiveCall { .. } | CTerm::OperationCall { effect: Effect::Simple, .. }, _) => { - self.translate_c_term(c_term, false) - } - (_, false) => { - self.translate_c_term(c_term, false) - } + ( + CTerm::Return { .. } + | CTerm::Lambda { .. } + | CTerm::MemGet { .. } + | CTerm::MemSet { .. } + | CTerm::PrimitiveCall { .. } + | CTerm::OperationCall { + effect: Effect::Simple, + .. + }, + _, + ) => self.translate_c_term(c_term, false), + (_, false) => self.translate_c_term(c_term, false), } } } @@ -292,31 +364,62 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { function_builder.block_params(entry_block)[0] }, // We don't do anything here and each block will load the parameters on demand. - |_, _, _, _| None); + |_, _, _, _| None, + ); let num_local_vars = function_definition.var_bound - function_definition.args.len(); let continuation_size = 32 + num_local_vars * 8; - let continuation_size_value = translator.function_builder.ins().iconst(I64, continuation_size as i64); + let continuation_size_value = translator + .function_builder + .ins() + .iconst(I64, continuation_size as i64); let inst = translator.call_builtin_func(BuiltinFunction::Alloc, &[continuation_size_value]); let continuation = translator.function_builder.inst_results(inst)[0]; // initialize continuation object - let cps_impl_func_ref = translator.module.declare_func_in_func(cps_impl_func_id, translator.function_builder.func); - let cps_impl_func_addr = translator.function_builder.ins().func_addr(I64, cps_impl_func_ref); + let cps_impl_func_ref = translator + .module + .declare_func_in_func(cps_impl_func_id, translator.function_builder.func); + let cps_impl_func_addr = translator + .function_builder + .ins() + .func_addr(I64, cps_impl_func_ref); // set up continuation impl function pointer - translator.function_builder.ins().store(MemFlags::new(), cps_impl_func_addr, continuation, 0); + translator.function_builder.ins().store( + MemFlags::new(), + cps_impl_func_addr, + continuation, + 0, + ); // set up next continuation - translator.function_builder.ins().store(MemFlags::new(), next_continuation, continuation, 16); + translator.function_builder.ins().store( + MemFlags::new(), + next_continuation, + continuation, + 16, + ); // state defaults to 0 so there is nothing to do for it. // frame height defaults to 0 so there is nothing to do for it. // Initially sets the last result pointer to the base address -8 so that the tip address is // updated correctly when the continuation implementation function is called. - let last_result_ptr = translator.function_builder.ins().iadd_imm(translator.base_address, -8); - translator.function_builder.ins().return_call(cps_impl_func_ref, &[translator.base_address, continuation, last_result_ptr]); + let last_result_ptr = translator + .function_builder + .ins() + .iadd_imm(translator.base_address, -8); + translator.function_builder.ins().return_call( + cps_impl_func_ref, + &[translator.base_address, continuation, last_result_ptr], + ); translator.function_builder.finalize(); let cps_name = FunctionFlavor::Cps.decorate_name(name); let func_id = compiler.local_functions.get(&cps_name).unwrap(); - SimpleFunctionTranslator::define_function(&mut compiler.module, &mut compiler.ctx, &cps_name, *func_id, clir); + SimpleFunctionTranslator::define_function( + &mut compiler.module, + &mut compiler.ctx, + &cps_name, + *func_id, + clir, + ); } fn compile_cps_impl_function( @@ -328,7 +431,10 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { case_blocks: HashMap, usize, usize)>, clir: &mut Option<&mut Vec<(String, String)>>, ) -> FuncId { - assert!(num_blocks > 1, "if there is only a single block, one should not create a cps_impl function at all!"); + assert!( + num_blocks > 1, + "if there is only a single block, one should not create a cps_impl function at all!" + ); let mut translator = ComplexCpsFunctionTranslator::new( compiler, function_definition, @@ -336,7 +442,8 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { num_blocks, case_blocks, ); - let typed_return_value = translator.translate_c_term_cps_impl(&function_definition.body, true); + let typed_return_value = + translator.translate_c_term_cps_impl(&function_definition.body, true); match typed_return_value { None => { // Nothing to do since tail call is already a terminating instruction @@ -349,16 +456,37 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { ); let continuation = translator.continuation; - let next_continuation = translator.function_builder.ins().load(I64, MemFlags::new(), continuation, 16); - - invoke_next_continuation_in_the_end(&mut translator, return_address, next_continuation); + let next_continuation = + translator + .function_builder + .ins() + .load(I64, MemFlags::new(), continuation, 16); + + invoke_next_continuation_in_the_end( + &mut translator, + return_address, + next_continuation, + ); } } translator.function_translator.function_builder.finalize(); let cps_impl_name = FunctionFlavor::CpsImpl.decorate_name(name); - let func_id = compiler.module.declare_function(&cps_impl_name, Linkage::Local, &compiler.uniform_cps_impl_func_signature).unwrap(); - SimpleFunctionTranslator::define_function(&mut compiler.module, &mut compiler.ctx, &cps_impl_name, func_id, clir); + let func_id = compiler + .module + .declare_function( + &cps_impl_name, + Linkage::Local, + &compiler.uniform_cps_impl_func_signature, + ) + .unwrap(); + SimpleFunctionTranslator::define_function( + &mut compiler.module, + &mut compiler.ctx, + &cps_impl_name, + func_id, + clir, + ); func_id } @@ -384,16 +512,32 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { function_builder.block_params(entry_block)[0] }, // We don't do anything here and each block will load the parameters on demand. - |_, _, _, _| None); - let state = function_translator.function_builder.ins().load(I64, MemFlags::new(), continuation, 24); + |_, _, _, _| None, + ); + let state = + function_translator + .function_builder + .ins() + .load(I64, MemFlags::new(), continuation, 24); // local vars are stored in the continuation object starting at the fifth word - function_translator.local_var_ptr = function_translator.function_builder.ins().iadd_imm(continuation, 32); + function_translator.local_var_ptr = function_translator + .function_builder + .ins() + .iadd_imm(continuation, 32); // set the tip address according to the last result pointer. - let start_tip_address = function_translator.function_builder.ins().iadd_imm(last_result_ptr, 8); + let start_tip_address = function_translator + .function_builder + .ins() + .iadd_imm(last_result_ptr, 8); function_translator.tip_address = start_tip_address; - let last_result = function_translator.function_builder.ins().load(I64, MemFlags::new(), last_result_ptr, 0); + let last_result = function_translator.function_builder.ins().load( + I64, + MemFlags::new(), + last_result_ptr, + 0, + ); let argument_count = function_definition.args.len(); @@ -403,11 +547,13 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { // skip case blocks because these blocks won't suspend. let mut skipped_block_ids = HashSet::new(); - case_blocks.values().for_each(|(branch_block_ids, default_block_id, joining_block_id)| { - skipped_block_ids.extend(branch_block_ids); - skipped_block_ids.insert(*default_block_id); - skipped_block_ids.insert(*joining_block_id); - }); + case_blocks + .values() + .for_each(|(branch_block_ids, default_block_id, joining_block_id)| { + skipped_block_ids.extend(branch_block_ids); + skipped_block_ids.insert(*default_block_id); + skipped_block_ids.insert(*joining_block_id); + }); for i in 0..num_blocks { let block = function_translator.function_builder.create_block(); blocks.push(block); @@ -418,9 +564,15 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { let first_block = blocks[0]; // The state number cannot be outside of the range of the switch table so the default block // is unreachable. Hence we just arbitrarily set it to the first block. - switch.emit(&mut function_translator.function_builder, state, first_block); + switch.emit( + &mut function_translator.function_builder, + state, + first_block, + ); function_translator.function_builder.seal_block(first_block); - function_translator.function_builder.switch_to_block(first_block); + function_translator + .function_builder + .switch_to_block(first_block); Self { function_translator, @@ -439,32 +591,40 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { match (c_term, is_tail) { (CTerm::Redex { .. }, _) => { let s = self as *mut ComplexCpsFunctionTranslator; - self.translate_redex(c_term, is_tail, Effect::Complex, |c_term, is_tail| { - unsafe { - (*s).translate_c_term_cps_impl(c_term, is_tail) - } + self.translate_redex(c_term, is_tail, Effect::Complex, |c_term, is_tail| unsafe { + (*s).translate_c_term_cps_impl(c_term, is_tail) }) } - (CTerm::Force { thunk, effect: Effect::Complex }, _) => { + ( + CTerm::Force { + thunk, + effect: Effect::Complex, + }, + _, + ) => { let continuation = self.continuation; let func_pointer = self.process_thunk(thunk); let signature = self.uniform_cps_func_signature.clone(); let sig_ref = self.function_builder.import_signature(signature); if is_tail { - let next_continuation = self.adjust_next_continuation_frame_height(continuation); + let next_continuation = + self.adjust_next_continuation_frame_height(continuation); let tip_address = self.tip_address; - self.function_builder.ins().return_call_indirect(sig_ref, func_pointer, &[ - tip_address, next_continuation, - ]); + self.function_builder.ins().return_call_indirect( + sig_ref, + func_pointer, + &[tip_address, next_continuation], + ); None } else { self.pack_up_continuation(); let tip_address = self.tip_address; - self.function_builder.ins().return_call_indirect(sig_ref, func_pointer, &[ - tip_address, - continuation, - ]); + self.function_builder.ins().return_call_indirect( + sig_ref, + func_pointer, + &[tip_address, continuation], + ); self.advance_for_complex_effect() } } @@ -473,18 +633,34 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { let next_continuation = self.adjust_next_continuation_frame_height(continuation); self.invoke_thunk(is_tail, thunk, next_continuation) } - (CTerm::Let { box t, bound_index, box body }, _) => { + ( + CTerm::Let { + box t, + bound_index, + box body, + }, + _, + ) => { let t_value = self.translate_c_term_cps_impl(t, false); self.local_vars[*bound_index] = t_value; self.touched_vars_in_current_session.insert(*bound_index); self.translate_c_term_cps_impl(body, is_tail) } - (CTerm::Def { name, effect: Effect::Complex }, _) => { + ( + CTerm::Def { + name, + effect: Effect::Complex, + }, + _, + ) => { let func_ref = self.get_local_function(name, FunctionFlavor::Cps); if is_tail { - let next_continuation = self.adjust_next_continuation_frame_height(self.continuation); + let next_continuation = + self.adjust_next_continuation_frame_height(self.continuation); let tip_address = self.tip_address; - self.function_builder.ins().return_call(func_ref, &[tip_address, next_continuation]); + self.function_builder + .ins() + .return_call(func_ref, &[tip_address, next_continuation]); None } else { self.pack_up_continuation(); @@ -500,16 +676,32 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { let return_ptr = self.function_builder.inst_results(inst)[0]; let continuation = self.continuation; let next_continuation = self.adjust_next_continuation_frame_height(continuation); - invoke_next_continuation_in_the_end(&mut self.function_translator, return_ptr, next_continuation); + invoke_next_continuation_in_the_end( + &mut self.function_translator, + return_ptr, + next_continuation, + ); None } - (CTerm::CaseInt { t, result_type, branches, default_branch }, _) => { - let (branch_block_ids, default_block_id, joining_block_id) = &self.case_blocks[&self.current_block_id].clone(); + ( + CTerm::CaseInt { + t, + result_type, + branches, + default_branch, + }, + _, + ) => { + let (branch_block_ids, default_block_id, joining_block_id) = + &self.case_blocks[&self.current_block_id].clone(); let t_value = self.translate_v_term(t); let t_value = self.convert_to_special(t_value, Integer); // Create table jump - let branch_body_and_blocks: Vec<_> = branches.iter().zip(branch_block_ids.iter().map(|id| (*id, self.blocks[*id]))).collect(); + let branch_body_and_blocks: Vec<_> = branches + .iter() + .zip(branch_block_ids.iter().map(|id| (*id, self.blocks[*id]))) + .collect(); let mut switch = Switch::new(); for ((value, _), (_, branch_block)) in branch_body_and_blocks.iter() { switch.set_entry(*value as u128, *branch_block); @@ -524,7 +716,8 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { CType::SpecializedF(vty) => vty, }; // return value - self.function_builder.append_block_param(joining_block, result_v_type.get_type()); + self.function_builder + .append_block_param(joining_block, result_v_type.get_type()); // tip address self.function_builder.append_block_param(joining_block, I64); @@ -534,24 +727,46 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { self.advance_for_case(); assert_eq!(self.current_block_id, block_id); self.tip_address = start_tip_address; - self.create_branch_block(branch_block, is_tail, joining_block, result_v_type, Some(c_term)); + self.create_branch_block( + branch_block, + is_tail, + joining_block, + result_v_type, + Some(c_term), + ); } self.advance_for_case(); assert_eq!(self.current_block_id, *default_block_id); self.tip_address = start_tip_address; - self.create_branch_block(default_block, is_tail, joining_block, result_v_type, match default_branch { - None => None, - Some(box branch) => Some(branch), - }); + self.create_branch_block( + default_block, + is_tail, + joining_block, + result_v_type, + match default_branch { + None => None, + Some(box branch) => Some(branch), + }, + ); // Switch to joining block for future code generation self.advance_for_case(); assert_eq!(self.current_block_id, *joining_block_id); self.tip_address = self.function_builder.block_params(joining_block)[1]; - Some((self.function_builder.block_params(joining_block)[0], *result_v_type)) + Some(( + self.function_builder.block_params(joining_block)[0], + *result_v_type, + )) } - (CTerm::OperationCall { eff, args, effect: Effect::Complex }, _) => { + ( + CTerm::OperationCall { + eff, + args, + effect: Effect::Complex, + }, + _, + ) => { let eff_value = self.translate_v_term(eff); let eff_value = self.convert_to_uniform(eff_value); self.push_arg_v_terms(args); @@ -579,27 +794,45 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { self.translate_handler(is_tail, c_term, continuation) } // These terms cannot return a computation so we just extract the return value with the simple translator. - (CTerm::Return { .. } | CTerm::Lambda { .. } | CTerm::MemGet { .. } | CTerm::MemSet { .. } | - CTerm::PrimitiveCall { .. } | CTerm::OperationCall { effect: Effect::Simple, .. }, _) => { - self.translate_c_term(c_term, false) - } - (_, false) => { - self.translate_c_term(c_term, false) - } + ( + CTerm::Return { .. } + | CTerm::Lambda { .. } + | CTerm::MemGet { .. } + | CTerm::MemSet { .. } + | CTerm::PrimitiveCall { .. } + | CTerm::OperationCall { + effect: Effect::Simple, + .. + }, + _, + ) => self.translate_c_term(c_term, false), + (_, false) => self.translate_c_term(c_term, false), } } fn adjust_next_continuation_frame_height(&mut self, continuation: Value) -> Value { - let next_continuation = self.function_builder.ins().load(I64, MemFlags::new(), continuation, 16); + let next_continuation = + self.function_builder + .ins() + .load(I64, MemFlags::new(), continuation, 16); compute_cps_tail_call_base_address(self, next_continuation); next_continuation } - fn create_branch_block(&mut self, branch_block: Block, is_tail: bool, joining_block: Block, result_v_type: &VType, branch: Option<&CTerm>) { + fn create_branch_block( + &mut self, + branch_block: Block, + is_tail: bool, + joining_block: Block, + result_v_type: &VType, + branch: Option<&CTerm>, + ) { self.function_builder.switch_to_block(branch_block); let typed_return_value = match branch { None => { - self.function_builder.ins().trap(TrapCode::UnreachableCodeReached); + self.function_builder + .ins() + .trap(TrapCode::UnreachableCodeReached); None } Some(branch) => self.translate_c_term_cps_impl(branch, is_tail), @@ -611,7 +844,9 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { Some(..) => { let value = self.adapt_type(typed_return_value, result_v_type); let tip_address = self.tip_address; - self.function_builder.ins().jump(joining_block, &[value, tip_address]); + self.function_builder + .ins() + .jump(joining_block, &[value, tip_address]); } } } @@ -622,24 +857,44 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { // Store current height let tip_address = self.tip_address; let base_address = self.base_address; - let arg_stack_frame_height_bytes = self.function_builder.ins().isub(base_address, tip_address); - let arg_stack_frame_height = self.function_builder.ins().ushr_imm(arg_stack_frame_height_bytes, 3); - self.function_builder.ins().store(MemFlags::new(), arg_stack_frame_height, continuation, 8); + let arg_stack_frame_height_bytes = + self.function_builder.ins().isub(base_address, tip_address); + let arg_stack_frame_height = self + .function_builder + .ins() + .ushr_imm(arg_stack_frame_height_bytes, 3); + self.function_builder + .ins() + .store(MemFlags::new(), arg_stack_frame_height, continuation, 8); // Store next state let current_block_id = self.current_block_id; - let next_block_id = self.function_builder.ins().iconst(I64, current_block_id as i64 + 1); - self.function_builder.ins().store(MemFlags::new(), next_block_id, continuation, 24); + let next_block_id = self + .function_builder + .ins() + .iconst(I64, current_block_id as i64 + 1); + self.function_builder + .ins() + .store(MemFlags::new(), next_block_id, continuation, 24); // Store local vars let local_var_ptr = self.local_var_ptr; - let mut touched_vars: Vec<_> = self.touched_vars_in_current_session.iter().copied().collect(); + let mut touched_vars: Vec<_> = self + .touched_vars_in_current_session + .iter() + .copied() + .collect(); touched_vars.sort(); for index in touched_vars { let local_var = self.local_vars[index]; let value = self.convert_to_uniform(local_var); let num_args = self.num_args; - self.function_builder.ins().store(MemFlags::new(), value, local_var_ptr, ((index - num_args) * 8) as i32); + self.function_builder.ins().store( + MemFlags::new(), + value, + local_var_ptr, + ((index - num_args) * 8) as i32, + ); } } @@ -665,7 +920,10 @@ impl<'a, M: Module> ComplexCpsFunctionTranslator<'a, M> { } } -fn compute_cps_tail_call_base_address(translator: &mut SimpleFunctionTranslator, next_continuation: Value) { +fn compute_cps_tail_call_base_address( + translator: &mut SimpleFunctionTranslator, + next_continuation: Value, +) { let base_address = translator.base_address; // accommodate the height of the next continuation is updated here because tail // call causes the next continuation to be directly passed to the callee, which, @@ -673,23 +931,48 @@ fn compute_cps_tail_call_base_address(translator: &mut SimpleFunction // address from this new height. The height can be different because the // callee args are effectively altered by the current function. let new_base_address = translator.copy_tail_call_args_and_get_new_base(); - let offset_in_bytes = translator.function_builder.ins().isub(base_address, new_base_address); + let offset_in_bytes = translator + .function_builder + .ins() + .isub(base_address, new_base_address); translator.adjust_continuation_height(next_continuation, offset_in_bytes); translator.tip_address = new_base_address; } -fn invoke_next_continuation_in_the_end(translator: &mut SimpleFunctionTranslator, return_address: Value, next_continuation: Value) { +fn invoke_next_continuation_in_the_end( + translator: &mut SimpleFunctionTranslator, + return_address: Value, + next_continuation: Value, +) { // compute next base address - let next_continuation_height = translator.function_builder.ins().load(I64, MemFlags::new(), next_continuation, 8); - let next_continuation_height_bytes = translator.function_builder.ins().ishl_imm(next_continuation_height, 3); + let next_continuation_height = + translator + .function_builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 8); + let next_continuation_height_bytes = translator + .function_builder + .ins() + .ishl_imm(next_continuation_height, 3); let base_address = translator.base_address; - let next_base_address = translator.function_builder.ins().iadd(base_address, next_continuation_height_bytes); + let next_base_address = translator + .function_builder + .ins() + .iadd(base_address, next_continuation_height_bytes); // get next continuation impl function - let next_continuation_impl = translator.function_builder.ins().load(I64, MemFlags::new(), next_continuation, 0); + let next_continuation_impl = + translator + .function_builder + .ins() + .load(I64, MemFlags::new(), next_continuation, 0); // call the next continuation let signature = translator.uniform_cps_impl_func_signature.clone(); let sig_ref = translator.function_builder.import_signature(signature); - translator.function_builder.ins().return_call_indirect(sig_ref, next_continuation_impl, &[next_base_address, next_continuation, return_address]); + translator.function_builder.ins().return_call_indirect( + sig_ref, + next_continuation_impl, + &[next_base_address, next_continuation, return_address], + ); } diff --git a/src/backend/function_analyzer.rs b/src/backend/function_analyzer.rs index 47e1e49..d1d4fdc 100644 --- a/src/backend/function_analyzer.rs +++ b/src/backend/function_analyzer.rs @@ -1,5 +1,5 @@ -use std::collections::HashMap; use crate::ast::term::{CTerm, Effect}; +use std::collections::HashMap; #[derive(Debug)] pub struct FunctionAnalyzer { @@ -21,7 +21,10 @@ impl FunctionAnalyzer { pub(crate) fn analyze(&mut self, c_term: &CTerm, is_tail: bool) { match c_term { CTerm::Redex { function, .. } => self.analyze(function, is_tail), - CTerm::Force { effect: Effect::Complex, .. } if !is_tail => { + CTerm::Force { + effect: Effect::Complex, + .. + } if !is_tail => { self.count += 1; self.has_non_tail_complex_effects = true; } @@ -29,11 +32,18 @@ impl FunctionAnalyzer { self.analyze(t, false); self.analyze(body, is_tail); } - CTerm::Def { effect: Effect::Complex, .. } if !is_tail => { + CTerm::Def { + effect: Effect::Complex, + .. + } if !is_tail => { self.count += 1; self.has_non_tail_complex_effects = true; } - CTerm::CaseInt { branches, default_branch, .. } => { + CTerm::CaseInt { + branches, + default_branch, + .. + } => { let current_block_id = self.count - 1; let mut branch_block_ids = Vec::new(); for (_, branch) in branches { @@ -49,9 +59,15 @@ impl FunctionAnalyzer { } let joining_block_id = self.count; self.count += 1; - self.case_blocks.insert(current_block_id, (branch_block_ids, default_block_id, joining_block_id)); + self.case_blocks.insert( + current_block_id, + (branch_block_ids, default_block_id, joining_block_id), + ); } - CTerm::OperationCall { effect: Effect::Complex, .. } if !is_tail => { + CTerm::OperationCall { + effect: Effect::Complex, + .. + } if !is_tail => { self.count += 1; self.has_non_tail_complex_effects = true; } diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 2768113..e0b672f 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,5 +1,5 @@ mod common; pub(crate) mod compiler; -mod simple_function_translator; mod cps_function_translator; mod function_analyzer; +mod simple_function_translator; diff --git a/src/backend/simple_function_translator.rs b/src/backend/simple_function_translator.rs index d3101dd..4ebba88 100644 --- a/src/backend/simple_function_translator.rs +++ b/src/backend/simple_function_translator.rs @@ -1,20 +1,22 @@ -use std::collections::HashMap; -use std::iter; +use crate::ast::primitive_functions::PRIMITIVE_FUNCTIONS; +use crate::ast::signature::FunctionDefinition; +use crate::ast::term::{CTerm, CType, Effect, PType, SpecializedType, VTerm, VType}; +use crate::backend::common::{ + BuiltinData, BuiltinFunction, FunctionFlavor, HasType, TypedReturnValue, TypedValue, +}; +use crate::backend::compiler::Compiler; +use archon_vm_runtime::runtime::HandlerType; use cranelift::codegen::ir::{Endianness, FuncRef, Inst, StackSlot}; use cranelift::frontend::Switch; -use cranelift::prelude::*; use cranelift::prelude::types::{F32, I32, I64}; +use cranelift::prelude::*; use cranelift_module::{DataDescription, DataId, FuncId, Linkage, Module}; -use crate::ast::term::{CTerm, VTerm, VType, SpecializedType, PType, CType, Effect}; -use enum_map::{EnumMap}; +use enum_map::EnumMap; use enum_ordinalize::Ordinalize; -use archon_vm_runtime::runtime::HandlerType; -use VType::{Specialized, Uniform}; +use std::collections::HashMap; +use std::iter; use SpecializedType::{Integer, PrimitivePtr, StructPtr}; -use crate::backend::common::{BuiltinData, BuiltinFunction, FunctionFlavor, HasType, TypedReturnValue, TypedValue}; -use crate::ast::primitive_functions::PRIMITIVE_FUNCTIONS; -use crate::ast::signature::FunctionDefinition; -use crate::backend::compiler::Compiler; +use VType::{Specialized, Uniform}; pub struct SimpleFunctionTranslator<'a, M: Module> { pub module: &'a mut M, @@ -94,7 +96,12 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { // stack grows from higher address to lower address, so parameter list grows in the // reverse order and hence the offset is the index of the parameter in the parameter // list. - let value = translator.function_builder.ins().load(I64, MemFlags::new(), translator.base_address, (i * 8) as i32); + let value = translator.function_builder.ins().load( + I64, + MemFlags::new(), + translator.base_address, + (i * 8) as i32, + ); Some((value, Uniform)) }, ); @@ -107,7 +114,10 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { return_value_or_param, function_definition.args.len(), ); - translator.function_builder.ins().return_(&[return_value_address]); + translator + .function_builder + .ins() + .return_(&[return_value_address]); } None => { // Nothing to do since tail call is already a terminating instruction. @@ -118,14 +128,29 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let simple_name = FunctionFlavor::Simple.decorate_name(name); let func_id = compiler.local_functions.get(&simple_name).unwrap(); - SimpleFunctionTranslator::define_function(&mut compiler.module, &mut compiler.ctx, &simple_name, *func_id, clir); + SimpleFunctionTranslator::define_function( + &mut compiler.module, + &mut compiler.ctx, + &simple_name, + *func_id, + clir, + ); } - pub fn store_return_value_on_argument_stack(&mut self, return_value_or_param: TypedReturnValue, num_args: usize) -> Value { + pub fn store_return_value_on_argument_stack( + &mut self, + return_value_or_param: TypedReturnValue, + num_args: usize, + ) -> Value { let value = self.convert_to_uniform(return_value_or_param); let return_address_offset = (num_args as i64 - 1) * 8; - let return_address = self.function_builder.ins().iadd_imm(self.base_address, return_address_offset); - self.function_builder.ins().store(MemFlags::new(), value, return_address, 0); + let return_address = self + .function_builder + .ins() + .iadd_imm(self.base_address, return_address_offset); + self.function_builder + .ins() + .store(MemFlags::new(), value, return_address, 0); return_address } pub fn compile_specialized_function( @@ -160,7 +185,9 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let return_value_or_param = translator.translate_c_term(&function_definition.body, true); match return_value_or_param { Some(_) => { - let CType::SpecializedF(v_type) = function_definition.c_type else { unreachable!() }; + let CType::SpecializedF(v_type) = function_definition.c_type else { + unreachable!() + }; let value = translator.adapt_type(return_value_or_param, &v_type); translator.function_builder.ins().return_(&[value]); } @@ -173,7 +200,13 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let specialized_name = FunctionFlavor::Specialized.decorate_name(name); let func_id = compiler.local_functions.get(&specialized_name).unwrap(); - SimpleFunctionTranslator::define_function(&mut compiler.module, &mut compiler.ctx, &specialized_name, *func_id, clir); + SimpleFunctionTranslator::define_function( + &mut compiler.module, + &mut compiler.ctx, + &specialized_name, + *func_id, + clir, + ); } pub fn new( compiler: &'a mut Compiler, @@ -182,17 +215,30 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { local_function_arg_types: &'a HashMap, CType)>, is_specialized: bool, base_address_getter: F, - parameter_initializer: fn(&mut SimpleFunctionTranslator, Block, usize, &VType) -> TypedReturnValue, - ) -> SimpleFunctionTranslator<'a, M> where F: FnOnce(&mut FunctionBuilder, Block) -> Value { + parameter_initializer: fn( + &mut SimpleFunctionTranslator, + Block, + usize, + &VType, + ) -> TypedReturnValue, + ) -> SimpleFunctionTranslator<'a, M> + where + F: FnOnce(&mut FunctionBuilder, Block) -> Value, + { compiler.ctx.clear(); compiler.ctx.func.signature = sig; - let mut function_builder = FunctionBuilder::new(&mut compiler.ctx.func, &mut compiler.builder_context); + let mut function_builder = + FunctionBuilder::new(&mut compiler.ctx.func, &mut compiler.builder_context); let entry_block = function_builder.create_block(); // Allocate slot for storing the tip address so that a pointer to the tip address can be // passed to built-in force call helper function in order to have the tip address updated. - let tip_address_slot = function_builder.create_sized_stack_slot(StackSlotData::new(StackSlotKind::ExplicitSlot, 8, 0)); + let tip_address_slot = function_builder.create_sized_stack_slot(StackSlotData::new( + StackSlotKind::ExplicitSlot, + 8, + 0, + )); function_builder.append_block_params_for_function_params(entry_block); function_builder.switch_to_block(entry_block); function_builder.seal_block(entry_block); @@ -222,7 +268,8 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { // Here we transform the function body to non-specialized version, hence the argument types // are ignored. for (i, (v, v_type)) in function_definition.args.iter().enumerate() { - translator.local_vars[*v] = parameter_initializer(&mut translator, entry_block, i, v_type); + translator.local_vars[*v] = + parameter_initializer(&mut translator, entry_block, i, v_type); } translator } @@ -231,23 +278,20 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { match c_term { CTerm::Redex { .. } => { let s = self as *mut SimpleFunctionTranslator; - self.translate_redex( - c_term, - is_tail, - Effect::Simple, - |c_term, is_tail| { - unsafe { - (*s).translate_c_term(c_term, is_tail) - } - }, - ) + self.translate_redex(c_term, is_tail, Effect::Simple, |c_term, is_tail| unsafe { + (*s).translate_c_term(c_term, is_tail) + }) } CTerm::Return { value } => self.translate_v_term(value), CTerm::Force { thunk, .. } => { let trivial_continuation = self.get_builtin_data(BuiltinData::TrivialContinuation); self.invoke_thunk(is_tail, thunk, trivial_continuation) } - CTerm::Let { box t, bound_index, box body } => { + CTerm::Let { + box t, + bound_index, + box body, + } => { let t_value = self.translate_c_term(t, false); self.local_vars[*bound_index] = t_value; self.translate_c_term(body, is_tail) @@ -256,19 +300,22 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let func_ref = self.get_local_function(name, FunctionFlavor::Simple); if is_tail && !self.is_specialized { let base_address = self.copy_tail_call_args_and_get_new_base(); - self.function_builder.ins().return_call(func_ref, &[base_address]); + self.function_builder + .ins() + .return_call(func_ref, &[base_address]); None } else { - let inst = self.function_builder.ins().call(func_ref, &[self.tip_address]); + let inst = self + .function_builder + .ins() + .call(func_ref, &[self.tip_address]); self.extract_return_value(inst) } } CTerm::CaseInt { .. } => { let s = self as *mut SimpleFunctionTranslator; - self.translate_case_int(c_term, is_tail, |c_term, is_tail| { - unsafe { - (*s).translate_c_term(c_term, is_tail) - } + self.translate_case_int(c_term, is_tail, |c_term, is_tail| unsafe { + (*s).translate_c_term(c_term, is_tail) }) } CTerm::Lambda { .. } => unreachable!("lambda should have been lifted away"), @@ -279,10 +326,17 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let offset_value = self.convert_to_special(offset_value, Integer); let offset_value = self.function_builder.ins().ishl_imm(offset_value, 3); let load_address = self.function_builder.ins().iadd(base_value, offset_value); - let value = self.function_builder.ins().load(I64, MemFlags::new(), load_address, 0); + let value = self + .function_builder + .ins() + .load(I64, MemFlags::new(), load_address, 0); Some((value, Uniform)) } - CTerm::MemSet { base, offset, value } => { + CTerm::MemSet { + base, + offset, + value, + } => { let base_value = self.translate_v_term(base); let offset_value = self.translate_v_term(offset); let value_value = self.translate_v_term(value); @@ -291,15 +345,26 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let offset_value = self.function_builder.ins().ishl_imm(offset_value, 3); let value_value = self.convert_to_uniform(value_value); let store_address = self.function_builder.ins().iadd(base_value, offset_value); - self.function_builder.ins().store(MemFlags::new(), value_value, store_address, 0); + self.function_builder + .ins() + .store(MemFlags::new(), value_value, store_address, 0); // Return the base address so that the caller can continue to use it. Some((base_value, Specialized(StructPtr))) } CTerm::PrimitiveCall { name, args } => { - let args = args.iter().map(|arg| { self.translate_v_term(arg) }).collect::>(); + let args = args + .iter() + .map(|arg| self.translate_v_term(arg)) + .collect::>(); let primitive_function = *PRIMITIVE_FUNCTIONS.get(name).unwrap(); - let arg_values: Vec = primitive_function.arg_types.iter().zip(args).map(|(ty, arg)| self.adapt_type(arg, ty)).collect(); - let return_value = (primitive_function.code_gen)(&mut self.function_builder, &arg_values); + let arg_values: Vec = primitive_function + .arg_types + .iter() + .zip(args) + .map(|(ty, arg)| self.adapt_type(arg, ty)) + .collect(); + let return_value = + (primitive_function.code_gen)(&mut self.function_builder, &arg_values); Some((return_value, primitive_function.return_type)) } CTerm::OperationCall { eff, args, .. } => { @@ -309,7 +374,13 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { self.push_arg_v_terms(args); let use_tail_call = is_tail && !self.is_specialized; self.update_tip_address(use_tail_call); - let inst = self.handle_operation_call(eff_value, trivial_continuation, args.len(), false, use_tail_call); + let inst = self.handle_operation_call( + eff_value, + trivial_continuation, + args.len(), + false, + use_tail_call, + ); if use_tail_call { None } else { @@ -324,21 +395,32 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { } } - pub fn translate_redex(&mut self, c_term: &CTerm, is_tail: bool, context_effect: Effect, mut translate_c_term: F) -> TypedReturnValue where F: FnMut(&CTerm, bool) -> TypedReturnValue { - let CTerm::Redex { box function, args } = c_term else { unreachable!() }; + pub fn translate_redex( + &mut self, + c_term: &CTerm, + is_tail: bool, + context_effect: Effect, + mut translate_c_term: F, + ) -> TypedReturnValue + where + F: FnMut(&CTerm, bool) -> TypedReturnValue, + { + let CTerm::Redex { box function, args } = c_term else { + unreachable!() + }; if let CTerm::Def { name, effect } = function { // Handle specialized function call let (arg_types, return_type) = self.local_function_arg_types.get(name).unwrap(); - if let CType::SpecializedF(return_type) = return_type && arg_types.len() == args.len() && (effect.intersect(context_effect) != Effect::Complex) { + if let CType::SpecializedF(return_type) = return_type + && arg_types.len() == args.len() + && (effect.intersect(context_effect) != Effect::Complex) + { let tip_address = self.tip_address; let all_args = iter::once(tip_address) - .chain(args.iter() - .zip(arg_types) - .map(|(arg, v_type)| { - let arg = self.translate_v_term(arg); - self.adapt_type(arg, v_type) - } - )) + .chain(args.iter().zip(arg_types).map(|(arg, v_type)| { + let arg = self.translate_v_term(arg); + self.adapt_type(arg, v_type) + })) .collect::>(); let func_ref = self.get_local_function(name, FunctionFlavor::Specialized); if is_tail && self.is_specialized { @@ -354,8 +436,24 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { translate_c_term(function, is_tail) } - pub fn translate_case_int(&mut self, c_term: &CTerm, is_tail: bool, mut translate_c_term: F) -> TypedReturnValue where F: FnMut(&CTerm, bool) -> TypedReturnValue { - let CTerm::CaseInt { t, branches, default_branch, result_type } = c_term else { unreachable!() }; + pub fn translate_case_int( + &mut self, + c_term: &CTerm, + is_tail: bool, + mut translate_c_term: F, + ) -> TypedReturnValue + where + F: FnMut(&CTerm, bool) -> TypedReturnValue, + { + let CTerm::CaseInt { + t, + branches, + default_branch, + result_type, + } = c_term + else { + unreachable!() + }; let branch_map: HashMap<_, _> = branches.iter().map(|(i, v)| (i, v)).collect(); let t_value = self.translate_v_term(t); let t_value = self.convert_to_special(t_value, Integer); @@ -368,7 +466,8 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { }; let result_value_type = result_v_type.get_type(); // return value - self.function_builder.append_block_param(joining_block, result_value_type); + self.function_builder + .append_block_param(joining_block, result_value_type); // tip address self.function_builder.append_block_param(joining_block, I64); @@ -395,20 +494,37 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { for (value, branch_block) in branch_blocks.into_iter() { self.tip_address = start_tip_address; let branch = branch_map.get(&value).unwrap(); - self.create_branch_block(branch_block, is_tail, joining_block, result_v_type, Some(branch), &mut translate_c_term); + self.create_branch_block( + branch_block, + is_tail, + joining_block, + result_v_type, + Some(branch), + &mut translate_c_term, + ); } self.tip_address = start_tip_address; - self.create_branch_block(default_block, is_tail, joining_block, result_v_type, match default_branch { - None => None, - Some(box branch) => Some(branch), - }, translate_c_term); + self.create_branch_block( + default_block, + is_tail, + joining_block, + result_v_type, + match default_branch { + None => None, + Some(box branch) => Some(branch), + }, + translate_c_term, + ); // Switch to joining block for future code generation self.function_builder.seal_all_blocks(); self.function_builder.switch_to_block(joining_block); self.tip_address = self.function_builder.block_params(joining_block)[1]; - Some((self.function_builder.block_params(joining_block)[0], *result_v_type)) + Some(( + self.function_builder.block_params(joining_block)[0], + *result_v_type, + )) } fn update_tip_address(&mut self, use_tail_call: bool) { @@ -417,15 +533,23 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { } } - pub fn translate_handler(&mut self, is_tail: bool, handler_c_term: &CTerm, next_continuation: Value) -> TypedReturnValue { + pub fn translate_handler( + &mut self, + is_tail: bool, + handler_c_term: &CTerm, + next_continuation: Value, + ) -> TypedReturnValue { let CTerm::Handler { parameter, parameter_disposer, parameter_replicator, transform, handlers, - input - } = handler_c_term else { unreachable!() }; + input, + } = handler_c_term + else { + unreachable!() + }; let parameter_typed_value = self.translate_v_term(parameter); let parameter_value = self.convert_to_uniform(parameter_typed_value); let parameter_disposer_typed_value = match parameter_disposer { @@ -433,7 +557,7 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { None => Some((self.function_builder.ins().iconst(I64, 0b11), Uniform)), }; let parameter_disposer_value = self.convert_to_uniform(parameter_disposer_typed_value); - let parameter_replicator_typed_value = match parameter_replicator { + let parameter_replicator_typed_value = match parameter_replicator { Some(parameter_replicator) => self.translate_v_term(parameter_replicator), None => Some((self.function_builder.ins().iconst(I64, 0b11), Uniform)), }; @@ -444,17 +568,23 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { self.builtin_functions[BuiltinFunction::TransformLoaderCpsImpl], self.function_builder.func, ); - let transform_loader_cps_impl_func_ptr = self.function_builder.ins().func_addr(I64, transform_loader_cps_impl_func_ref); + let transform_loader_cps_impl_func_ptr = self + .function_builder + .ins() + .func_addr(I64, transform_loader_cps_impl_func_ref); let tip_address_ptr = self.store_tip_address_to_stack(); - let inst = self.call_builtin_func(BuiltinFunction::RegisterHandler, &[ - tip_address_ptr, - next_continuation, - parameter_value, - parameter_disposer_value, - parameter_replicator_value, - transform_value, - transform_loader_cps_impl_func_ptr, - ]); + let inst = self.call_builtin_func( + BuiltinFunction::RegisterHandler, + &[ + tip_address_ptr, + next_continuation, + parameter_value, + parameter_disposer_value, + parameter_replicator_value, + transform_value, + transform_loader_cps_impl_func_ptr, + ], + ); let handler = self.function_builder.inst_results(inst)[0]; self.load_tip_address_from_stack(); @@ -462,12 +592,21 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { self.add_handlers(handler, handlers); // The transform loader continuation is the first value inside the handler struct. - let transform_loader_continuation = self.function_builder.ins().load(I64, MemFlags::new(), handler, 0); - let mark_handler_ref = self.module.declare_func_in_func(self.builtin_functions[BuiltinFunction::MarkHandler], self.function_builder.func); + let transform_loader_continuation = + self.function_builder + .ins() + .load(I64, MemFlags::new(), handler, 0); + let mark_handler_ref = self.module.declare_func_in_func( + self.builtin_functions[BuiltinFunction::MarkHandler], + self.function_builder.func, + ); let old_tip_address = self.tip_address; let input_thunk_func_ptr = self.process_thunk(input); // stack grows downwards so if we expect positive offset, we subtract old tip address by new tip address. - let offset_in_bytes = self.function_builder.ins().isub(old_tip_address, self.tip_address); + let offset_in_bytes = self + .function_builder + .ins() + .isub(old_tip_address, self.tip_address); self.adjust_continuation_height(transform_loader_continuation, offset_in_bytes); let args = &[ self.tip_address, @@ -476,7 +615,9 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { handler, ]; if is_tail && !self.is_specialized { - self.function_builder.ins().return_call(mark_handler_ref, args); + self.function_builder + .ins() + .return_call(mark_handler_ref, args); None } else { let inst = self.function_builder.ins().call(mark_handler_ref, args); @@ -485,11 +626,26 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { } pub fn adjust_continuation_height(&mut self, continuation: Value, offset_in_bytes: Value) { - let continuation_height = self.function_builder.ins().load(I64, MemFlags::new(), continuation, 8); - let continuation_height_bytes = self.function_builder.ins().ishl_imm(continuation_height, 3); - let new_continuation_height_bytes = self.function_builder.ins().iadd(continuation_height_bytes, offset_in_bytes); - let new_continuation_height = self.function_builder.ins().ushr_imm(new_continuation_height_bytes, 3); - self.function_builder.ins().store(MemFlags::new(), new_continuation_height, continuation, 8); + let continuation_height = + self.function_builder + .ins() + .load(I64, MemFlags::new(), continuation, 8); + let continuation_height_bytes = + self.function_builder.ins().ishl_imm(continuation_height, 3); + let new_continuation_height_bytes = self + .function_builder + .ins() + .iadd(continuation_height_bytes, offset_in_bytes); + let new_continuation_height = self + .function_builder + .ins() + .ushr_imm(new_continuation_height_bytes, 3); + self.function_builder.ins().store( + MemFlags::new(), + new_continuation_height, + continuation, + 8, + ); } fn add_handlers(&mut self, handler: Value, handler_impls: &Vec<(VTerm, VTerm, HandlerType)>) { @@ -498,27 +654,47 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let eff_value = self.convert_to_uniform(eff_value); let handler_impl_value = self.translate_v_term(handler_impl); let handler_impl_value = self.convert_to_uniform(handler_impl_value); - let simple_operation_type_value = self.function_builder.ins().iconst(I64, handler_type.ordinal() as i64); - self.call_builtin_func(BuiltinFunction::AddHandler, &[handler, eff_value, handler_impl_value, simple_operation_type_value]); + let simple_operation_type_value = self + .function_builder + .ins() + .iconst(I64, handler_type.ordinal() as i64); + self.call_builtin_func( + BuiltinFunction::AddHandler, + &[ + handler, + eff_value, + handler_impl_value, + simple_operation_type_value, + ], + ); } } - pub(crate) fn invoke_thunk(&mut self, is_tail: bool, thunk: &VTerm, next_continuation: Value) -> Option { + pub(crate) fn invoke_thunk( + &mut self, + is_tail: bool, + thunk: &VTerm, + next_continuation: Value, + ) -> Option { let func_pointer = self.process_thunk(thunk); - let sig_ref = self.function_builder.import_signature(self.uniform_cps_func_signature.clone()); + let sig_ref = self + .function_builder + .import_signature(self.uniform_cps_func_signature.clone()); if is_tail && !self.is_specialized { let base_address = self.copy_tail_call_args_and_get_new_base(); - self.function_builder.ins().return_call_indirect(sig_ref, func_pointer, &[ - base_address, - next_continuation, - ]); + self.function_builder.ins().return_call_indirect( + sig_ref, + func_pointer, + &[base_address, next_continuation], + ); None } else { - let inst = self.function_builder.ins().call_indirect(sig_ref, func_pointer, &[ - self.tip_address, - next_continuation, - ]); + let inst = self.function_builder.ins().call_indirect( + sig_ref, + func_pointer, + &[self.tip_address, next_continuation], + ); self.extract_return_value(inst) } } @@ -531,37 +707,61 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let thunk_value = self.convert_to_uniform(thunk_value); let tip_address_ptr = self.store_tip_address_to_stack(); - let inst = self.call_builtin_func(BuiltinFunction::ForceThunk, &[thunk_value, tip_address_ptr]); + let inst = + self.call_builtin_func(BuiltinFunction::ForceThunk, &[thunk_value, tip_address_ptr]); let func_pointer = self.function_builder.inst_results(inst)[0]; self.load_tip_address_from_stack(); func_pointer } fn load_tip_address_from_stack(&mut self) { - self.tip_address = self.function_builder.ins().stack_load(I64, self.tip_address_slot, 0); + self.tip_address = self + .function_builder + .ins() + .stack_load(I64, self.tip_address_slot, 0); } fn store_tip_address_to_stack(&mut self) -> Value { - self.function_builder.ins().stack_store(self.tip_address, self.tip_address_slot, 0); - let tip_address_ptr = self.function_builder.ins().stack_addr(I64, self.tip_address_slot, 0); + self.function_builder + .ins() + .stack_store(self.tip_address, self.tip_address_slot, 0); + let tip_address_ptr = self + .function_builder + .ins() + .stack_addr(I64, self.tip_address_slot, 0); tip_address_ptr } pub fn push_arg_v_terms(&mut self, args: &[VTerm]) { // TODO: take an argument determining if it's tall call. If so, just copy existing arguments // to the correct location and push new args rest after that. This way we can copy less. - let arg_values = args.iter().map(|arg| { - let v = self.translate_v_term(arg); - self.convert_to_uniform(v) - }).collect::>(); + let arg_values = args + .iter() + .map(|arg| { + let v = self.translate_v_term(arg); + self.convert_to_uniform(v) + }) + .collect::>(); self.push_args(arg_values); } - fn create_branch_block(&mut self, branch_block: Block, is_tail: bool, joining_block: Block, result_v_type: &VType, branch: Option<&CTerm>, mut translate_c_term: F) where F: FnMut(&CTerm, bool) -> TypedReturnValue { + fn create_branch_block( + &mut self, + branch_block: Block, + is_tail: bool, + joining_block: Block, + result_v_type: &VType, + branch: Option<&CTerm>, + mut translate_c_term: F, + ) where + F: FnMut(&CTerm, bool) -> TypedReturnValue, + { self.function_builder.switch_to_block(branch_block); let typed_return_value = match branch { None => { - self.function_builder.ins().trap(TrapCode::UnreachableCodeReached); + self.function_builder + .ins() + .trap(TrapCode::UnreachableCodeReached); None } Some(branch) => translate_c_term(branch, is_tail), @@ -572,14 +772,19 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { } Some(..) => { let value = self.adapt_type(typed_return_value, result_v_type); - self.function_builder.ins().jump(joining_block, &[value, self.tip_address]); + self.function_builder + .ins() + .jump(joining_block, &[value, self.tip_address]); } } } pub fn extract_return_value(&mut self, inst: Inst) -> TypedReturnValue { let return_address = self.function_builder.inst_results(inst)[0]; - let return_value = self.function_builder.ins().load(I64, MemFlags::new(), return_address, 0); + let return_value = + self.function_builder + .ins() + .load(I64, MemFlags::new(), return_address, 0); self.tip_address = self.function_builder.ins().iadd_imm(return_address, 8); Some((return_value, Uniform)) } @@ -588,9 +793,20 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { if self.num_args == 0 { return self.tip_address; } - let new_base_value = self.function_builder.ins().iadd_imm(self.tip_address, (self.num_args * 8) as i64); - let num_bytes_to_copy = self.function_builder.ins().isub(self.base_address, self.tip_address); - self.function_builder.call_memmove(self.module.target_config(), new_base_value, self.tip_address, num_bytes_to_copy); + let new_base_value = self + .function_builder + .ins() + .iadd_imm(self.tip_address, (self.num_args * 8) as i64); + let num_bytes_to_copy = self + .function_builder + .ins() + .isub(self.base_address, self.tip_address); + self.function_builder.call_memmove( + self.module.target_config(), + new_base_value, + self.tip_address, + num_bytes_to_copy, + ); new_base_value } @@ -600,14 +816,24 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { None => { if *index < self.num_args { let base_address = self.base_address; - let value = self.function_builder.ins().load(I64, MemFlags::new(), base_address, (8 * index) as i32); + let value = self.function_builder.ins().load( + I64, + MemFlags::new(), + base_address, + (8 * index) as i32, + ); let typed_return_value = Some((value, Uniform)); self.local_vars[*index] = typed_return_value; typed_return_value } else { let local_var_index = *index - self.num_args; let local_var_ptr = self.local_var_ptr; - let value = self.function_builder.ins().load(I64, MemFlags::new(), local_var_ptr, (8 * local_var_index) as i32); + let value = self.function_builder.ins().load( + I64, + MemFlags::new(), + local_var_ptr, + (8 * local_var_index) as i32, + ); let typed_return_value = Some((value, Uniform)); self.local_vars[*index] = typed_return_value; typed_return_value @@ -618,9 +844,12 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { VTerm::Thunk { box t, .. } => { let empty_args = &vec![]; let (name, args) = match t { - CTerm::Redex { function: box CTerm::Def { name, .. }, args } => (name, args), + CTerm::Redex { + function: box CTerm::Def { name, .. }, + args, + } => (name, args), CTerm::Def { name, .. } => (name, empty_args), - _ => unreachable!("thunk lifting should have guaranteed this") + _ => unreachable!("thunk lifting should have guaranteed this"), }; let func_ref = self.get_local_function(name, FunctionFlavor::Cps); let func_pointer = self.function_builder.ins().func_addr(I64, func_ref); @@ -638,7 +867,10 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { } Some((self.create_struct(thunk_components), Specialized(StructPtr))) } - VTerm::Int { value } => Some((self.function_builder.ins().iconst(I64, *value), Specialized(Integer))), + VTerm::Int { value } => Some(( + self.function_builder.ins().iconst(I64, *value), + Specialized(Integer), + )), VTerm::Str { value } => { // Insert into the global data section if not already there. let data_id = self.static_strings.entry(value.clone()).or_insert_with(|| { @@ -653,12 +885,19 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { self.data_description.define(bytes.into_boxed_slice()); // Align to 8 bytes because comparison logic compares aligned words. self.data_description.set_align(8); - let data_id = self.module.declare_data(value, Linkage::Local, false, false).unwrap(); - self.module.define_data(data_id, &self.data_description).unwrap(); + let data_id = self + .module + .declare_data(value, Linkage::Local, false, false) + .unwrap(); + self.module + .define_data(data_id, &self.data_description) + .unwrap(); self.data_description.clear(); data_id }); - let global_value = self.module.declare_data_in_func(*data_id, self.function_builder.func); + let global_value = self + .module + .declare_data_in_func(*data_id, self.function_builder.func); let raw_data_ptr = self.function_builder.ins().symbol_value(I64, global_value); // Offset the pointer by 8 bytes to skip the length field. let data_ptr = self.function_builder.ins().iadd_imm(raw_data_ptr, 8); @@ -666,10 +905,13 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { } VTerm::Struct { values } => { // TODO: use a common empty struct if values is empty - let translated = values.iter().map(|v| { - let v = self.translate_v_term(v); - self.convert_to_uniform(v) - }).collect::>(); + let translated = values + .iter() + .map(|v| { + let v = self.translate_v_term(v); + self.convert_to_uniform(v) + }) + .collect::>(); Some((self.create_struct(translated), Specialized(StructPtr))) } } @@ -678,44 +920,66 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { fn push_args(&mut self, args: Vec) { for arg in args.into_iter().rev() { self.tip_address = self.function_builder.ins().iadd_imm(self.tip_address, -8); - self.function_builder.ins().store(MemFlags::new().with_aligned(), arg, self.tip_address, 0); + self.function_builder.ins().store( + MemFlags::new().with_aligned(), + arg, + self.tip_address, + 0, + ); } } pub fn get_local_function(&mut self, name: &str, flavor: FunctionFlavor) -> FuncRef { let desired_func_name = flavor.decorate_name(name); - let func_id = self.local_functions.get(&desired_func_name).unwrap_or_else(|| panic!("Cannot get function '{}'", &desired_func_name)); - self.module.declare_func_in_func(*func_id, self.function_builder.func) + let func_id = self + .local_functions + .get(&desired_func_name) + .unwrap_or_else(|| panic!("Cannot get function '{}'", &desired_func_name)); + self.module + .declare_func_in_func(*func_id, self.function_builder.func) } fn create_struct(&mut self, values: Vec) -> Value { let struct_size = values.len(); let struct_size_value = self.function_builder.ins().iconst(I64, struct_size as i64); - let runtime_alloc_call = self.call_builtin_func(BuiltinFunction::Alloc, &[struct_size_value]); + let runtime_alloc_call = + self.call_builtin_func(BuiltinFunction::Alloc, &[struct_size_value]); let struct_address = self.function_builder.inst_results(runtime_alloc_call)[0]; for (offset, value) in values.into_iter().enumerate() { self.function_builder.ins().store( MemFlags::new().with_aligned(), value, struct_address, - (offset * 8) as i32); + (offset * 8) as i32, + ); } struct_address } - pub(crate) fn call_builtin_func(&mut self, builtin_function: BuiltinFunction, args: &[Value]) -> Inst { - let func_ref = self.module.declare_func_in_func(self.builtin_functions[builtin_function], self.function_builder.func); + pub(crate) fn call_builtin_func( + &mut self, + builtin_function: BuiltinFunction, + args: &[Value], + ) -> Inst { + let func_ref = self.module.declare_func_in_func( + self.builtin_functions[builtin_function], + self.function_builder.func, + ); self.function_builder.ins().call(func_ref, args) } pub(crate) fn get_builtin_data(&mut self, builtin_data: BuiltinData) -> Value { - let data_ref = self.module.declare_data_in_func(self.builtin_data[builtin_data], self.function_builder.func); + let data_ref = self + .module + .declare_data_in_func(self.builtin_data[builtin_data], self.function_builder.func); let result = if builtin_data.is_tls() { return self.function_builder.ins().tls_value(I64, data_ref); } else { self.function_builder.ins().global_value(I64, data_ref) }; - self.function_builder.ins().iadd_imm(result, builtin_data.offset()) + self.function_builder + .ins() + .iadd_imm(result, builtin_data.offset()) } pub(crate) fn convert_to_uniform(&mut self, value_and_type: TypedReturnValue) -> Value { @@ -731,7 +995,9 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { let alloc_size = self.function_builder.ins().iconst(I64, 8); let inst = self.call_builtin_func(BuiltinFunction::Alloc, &[alloc_size]); let ptr = self.function_builder.inst_results(inst)[0]; - self.function_builder.ins().store(MemFlags::new(), value, ptr, 0); + self.function_builder + .ins() + .store(MemFlags::new(), value, ptr, 0); // Add 0b11 to the end to signify this is a primitive pointer self.function_builder.ins().iadd_imm(ptr, 0b11) } @@ -740,16 +1006,23 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { self.function_builder.ins().ishl_imm(extended, 32) } PType::F32 => { - let casted = self.function_builder.ins().bitcast(I32, MemFlags::new(), value); + let casted = + self.function_builder + .ins() + .bitcast(I32, MemFlags::new(), value); let extended = self.function_builder.ins().sextend(I64, casted); self.function_builder.ins().ishl_imm(extended, 32) } - } - } + }, + }, } } - pub(crate) fn convert_to_special(&mut self, value_and_type: TypedReturnValue, specialized_type: SpecializedType) -> Value { + pub(crate) fn convert_to_special( + &mut self, + value_and_type: TypedReturnValue, + specialized_type: SpecializedType, + ) -> Value { let (value, value_type) = Self::extract_value_and_type(value_and_type); match value_type { Uniform => match specialized_type { @@ -759,7 +1032,9 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { SpecializedType::Primitive(p) => match p { PType::I64 | PType::F64 => { let ptr = self.function_builder.ins().iadd_imm(value, -0b11); - self.function_builder.ins().load(p.get_type(), MemFlags::new(), ptr, 0) + self.function_builder + .ins() + .load(p.get_type(), MemFlags::new(), ptr, 0) } PType::I32 => { let shifted = self.function_builder.ins().sshr_imm(value, 32); @@ -768,19 +1043,27 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { PType::F32 => { let shifted = self.function_builder.ins().sshr_imm(value, 32); let truncated = self.function_builder.ins().ireduce(I32, shifted); - self.function_builder.ins().bitcast(F32, MemFlags::new(), truncated) + self.function_builder + .ins() + .bitcast(F32, MemFlags::new(), truncated) } + }, + }, + Specialized(s) => { + if s == specialized_type { + value + } else { + unreachable!("type conversion between two specialized types is not supported and this must be a type error in the input program") } } - Specialized(s) => if s == specialized_type { - value - } else { - unreachable!("type conversion between two specialized types is not supported and this must be a type error in the input program") - } } } - pub(crate) fn adapt_type(&mut self, value_and_type: TypedReturnValue, target_type: &VType) -> Value { + pub(crate) fn adapt_type( + &mut self, + value_and_type: TypedReturnValue, + target_type: &VType, + ) -> Value { let (value, value_type) = Self::extract_value_and_type(value_and_type); if value_type == *target_type { return value; @@ -797,38 +1080,96 @@ impl<'a, M: Module> SimpleFunctionTranslator<'a, M> { value_and_type.expect("non-local return value cannot be converted and this must be a bug in the compilation logic or input is not well-typed") } - pub fn define_function(module: &mut M, ctx: &mut codegen::Context, name: &str, func_id: FuncId, clir: &mut Option<&mut Vec<(String, String)>>) { + pub fn define_function( + module: &mut M, + ctx: &mut codegen::Context, + name: &str, + func_id: FuncId, + clir: &mut Option<&mut Vec<(String, String)>>, + ) { if let Some(clir) = clir { clir.push((name.to_owned(), format!("{}", ctx.func.display()))); } module.define_function(func_id, ctx).unwrap_or_else(|e| { - panic!("failed to define function {}: {:#?}\nDetails: {}", name, e, ctx.func.display()); + panic!( + "failed to define function {}: {:#?}\nDetails: {}", + name, + e, + ctx.func.display() + ); }); } - pub fn handle_operation_call(&mut self, eff_value: Value, continuation: Value, num_args: usize, may_be_complex: bool, use_return_call: bool) -> Inst { + pub fn handle_operation_call( + &mut self, + eff_value: Value, + continuation: Value, + num_args: usize, + may_be_complex: bool, + use_return_call: bool, + ) -> Inst { let num_args_value = self.function_builder.ins().iconst(I64, num_args as i64); - let captured_continuation_record_impl_ref = self.module.declare_func_in_func(self.builtin_functions[BuiltinFunction::CapturedContinuationRecordImpl], self.function_builder.func); - let captured_continuation_record_impl_ptr = self.function_builder.ins().func_addr(I64, captured_continuation_record_impl_ref); - let simple_handler_runner_impl_ref = self.module.declare_func_in_func(self.builtin_functions[BuiltinFunction::SimpleHandlerRunnerImpl], self.function_builder.func); - let simple_handler_runner_impl_ptr = self.function_builder.ins().func_addr(I64, simple_handler_runner_impl_ref); - let may_be_complex_value = self.function_builder.ins().iconst(I64, may_be_complex as i64); + let captured_continuation_record_impl_ref = self.module.declare_func_in_func( + self.builtin_functions[BuiltinFunction::CapturedContinuationRecordImpl], + self.function_builder.func, + ); + let captured_continuation_record_impl_ptr = self + .function_builder + .ins() + .func_addr(I64, captured_continuation_record_impl_ref); + let simple_handler_runner_impl_ref = self.module.declare_func_in_func( + self.builtin_functions[BuiltinFunction::SimpleHandlerRunnerImpl], + self.function_builder.func, + ); + let simple_handler_runner_impl_ptr = self + .function_builder + .ins() + .func_addr(I64, simple_handler_runner_impl_ref); + let may_be_complex_value = self + .function_builder + .ins() + .iconst(I64, may_be_complex as i64); let inst = self.call_builtin_func( BuiltinFunction::PrepareOperation, - &[eff_value, self.tip_address, continuation, num_args_value, captured_continuation_record_impl_ptr, simple_handler_runner_impl_ptr, may_be_complex_value], + &[ + eff_value, + self.tip_address, + continuation, + num_args_value, + captured_continuation_record_impl_ptr, + simple_handler_runner_impl_ptr, + may_be_complex_value, + ], ); let result_ptr = self.function_builder.inst_results(inst)[0]; - let handler_impl = self.function_builder.ins().load(I64, MemFlags::new(), result_ptr, 0); - let handler_base_address = self.function_builder.ins().load(I64, MemFlags::new(), result_ptr, 8); - let next_continuation = self.function_builder.ins().load(I64, MemFlags::new(), result_ptr, 16); + let handler_impl = self + .function_builder + .ins() + .load(I64, MemFlags::new(), result_ptr, 0); + let handler_base_address = + self.function_builder + .ins() + .load(I64, MemFlags::new(), result_ptr, 8); + let next_continuation = + self.function_builder + .ins() + .load(I64, MemFlags::new(), result_ptr, 16); let signature = self.uniform_cps_func_signature.clone(); let sig_ref = self.function_builder.import_signature(signature); if use_return_call { - self.function_builder.ins().return_call_indirect(sig_ref, handler_impl, &[handler_base_address, next_continuation]) + self.function_builder.ins().return_call_indirect( + sig_ref, + handler_impl, + &[handler_base_address, next_continuation], + ) } else { - self.function_builder.ins().call_indirect(sig_ref, handler_impl, &[handler_base_address, next_continuation]) + self.function_builder.ins().call_indirect( + sig_ref, + handler_impl, + &[handler_base_address, next_continuation], + ) } } } diff --git a/src/bin/asm_gen_playground.rs b/src/bin/asm_gen_playground.rs index c3c59b3..769347b 100644 --- a/src/bin/asm_gen_playground.rs +++ b/src/bin/asm_gen_playground.rs @@ -1,13 +1,13 @@ -use std::fs::File; -use std::io::{BufWriter}; use cranelift::codegen; use cranelift::codegen::isa::CallConv; use cranelift::codegen::settings; use cranelift::frontend::{FunctionBuilder, FunctionBuilderContext}; -use cranelift::prelude::{AbiParam, Block, Configurable, InstBuilder, MemFlags, Signature}; use cranelift::prelude::types::I64; +use cranelift::prelude::{AbiParam, Block, Configurable, InstBuilder, MemFlags, Signature}; use cranelift_module::{FuncId, Module}; use cranelift_object::{ObjectBuilder, ObjectModule, ObjectProduct}; +use std::fs::File; +use std::io::BufWriter; struct Compiler { builder_context: FunctionBuilderContext, @@ -35,7 +35,9 @@ impl Default for Compiler { .finish(settings::Flags::new(flag_builder)) .unwrap(); - let builder = ObjectBuilder::new(isa, "playground", cranelift_module::default_libcall_names()).unwrap(); + let builder = + ObjectBuilder::new(isa, "playground", cranelift_module::default_libcall_names()) + .unwrap(); let module = ObjectModule::new(builder); Self { builder_context: FunctionBuilderContext::new(), @@ -48,11 +50,19 @@ impl Default for Compiler { impl Compiler { fn declare_function(&mut self, name: &str, sig: &Signature) -> FuncId { - self.module.declare_function(name, cranelift_module::Linkage::Local, sig).unwrap() + self.module + .declare_function(name, cranelift_module::Linkage::Local, sig) + .unwrap() } - fn define_function(&mut self, name: &str, sig: &Signature, f: F) -> FuncId where F: FnOnce(&mut ObjectModule, &mut FunctionBuilder, Block) { - let func_id = self.module.declare_function(name, cranelift_module::Linkage::Local, sig).unwrap(); + fn define_function(&mut self, name: &str, sig: &Signature, f: F) -> FuncId + where + F: FnOnce(&mut ObjectModule, &mut FunctionBuilder, Block), + { + let func_id = self + .module + .declare_function(name, cranelift_module::Linkage::Local, sig) + .unwrap(); self.ctx.func.signature = sig.clone(); let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); let entry_block = builder.create_block(); @@ -97,5 +107,8 @@ fn main() { let object_product = compiler.finish(); let output_path = home::home_dir().unwrap().join("tmp/playground.o"); let temp_output_writer = BufWriter::new(File::create(output_path).unwrap()); - object_product.object.write_stream(temp_output_writer).unwrap(); -} \ No newline at end of file + object_product + .object + .write_stream(temp_output_writer) + .unwrap(); +} diff --git a/src/bin/asm_playground.rs b/src/bin/asm_playground.rs index 4cdb97f..64d18e7 100644 --- a/src/bin/asm_playground.rs +++ b/src/bin/asm_playground.rs @@ -1,12 +1,14 @@ use std::arch::global_asm; -global_asm!(r#" +global_asm!( + r#" .global _foo _foo: mov x16, sp str x16, [x1] ret -"#); +"# +); extern "C" { fn foo(a: i64, b: &mut i64) -> i64; @@ -18,4 +20,4 @@ fn main() { foo(1, &mut i); } println!("{}", i); -} \ No newline at end of file +} diff --git a/src/frontend/f_term.rs b/src/frontend/f_term.rs index 4680ed1..e66aeb1 100644 --- a/src/frontend/f_term.rs +++ b/src/frontend/f_term.rs @@ -1,22 +1,67 @@ -use archon_vm_runtime::runtime::HandlerType; use crate::ast::term::{CType, Effect, VType}; +use archon_vm_runtime::runtime::HandlerType; #[derive(Debug, Clone, PartialEq)] pub enum FTerm { - Identifier { name: String, effect: Effect }, - Int { value: i64 }, - Str { value: String }, - Struct { values: Vec }, - Lambda { arg_names: Vec<(String, VType)>, body: Box, effect: Effect }, - Redex { function: Box, args: Vec }, - Force { thunk: Box, effect: Effect }, - Thunk { computation: Box, effect: Effect }, - CaseInt { t: Box, result_type: CType, branches: Vec<(i64, FTerm)>, default_branch: Option> }, - MemGet { base: Box, offset: Box }, - MemSet { base: Box, offset: Box, value: Box }, - Let { name: String, t: Box, body: Box }, - Defs { defs: Vec<(String, Def)>, body: Option> }, - OperationCall { eff: Box, args: Vec, effect: Effect }, + Identifier { + name: String, + effect: Effect, + }, + Int { + value: i64, + }, + Str { + value: String, + }, + Struct { + values: Vec, + }, + Lambda { + arg_names: Vec<(String, VType)>, + body: Box, + effect: Effect, + }, + Redex { + function: Box, + args: Vec, + }, + Force { + thunk: Box, + effect: Effect, + }, + Thunk { + computation: Box, + effect: Effect, + }, + CaseInt { + t: Box, + result_type: CType, + branches: Vec<(i64, FTerm)>, + default_branch: Option>, + }, + MemGet { + base: Box, + offset: Box, + }, + MemSet { + base: Box, + offset: Box, + value: Box, + }, + Let { + name: String, + t: Box, + body: Box, + }, + Defs { + defs: Vec<(String, Def)>, + body: Option>, + }, + OperationCall { + eff: Box, + args: Vec, + effect: Effect, + }, Handler { parameter: Box, parameter_disposer: Option>, @@ -32,4 +77,4 @@ pub struct Def { pub args: Vec<(String, VType)>, pub body: Box, pub c_type: CType, -} \ No newline at end of file +} diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 4c2d7ca..cc6481d 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1,3 +1,3 @@ +mod f_term; mod parser; mod transpiler; -mod f_term; \ No newline at end of file diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs index 59830c3..35cc612 100644 --- a/src/frontend/parser.rs +++ b/src/frontend/parser.rs @@ -1,18 +1,18 @@ +use crate::ast::term::{CType, Effect, PType, SpecializedType, VType}; +use crate::frontend::f_term::{Def, FTerm}; +use crate::frontend::parser::Fixity::*; +use archon_vm_runtime::runtime::HandlerType; use either::Either; use nom::branch::alt; -use nom::{InputLength, IResult}; -use nom::character::complete::{space0, char, satisfy as char_satisfy, alphanumeric1, one_of}; +use nom::bytes::complete::{escaped, take_while1}; +use nom::character::complete::{alphanumeric1, char, one_of, satisfy as char_satisfy, space0}; use nom::combinator::{cut, map, map_res, opt}; -use nom::bytes::complete::{take_while1, escaped}; use nom::error::{context, ErrorKind, ParseError}; -use nom::Parser; use nom::multi::{many0, many1}; use nom::sequence::{delimited, pair, preceded, terminated, tuple}; -use crate::frontend::f_term::{Def, FTerm}; -use nom_locate::{LocatedSpan}; -use archon_vm_runtime::runtime::HandlerType; -use crate::frontend::parser::Fixity::{*}; -use crate::ast::term::{CType, SpecializedType, PType, VType, Effect}; +use nom::Parser; +use nom::{IResult, InputLength}; +use nom_locate::LocatedSpan; type Span<'a> = LocatedSpan<&'a str>; @@ -29,9 +29,22 @@ type OperatorAndName = (&'static str, &'static str); static PRECEDENCE: &[(&[OperatorAndName], Fixity)] = &[ (&[("+", "_int_pos"), ("-", "_int_neg")], Prefix), - (&[("*", "_int_mul"), ("/", "_int_div"), ("%", "_int_mod")], Infixl), + ( + &[("*", "_int_mul"), ("/", "_int_div"), ("%", "_int_mod")], + Infixl, + ), (&[("+", "_int_add"), ("-", "_int_sub")], Infixl), - (&[(">", "_int_gt"), ("<", "_int_lt"), (">=", "_int_gte"), ("<=", "_int_lte"), ("==", "_int_eq"), ("!=", "_int_ne")], Infix), + ( + &[ + (">", "_int_gt"), + ("<", "_int_lt"), + (">=", "_int_gte"), + ("<=", "_int_lte"), + ("==", "_int_eq"), + ("!=", "_int_ne"), + ], + Infix, + ), (&[("~", "_bool_not")], Prefix), (&[("&&", "_bool_and")], Infixl), (&[("||", "_bool_or")], Infixl), @@ -39,7 +52,32 @@ static PRECEDENCE: &[(&[OperatorAndName], Fixity)] = &[ // keywords static KEYWORDS: &[&str] = &[ - "let", "def", "case", "force", "thunk", "handler", "=>", "=>!", "=", "(", ")", ",", "\\", "{", "}", "@", "_", ":", "->", "!", "#", "#!", "#^", "##", "disposer", "replicator" + "let", + "def", + "case", + "force", + "thunk", + "handler", + "=>", + "=>!", + "=", + "(", + ")", + ",", + "\\", + "{", + "}", + "@", + "_", + ":", + "->", + "!", + "#", + "#!", + "#^", + "##", + "disposer", + "replicator", ]; // tokenizer @@ -67,37 +105,70 @@ fn identifier_token(input: Span) -> IResult { context( "identifier_token", alt(( - // alpha-numeric identifier - char_satisfy(|c| c.is_alphabetic() || c == '_') - .and(opt(take_while1(|c: char| c.is_alphanumeric() || c == '_'))) - .map(|(head, tail)| { - let id_string = match tail { - Some(tail) => format!("{}{}", head, tail), - None => head.to_string(), - }; - Token::Normal(id_string, input.location_line() - 1, input.naive_get_utf8_column() - 1) - }), - // specially handle some punctuations that should never be combined with others - one_of("(),\\{}").map(|c| Token::Normal(c.to_string(), input.location_line() - 1, input.naive_get_utf8_column() - 1)), - // punctuation identifier - take_while1(|c: char| c.is_ascii_punctuation() && - c != '`' && c != '"' && c != '(' && c != ')' && c != ',' && c != '\\' && c != '{' && c != '}') - .map(|s: Span| Token::Normal(s.to_string(), input.location_line() - 1, input.naive_get_utf8_column() - 1)), - // backtick-quoted identifier - delimited( - char('`'), - take_while1(|c| c != '`'), char('`')).map(|s: Span| Token::Normal(s.to_string(), input.location_line() - 1, input.naive_get_utf8_column() - 1)), - ), - ))(input) + // alpha-numeric identifier + char_satisfy(|c| c.is_alphabetic() || c == '_') + .and(opt(take_while1(|c: char| c.is_alphanumeric() || c == '_'))) + .map(|(head, tail)| { + let id_string = match tail { + Some(tail) => format!("{}{}", head, tail), + None => head.to_string(), + }; + Token::Normal( + id_string, + input.location_line() - 1, + input.naive_get_utf8_column() - 1, + ) + }), + // specially handle some punctuations that should never be combined with others + one_of("(),\\{}").map(|c| { + Token::Normal( + c.to_string(), + input.location_line() - 1, + input.naive_get_utf8_column() - 1, + ) + }), + // punctuation identifier + take_while1(|c: char| { + c.is_ascii_punctuation() + && c != '`' + && c != '"' + && c != '(' + && c != ')' + && c != ',' + && c != '\\' + && c != '{' + && c != '}' + }) + .map(|s: Span| { + Token::Normal( + s.to_string(), + input.location_line() - 1, + input.naive_get_utf8_column() - 1, + ) + }), + // backtick-quoted identifier + delimited(char('`'), take_while1(|c| c != '`'), char('`')).map(|s: Span| { + Token::Normal( + s.to_string(), + input.location_line() - 1, + input.naive_get_utf8_column() - 1, + ) + }), + )), + )(input) } fn int_token(input: Span) -> IResult { context( "int_token", - map( - take_while1(|c: char| c.is_ascii_digit()), - |s: Span| Token::Int(s.to_string().parse::().unwrap(), input.location_line() - 1, s.naive_get_utf8_column() - 1), - ))(input) + map(take_while1(|c: char| c.is_ascii_digit()), |s: Span| { + Token::Int( + s.to_string().parse::().unwrap(), + input.location_line() - 1, + s.naive_get_utf8_column() - 1, + ) + }), + )(input) } fn str_token(input: Span) -> IResult { @@ -108,45 +179,51 @@ fn str_token(input: Span) -> IResult { "string", preceded( char('\"'), - cut( - terminated( - escaped( - alphanumeric1, - '\\', - one_of("\"\\")), - char('\"')))), + cut(terminated( + escaped(alphanumeric1, '\\', one_of("\"\\")), + char('\"'), + )), + ), ), - |s: Span| Token::Str(s.to_string(), input.location_line() - 1, s.naive_get_utf8_column() - 2), // offset quote - ))(input) + |s: Span| { + Token::Str( + s.to_string(), + input.location_line() - 1, + s.naive_get_utf8_column() - 2, + ) + }, // offset quote + ), + )(input) } fn indent_token(input: Span) -> IResult { context( "indent_token", - map_res( - many0(one_of(" \t\n\r")), - |whitespaces| { - let mut indent = 0; - let mut has_newline = false; - for c in whitespaces { - match c { - ' ' => indent += 1, - '\t' => indent += 4, - '\n' => { - indent = 0; - has_newline = true; - } - '\r' => indent = 0, - _ => panic!("unexpected whitespace character: {}", c), + map_res(many0(one_of(" \t\n\r")), |whitespaces| { + let mut indent = 0; + let mut has_newline = false; + for c in whitespaces { + match c { + ' ' => indent += 1, + '\t' => indent += 4, + '\n' => { + indent = 0; + has_newline = true; } + '\r' => indent = 0, + _ => panic!("unexpected whitespace character: {}", c), } - if has_newline { - Ok(Token::Indent(input.location_line(), indent)) - } else { - Err(nom::Err::Error(nom::error::Error { input, code: ErrorKind::Space })) - } - }, - ))(input) + } + if has_newline { + Ok(Token::Indent(input.location_line(), indent)) + } else { + Err(nom::Err::Error(nom::error::Error { + input, + code: ErrorKind::Space, + })) + } + }), + )(input) } fn tokens(input: Span) -> IResult> { @@ -154,12 +231,8 @@ fn tokens(input: Span) -> IResult> { many0(one_of(" \t\n\r")), separated_list0( space0, - alt(( - identifier_token, - int_token, - str_token, - indent_token, - ))), + alt((identifier_token, int_token, str_token, indent_token)), + ), )(input) } @@ -177,11 +250,15 @@ impl<'a> InputLength for Input<'a> { } } -fn boolean(false_parser: P1, true_parser: P2) -> impl FnMut(I) -> IResult where +fn boolean( + false_parser: P1, + true_parser: P2, +) -> impl FnMut(I) -> IResult +where I: Clone + InputLength, P1: Parser, P2: Parser, - E: ParseError + E: ParseError, { alt((false_parser.map(|_| false), true_parser.map(|_| true))) } @@ -197,29 +274,59 @@ fn boolean(false_parser: P1, true_parser: P2) -> impl FnMu // E: ParseError, // { -fn map_token(f: F) -> impl FnMut(Input) -> IResult where F: Fn(&Token) -> Option { +fn map_token(f: F) -> impl FnMut(Input) -> IResult +where + F: Fn(&Token) -> Option, +{ move |input: Input| { if input.tokens.is_empty() { - Err(nom::Err::Error(nom::error::Error { input, code: ErrorKind::Eof })) + Err(nom::Err::Error(nom::error::Error { + input, + code: ErrorKind::Eof, + })) } else { let token = input.tokens.first().unwrap(); match f(token) { - Some(r) => Ok((Input { tokens: &input.tokens[1..], current_indent: input.current_indent }, r)), - None => Err(nom::Err::Error(nom::error::Error { input, code: ErrorKind::MapRes })), + Some(r) => Ok(( + Input { + tokens: &input.tokens[1..], + current_indent: input.current_indent, + }, + r, + )), + None => Err(nom::Err::Error(nom::error::Error { + input, + code: ErrorKind::MapRes, + })), } } } } -fn satisfy(f: F) -> impl FnMut(Input) -> IResult where F: Fn(&Token) -> bool { +fn satisfy(f: F) -> impl FnMut(Input) -> IResult +where + F: Fn(&Token) -> bool, +{ move |input: Input| { if input.tokens.is_empty() { - Err(nom::Err::Error(nom::error::Error { input, code: ErrorKind::Eof })) + Err(nom::Err::Error(nom::error::Error { + input, + code: ErrorKind::Eof, + })) } else { let token = input.tokens.first().unwrap(); match f(token) { - true => Ok((Input { tokens: &input.tokens[1..], current_indent: input.current_indent }, ())), - false => Err(nom::Err::Error(nom::error::Error { input, code: ErrorKind::Satisfy })), + true => Ok(( + Input { + tokens: &input.tokens[1..], + current_indent: input.current_indent, + }, + (), + )), + false => Err(nom::Err::Error(nom::error::Error { + input, + code: ErrorKind::Satisfy, + })), } } } @@ -227,14 +334,29 @@ fn satisfy(f: F) -> impl FnMut(Input) -> IResult where F: Fn(&Toke /// Update the current indent level according to the next token, which is assumed to be an /// identifier. -fn scoped<'a, F, R>(mut f: F) -> impl FnMut(Input<'a>) -> IResult, R> where F: Parser, R, nom::error::Error>> { +fn scoped<'a, F, R>(mut f: F) -> impl FnMut(Input<'a>) -> IResult, R> +where + F: Parser, R, nom::error::Error>>, +{ move |input| { if input.tokens.is_empty() { - Err(nom::Err::Error(nom::error::Error { input, code: ErrorKind::Eof })) + Err(nom::Err::Error(nom::error::Error { + input, + code: ErrorKind::Eof, + })) } else { let token = input.tokens.first().unwrap(); - match f.parse(Input { tokens: input.tokens, current_indent: token.column() + 1 }) { - Ok((new_input, r)) => Ok((Input { tokens: new_input.tokens, current_indent: input.current_indent }, r)), + match f.parse(Input { + tokens: input.tokens, + current_indent: token.column() + 1, + }) { + Ok((new_input, r)) => Ok(( + Input { + tokens: new_input.tokens, + current_indent: input.current_indent, + }, + r, + )), Err(e) => Err(e), } } @@ -258,10 +380,9 @@ fn v_type(input: Input) -> IResult { } fn v_type_decl(input: Input) -> IResult { - map( - opt(preceded(token(":"), v_type)), - |v_type_option| v_type_option.unwrap_or(VType::Uniform), - )(input) + map(opt(preceded(token(":"), v_type)), |v_type_option| { + v_type_option.unwrap_or(VType::Uniform) + })(input) } fn specialized_c_type(input: Input) -> IResult { @@ -279,7 +400,8 @@ fn newline(input: Input) -> IResult { let current_indent = input.current_indent; // This local variable is needed to make rust's borrow checker happy. Otherwise it complains // `current_indent` does not live long enough. - let mut p = satisfy(|token| matches!(token,Token::Indent(_, column) if *column >= current_indent )); + let mut p = + satisfy(|token| matches!(token,Token::Indent(_, column) if *column >= current_indent )); p(input) } @@ -299,18 +421,21 @@ fn token(s: &'static str) -> impl FnMut(Input) -> IResult { fn id(input: Input) -> IResult { map_token(|token| match token { Token::Normal(name, _, _) - if KEYWORDS.iter().all(|k| k != name) && - PRECEDENCE.iter().all(|(names, ..)| names.iter().all(|(n, _)| n != name)) - => Some(name.clone()), + if KEYWORDS.iter().all(|k| k != name) + && PRECEDENCE + .iter() + .all(|(names, ..)| names.iter().all(|(n, _)| n != name)) => + { + Some(name.clone()) + } _ => None, })(input) } fn effect(input: Input) -> IResult { - map( - opt(token("!").map(|_| Effect::Complex)), - |effect_opt| effect_opt.unwrap_or(Effect::Simple), - )(input) + map(opt(token("!").map(|_| Effect::Complex)), |effect_opt| { + effect_opt.unwrap_or(Effect::Simple) + })(input) } fn op_effect(input: Input) -> IResult { @@ -337,7 +462,13 @@ fn lambda_effect(input: Input) -> IResult { } fn id_term(input: Input) -> IResult { - context("id_term", map(pair(id, effect), |(name, effect)| FTerm::Identifier { name, effect }))(input) + context( + "id_term", + map(pair(id, effect), |(name, effect)| FTerm::Identifier { + name, + effect, + }), + )(input) } fn int(input: Input) -> IResult { @@ -363,78 +494,117 @@ fn str_term(input: Input) -> IResult { } fn struct_(input: Input) -> IResult { - context("struct", map( - delimited( - token("{"), - cut(separated_list0( - delimited(newline_opt, token(","), newline_opt), - expr)), - token("}"), + context( + "struct", + map( + delimited( + token("{"), + cut(separated_list0( + delimited(newline_opt, token(","), newline_opt), + expr, + )), + token("}"), + ), + |values: Vec| FTerm::Struct { values }, ), - |values: Vec| FTerm::Struct { values }, - ))(input) + )(input) } fn atom(input: Input) -> IResult { - context("atom", alt(( - id_term, - int_term, - str_term, - struct_, - delimited( - pair(token("("), newline_opt), - cut(f_term), - pair(newline_opt, token(")"))), - )))(input) + context( + "atom", + alt(( + id_term, + int_term, + str_term, + struct_, + delimited( + pair(token("("), newline_opt), + cut(f_term), + pair(newline_opt, token(")")), + ), + )), + )(input) } fn force(input: Input) -> IResult { - context("force", map(preceded(token("force"), pair(effect, cut(atom))), |(effect, t)| FTerm::Force { thunk: Box::new(t), effect }))(input) + context( + "force", + map( + preceded(token("force"), pair(effect, cut(atom))), + |(effect, t)| FTerm::Force { + thunk: Box::new(t), + effect, + }, + ), + )(input) } fn atomic_call(input: Input) -> IResult { - context("atomic_call", - map( - pair( - atom, - alt(( - map(pair(op_effect, cut(struct_)), Either::Right), - map(many0(pair(preceded(token("@"), cut(atom)), opt(preceded(token("="), cut(f_term))))), Either::Left), - )), + context( + "atomic_call", + map( + pair( + atom, + alt(( + map(pair(op_effect, cut(struct_)), Either::Right), + map( + many0(pair( + preceded(token("@"), cut(atom)), + opt(preceded(token("="), cut(f_term))), + )), + Either::Left, + ), + )), + ), + |(t, either)| match either { + Either::Left(index_and_values) => index_and_values.into_iter().fold( + t, + |t, (index, assignment)| match assignment { + None => FTerm::MemGet { + base: Box::new(t), + offset: Box::new(index), + }, + Some(value) => FTerm::MemSet { + base: Box::new(t), + offset: Box::new(index), + value: Box::new(value), + }, + }, ), - |(t, either)| { - match either { - Either::Left(index_and_values) => index_and_values.into_iter().fold(t, |t, (index, assignment)| { - match assignment { - None => FTerm::MemGet { base: Box::new(t), offset: Box::new(index) }, - Some(value) => FTerm::MemSet { base: Box::new(t), offset: Box::new(index), value: Box::new(value) }, - } - }), - Either::Right((effect, FTerm::Struct { values })) => { - FTerm::OperationCall { eff: Box::new(t), args: values, effect } - } - _ => unreachable!() - } + Either::Right((effect, FTerm::Struct { values })) => FTerm::OperationCall { + eff: Box::new(t), + args: values, + effect, }, - ), + _ => unreachable!(), + }, + ), )(input) } fn scoped_app(input: Input) -> IResult { - context("scoped_app", scoped( - map( - pair(many1(alt((atomic_call, force))), many0(preceded(newline, f_term))), + context( + "scoped_app", + scoped(map( + pair( + many1(alt((atomic_call, force))), + many0(preceded(newline, f_term)), + ), |(f_and_args, more_args)| { if f_and_args.len() == 1 && more_args.is_empty() { f_and_args.into_iter().next().unwrap() } else { let f = f_and_args.first().unwrap().clone(); let all_args = f_and_args.into_iter().skip(1).chain(more_args).collect(); - FTerm::Redex { function: Box::new(f), args: all_args } + FTerm::Redex { + function: Box::new(f), + args: all_args, + } } }, - ) - ))(input) + )), + )(input) } #[allow(clippy::redundant_closure)] @@ -442,19 +612,18 @@ fn scoped_app_boxed() -> BoxedFTermParser { Box::new(move |input| scoped_app(input)) } -fn operator_id(operators: &'static [OperatorAndName]) -> impl FnMut(Input) -> IResult { +fn operator_id( + operators: &'static [OperatorAndName], +) -> impl FnMut(Input) -> IResult { map_token(|token| match token { - Token::Normal(name, _, _) => { - operators.iter() - .filter_map(|(op, fun_name)| - if op == name { - Some(fun_name) - } else { - None - }) - .next() - .map(|fun_name| FTerm::Identifier { name: fun_name.to_string(), effect: Effect::Simple }) - } + Token::Normal(name, _, _) => operators + .iter() + .filter_map(|(op, fun_name)| if op == name { Some(fun_name) } else { None }) + .next() + .map(|fun_name| FTerm::Identifier { + name: fun_name.to_string(), + effect: Effect::Simple, + }), _ => None, }) } @@ -465,11 +634,11 @@ pub fn separated_list0( mut sep: G, mut f: F, ) -> impl FnMut(I) -> IResult, E> - where - I: Clone + InputLength, - F: Parser, - G: Parser, - E: ParseError, +where + I: Clone + InputLength, + F: Parser, + G: Parser, + E: ParseError, { move |mut i: I| { let mut res = Vec::new(); @@ -487,27 +656,28 @@ pub fn separated_list0( match sep.parse(i.clone()) { Err(nom::Err::Error(_)) => return Ok((i, res)), Err(e) => return Err(e), - Ok((i1, _)) => { - match f.parse(i1.clone()) { - Err(nom::Err::Error(_)) => return Ok((i, res)), - Err(e) => return Err(e), - Ok((i2, o)) => { - res.push(o); - i = i2; - } + Ok((i1, _)) => match f.parse(i1.clone()) { + Err(nom::Err::Error(_)) => return Ok((i, res)), + Err(e) => return Err(e), + Ok((i2, o)) => { + res.push(o); + i = i2; } - } + }, } } } } -pub fn infixl(mut operator: F, mut operand: G) -> impl FnMut(I) -> IResult), E> - where - I: Clone + InputLength, - F: Parser, - G: Parser, - E: ParseError, +pub fn infixl( + mut operator: F, + mut operand: G, +) -> impl FnMut(I) -> IResult), E> +where + I: Clone + InputLength, + F: Parser, + G: Parser, + E: ParseError, { move |mut i: I| { let head = match operand.parse(i.clone()) { @@ -548,12 +718,15 @@ pub fn infixl(mut operator: F, mut operand: G) -> impl FnMut } } -pub fn infixr(mut operator: F, mut operand: G) -> impl FnMut(I) -> IResult, O2), E> - where - I: Clone + InputLength, - F: Parser, - G: Parser, - E: ParseError, +pub fn infixr( + mut operator: F, + mut operand: G, +) -> impl FnMut(I) -> IResult, O2), E> +where + I: Clone + InputLength, + F: Parser, + G: Parser, + E: ParseError, { move |mut i: I| { let mut res = Vec::with_capacity(4); @@ -594,12 +767,15 @@ pub fn infixr(mut operator: F, mut operand: G) -> impl FnMut } } -pub fn infix(mut operator: F, mut operand: G) -> impl FnMut(I) -> IResult), E> - where - I: Clone + InputLength, - F: Parser, - G: Parser, - E: ParseError, +pub fn infix( + mut operator: F, + mut operand: G, +) -> impl FnMut(I) -> IResult), E> +where + I: Clone + InputLength, + F: Parser, + G: Parser, + E: ParseError, { move |mut i: I| { let len = i.input_len(); @@ -643,54 +819,95 @@ pub fn infix(mut operator: F, mut operand: G) -> impl FnMut( } } - type BoxedFTermParser = Box IResult>; -fn operator_call(operators: &'static [OperatorAndName], fixity: Fixity, mut component: BoxedFTermParser) -> BoxedFTermParser { - Box::new( - move |input| - match fixity { - Infixl => context("infixl_operator", map( - infixl(delimited(newline_opt, operator_id(operators), newline_opt), |input| (*component)(input)), - |(head, rest)| { - rest.into_iter().fold(head, |acc, (op, arg)| FTerm::Redex { function: Box::new(op), args: vec![acc, arg] }) - }, - ))(input), - Infixr => context("infixr_operator", map( - infixr(delimited(newline_opt, operator_id(operators), newline_opt), |input| (*component)(input)), - |(init, last)| { - init.into_iter().rfold(last, |acc, (arg, op)| FTerm::Redex { function: Box::new(op), args: vec![acc, arg] }) - }, - ))(input), - Infix => context("infix_operator", map( - infix(delimited(newline_opt, operator_id(operators), newline_opt), |input| (*component)(input)), - |(first, middle_last)| match middle_last { - None => first, - Some((middle, last)) => FTerm::Redex { function: Box::new(middle), args: vec![first, last] } +fn operator_call( + operators: &'static [OperatorAndName], + fixity: Fixity, + mut component: BoxedFTermParser, +) -> BoxedFTermParser { + Box::new(move |input| match fixity { + Infixl => context( + "infixl_operator", + map( + infixl( + delimited(newline_opt, operator_id(operators), newline_opt), + |input| (*component)(input), + ), + |(head, rest)| { + rest.into_iter().fold(head, |acc, (op, arg)| FTerm::Redex { + function: Box::new(op), + args: vec![acc, arg], + }) + }, + ), + )(input), + Infixr => context( + "infixr_operator", + map( + infixr( + delimited(newline_opt, operator_id(operators), newline_opt), + |input| (*component)(input), + ), + |(init, last)| { + init.into_iter().rfold(last, |acc, (arg, op)| FTerm::Redex { + function: Box::new(op), + args: vec![acc, arg], + }) + }, + ), + )(input), + Infix => context( + "infix_operator", + map( + infix( + delimited(newline_opt, operator_id(operators), newline_opt), + |input| (*component)(input), + ), + |(first, middle_last)| match middle_last { + None => first, + Some((middle, last)) => FTerm::Redex { + function: Box::new(middle), + args: vec![first, last], }, - ))(input), - Prefix => context("prefix_operator", map( - pair(opt(operator_id(operators)), |input| (*component)(input)), - |(operator, operand)| match operator { - None => operand, - Some(operator) => FTerm::Redex { function: Box::new(operator), args: vec![operand] }, + }, + ), + )(input), + Prefix => context( + "prefix_operator", + map( + pair(opt(operator_id(operators)), |input| (*component)(input)), + |(operator, operand)| match operator { + None => operand, + Some(operator) => FTerm::Redex { + function: Box::new(operator), + args: vec![operand], }, - ))(input), - Postfix => context("postfix_operator", map( - pair(|input| (*component)(input), opt(operator_id(operators))), - |(operand, operator)| match operator { - None => operand, - Some(operator) => FTerm::Redex { function: Box::new(operator), args: vec![operand] }, + }, + ), + )(input), + Postfix => context( + "postfix_operator", + map( + pair(|input| (*component)(input), opt(operator_id(operators))), + |(operand, operator)| match operator { + None => operand, + Some(operator) => FTerm::Redex { + function: Box::new(operator), + args: vec![operand], }, - ))(input), - } - ) + }, + ), + )(input), + }) } fn expr_impl(input: Input) -> IResult { - let mut p = PRECEDENCE.iter().fold(scoped_app_boxed(), |f, (operators, fixity)| { - operator_call(operators, *fixity, f) - }); + let mut p = PRECEDENCE + .iter() + .fold(scoped_app_boxed(), |f, (operators, fixity)| { + operator_call(operators, *fixity, f) + }); (*p)(input) } @@ -699,139 +916,206 @@ fn expr(input: Input) -> IResult { } fn lambda(input: Input) -> IResult { - context("lambda", scoped( - map( + context( + "lambda", + scoped(map( tuple(( preceded( token("\\"), separated_list0(newline_opt, pair(id, v_type_decl)), ), lambda_effect, - preceded(opt(newline), cut(computation)))), - |(arg_names, effect, body)| - FTerm::Lambda { arg_names, body: Box::new(body), effect }, - )))(input) + preceded(opt(newline), cut(computation)), + )), + |(arg_names, effect, body)| FTerm::Lambda { + arg_names, + body: Box::new(body), + effect, + }, + )), + )(input) } fn case(input: Input) -> IResult { - context("case", scoped( - map( + context( + "case", + scoped(map( tuple(( preceded(token("case"), map(expr, Box::new)), c_type_decl, - many0(map(preceded(newline, scoped(tuple((int, preceded(token("=>"), f_term))))), |(i, branch)| (i, branch))), - opt(preceded(newline, scoped(preceded(pair(token("_"), token("=>")), map(f_term, Box::new))))), + many0(map( + preceded(newline, scoped(tuple((int, preceded(token("=>"), f_term))))), + |(i, branch)| (i, branch), + )), + opt(preceded( + newline, + scoped(preceded( + pair(token("_"), token("=>")), + map(f_term, Box::new), + )), + )), )), - |(t, result_type, branches, default_branch)| { - FTerm::CaseInt { t, result_type, branches, default_branch } - }) - ))(input) + |(t, result_type, branches, default_branch)| FTerm::CaseInt { + t, + result_type, + branches, + default_branch, + }, + )), + )(input) } fn let_term(input: Input) -> IResult { - context("let_term", map( - tuple(( - alt(( - scoped(pair(delimited(token("let"), id, token("=")), preceded(newline_opt, cut(f_term)))), - map(expr, |t| (String::from("_"), t)), + context( + "let_term", + map( + tuple(( + alt(( + scoped(pair( + delimited(token("let"), id, token("=")), + preceded(newline_opt, cut(f_term)), + )), + map(expr, |t| (String::from("_"), t)), + )), + preceded(newline, f_term), )), - preceded(newline, f_term), - )), - |((name, t), body)| { - FTerm::Let { name, t: Box::new(t), body: Box::new(body) } - }))(input) + |((name, t), body)| FTerm::Let { + name, + t: Box::new(t), + body: Box::new(body), + }, + ), + )(input) } fn defs_term(input: Input) -> IResult { - context("defs_term", map( - pair( - many1( - map( - scoped( - tuple(( - preceded(token("def"), cut(id)), - many0(pair(id, v_type_decl)), - c_type_decl, - preceded(preceded(token("=>"), newline_opt), map(cut(f_term), Box::new))))), + context( + "defs_term", + map( + pair( + many1(map( + scoped(tuple(( + preceded(token("def"), cut(id)), + many0(pair(id, v_type_decl)), + c_type_decl, + preceded( + preceded(token("=>"), newline_opt), + map(cut(f_term), Box::new), + ), + ))), |(name, args, c_type, body)| (name, Def { args, c_type, body }), - ) - ), opt(preceded(newline, map(f_term, Box::new)))), - |(defs, body)| FTerm::Defs { defs, body }, - ))(input) + )), + opt(preceded(newline, map(f_term, Box::new))), + ), + |(defs, body)| FTerm::Defs { defs, body }, + ), + )(input) } enum HandlerComponent { Disposer(FTerm), Replicator(FTerm), Transform(FTerm), - Handler { eff: FTerm, handler: FTerm, handler_type: HandlerType }, + Handler { + eff: FTerm, + handler: FTerm, + handler_type: HandlerType, + }, } fn handler_component(input: Input) -> IResult { alt(( - map(preceded(token("disposer"), cut(computation)), HandlerComponent::Disposer), - map(preceded(token("replicator"), cut(computation)), HandlerComponent::Replicator), - map(preceded(token("#"), preceded(opt(newline), cut(computation))), HandlerComponent::Transform), map( - tuple(( - atom, - op_declaration, - preceded(opt(newline), computation), - )), - |(eff, effect, handler, )| HandlerComponent::Handler { eff, handler, handler_type: effect }, + preceded(token("disposer"), cut(computation)), + HandlerComponent::Disposer, + ), + map( + preceded(token("replicator"), cut(computation)), + HandlerComponent::Replicator, + ), + map( + preceded(token("#"), preceded(opt(newline), cut(computation))), + HandlerComponent::Transform, + ), + map( + tuple((atom, op_declaration, preceded(opt(newline), computation))), + |(eff, effect, handler)| HandlerComponent::Handler { + eff, + handler, + handler_type: effect, + }, ), ))(input) } fn handler_term(input: Input) -> IResult { - context("handler_term", map( - pair( - scoped(tuple(( - preceded(token("handler"), pair(effect, cut(opt(atom)))), - many0(preceded(newline, cut(handler_component))), - ))), - preceded(newline, f_term), - ), - |(((effect, parameter, ), handler_components), input)| { - let mut handler = FTerm::Handler { - parameter: Box::new(parameter.unwrap_or(FTerm::Struct { values: vec![] })), - parameter_disposer: None, - parameter_replicator: None, - transform: Box::new(FTerm::Lambda { - arg_names: vec![("p".to_owned(), VType::Uniform), ("r".to_owned(), VType::Uniform)], - body: Box::new(FTerm::Identifier { name: "r".to_owned(), effect: Effect::Simple }), - effect: Effect::Simple, - }), - handlers: vec![], - input: Box::new(FTerm::Thunk { computation: Box::new(input), effect }), - }; - for handler_component in handler_components.into_iter() { - let FTerm::Handler { - parameter_disposer, - parameter_replicator, - box transform, - handlers, - .. - } = &mut handler else { unreachable!() }; - - match handler_component { - HandlerComponent::Disposer(disposer) => { - *parameter_disposer = Some(Box::new(disposer)); - } - HandlerComponent::Replicator(replicator) => { - *parameter_replicator = Some(Box::new(replicator)); - } - HandlerComponent::Transform(t) => { - *transform = t; - } - HandlerComponent::Handler { eff, handler, handler_type } => { - handlers.push((eff, handler, handler_type)); + context( + "handler_term", + map( + pair( + scoped(tuple(( + preceded(token("handler"), pair(effect, cut(opt(atom)))), + many0(preceded(newline, cut(handler_component))), + ))), + preceded(newline, f_term), + ), + |(((effect, parameter), handler_components), input)| { + let mut handler = FTerm::Handler { + parameter: Box::new(parameter.unwrap_or(FTerm::Struct { values: vec![] })), + parameter_disposer: None, + parameter_replicator: None, + transform: Box::new(FTerm::Lambda { + arg_names: vec![ + ("p".to_owned(), VType::Uniform), + ("r".to_owned(), VType::Uniform), + ], + body: Box::new(FTerm::Identifier { + name: "r".to_owned(), + effect: Effect::Simple, + }), + effect: Effect::Simple, + }), + handlers: vec![], + input: Box::new(FTerm::Thunk { + computation: Box::new(input), + effect, + }), + }; + for handler_component in handler_components.into_iter() { + let FTerm::Handler { + parameter_disposer, + parameter_replicator, + box transform, + handlers, + .. + } = &mut handler + else { + unreachable!() + }; + + match handler_component { + HandlerComponent::Disposer(disposer) => { + *parameter_disposer = Some(Box::new(disposer)); + } + HandlerComponent::Replicator(replicator) => { + *parameter_replicator = Some(Box::new(replicator)); + } + HandlerComponent::Transform(t) => { + *transform = t; + } + HandlerComponent::Handler { + eff, + handler, + handler_type, + } => { + handlers.push((eff, handler, handler_type)); + } } } - } - handler - }, - ))(input) + handler + }, + ), + )(input) } fn computation(input: Input) -> IResult { @@ -841,11 +1125,17 @@ fn computation(input: Input) -> IResult { fn thunk(input: Input) -> IResult { context( "thunk", - scoped( - map( - preceded(preceded(token("thunk"), newline_opt), pair(effect, cut(computation))), - |(effect, t)| FTerm::Thunk { computation: Box::new(t), effect }) - ))(input) + scoped(map( + preceded( + preceded(token("thunk"), newline_opt), + pair(effect, cut(computation)), + ), + |(effect, t)| FTerm::Thunk { + computation: Box::new(t), + effect, + }, + )), + )(input) } fn f_term(input: Input) -> IResult { @@ -856,36 +1146,55 @@ fn tokenize(input: &str) -> Result, String> { let input = Span::new(input); let (input, tokens) = tokens(input).map_err(|e| format!("lex error: {:?}", e))?; if !input.is_empty() { - return Err(format!("lex error: unexpected character at ({:?}:{:?}): {:?}", input.location_line(), input.naive_get_utf8_column(), input.lines().next().unwrap())); + return Err(format!( + "lex error: unexpected character at ({:?}:{:?}): {:?}", + input.location_line(), + input.naive_get_utf8_column(), + input.lines().next().unwrap() + )); } Ok(tokens) } pub fn parse_f_term(input: &str) -> Result { let tokens = tokenize(input)?; - let input = Input { tokens: &tokens, current_indent: 0 }; + let input = Input { + tokens: &tokens, + current_indent: 0, + }; let (input, term) = f_term(input).map_err(|e| format!("parse error: {:?}", e))?; - if input.tokens.iter().any(|token| !matches!(token, Token::Indent(_, _))) { - return Err(format!("parse error: unexpected token at {:?}", input.tokens.first())); + if input + .tokens + .iter() + .any(|token| !matches!(token, Token::Indent(_, _))) + { + return Err(format!( + "parse error: unexpected token at {:?}", + input.tokens.first() + )); } Ok(term) } #[cfg(test)] mod tests { - use std::fs; - use std::path::PathBuf; use crate::frontend::parser::{parse_f_term, tokenize}; use crate::test_utils::debug_print; + use std::fs; + use std::path::PathBuf; #[test] fn check_tokenize() -> Result<(), String> { - let result = tokenize(r#"a b "c" ++- (+) () ,.~ + let result = tokenize( + r#"a b "c" ++- (+) () ,.~ x y - z"#)?; - assert_eq!(format!("{:#?}", result), r#"[ + z"#, + )?; + assert_eq!( + format!("{:#?}", result), + r#"[ Normal( "a", 0, @@ -968,7 +1277,8 @@ mod tests { 4, 2, ), -]"#); +]"# + ); Ok(()) } @@ -979,18 +1289,43 @@ mod tests { let mut test_input_paths = fs::read_dir(resource_dir) .unwrap() .map(|r| r.unwrap().path()) - .filter(|p| p.file_name().unwrap().to_str().unwrap().ends_with("input.txt")) + .filter(|p| { + p.file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with("input.txt") + }) .collect::>(); test_input_paths.sort(); - let all_results = test_input_paths.into_iter().map(|test_input_path| { - let test_output_path = test_input_path.with_extension("").with_extension("output.txt"); - let result = check(test_input_path.to_str().unwrap(), test_output_path.to_str().unwrap()); - (test_input_path, result) - }).filter(|(_, r)| r.is_err()).collect::>(); + let all_results = test_input_paths + .into_iter() + .map(|test_input_path| { + let test_output_path = test_input_path + .with_extension("") + .with_extension("output.txt"); + let result = check( + test_input_path.to_str().unwrap(), + test_output_path.to_str().unwrap(), + ); + (test_input_path, result) + }) + .filter(|(_, r)| r.is_err()) + .collect::>(); if all_results.is_empty() { Ok(()) } else { - Err(all_results.into_iter().map(|(test_input_path, r)| format!("[{}] {}", test_input_path.file_name().unwrap().to_str().unwrap(), r.unwrap_err())).collect::>().join("\n")) + Err(all_results + .into_iter() + .map(|(test_input_path, r)| { + format!( + "[{}] {}", + test_input_path.file_name().unwrap().to_str().unwrap(), + r.unwrap_err() + ) + }) + .collect::>() + .join("\n")) } } @@ -1007,4 +1342,4 @@ mod tests { Ok(()) } } -} \ No newline at end of file +} diff --git a/src/frontend/transpiler.rs b/src/frontend/transpiler.rs index 4c3990e..a0fb02b 100644 --- a/src/frontend/transpiler.rs +++ b/src/frontend/transpiler.rs @@ -1,9 +1,9 @@ -use std::collections::{HashMap, HashSet}; -use either::{Either, Left, Right}; +use crate::ast::primitive_functions::PRIMITIVE_FUNCTIONS; use crate::ast::signature::{FunctionDefinition, Signature}; use crate::ast::term::{CTerm, CType, Effect, SpecializedType, VTerm, VType}; use crate::frontend::f_term::{Def, FTerm}; -use crate::ast::primitive_functions::PRIMITIVE_FUNCTIONS; +use either::{Either, Left, Right}; +use std::collections::{HashMap, HashSet}; pub struct Transpiler { signature: Signature, @@ -22,254 +22,398 @@ impl Transpiler { self.signature } fn transpile(&mut self, f_term: FTerm) { - let main = self.transpile_impl(f_term, &Context { - enclosing_def_name: "", - def_map: &HashMap::new(), - var_map: &HashMap::new(), - }); - self.signature.insert("main".to_string(), FunctionDefinition { - args: vec![], - body: main, - c_type: CType::SpecializedF(VType::Specialized(SpecializedType::Integer)), - var_bound: 0, - need_simple: false, - need_cps: false, - need_specialized: false, - }); + let main = self.transpile_impl( + f_term, + &Context { + enclosing_def_name: "", + def_map: &HashMap::new(), + var_map: &HashMap::new(), + }, + ); + self.signature.insert( + "main".to_string(), + FunctionDefinition { + args: vec![], + body: main, + c_type: CType::SpecializedF(VType::Specialized(SpecializedType::Integer)), + var_bound: 0, + need_simple: false, + need_cps: false, + need_specialized: false, + }, + ); } fn transpile_impl(&mut self, f_term: FTerm, context: &Context) -> CTerm { match f_term { - FTerm::Identifier { name, effect } => + FTerm::Identifier { name, effect } => { match self.transpile_identifier(&name, context, effect) { Left(c) => c, Right(v) => CTerm::Return { value: v }, - }, - FTerm::Int { value } => CTerm::Return { value: VTerm::Int { value } }, - FTerm::Str { value } => CTerm::Return { value: VTerm::Str { value } }, + } + } + FTerm::Int { value } => CTerm::Return { + value: VTerm::Int { value }, + }, + FTerm::Str { value } => CTerm::Return { + value: VTerm::Str { value }, + }, FTerm::Struct { values } => { - let (transpiled_values, transpiled_computations) = self.transpile_values(values, context); - Self::squash_computations(CTerm::Return { value: VTerm::Struct { values: transpiled_values } }, transpiled_computations) + let (transpiled_values, transpiled_computations) = + self.transpile_values(values, context); + Self::squash_computations( + CTerm::Return { + value: VTerm::Struct { + values: transpiled_values, + }, + }, + transpiled_computations, + ) } - FTerm::Lambda { arg_names, body, effect } => { + FTerm::Lambda { + arg_names, + body, + effect, + } => { let mut var_map = context.var_map.clone(); - let args: Vec<_> = arg_names.iter().map(|(name, v_type)| { - let index = self.new_local_index(); - var_map.insert(name.clone(), index); - (index, *v_type) - }).collect(); - let transpiled_body = self.transpile_impl(*body, &Context { - enclosing_def_name: context.enclosing_def_name, - def_map: context.def_map, - var_map: &var_map, - }); - CTerm::Lambda { args, body: Box::new(transpiled_body), effect } + let args: Vec<_> = arg_names + .iter() + .map(|(name, v_type)| { + let index = self.new_local_index(); + var_map.insert(name.clone(), index); + (index, *v_type) + }) + .collect(); + let transpiled_body = self.transpile_impl( + *body, + &Context { + enclosing_def_name: context.enclosing_def_name, + def_map: context.def_map, + var_map: &var_map, + }, + ); + CTerm::Lambda { + args, + body: Box::new(transpiled_body), + effect, + } } FTerm::Redex { function, args } => { - let (transpiled_args, transpiled_computations) = self.transpile_values(args, context); - let body = CTerm::Redex { function: Box::new(self.transpile_impl(*function, context)), args: transpiled_args }; + let (transpiled_args, transpiled_computations) = + self.transpile_values(args, context); + let body = CTerm::Redex { + function: Box::new(self.transpile_impl(*function, context)), + args: transpiled_args, + }; Self::squash_computations(body, transpiled_computations) } - FTerm::Force { thunk, effect } => - self.transpile_value_and_map(*thunk, context, |(_, t)| CTerm::Force { thunk: t, effect }), - FTerm::Thunk { computation, effect } => - CTerm::Return { - value: VTerm::Thunk { - t: Box::new(self.transpile_impl(*computation, context)), - effect, - } + FTerm::Force { thunk, effect } => { + self.transpile_value_and_map(*thunk, context, |(_, t)| CTerm::Force { + thunk: t, + effect, + }) + } + FTerm::Thunk { + computation, + effect, + } => CTerm::Return { + value: VTerm::Thunk { + t: Box::new(self.transpile_impl(*computation, context)), + effect, }, - FTerm::CaseInt { t, result_type, branches, default_branch } => { + }, + FTerm::CaseInt { + t, + result_type, + branches, + default_branch, + } => { let mut transpiled_branches = Vec::new(); for (value, branch) in branches { transpiled_branches.push((value, self.transpile_impl(branch, context))); } - let transpiled_default_branch = default_branch.map(|b| Box::new(self.transpile_impl(*b, context))); - self.transpile_value_and_map(*t, context, |(_, t)| { - CTerm::CaseInt { t, result_type, branches: transpiled_branches, default_branch: transpiled_default_branch } + let transpiled_default_branch = + default_branch.map(|b| Box::new(self.transpile_impl(*b, context))); + self.transpile_value_and_map(*t, context, |(_, t)| CTerm::CaseInt { + t, + result_type, + branches: transpiled_branches, + default_branch: transpiled_default_branch, }) } - FTerm::MemGet { box base, box offset } => { - self.transpile_value_and_map(base, context, |(s, base)| { - s.transpile_value_and_map(offset, context, |(_, offset)| { - CTerm::MemGet { base, offset } - }) + FTerm::MemGet { + box base, + box offset, + } => self.transpile_value_and_map(base, context, |(s, base)| { + s.transpile_value_and_map(offset, context, |(_, offset)| CTerm::MemGet { + base, + offset, }) - } - FTerm::MemSet { box base, box offset, box value } => { - self.transpile_value_and_map(base, context, |(s, base)| { - s.transpile_value_and_map(offset, context, |(s, offset)| { - s.transpile_value_and_map(value, context, |(_, value)| { - CTerm::MemSet { base, offset, value } - }) + }), + FTerm::MemSet { + box base, + box offset, + box value, + } => self.transpile_value_and_map(base, context, |(s, base)| { + s.transpile_value_and_map(offset, context, |(s, offset)| { + s.transpile_value_and_map(value, context, |(_, value)| CTerm::MemSet { + base, + offset, + value, }) }) - } + }), FTerm::Let { name, t, body } => { let transpiled_t = self.transpile_impl(*t, context); let mut var_map = context.var_map.clone(); let bound_index = self.new_local_index(); var_map.insert(name.to_string(), bound_index); - let transpiled_body = self.transpile_impl(*body, &Context { - enclosing_def_name: context.enclosing_def_name, - def_map: context.def_map, - var_map: &var_map, - }); - CTerm::Let { t: Box::new(transpiled_t), bound_index, body: Box::new(transpiled_body) } + let transpiled_body = self.transpile_impl( + *body, + &Context { + enclosing_def_name: context.enclosing_def_name, + def_map: context.def_map, + var_map: &var_map, + }, + ); + CTerm::Let { + t: Box::new(transpiled_t), + bound_index, + body: Box::new(transpiled_body), + } } FTerm::Defs { defs, body } => { let def_names: Vec = defs.iter().map(|(s, _)| s.to_string()).collect(); let mut def_map = context.def_map.clone(); - let def_with_names: Vec<(Def, String, Vec)> = defs.into_iter().map(|(name, def)| { - let mut free_vars = HashSet::new(); - let identifier_names: HashSet<&str> = HashSet::new(); - Self::get_free_vars(&def.body, &identifier_names, &mut free_vars); - // remove all names bound inside the current def - def.args.iter().for_each(|(v, _v_type)| { - free_vars.remove(v.as_str()); - }); - // remove all names matching defs bound in current scope - def_names.iter().for_each(|v| { - free_vars.remove(v.as_str()); - }); - // remove all names matching defs bound in parent scopes - context.def_map.iter().for_each(|(name, _)| { free_vars.remove(name.as_str()); }); - let def_name = if context.enclosing_def_name.is_empty() { - name.clone() - } else { - format!("{}${}", context.enclosing_def_name, name) - }; - let mut free_var_vec: Vec = free_vars.into_iter().map(|s| s.to_owned()).collect(); - free_var_vec.sort(); - def_map.insert( - name, - ( - def_name.clone(), - free_var_vec - .iter() - .map(|name| VTerm::Var { index: *context.var_map.get(name).unwrap() }) - .collect())); - (def, def_name.clone(), free_var_vec) - }).collect(); - def_with_names.into_iter().for_each(|(def, name, free_vars)| { - let mut var_map = HashMap::new(); - let bound_indexes: Vec<_> = free_vars.iter().map(|v| (v.clone(), VType::Uniform)).chain(def.args.clone()).map(|(arg, arg_type)| { - let index = self.new_local_index(); - var_map.insert(arg.clone(), index); - (index, arg_type) - }).collect(); - let def_body = self.transpile_impl(*def.body, &Context { - enclosing_def_name: &name, - def_map: &def_map, - var_map: &var_map, + let def_with_names: Vec<(Def, String, Vec)> = defs + .into_iter() + .map(|(name, def)| { + let mut free_vars = HashSet::new(); + let identifier_names: HashSet<&str> = HashSet::new(); + Self::get_free_vars(&def.body, &identifier_names, &mut free_vars); + // remove all names bound inside the current def + def.args.iter().for_each(|(v, _v_type)| { + free_vars.remove(v.as_str()); + }); + // remove all names matching defs bound in current scope + def_names.iter().for_each(|v| { + free_vars.remove(v.as_str()); + }); + // remove all names matching defs bound in parent scopes + context.def_map.iter().for_each(|(name, _)| { + free_vars.remove(name.as_str()); + }); + let def_name = if context.enclosing_def_name.is_empty() { + name.clone() + } else { + format!("{}${}", context.enclosing_def_name, name) + }; + let mut free_var_vec: Vec = + free_vars.into_iter().map(|s| s.to_owned()).collect(); + free_var_vec.sort(); + def_map.insert( + name, + ( + def_name.clone(), + free_var_vec + .iter() + .map(|name| VTerm::Var { + index: *context.var_map.get(name).unwrap(), + }) + .collect(), + ), + ); + (def, def_name.clone(), free_var_vec) + }) + .collect(); + def_with_names + .into_iter() + .for_each(|(def, name, free_vars)| { + let mut var_map = HashMap::new(); + let bound_indexes: Vec<_> = free_vars + .iter() + .map(|v| (v.clone(), VType::Uniform)) + .chain(def.args.clone()) + .map(|(arg, arg_type)| { + let index = self.new_local_index(); + var_map.insert(arg.clone(), index); + (index, arg_type) + }) + .collect(); + let def_body = self.transpile_impl( + *def.body, + &Context { + enclosing_def_name: &name, + def_map: &def_map, + var_map: &var_map, + }, + ); + let mut free_var_strings: Vec = + free_vars.iter().map(|s| s.to_string()).collect(); + free_var_strings.extend(def.args.into_iter().map(|(v, _)| v)); + self.signature.defs.insert( + name.clone(), + FunctionDefinition { + args: bound_indexes, + body: def_body, + c_type: def.c_type, + var_bound: self.local_counter, + need_simple: false, + need_cps: false, + need_specialized: false, + }, + ); }); - let mut free_var_strings: Vec = free_vars.iter().map(|s| s.to_string()).collect(); - free_var_strings.extend(def.args.into_iter().map(|(v, _)| v)); - self.signature.defs.insert( - name.clone(), - FunctionDefinition { - args: bound_indexes, - body: def_body, - c_type: def.c_type, - var_bound: self.local_counter, - need_simple: false, - need_cps: false, - need_specialized: false, - }, - ); - }); match body { - None => CTerm::Return { value: VTerm::Struct { values: Vec::new() } }, - Some(body) => self.transpile_impl(*body, &Context { - enclosing_def_name: context.enclosing_def_name, - var_map: context.var_map, - def_map: &def_map, - }) + None => CTerm::Return { + value: VTerm::Struct { values: Vec::new() }, + }, + Some(body) => self.transpile_impl( + *body, + &Context { + enclosing_def_name: context.enclosing_def_name, + var_map: context.var_map, + def_map: &def_map, + }, + ), } } - FTerm::OperationCall { box eff, args, effect } => { - self.transpile_value_and_map(eff, context, |(s, eff)| { - let (transpiled_args, transpiled_computations) = s.transpile_values(args, context); - let operation_call = CTerm::OperationCall { eff, args: transpiled_args, effect }; - Self::squash_computations(operation_call, transpiled_computations) - }) - } + FTerm::OperationCall { + box eff, + args, + effect, + } => self.transpile_value_and_map(eff, context, |(s, eff)| { + let (transpiled_args, transpiled_computations) = s.transpile_values(args, context); + let operation_call = CTerm::OperationCall { + eff, + args: transpiled_args, + effect, + }; + Self::squash_computations(operation_call, transpiled_computations) + }), FTerm::Handler { box parameter, parameter_disposer, parameter_replicator, box transform, handlers, - box input - } => { - self.transpile_value_and_map(parameter, context, |(s, parameter)| { - s.transpile_option_value_and_map(parameter_disposer, context, |(s, parameter_disposer)| { - s.transpile_option_value_and_map(parameter_replicator, context, |(s, parameter_replicator)| { - s.transpile_value_and_map(transform, context, |(s, transform)| { - let (simple_effs, simple_handlers, handler_types): (Vec<_>, Vec<_>, Vec<_>) = itertools::multiunzip(handlers); - let (effs_v, simple_effs_c) = s.transpile_values(simple_effs, context); - let (handlers_v, simple_handlers_c) = s.transpile_values(simple_handlers, context); - s.transpile_value_and_map(input, context, |(_s, input)| { - let handler = CTerm::Handler { - parameter, - parameter_disposer, - parameter_replicator, - transform, - handlers: itertools::multizip((effs_v, handlers_v, handler_types)).collect(), - input, - }; - let handler = Self::squash_computations(handler, simple_effs_c); - Self::squash_computations(handler, simple_handlers_c) + box input, + } => self.transpile_value_and_map(parameter, context, |(s, parameter)| { + s.transpile_option_value_and_map( + parameter_disposer, + context, + |(s, parameter_disposer)| { + s.transpile_option_value_and_map( + parameter_replicator, + context, + |(s, parameter_replicator)| { + s.transpile_value_and_map(transform, context, |(s, transform)| { + let (simple_effs, simple_handlers, handler_types): ( + Vec<_>, + Vec<_>, + Vec<_>, + ) = itertools::multiunzip(handlers); + let (effs_v, simple_effs_c) = + s.transpile_values(simple_effs, context); + let (handlers_v, simple_handlers_c) = + s.transpile_values(simple_handlers, context); + s.transpile_value_and_map(input, context, |(_s, input)| { + let handler = CTerm::Handler { + parameter, + parameter_disposer, + parameter_replicator, + transform, + handlers: itertools::multizip(( + effs_v, + handlers_v, + handler_types, + )) + .collect(), + input, + }; + let handler = + Self::squash_computations(handler, simple_effs_c); + Self::squash_computations(handler, simple_handlers_c) + }) }) - }) - }) - }) - }) - } + }, + ) + }, + ) + }), } } - fn transpile_value(&mut self, f_term: FTerm, context: &Context) -> (VTerm, Option<(usize, CTerm)>) { + fn transpile_value( + &mut self, + f_term: FTerm, + context: &Context, + ) -> (VTerm, Option<(usize, CTerm)>) { match f_term { - FTerm::Identifier { name, effect } => match self.transpile_identifier(&name, context, effect) { - Left(c_term) => self.new_computation(c_term), - Right(v_term) => (v_term, None), + FTerm::Identifier { name, effect } => { + match self.transpile_identifier(&name, context, effect) { + Left(c_term) => self.new_computation(c_term), + Right(v_term) => (v_term, None), + } } FTerm::Int { value } => (VTerm::Int { value }, None), FTerm::Str { value } => (VTerm::Str { value }, None), - FTerm::Struct { .. } => { - match self.transpile_impl(f_term, context) { - CTerm::Return { value } => (value, None), - c_term => self.new_computation(c_term), - } - } + FTerm::Struct { .. } => match self.transpile_impl(f_term, context) { + CTerm::Return { value } => (value, None), + c_term => self.new_computation(c_term), + }, FTerm::Lambda { effect, .. } => { let c_term = self.transpile_impl(f_term, context); - (VTerm::Thunk { t: Box::new(c_term), effect }, None) - } - FTerm::Thunk { computation, effect } => { - (VTerm::Thunk { t: Box::new(self.transpile_impl(*computation, context)), effect }, None) + ( + VTerm::Thunk { + t: Box::new(c_term), + effect, + }, + None, + ) } - FTerm::CaseInt { .. } | - FTerm::MemGet { .. } | - FTerm::MemSet { .. } | - FTerm::Redex { .. } | - FTerm::Force { .. } | - FTerm::Let { .. } | - FTerm::OperationCall { .. } | - FTerm::Handler { .. } | - FTerm::Defs { .. } => { + FTerm::Thunk { + computation, + effect, + } => ( + VTerm::Thunk { + t: Box::new(self.transpile_impl(*computation, context)), + effect, + }, + None, + ), + FTerm::CaseInt { .. } + | FTerm::MemGet { .. } + | FTerm::MemSet { .. } + | FTerm::Redex { .. } + | FTerm::Force { .. } + | FTerm::Let { .. } + | FTerm::OperationCall { .. } + | FTerm::Handler { .. } + | FTerm::Defs { .. } => { let c_term = self.transpile_impl(f_term, context); self.new_computation(c_term) } } } - fn transpile_values(&mut self, f_terms: Vec, context: &Context) -> (Vec, Vec>) { - f_terms.into_iter().map(|v| self.transpile_value(v, context)).unzip() + fn transpile_values( + &mut self, + f_terms: Vec, + context: &Context, + ) -> (Vec, Vec>) { + f_terms + .into_iter() + .map(|v| self.transpile_value(v, context)) + .unzip() } - fn transpile_option_value_and_map(&mut self, f_term: Option>, context: &Context, f: F) -> CTerm + fn transpile_option_value_and_map( + &mut self, + f_term: Option>, + context: &Context, + f: F, + ) -> CTerm where F: FnOnce((&mut Self, Option)) -> CTerm, { @@ -277,7 +421,11 @@ impl Transpiler { Some(box f_term) => { let (v_term, computation) = self.transpile_value(f_term, context); if let Some((name, computation)) = computation { - CTerm::Let { t: Box::new(computation), bound_index: name, body: Box::new(f((self, Some(v_term)))) } + CTerm::Let { + t: Box::new(computation), + bound_index: name, + body: Box::new(f((self, Some(v_term)))), + } } else { f((self, Some(v_term))) } @@ -292,7 +440,11 @@ impl Transpiler { { let (v_term, computation) = self.transpile_value(f_term, context); if let Some((name, computation)) = computation { - CTerm::Let { t: Box::new(computation), bound_index: name, body: Box::new(f((self, v_term))) } + CTerm::Let { + t: Box::new(computation), + bound_index: name, + body: Box::new(f((self, v_term))), + } } else { f((self, v_term)) } @@ -301,7 +453,11 @@ impl Transpiler { fn squash_computations(body: CTerm, computations: Vec>) -> CTerm { computations.into_iter().rfold(body, |c, o| { if let Some((bound_index, t)) = o { - CTerm::Let { t: Box::new(t), bound_index, body: Box::new(c) } + CTerm::Let { + t: Box::new(t), + bound_index, + body: Box::new(c), + } } else { c } @@ -313,17 +469,28 @@ impl Transpiler { (VTerm::Var { index }, Some((index, c_term))) } - fn transpile_identifier(&self, name: &str, context: &Context, effect: Effect) -> Either { + fn transpile_identifier( + &self, + name: &str, + context: &Context, + effect: Effect, + ) -> Either { if let Some(index) = context.var_map.get(name) { Right(VTerm::Var { index: *index }) } else if let Some((name, args)) = context.def_map.get(name) { let term = CTerm::Redex { - function: Box::new(CTerm::Def { name: name.to_owned(), effect }), + function: Box::new(CTerm::Def { + name: name.to_owned(), + effect, + }), args: args.to_owned(), }; Left(term.clone()) } else if let Some(name) = PRIMITIVE_FUNCTIONS.get_key(name) { - Left(CTerm::Def { name: (*name).to_owned(), effect }) + Left(CTerm::Def { + name: (*name).to_owned(), + effect, + }) } else { // properly returning a Result is better but very annoying since that requires transposing out of various collection panic!("Unknown identifier: {}", name) @@ -342,28 +509,50 @@ impl Transpiler { new_local_index } - fn get_free_vars<'a>(f_term: &'a FTerm, bound_names: &HashSet<&'a str>, free_vars: &mut HashSet<&'a str>) { + fn get_free_vars<'a>( + f_term: &'a FTerm, + bound_names: &HashSet<&'a str>, + free_vars: &mut HashSet<&'a str>, + ) { match f_term { - FTerm::Identifier { name, .. } => if !bound_names.contains(name.as_str()) && !PRIMITIVE_FUNCTIONS.contains_key(name.as_str()) { - free_vars.insert(name.as_str()); + FTerm::Identifier { name, .. } => { + if !bound_names.contains(name.as_str()) + && !PRIMITIVE_FUNCTIONS.contains_key(name.as_str()) + { + free_vars.insert(name.as_str()); + } } FTerm::Int { .. } => {} FTerm::Str { .. } => {} - FTerm::Struct { values } => values.iter().for_each(|v| Self::get_free_vars(v, bound_names, free_vars)), - FTerm::Lambda { arg_names, body, .. } => { + FTerm::Struct { values } => values + .iter() + .for_each(|v| Self::get_free_vars(v, bound_names, free_vars)), + FTerm::Lambda { + arg_names, body, .. + } => { let mut new_bound_names = bound_names.clone(); new_bound_names.extend(arg_names.iter().map(|(s, _)| s.as_str())); Self::get_free_vars(body, &new_bound_names, free_vars); } FTerm::Redex { function, args } => { Self::get_free_vars(function, bound_names, free_vars); - args.iter().for_each(|v| Self::get_free_vars(v, bound_names, free_vars)); + args.iter() + .for_each(|v| Self::get_free_vars(v, bound_names, free_vars)); } FTerm::Force { thunk, .. } => Self::get_free_vars(thunk, bound_names, free_vars), - FTerm::Thunk { computation, .. } => Self::get_free_vars(computation, bound_names, free_vars), - FTerm::CaseInt { t, branches, default_branch, .. } => { + FTerm::Thunk { computation, .. } => { + Self::get_free_vars(computation, bound_names, free_vars) + } + FTerm::CaseInt { + t, + branches, + default_branch, + .. + } => { Self::get_free_vars(t, bound_names, free_vars); - branches.iter().for_each(|(_, v)| Self::get_free_vars(v, bound_names, free_vars)); + branches + .iter() + .for_each(|(_, v)| Self::get_free_vars(v, bound_names, free_vars)); if let Some(default_branch) = default_branch { Self::get_free_vars(default_branch, bound_names, free_vars); } @@ -372,7 +561,11 @@ impl Transpiler { Self::get_free_vars(base, bound_names, free_vars); Self::get_free_vars(offset, bound_names, free_vars); } - FTerm::MemSet { base, offset, value } => { + FTerm::MemSet { + base, + offset, + value, + } => { Self::get_free_vars(base, bound_names, free_vars); Self::get_free_vars(offset, bound_names, free_vars); Self::get_free_vars(value, bound_names, free_vars); @@ -397,7 +590,8 @@ impl Transpiler { } FTerm::OperationCall { box eff, args, .. } => { Self::get_free_vars(eff, bound_names, free_vars); - args.iter().for_each(|v| Self::get_free_vars(v, bound_names, free_vars)); + args.iter() + .for_each(|v| Self::get_free_vars(v, bound_names, free_vars)); } FTerm::Handler { box parameter, @@ -405,15 +599,19 @@ impl Transpiler { parameter_replicator, box transform, handlers: simple_handlers, - box input + box input, } => { Self::get_free_vars(parameter, bound_names, free_vars); match parameter_disposer { - Some(box parameter_disposer) => Self::get_free_vars(parameter_disposer, bound_names, free_vars), + Some(box parameter_disposer) => { + Self::get_free_vars(parameter_disposer, bound_names, free_vars) + } None => {} }; match parameter_replicator { - Some(box parameter_replicator) => Self::get_free_vars(parameter_replicator, bound_names, free_vars), + Some(box parameter_replicator) => { + Self::get_free_vars(parameter_replicator, bound_names, free_vars) + } None => {} }; Self::get_free_vars(transform, bound_names, free_vars); @@ -429,13 +627,13 @@ impl Transpiler { #[cfg(test)] mod tests { - use std::fs; - use std::path::PathBuf; - use cranelift_jit::JITModule; + use crate::ast::signature::{FunctionEnablement, Signature}; use crate::backend::compiler::Compiler; use crate::frontend::parser::parse_f_term; - use crate::ast::signature::{FunctionEnablement, Signature}; use crate::frontend::transpiler::Transpiler; + use cranelift_jit::JITModule; + use std::fs; + use std::path::PathBuf; fn check(test_input_path: &PathBuf) -> Result<(), String> { println!("checking {}", test_input_path.to_str().unwrap()); @@ -456,14 +654,16 @@ mod tests { let expected = fs::read_to_string(&test_ir_path).unwrap_or_else(|_| "".to_owned()); let partial_actual = format!( "FTerm\n========\n{:#?}\n\nDefs\n========\n{:#?}", - f_term, - defs); + f_term, defs + ); if expected != partial_actual { // Write partial actual to expected in case executing the compiled function crashes the // test. fs::write(&test_ir_path, &partial_actual).unwrap(); } - let test_clir_path = test_input_path.with_extension("").with_extension("clir.txt"); + let test_clir_path = test_input_path + .with_extension("") + .with_extension("clir.txt"); let expected = fs::read_to_string(&test_clir_path).unwrap_or_else(|_| "".to_owned()); let mut compiler: Compiler = Default::default(); let mut clir = vec![]; @@ -472,20 +672,29 @@ mod tests { let partial_actual = format!( "CLIR\n========\n{}", - clir.iter().map(|(name, clir)| format!("[{}]\n{}", name, clir)).collect::>().join("\n\n")); + clir.iter() + .map(|(name, clir)| format!("[{}]\n{}", name, clir)) + .collect::>() + .join("\n\n") + ); if expected != partial_actual { // Write partial actual to expected in case executing the compiled function crashes the // test. fs::write(test_clir_path, &partial_actual).unwrap(); } - let test_output_path = test_input_path.with_extension("").with_extension("output.txt"); + let test_output_path = test_input_path + .with_extension("") + .with_extension("output.txt"); let expected = fs::read_to_string(&test_output_path).unwrap_or_else(|_| "".to_owned()); let result = main_func(); let actual = format!("{}", result); if expected != actual { fs::write(test_output_path, actual).unwrap(); - Err(format!("Output mismatch for {}", test_input_path.to_str().unwrap())) + Err(format!( + "Output mismatch for {}", + test_input_path.to_str().unwrap() + )) } else { Ok(()) } @@ -498,17 +707,37 @@ mod tests { let mut test_input_paths = fs::read_dir(resource_dir) .unwrap() .map(|r| r.unwrap().path()) - .filter(|p| p.file_name().unwrap().to_str().unwrap().ends_with("input.txt")) + .filter(|p| { + p.file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with("input.txt") + }) .collect::>(); test_input_paths.sort(); - let all_results = test_input_paths.into_iter().map(|test_input_path| { - let result = check(&test_input_path); - (test_input_path, result) - }).filter(|(_, r)| r.is_err()).collect::>(); + let all_results = test_input_paths + .into_iter() + .map(|test_input_path| { + let result = check(&test_input_path); + (test_input_path, result) + }) + .filter(|(_, r)| r.is_err()) + .collect::>(); if all_results.is_empty() { Ok(()) } else { - Err(all_results.into_iter().map(|(test_input_path, r)| format!("[{}] {}", test_input_path.file_name().unwrap().to_str().unwrap(), r.unwrap_err())).collect::>().join("\n")) + Err(all_results + .into_iter() + .map(|(test_input_path, r)| { + format!( + "[{}] {}", + test_input_path.file_name().unwrap().to_str().unwrap(), + r.unwrap_err() + ) + }) + .collect::>() + .join("\n")) } } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 390eecc..8be56aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ extern crate core; -mod frontend; mod ast; mod backend; -mod test_utils; \ No newline at end of file +mod frontend; +mod test_utils; diff --git a/src/test_utils.rs b/src/test_utils.rs index 5f58067..9a74851 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -2,6 +2,9 @@ use std::fmt::Debug; -pub fn debug_print(t: T) -> String where T: Debug { +pub fn debug_print(t: T) -> String +where + T: Debug, +{ format!("{:#?}", t) }