From b252fe625c09a99b09ff2e4e8ec7612e5377064b Mon Sep 17 00:00:00 2001 From: Robin Brown Date: Fri, 23 Feb 2024 14:06:36 -0500 Subject: [PATCH] feat: Treat values as being made of fields (#35) --- crates/ast/src/types.rs | 26 +- crates/codegen/src/builders/component.rs | 237 +++++++++ crates/codegen/src/builders/mod.rs | 2 + .../{module_builder.rs => builders/module.rs} | 0 crates/codegen/src/code.rs | 380 ++++++++++++++ crates/codegen/src/component_builder.rs | 293 ----------- crates/codegen/src/expression.rs | 469 +++++++++++------- crates/codegen/src/function.rs | 144 ++++++ crates/codegen/src/lib.rs | 462 ++++++----------- crates/codegen/src/statement.rs | 218 ++++---- crates/codegen/src/types.rs | 365 ++++++++++++++ crates/lib/src/lib.rs | 5 +- crates/lib/tests/programs/claw.wit | 4 + crates/lib/tests/programs/strings.claw | 11 + crates/lib/tests/runtime.rs | 36 +- crates/parser/src/types.rs | 5 + crates/resolver/src/lib.rs | 6 +- src/bin.rs | 5 +- test.wat | 119 +++++ 19 files changed, 1867 insertions(+), 920 deletions(-) create mode 100644 crates/codegen/src/builders/component.rs create mode 100644 crates/codegen/src/builders/mod.rs rename crates/codegen/src/{module_builder.rs => builders/module.rs} (100%) create mode 100644 crates/codegen/src/code.rs delete mode 100644 crates/codegen/src/component_builder.rs create mode 100644 crates/codegen/src/function.rs create mode 100644 crates/codegen/src/types.rs create mode 100644 crates/lib/tests/programs/strings.claw create mode 100644 test.wat diff --git a/crates/ast/src/types.rs b/crates/ast/src/types.rs index bac06ab..4cd59e0 100644 --- a/crates/ast/src/types.rs +++ b/crates/ast/src/types.rs @@ -9,8 +9,6 @@ entity_impl!(TypeId, "type"); /// The type for all values #[derive(Debug, Hash, Clone)] pub enum ValType { - // TypeName(NameId), - // Result Type Result { ok: TypeId, err: TypeId }, @@ -21,21 +19,23 @@ pub enum ValType { #[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)] pub enum PrimitiveType { - // 
Unsigned Integers - U64, - U32, - U16, + // The boolean type + Bool, + // 8-bit Integers U8, - // Signed Integers - S64, - S32, - S16, S8, + // 16-bit Integers + U16, + S16, + // 32-bit Integers + U32, + S32, + // 64-bit Integers + U64, + S64, // Floating Point Numbers - F64, F32, - // The boolean type - Bool, + F64, } #[derive(PartialEq, Eq, Debug, Clone, Copy)] diff --git a/crates/codegen/src/builders/component.rs b/crates/codegen/src/builders/component.rs new file mode 100644 index 0000000..6ad9bb6 --- /dev/null +++ b/crates/codegen/src/builders/component.rs @@ -0,0 +1,237 @@ +use wasm_encoder as enc; + +#[derive(Default)] +pub struct ComponentBuilder { + component: enc::Component, + + num_types: u32, + num_funcs: u32, + num_core_funcs: u32, + num_core_mems: u32, + num_modules: u32, + num_module_instances: u32, +} + +#[derive(Clone, Copy, Debug)] +pub struct ComponentModuleIndex(u32); + +#[derive(Clone, Copy, Debug)] +pub struct ComponentModuleInstanceIndex(u32); + +#[derive(Clone, Copy, Debug)] +pub struct ComponentTypeIndex(u32); + +#[derive(Clone, Copy, Debug)] +pub struct ComponentFunctionIndex(u32); + +#[derive(Clone, Copy, Debug)] +pub struct ComponentCoreFunctionIndex(u32); + +#[derive(Clone, Copy, Debug)] +pub struct ComponentCoreMemoryIndex(u32); + +pub enum InlineExportItem { + Func(ComponentCoreFunctionIndex), +} + +pub enum ModuleInstiateArgs { + Instance(ComponentModuleInstanceIndex), +} + +impl ComponentBuilder { + pub fn module(&mut self, module: enc::Module) -> ComponentModuleIndex { + self.component.section(&enc::ModuleSection(&module)); + self.next_mod_idx() + } + + pub fn module_bytes(&mut self, bytes: Vec) -> ComponentModuleIndex { + self.component.section(&enc::RawSection { + id: enc::ComponentSectionId::CoreModule.into(), + data: bytes.as_slice(), + }); + self.next_mod_idx() + } + + #[allow(dead_code)] + pub fn inline_export( + &mut self, + exports: Vec<(String, InlineExportItem)>, + ) -> ComponentModuleInstanceIndex { + let exports: 
Vec<(String, enc::ExportKind, u32)> = exports + .into_iter() + .map(|(name, arg)| match arg { + InlineExportItem::Func(func) => (name, enc::ExportKind::Func, func.0), + }) + .collect(); + let mut section = enc::InstanceSection::new(); + section.export_items(exports); + self.component.section(§ion); + self.next_mod_instance_idx() + } + + pub fn instantiate( + &mut self, + module: ComponentModuleIndex, + args: Vec<(String, ModuleInstiateArgs)>, + ) -> ComponentModuleInstanceIndex { + let args: Vec<_> = args + .into_iter() + .map(|(name, arg)| match arg { + ModuleInstiateArgs::Instance(instance) => { + (name, enc::ModuleArg::Instance(instance.0)) + } + }) + .collect(); + let mut section = enc::InstanceSection::new(); + section.instantiate(module.0, args); + self.component.section(§ion); + self.next_mod_instance_idx() + } + + pub fn func_type<'b, P>( + &mut self, + params: P, + result: Option, + ) -> ComponentTypeIndex + where + P: IntoIterator, + P::IntoIter: ExactSizeIterator, + { + let mut section = enc::ComponentTypeSection::new(); + let mut builder = section.function(); + builder.params(params); + match result { + Some(return_type) => { + builder.result(return_type); + } + None => { + builder.results([] as [(&str, enc::ComponentValType); 0]); + } + } + self.component.section(§ion); + self.next_type_idx() + } + + pub fn import_func( + &mut self, + name: &str, + fn_type: ComponentTypeIndex, + ) -> ComponentFunctionIndex { + let mut section = enc::ComponentImportSection::new(); + let ty = enc::ComponentTypeRef::Func(fn_type.0); + section.import(name, ty); + self.component.section(§ion); + self.next_func_idx() + } + + pub fn lower_func(&mut self, func: ComponentFunctionIndex) -> ComponentCoreFunctionIndex { + let mut section = enc::CanonicalFunctionSection::new(); + section.lower(func.0, []); + self.component.section(§ion); + self.next_core_func_idx() + } + + pub fn alias_memory( + &mut self, + instance: ComponentModuleInstanceIndex, + name: &str, + ) -> 
ComponentCoreMemoryIndex { + let mut section = enc::ComponentAliasSection::new(); + section.alias(enc::Alias::CoreInstanceExport { + instance: instance.0, + kind: enc::ExportKind::Memory, + name, + }); + self.component.section(§ion); + self.next_core_memory_idx() + } + + pub fn alias_func( + &mut self, + instance: ComponentModuleInstanceIndex, + name: &str, + ) -> ComponentCoreFunctionIndex { + let mut section = enc::ComponentAliasSection::new(); + section.alias(enc::Alias::CoreInstanceExport { + instance: instance.0, + kind: enc::ExportKind::Func, + name, + }); + self.component.section(§ion); + self.next_core_func_idx() + } + + pub fn lift_func( + &mut self, + func: ComponentCoreFunctionIndex, + fn_type: ComponentTypeIndex, + memory: ComponentCoreMemoryIndex, + realloc: ComponentCoreFunctionIndex, + ) -> ComponentFunctionIndex { + let mut section = enc::CanonicalFunctionSection::new(); + let canon_opts: [enc::CanonicalOption; 2] = [ + enc::CanonicalOption::Memory(memory.0), + enc::CanonicalOption::Realloc(realloc.0), + ]; + section.lift(func.0, fn_type.0, canon_opts); + self.component.section(§ion); + self.next_func_idx() + } + + pub fn export_func( + &mut self, + name: &str, + func: ComponentFunctionIndex, + fn_type: ComponentTypeIndex, + ) -> ComponentFunctionIndex { + let mut section = enc::ComponentExportSection::new(); + section.export( + name, + enc::ComponentExportKind::Func, + func.0, + Some(enc::ComponentTypeRef::Func(fn_type.0)), + ); + self.component.section(§ion); + self.next_func_idx() + } + + pub fn finalize(self) -> enc::Component { + self.component + } + + fn next_mod_idx(&mut self) -> ComponentModuleIndex { + let index = ComponentModuleIndex(self.num_modules); + self.num_modules += 1; + index + } + + fn next_mod_instance_idx(&mut self) -> ComponentModuleInstanceIndex { + let index = ComponentModuleInstanceIndex(self.num_module_instances); + self.num_module_instances += 1; + index + } + + fn next_type_idx(&mut self) -> ComponentTypeIndex { + let 
index = ComponentTypeIndex(self.num_types); + self.num_types += 1; + index + } + + fn next_func_idx(&mut self) -> ComponentFunctionIndex { + let index = ComponentFunctionIndex(self.num_funcs); + self.num_funcs += 1; + index + } + + fn next_core_func_idx(&mut self) -> ComponentCoreFunctionIndex { + let index = ComponentCoreFunctionIndex(self.num_core_funcs); + self.num_core_funcs += 1; + index + } + + fn next_core_memory_idx(&mut self) -> ComponentCoreMemoryIndex { + let index = ComponentCoreMemoryIndex(self.num_core_mems); + self.num_core_mems += 1; + index + } +} diff --git a/crates/codegen/src/builders/mod.rs b/crates/codegen/src/builders/mod.rs new file mode 100644 index 0000000..74a9898 --- /dev/null +++ b/crates/codegen/src/builders/mod.rs @@ -0,0 +1,2 @@ +pub mod component; +pub mod module; diff --git a/crates/codegen/src/module_builder.rs b/crates/codegen/src/builders/module.rs similarity index 100% rename from crates/codegen/src/module_builder.rs rename to crates/codegen/src/builders/module.rs diff --git a/crates/codegen/src/code.rs b/crates/codegen/src/code.rs new file mode 100644 index 0000000..058f2b7 --- /dev/null +++ b/crates/codegen/src/code.rs @@ -0,0 +1,380 @@ +use std::collections::HashMap; + +use ast::{ExpressionId, FunctionId, NameId, StatementId}; +use claw_ast as ast; + +use crate::{ + function::{FunctionGenerator, ParamInfo, ReturnInfo}, + types::{ptype_mem_arg, EncodeType, FieldInfo}, + ComponentGenerator, EncodeExpression, EncodeStatement, GenerationError, ModuleFunctionIndex, +}; +use claw_resolver::{FunctionResolver, ItemId, LocalId, ParamId, ResolvedComponent, ResolvedType}; +use cranelift_entity::EntityRef; +use wasm_encoder as enc; + +pub struct CodeGenerator<'gen> { + id: FunctionId, + comp: &'gen ResolvedComponent, + parent: &'gen mut ComponentGenerator, + + builder: enc::Function, + + params: Vec, + return_info: ReturnInfo, + return_index: Option, + index_for_local: HashMap, + index_for_expr: HashMap, + #[allow(dead_code)] + 
local_space: Vec, +} +pub struct CoreLocalId(u32); + +impl From for CoreLocalId { + fn from(value: u32) -> Self { + CoreLocalId(value) + } +} + +impl<'gen> CodeGenerator<'gen> { + pub fn new( + code_gen: &'gen mut ComponentGenerator, + comp: &'gen ResolvedComponent, + func_gen: FunctionGenerator, + realloc: ModuleFunctionIndex, + id: FunctionId, + ) -> Result { + let resolver = &comp.resolved_funcs[&id]; + + // Layout parameters + let FunctionGenerator { + params, + param_types, + return_type, + } = func_gen; + let mut local_space = param_types; + + let locals_start = local_space.len(); + + let return_index = if return_type.spill() { + let index = local_space.len(); + local_space.push(enc::ValType::I32); + Some(index as u32) + } else { + None + }; + + // Layout locals + let mut index_for_local = HashMap::new(); + let mut locals = Vec::with_capacity(resolver.locals.len()); + for (id, _local) in resolver.locals.iter() { + let rtype = resolver.get_resolved_local_type(id, &comp.component)?; + let local_id = CoreLocalId((local_space.len() + locals.len()) as u32); + index_for_local.insert(id, local_id); + rtype.append_flattened(&comp.component, &mut locals); + } + local_space.extend(locals); + + // Layout expressions + let resolver = &comp.resolved_funcs[&id]; + let mut index_for_expr = HashMap::new(); + let mut allocator = + ExpressionAllocator::new(comp, resolver, &mut local_space, &mut index_for_expr); + let function = &comp.component.functions[id]; + for statement in function.body.iter() { + let statement = comp.component.get_statement(*statement); + statement.alloc_expr_locals(&mut allocator)?; + } + + let locals = &local_space[locals_start..]; + let locals = locals.iter().map(|l| (1, *l)); + let mut builder = enc::Function::new(locals); + + if let Some(return_index) = return_index { + // old ptr, old size + builder.instruction(&enc::Instruction::I32Const(0)); + builder.instruction(&enc::Instruction::I32Const(0)); + + let result_type = 
comp.component.functions[id].return_type.unwrap(); + // align + let align = result_type.align(&comp.component); + let align = 2u32.pow(align); + builder.instruction(&enc::Instruction::I32Const(align as i32)); + // new size + let size = result_type.mem_size(&comp.component); + builder.instruction(&enc::Instruction::I32Const(size as i32)); + // call allocator + builder.instruction(&enc::Instruction::Call(realloc.into())); + // store address + builder.instruction(&enc::Instruction::LocalSet(return_index)); + } + + Ok(Self { + id, + parent: code_gen, + comp, + builder, + params, + return_info: return_type, + return_index, + index_for_local, + index_for_expr, + local_space, + }) + } + + pub fn encode_statement(&mut self, statement: StatementId) -> Result<(), GenerationError> { + let stmt = self.comp.component.get_statement(statement); + stmt.encode(self) + } + + pub fn encode_child(&mut self, expression: ExpressionId) -> Result<(), GenerationError> { + let expr = self.comp.component.expr().get_exp(expression); + expr.encode(expression, self) + } + + pub fn instruction(&mut self, instruction: &enc::Instruction) { + self.builder.instruction(instruction); + } + + pub fn get_resolved_type( + &self, + expression: ExpressionId, + ) -> Result { + let resolver = &self.comp.resolved_funcs[&self.id]; + let type_id = resolver.get_resolved_type(expression, &self.comp.component)?; + Ok(type_id) + } + + pub fn fields(&self, expression: ExpressionId) -> Result, GenerationError> { + let rtype = self.get_resolved_type(expression)?; + Ok(rtype.fields(&self.comp.component)) + } + + pub fn lookup_name(&self, ident: NameId) -> ItemId { + let resolver = &self.comp.resolved_funcs[&self.id]; + resolver.bindings[&ident] + } + + pub fn spill_return(&self) -> bool { + self.return_info.spill() + } + + pub fn encode_call(&mut self, item: ItemId) -> Result<(), GenerationError> { + let index = match item { + ItemId::Import(import) => self.parent.func_idx_for_import[&import], + 
ItemId::Function(function) => self.parent.func_idx_for_func[&function], + _ => panic!(""), + }; + self.instruction(&enc::Instruction::Call(index.into())); + Ok(()) + } + + pub fn read_param_field(&mut self, param: ParamId, field: &FieldInfo) { + let param_info = &self.params[param.index()]; + match param_info { + ParamInfo::Local(local_info) => { + let local_index = local_info.index_offset + field.index_offset; + self.local_get(local_index); + } + ParamInfo::Spilled(spilled_info) => { + let mem_index = spilled_info.mem_offset + field.mem_offset; + self.builder.instruction(&enc::Instruction::LocalGet(0)); + self.builder + .instruction(&enc::Instruction::I32Const(mem_index as i32)); + self.builder.instruction(&enc::Instruction::I32Add); + self.load_ptype(field.ptype); + } + } + } + + pub fn read_local_field(&mut self, local: LocalId, field: &FieldInfo) { + let local_index = &self.index_for_local[&local]; + let local_index = local_index.0 + field.index_offset; + self.local_get(local_index); + } + + pub fn write_local_field(&mut self, local: LocalId, field: &FieldInfo) { + let local_index = &self.index_for_local[&local]; + let local_index = local_index.0 + field.index_offset; + self.local_set(local_index); + } + + pub fn read_expr_field(&mut self, expression: ExpressionId, field: &FieldInfo) { + let local_index = &self.index_for_expr[&expression]; + let local_index = local_index.0 + field.index_offset; + self.local_get(local_index); + } + + pub fn write_expr_field(&mut self, expression: ExpressionId, field: &FieldInfo) { + let local_index = &self.index_for_expr[&expression]; + let local_index = local_index.0 + field.index_offset; + self.local_set(local_index); + } + + pub fn read_return_ptr(&mut self) -> Result<(), GenerationError> { + let return_ptr_index = self.return_index.unwrap(); + self.local_get(return_ptr_index); + Ok(()) + } + + /// The value's base memory offset MUST be on the stack before calling this + pub fn field_address(&mut self, field: &FieldInfo) { + 
self.instruction(&enc::Instruction::I32Const(field.mem_offset as i32)); + self.instruction(&enc::Instruction::I32Add); + } + + /// The value's base memory offset MUST be on the stack before calling this + pub fn read_mem_field(&mut self, field: &FieldInfo) { + self.field_address(field); + self.load_ptype(field.ptype); + } + + /// Fields absolute offset in memory MUST be on the stack underneath the value before calling this + pub fn write_mem(&mut self, field: &FieldInfo) { + self.store_ptype(field.ptype); + } + + pub fn encode_const_int(&mut self, int: u64, field: &FieldInfo) { + let instruction = match field.ptype { + ast::PrimitiveType::U8 + | ast::PrimitiveType::S8 + | ast::PrimitiveType::U16 + | ast::PrimitiveType::S16 + | ast::PrimitiveType::U32 + | ast::PrimitiveType::S32 => enc::Instruction::I32Const(int as i32), + + ast::PrimitiveType::U64 | ast::PrimitiveType::S64 => { + enc::Instruction::I64Const(int as i64) + } + + _ => panic!("Not an integer"), + }; + self.instruction(&instruction); + } + + pub fn encode_const_float(&mut self, float: f64, field: &FieldInfo) { + let instruction = match field.ptype { + ast::PrimitiveType::F32 => enc::Instruction::F32Const(float as f32), + ast::PrimitiveType::F64 => enc::Instruction::F64Const(float), + _ => panic!("Not a float!"), + }; + self.instruction(&instruction); + } + + fn local_get(&mut self, local_index: u32) { + self.builder + .instruction(&enc::Instruction::LocalGet(local_index)); + } + + fn local_set(&mut self, local_index: u32) { + self.builder + .instruction(&enc::Instruction::LocalSet(local_index)); + } + + fn load_ptype(&mut self, ptype: ast::PrimitiveType) { + let mem_arg = ptype_mem_arg(ptype); + let instruction = match ptype { + // Small types with sign-extending + ast::PrimitiveType::U8 => enc::Instruction::I32Load8U(mem_arg), + ast::PrimitiveType::S8 => enc::Instruction::I32Load8S(mem_arg), + ast::PrimitiveType::U16 => enc::Instruction::I32Load16U(mem_arg), + ast::PrimitiveType::S16 => 
enc::Instruction::I32Load16S(mem_arg), + // 32 and 64 bit values don't need sign-extending + ast::PrimitiveType::U32 | ast::PrimitiveType::S32 => enc::Instruction::I32Load(mem_arg), + ast::PrimitiveType::U64 | ast::PrimitiveType::S64 => enc::Instruction::I64Load(mem_arg), + // Floats + ast::PrimitiveType::F32 => enc::Instruction::F32Load(mem_arg), + ast::PrimitiveType::F64 => enc::Instruction::F64Load(mem_arg), + // Booleans are treated as 8-bit unsigned values + ast::PrimitiveType::Bool => enc::Instruction::I32Load8U(mem_arg), + }; + self.builder.instruction(&instruction); + } + + fn store_ptype(&mut self, ptype: ast::PrimitiveType) { + let mem_arg = ptype_mem_arg(ptype); + let instruction = match ptype { + // All types which fit in i32 + ast::PrimitiveType::Bool + | ast::PrimitiveType::U8 + | ast::PrimitiveType::S8 + | ast::PrimitiveType::U16 + | ast::PrimitiveType::S16 + | ast::PrimitiveType::U32 + | ast::PrimitiveType::S32 => enc::Instruction::I32Store(mem_arg), + // Types that use i64 + ast::PrimitiveType::U64 | ast::PrimitiveType::S64 => { + enc::Instruction::I64Store(mem_arg) + } + // Floats + ast::PrimitiveType::F32 => enc::Instruction::F32Store(mem_arg), + ast::PrimitiveType::F64 => enc::Instruction::F64Store(mem_arg), + }; + self.builder.instruction(&instruction); + } + + pub fn finalize(mut self) -> Result<(), GenerationError> { + let function = &self.comp.component.functions[self.id]; + for statement in function.body.iter() { + self.encode_statement(*statement)?; + } + self.builder.instruction(&enc::Instruction::End); + + let mod_func_idx = self.parent.func_idx_for_func[&self.id]; + self.parent.module.code(mod_func_idx, self.builder); + Ok(()) + } +} + +pub struct ExpressionAllocator<'a> { + // Context + comp: &'a ResolvedComponent, + resolver: &'a FunctionResolver, + // State + local_space: &'a mut Vec, + index_for_expr: &'a mut HashMap, +} + +impl<'a> ExpressionAllocator<'a> { + pub fn new( + comp: &'a ResolvedComponent, + resolver: &'a 
FunctionResolver, + local_space: &'a mut Vec, + index_for_expr: &'a mut HashMap, + ) -> Self { + Self { + comp, + resolver, + local_space, + index_for_expr, + } + } + + pub fn alloc(&mut self, expression: ExpressionId) -> Result<(), GenerationError> { + // Record index + let index = self.local_space.len() as u32; + let index = CoreLocalId(index); + self.index_for_expr.insert(expression, index); + // Allocate locals + let rtype = self + .resolver + .get_resolved_type(expression, &self.comp.component)?; + rtype.append_flattened(&self.comp.component, self.local_space); + Ok(()) + } + + pub fn alloc_extra(&mut self, valtype: enc::ValType) -> Result<(), GenerationError> { + self.local_space.push(valtype); + Ok(()) + } + + pub fn alloc_child(&mut self, expression: ExpressionId) -> Result<(), GenerationError> { + let expr = self.comp.component.expr().get_exp(expression); + expr.alloc_expr_locals(expression, self) + } + + pub fn alloc_statement(&mut self, statement: StatementId) -> Result<(), GenerationError> { + let statement = self.comp.component.get_statement(statement); + statement.alloc_expr_locals(self) + } +} diff --git a/crates/codegen/src/component_builder.rs b/crates/codegen/src/component_builder.rs deleted file mode 100644 index 89b8609..0000000 --- a/crates/codegen/src/component_builder.rs +++ /dev/null @@ -1,293 +0,0 @@ -use wasm_encoder as enc; - -#[derive(Default)] -pub struct ComponentBuilder { - alias: enc::ComponentAliasSection, - types: enc::ComponentTypeSection, - imports: enc::ComponentImportSection, - lower_funcs: enc::CanonicalFunctionSection, - lift_funcs: enc::CanonicalFunctionSection, - exports: enc::ComponentExportSection, - - modules: Vec, - module_instances: Vec, - - num_types: u32, - num_funcs: u32, - num_core_funcs: u32, - num_module_instances: u32, -} - -#[derive(Clone, Copy, Debug)] -pub struct ComponentModuleIndex(u32); -#[derive(Clone, Copy, Debug)] -pub struct ComponentModuleInstanceIndex(u32); -#[derive(Clone, Copy, Debug)] -pub struct 
ComponentTypeIndex(u32); -#[derive(Clone, Copy, Debug)] -pub struct ComponentFunctionIndex(u32); -#[derive(Clone, Copy, Debug)] -pub struct ComponentCoreFunctionIndex(u32); - -enum ModuleSlot { - Empty, - Filled(enc::Module), - Bytes(Vec), -} - -enum ModuleInstanceSlots { - EmptyInlineExport, - FilledInlineExport { - exports: Vec<(String, enc::ExportKind, u32)>, - }, - ModuleInstance { - module_index: ComponentModuleIndex, - args: Vec<(String, enc::ModuleArg)>, - }, -} - -pub enum InlineExportItem { - Func(ComponentCoreFunctionIndex), -} - -pub enum ModuleInstiateArgs { - Instance(ComponentModuleInstanceIndex), -} - -impl ComponentBuilder { - #[allow(dead_code)] - pub fn module(&mut self, module: enc::Module) -> ComponentModuleIndex { - let index = self.modules.len() as u32; - self.modules.push(ModuleSlot::Filled(module)); - ComponentModuleIndex(index) - } - - pub fn module_bytes(&mut self, bytes: Vec) -> ComponentModuleIndex { - let index = self.modules.len() as u32; - self.modules.push(ModuleSlot::Bytes(bytes)); - ComponentModuleIndex(index) - } - - pub fn reserve_module(&mut self) -> ComponentModuleIndex { - let index = self.modules.len() as u32; - self.modules.push(ModuleSlot::Empty); - ComponentModuleIndex(index) - } - - pub fn fill_module(&mut self, index: ComponentModuleIndex, module: enc::Module) { - let index = index.0 as usize; - assert!(matches!(self.modules[index], ModuleSlot::Empty)); - self.modules[index] = ModuleSlot::Filled(module); - } - - #[allow(dead_code)] - pub fn inline_export( - &mut self, - exports: Vec<(String, enc::ExportKind, u32)>, - ) -> ComponentModuleInstanceIndex { - self.module_instances - .push(ModuleInstanceSlots::FilledInlineExport { exports }); - self.next_mod_instance_idx() - } - - pub fn reserve_inline_export(&mut self) -> ComponentModuleInstanceIndex { - self.module_instances - .push(ModuleInstanceSlots::EmptyInlineExport); - self.next_mod_instance_idx() - } - - pub fn fill_inline_export_args( - &mut self, - instance: 
ComponentModuleInstanceIndex, - exports: Vec<(String, InlineExportItem)>, - ) { - let index = instance.0 as usize; - match &self.module_instances[index] { - ModuleInstanceSlots::EmptyInlineExport => { - let exports = exports - .into_iter() - .map(|(name, arg)| match arg { - InlineExportItem::Func(func) => (name, enc::ExportKind::Func, func.0), - }) - .collect(); - self.module_instances[index] = ModuleInstanceSlots::FilledInlineExport { exports }; - } - ModuleInstanceSlots::FilledInlineExport { .. } => { - panic!("Slot for instance {} already filled", index) - } - ModuleInstanceSlots::ModuleInstance { .. } => { - panic!("Slot for instance {} already filled", index) - } - } - } - - pub fn instantiate( - &mut self, - module: ComponentModuleIndex, - args: Vec<(String, ModuleInstiateArgs)>, - ) -> ComponentModuleInstanceIndex { - let args = args - .into_iter() - .map(|(name, arg)| match arg { - ModuleInstiateArgs::Instance(instance) => { - (name, enc::ModuleArg::Instance(instance.0)) - } - }) - .collect(); - self.module_instances - .push(ModuleInstanceSlots::ModuleInstance { - module_index: module, - args, - }); - self.next_mod_instance_idx() - } - - pub fn func_type<'b, P>( - &mut self, - params: P, - result: Option, - ) -> ComponentTypeIndex - where - P: IntoIterator, - P::IntoIter: ExactSizeIterator, - { - let mut builder = self.types.function(); - builder.params(params); - match result { - Some(return_type) => { - builder.result(return_type); - } - None => { - builder.results([] as [(&str, enc::ComponentValType); 0]); - } - } - self.next_type_idx() - } - - pub fn import_func( - &mut self, - name: &str, - fn_type: ComponentTypeIndex, - ) -> ComponentFunctionIndex { - let ty = enc::ComponentTypeRef::Func(fn_type.0); - self.imports.import(name, ty); - self.next_func_idx() - } - - pub fn lower_func(&mut self, func: ComponentFunctionIndex) -> ComponentCoreFunctionIndex { - self.lower_funcs.lower(func.0, []); - self.next_core_func_idx() - } - - pub fn alias_func( - &mut 
self, - instance: ComponentModuleInstanceIndex, - name: &str, - ) -> ComponentCoreFunctionIndex { - self.alias.alias(enc::Alias::CoreInstanceExport { - instance: instance.0, - kind: enc::ExportKind::Func, - name, - }); - self.next_core_func_idx() - } - - pub fn lift_func( - &mut self, - func: ComponentCoreFunctionIndex, - fn_type: ComponentTypeIndex, - ) -> ComponentFunctionIndex { - const NO_CANON_OPTS: [enc::CanonicalOption; 0] = []; - self.lift_funcs.lift(func.0, fn_type.0, NO_CANON_OPTS); - self.next_func_idx() - } - - pub fn export_func( - &mut self, - name: &str, - func: ComponentFunctionIndex, - fn_type: ComponentTypeIndex, - ) { - self.exports.export( - name, - enc::ComponentExportKind::Func, - func.0, - Some(enc::ComponentTypeRef::Func(fn_type.0)), - ); - } - - pub fn finalize(self) -> enc::Component { - let mut component = enc::Component::new(); - - // Component Types - component.section(&self.types); - // Component Imports - component.section(&self.imports); - // Lower Imported Functions - component.section(&self.lower_funcs); - - for module in self.modules { - match module { - ModuleSlot::Empty => panic!("Module slot not filled before component finalized"), - ModuleSlot::Filled(module) => { - component.section(&enc::ModuleSection(&module)); - } - ModuleSlot::Bytes(bytes) => { - component.section(&enc::RawSection { - id: enc::ComponentSectionId::CoreModule.into(), - data: bytes.as_slice(), - }); - } - }; - } - - let mut instantiations = enc::InstanceSection::new(); - for instance in self.module_instances { - match instance { - ModuleInstanceSlots::EmptyInlineExport => { - panic!("Inline export instantiation slot not filled before component finalized") - } - ModuleInstanceSlots::FilledInlineExport { exports } => { - instantiations.export_items(exports); - } - ModuleInstanceSlots::ModuleInstance { module_index, args } => { - instantiations.instantiate(module_index.0, args); - } - } - } - component.section(&instantiations); - - // Alias module exports - 
component.section(&self.alias); - // Lift component functions - component.section(&self.lift_funcs); - // Export component functions - component.section(&self.exports); - - component - } - - fn next_mod_instance_idx(&mut self) -> ComponentModuleInstanceIndex { - let index = ComponentModuleInstanceIndex(self.num_module_instances); - self.num_module_instances += 1; - index - } - - fn next_type_idx(&mut self) -> ComponentTypeIndex { - let index = ComponentTypeIndex(self.num_types); - self.num_types += 1; - index - } - - fn next_func_idx(&mut self) -> ComponentFunctionIndex { - let index = ComponentFunctionIndex(self.num_funcs); - self.num_funcs += 1; - index - } - - fn next_core_func_idx(&mut self) -> ComponentCoreFunctionIndex { - let index = ComponentCoreFunctionIndex(self.num_core_funcs); - self.num_core_funcs += 1; - index - } -} diff --git a/crates/codegen/src/expression.rs b/crates/codegen/src/expression.rs index 2c71a04..3879bab 100644 --- a/crates/codegen/src/expression.rs +++ b/crates/codegen/src/expression.rs @@ -1,45 +1,48 @@ -use ast::{ExpressionId, FunctionId, Signedness}; +use ast::{ExpressionId, Signedness}; use claw_ast as ast; -use claw_resolver::{ItemId, ResolvedComponent}; +use claw_resolver::ItemId; -use super::{CodeGenerator, GenerationError}; +use crate::code::{CodeGenerator, ExpressionAllocator}; +use crate::GenerationError; use cranelift_entity::EntityRef; use wasm_encoder as enc; use wasm_encoder::Instruction; -/// A simple helper that calls EncodeExpression::encode -pub fn encode_expression( - generator: &CodeGenerator, - component: &ResolvedComponent, - expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, -) -> Result<(), GenerationError> { - let expr = component.component.expr().get_exp(expression); - expr.encode(generator, component, expression, func, builder)?; - Ok(()) -} - pub trait EncodeExpression { + fn alloc_expr_locals( + &self, + expression: ExpressionId, + allocator: &mut ExpressionAllocator, + ) -> 
Result<(), GenerationError>; + fn encode( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, + code_gen: &mut CodeGenerator, ) -> Result<(), GenerationError>; } impl EncodeExpression for ast::Expression { + fn alloc_expr_locals( + &self, + expression: ExpressionId, + allocator: &mut ExpressionAllocator, + ) -> Result<(), GenerationError> { + let expr: &dyn EncodeExpression = match self { + ast::Expression::Identifier(expr) => expr, + ast::Expression::Literal(expr) => expr, + ast::Expression::Call(expr) => expr, + ast::Expression::Unary(expr) => expr, + ast::Expression::Binary(expr) => expr, + }; + expr.alloc_expr_locals(expression, allocator) + } + fn encode( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, + code_gen: &mut CodeGenerator, ) -> Result<(), GenerationError> { let expr: &dyn EncodeExpression = match self { ast::Expression::Identifier(expr) => expr, @@ -48,104 +51,137 @@ impl EncodeExpression for ast::Expression { ast::Expression::Unary(expr) => expr, ast::Expression::Binary(expr) => expr, }; - expr.encode(generator, component, expression, func, builder)?; + expr.encode(expression, code_gen)?; Ok(()) } } impl EncodeExpression for ast::Identifier { + fn alloc_expr_locals( + &self, + expression: ExpressionId, + allocator: &mut ExpressionAllocator, + ) -> Result<(), GenerationError> { + allocator.alloc(expression) + } + fn encode( &self, - _generator: &CodeGenerator, - component: &ResolvedComponent, - _expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, + expression: ExpressionId, + code_gen: &mut CodeGenerator, ) -> Result<(), GenerationError> { - let resolver = component.resolved_funcs.get(&func).unwrap(); - match resolver.bindings.get(&self.ident).unwrap() { - ItemId::Import(_) => unimplemented!(), + let fields = 
code_gen.fields(expression)?; + match code_gen.lookup_name(self.ident) { + ItemId::Import(_) => panic!("Cannot use import as value!!"), ItemId::Global(global) => { - builder.instruction(&Instruction::GlobalGet(global.index() as u32)); + for field in fields.iter() { + // TODO handle composite globals + code_gen.instruction(&Instruction::GlobalGet(global.index() as u32)); + code_gen.write_expr_field(expression, field); + } } ItemId::Param(param) => { - let local_index = param.index(); - builder.instruction(&Instruction::LocalGet(local_index as u32)); + for field in fields.iter() { + code_gen.read_param_field(param, field); + code_gen.write_expr_field(expression, field); + } } ItemId::Local(local) => { - let func = component.component.functions.get(func).unwrap(); - let local_index = local.index() + func.arguments.len(); - builder.instruction(&Instruction::LocalGet(local_index as u32)); + for field in fields.iter() { + code_gen.read_local_field(local, field); + code_gen.write_expr_field(expression, field); + } } - ItemId::Function(_) => unimplemented!(), + ItemId::Function(_) => panic!("Cannot use function as value!!"), } Ok(()) } } impl EncodeExpression for ast::Literal { - fn encode( + fn alloc_expr_locals( &self, - _generator: &CodeGenerator, - component: &ResolvedComponent, expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, + allocator: &mut ExpressionAllocator, ) -> Result<(), GenerationError> { - let comp = &component.component; - let resolver = component.resolved_funcs.get(&func).unwrap(); - - let rtype = resolver.get_resolved_type(expression, comp)?; - let valtype = super::rtype_to_core_valtype(rtype, &component.component); + allocator.alloc(expression) + } - use ast::Literal; - let instruction = match (valtype, self) { - (enc::ValType::I32, Literal::Integer(value)) => Instruction::I32Const(*value as i32), - (enc::ValType::I64, Literal::Integer(value)) => Instruction::I64Const(*value as i64), - (enc::ValType::F32, 
Literal::Float(value)) => Instruction::F32Const(*value as f32), - (enc::ValType::F64, Literal::Float(value)) => Instruction::F64Const(*value), - _ => todo!(), - }; - builder.instruction(&instruction); + fn encode( + &self, + expression: ExpressionId, + code_gen: &mut CodeGenerator, + ) -> Result<(), GenerationError> { + let fields = code_gen.fields(expression)?; + dbg!(&fields); + for field in fields.iter() { + match self { + ast::Literal::Integer(int) => code_gen.encode_const_int(*int, field), + ast::Literal::Float(float) => code_gen.encode_const_float(*float, field), + } + code_gen.write_expr_field(expression, field); + } Ok(()) } } impl EncodeExpression for ast::Call { + fn alloc_expr_locals( + &self, + expression: ExpressionId, + allocator: &mut ExpressionAllocator, + ) -> Result<(), GenerationError> { + allocator.alloc(expression)?; + for arg in self.args.iter() { + allocator.alloc_child(*arg)?; + } + Ok(()) + } + fn encode( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - _expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, + expression: ExpressionId, + code_gen: &mut CodeGenerator, ) -> Result<(), GenerationError> { for arg in self.args.iter() { - encode_expression(generator, component, *arg, func, builder)?; + code_gen.encode_child(*arg)?; + for field in code_gen.fields(*arg)?.iter() { + code_gen.read_expr_field(*arg, field); + } + } + let item = code_gen.lookup_name(self.ident); + code_gen.encode_call(item)?; + for field in code_gen.fields(expression)?.iter() { + code_gen.write_expr_field(expression, field); } - let resolver = component.resolved_funcs.get(&func).unwrap(); - let index = match resolver.bindings.get(&self.ident).unwrap() { - ItemId::Import(import) => *generator.func_idx_for_import.get(import).unwrap(), - ItemId::Function(function) => *generator.func_idx_for_func.get(function).unwrap(), - _ => panic!(""), - }; - builder.instruction(&Instruction::Call(index.into())); Ok(()) } } impl 
EncodeExpression for ast::UnaryExpression { + fn alloc_expr_locals( + &self, + expression: ExpressionId, + allocator: &mut ExpressionAllocator, + ) -> Result<(), GenerationError> { + allocator.alloc(expression)?; + allocator.alloc_child(self.inner) + } + fn encode( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - _expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, + expression: ExpressionId, + code_gen: &mut CodeGenerator, ) -> Result<(), GenerationError> { - builder.instruction(&enc::Instruction::I32Const(0)); - encode_expression(generator, component, self.inner, func, builder)?; - builder.instruction(&enc::Instruction::I32Sub); + code_gen.instruction(&enc::Instruction::I32Const(0)); // TODO support 64 bit ints + code_gen.encode_child(self.inner)?; + for field in code_gen.fields(self.inner)?.iter() { + code_gen.read_expr_field(self.inner, field); + } + code_gen.instruction(&enc::Instruction::I32Sub); + for field in code_gen.fields(expression)?.iter() { + code_gen.write_expr_field(expression, field); + } Ok(()) } } @@ -154,125 +190,176 @@ const S: Signedness = Signedness::Signed; const U: Signedness = Signedness::Unsigned; impl EncodeExpression for ast::BinaryExpression { - fn encode( + fn alloc_expr_locals( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - _expression: ExpressionId, - func: FunctionId, - builder: &mut enc::Function, + expression: ExpressionId, + allocator: &mut ExpressionAllocator, ) -> Result<(), GenerationError> { - let comp = &component.component; - encode_expression(generator, component, self.left, func, builder)?; - encode_expression(generator, component, self.right, func, builder)?; + allocator.alloc(expression)?; + allocator.alloc_child(self.left)?; + allocator.alloc_child(self.right)?; + Ok(()) + } - let resolver = component.resolved_funcs.get(&func).unwrap(); - let rtype = resolver.get_resolved_type(self.left, comp)?; + fn encode( + &self, + expression: ExpressionId, + 
code_gen: &mut CodeGenerator, + ) -> Result<(), GenerationError> { + code_gen.encode_child(self.left)?; + code_gen.encode_child(self.right)?; - let ptype = crate::rtype_to_ptype(rtype, comp).unwrap(); + let left_fields = code_gen.fields(self.left)?; + for field in left_fields.iter() { + code_gen.read_expr_field(self.left, field); + } + let right_fields = code_gen.fields(self.right)?; + for field in right_fields.iter() { + code_gen.read_expr_field(self.right, field); + } - let core_valtype = crate::ptype_to_core_valtype(ptype); - let instruction = match (self.op, core_valtype, ptype.signedness()) { - // Multiply - (ast::BinaryOp::Multiply, enc::ValType::I32, _) => enc::Instruction::I32Mul, - (ast::BinaryOp::Multiply, enc::ValType::I64, _) => enc::Instruction::I64Mul, - (ast::BinaryOp::Multiply, enc::ValType::F32, _) => enc::Instruction::F32Mul, - (ast::BinaryOp::Multiply, enc::ValType::F64, _) => enc::Instruction::F64Mul, - // Divide - (ast::BinaryOp::Divide, enc::ValType::I32, S) => enc::Instruction::I32DivS, - (ast::BinaryOp::Divide, enc::ValType::I32, U) => enc::Instruction::I32DivU, - (ast::BinaryOp::Divide, enc::ValType::I64, S) => enc::Instruction::I64DivS, - (ast::BinaryOp::Divide, enc::ValType::I64, U) => enc::Instruction::I64DivU, - (ast::BinaryOp::Divide, enc::ValType::F32, _) => enc::Instruction::F32Div, - (ast::BinaryOp::Divide, enc::ValType::F64, _) => enc::Instruction::F64Div, - // Modulo - (ast::BinaryOp::Modulo, enc::ValType::I32, S) => enc::Instruction::I32RemS, - (ast::BinaryOp::Modulo, enc::ValType::I32, U) => enc::Instruction::I32RemU, - (ast::BinaryOp::Modulo, enc::ValType::I64, S) => enc::Instruction::I64RemS, - (ast::BinaryOp::Modulo, enc::ValType::I64, U) => enc::Instruction::I64RemU, - // Addition - (ast::BinaryOp::Add, enc::ValType::I32, _) => enc::Instruction::I32Add, - (ast::BinaryOp::Add, enc::ValType::I64, _) => enc::Instruction::I64Add, - (ast::BinaryOp::Add, enc::ValType::F32, _) => enc::Instruction::F32Add, - (ast::BinaryOp::Add, 
enc::ValType::F64, _) => enc::Instruction::F64Add, - // Subtraction - (ast::BinaryOp::Subtract, enc::ValType::I32, _) => enc::Instruction::I32Sub, - (ast::BinaryOp::Subtract, enc::ValType::I64, _) => enc::Instruction::I64Sub, - (ast::BinaryOp::Subtract, enc::ValType::F32, _) => enc::Instruction::F32Sub, - (ast::BinaryOp::Subtract, enc::ValType::F64, _) => enc::Instruction::F64Sub, - // Logical Bit Shifting - (ast::BinaryOp::BitShiftL, enc::ValType::I32, _) => enc::Instruction::I32Shl, - (ast::BinaryOp::BitShiftL, enc::ValType::I64, _) => enc::Instruction::I64Shl, - (ast::BinaryOp::BitShiftR, enc::ValType::I32, _) => enc::Instruction::I32ShrU, - (ast::BinaryOp::BitShiftR, enc::ValType::I64, _) => enc::Instruction::I64ShrU, - // Arithmetic Bit Shifting - (ast::BinaryOp::ArithShiftR, enc::ValType::I32, S) => enc::Instruction::I32ShrS, - (ast::BinaryOp::ArithShiftR, enc::ValType::I32, U) => enc::Instruction::I32ShrU, - (ast::BinaryOp::ArithShiftR, enc::ValType::I64, S) => enc::Instruction::I64ShrS, - (ast::BinaryOp::ArithShiftR, enc::ValType::I64, U) => enc::Instruction::I64ShrU, - // Less than - (ast::BinaryOp::LessThan, enc::ValType::I32, S) => enc::Instruction::I32LtS, - (ast::BinaryOp::LessThan, enc::ValType::I32, U) => enc::Instruction::I32LtU, - (ast::BinaryOp::LessThan, enc::ValType::I64, S) => enc::Instruction::I64LtS, - (ast::BinaryOp::LessThan, enc::ValType::I64, U) => enc::Instruction::I64LtU, - (ast::BinaryOp::LessThan, enc::ValType::F32, _) => enc::Instruction::F32Lt, - (ast::BinaryOp::LessThan, enc::ValType::F64, _) => enc::Instruction::F64Lt, - // Less than equal - (ast::BinaryOp::LessThanEqual, enc::ValType::I32, S) => enc::Instruction::I32LeS, - (ast::BinaryOp::LessThanEqual, enc::ValType::I32, U) => enc::Instruction::I32LeU, - (ast::BinaryOp::LessThanEqual, enc::ValType::I64, S) => enc::Instruction::I64LeS, - (ast::BinaryOp::LessThanEqual, enc::ValType::I64, U) => enc::Instruction::I64LeU, - (ast::BinaryOp::LessThanEqual, enc::ValType::F32, _) => 
enc::Instruction::F32Le, - (ast::BinaryOp::LessThanEqual, enc::ValType::F64, _) => enc::Instruction::F64Le, - // Greater than - (ast::BinaryOp::GreaterThan, enc::ValType::I32, S) => enc::Instruction::I32GtS, - (ast::BinaryOp::GreaterThan, enc::ValType::I32, U) => enc::Instruction::I32GtU, - (ast::BinaryOp::GreaterThan, enc::ValType::I64, S) => enc::Instruction::I64GtS, - (ast::BinaryOp::GreaterThan, enc::ValType::I64, U) => enc::Instruction::I64GtU, - (ast::BinaryOp::GreaterThan, enc::ValType::F32, _) => enc::Instruction::F32Gt, - (ast::BinaryOp::GreaterThan, enc::ValType::F64, _) => enc::Instruction::F64Gt, - // Greater than or equal - (ast::BinaryOp::GreaterThanEqual, enc::ValType::I32, S) => enc::Instruction::I32GeS, - (ast::BinaryOp::GreaterThanEqual, enc::ValType::I32, U) => enc::Instruction::I32GeU, - (ast::BinaryOp::GreaterThanEqual, enc::ValType::I64, S) => enc::Instruction::I64GeS, - (ast::BinaryOp::GreaterThanEqual, enc::ValType::I64, U) => enc::Instruction::I64GeU, - (ast::BinaryOp::GreaterThanEqual, enc::ValType::F32, _) => enc::Instruction::F32Ge, - (ast::BinaryOp::GreaterThanEqual, enc::ValType::F64, _) => enc::Instruction::F64Ge, - // Equal - (ast::BinaryOp::Equals, enc::ValType::I32, _) => enc::Instruction::I32Eq, - (ast::BinaryOp::Equals, enc::ValType::I64, _) => enc::Instruction::I64Eq, - (ast::BinaryOp::Equals, enc::ValType::F32, _) => enc::Instruction::F32Eq, - (ast::BinaryOp::Equals, enc::ValType::F64, _) => enc::Instruction::F64Eq, - // Not equal - (ast::BinaryOp::NotEquals, enc::ValType::I32, _) => enc::Instruction::I32Eq, - (ast::BinaryOp::NotEquals, enc::ValType::I64, _) => enc::Instruction::I64Eq, - (ast::BinaryOp::NotEquals, enc::ValType::F32, _) => enc::Instruction::F32Eq, - (ast::BinaryOp::NotEquals, enc::ValType::F64, _) => enc::Instruction::F64Eq, - // Bitwise and - (ast::BinaryOp::BitAnd, enc::ValType::I32, _) => enc::Instruction::I32And, - (ast::BinaryOp::BitAnd, enc::ValType::I64, _) => enc::Instruction::I64And, - // Bitwise xor - 
(ast::BinaryOp::BitXor, enc::ValType::I32, _) => enc::Instruction::I32Xor, - (ast::BinaryOp::BitXor, enc::ValType::I64, _) => enc::Instruction::I64Xor, - // Bitwise or - (ast::BinaryOp::BitOr, enc::ValType::I32, _) => enc::Instruction::I32Or, - (ast::BinaryOp::BitOr, enc::ValType::I64, _) => enc::Instruction::I64Or, - // Logical and/or - (ast::BinaryOp::LogicalAnd, enc::ValType::I32, _) => enc::Instruction::I32And, - (ast::BinaryOp::LogicalOr, enc::ValType::I32, _) => enc::Instruction::I32Or, - // Fallback - (operator, valtype, _) => panic!( - "Cannot apply binary operator {:?} to type {:?}", - operator, valtype - ), - }; - builder.instruction(&instruction); + if left_fields.len() == 1 { + let field = &left_fields[0]; + let ptype = field.ptype; + let valtype = primitive_to_valtype(ptype); + let signedness = ptype.signedness(); + let mask = ptype.core_type_mask(); + encode_binary_arithmetic(self.op, valtype, signedness, mask, code_gen); + } else { + todo!() + } - if let Some(mask) = ptype.core_type_mask() { - builder.instruction(&enc::Instruction::I32Const(mask)); - builder.instruction(&enc::Instruction::I32And); + let fields = code_gen.fields(expression)?; + for field in fields.iter() { + code_gen.write_expr_field(expression, field); } Ok(()) } } + +fn primitive_to_valtype(ptype: ast::PrimitiveType) -> enc::ValType { + match ptype { + ast::PrimitiveType::Bool + | ast::PrimitiveType::U8 + | ast::PrimitiveType::S8 + | ast::PrimitiveType::U16 + | ast::PrimitiveType::S16 + | ast::PrimitiveType::U32 + | ast::PrimitiveType::S32 => enc::ValType::I32, + + ast::PrimitiveType::U64 | ast::PrimitiveType::S64 => enc::ValType::I64, + + ast::PrimitiveType::F32 => enc::ValType::F32, + ast::PrimitiveType::F64 => enc::ValType::F64, + } +} + +fn encode_binary_arithmetic( + op: ast::BinaryOp, + valtype: enc::ValType, + signedness: ast::Signedness, + mask: Option, + code_gen: &mut CodeGenerator, +) { + let instruction = match (op, valtype, signedness) { + // Multiply + 
(ast::BinaryOp::Multiply, enc::ValType::I32, _) => enc::Instruction::I32Mul, + (ast::BinaryOp::Multiply, enc::ValType::I64, _) => enc::Instruction::I64Mul, + (ast::BinaryOp::Multiply, enc::ValType::F32, _) => enc::Instruction::F32Mul, + (ast::BinaryOp::Multiply, enc::ValType::F64, _) => enc::Instruction::F64Mul, + // Divide + (ast::BinaryOp::Divide, enc::ValType::I32, S) => enc::Instruction::I32DivS, + (ast::BinaryOp::Divide, enc::ValType::I32, U) => enc::Instruction::I32DivU, + (ast::BinaryOp::Divide, enc::ValType::I64, S) => enc::Instruction::I64DivS, + (ast::BinaryOp::Divide, enc::ValType::I64, U) => enc::Instruction::I64DivU, + (ast::BinaryOp::Divide, enc::ValType::F32, _) => enc::Instruction::F32Div, + (ast::BinaryOp::Divide, enc::ValType::F64, _) => enc::Instruction::F64Div, + // Modulo + (ast::BinaryOp::Modulo, enc::ValType::I32, S) => enc::Instruction::I32RemS, + (ast::BinaryOp::Modulo, enc::ValType::I32, U) => enc::Instruction::I32RemU, + (ast::BinaryOp::Modulo, enc::ValType::I64, S) => enc::Instruction::I64RemS, + (ast::BinaryOp::Modulo, enc::ValType::I64, U) => enc::Instruction::I64RemU, + // Addition + (ast::BinaryOp::Add, enc::ValType::I32, _) => enc::Instruction::I32Add, + (ast::BinaryOp::Add, enc::ValType::I64, _) => enc::Instruction::I64Add, + (ast::BinaryOp::Add, enc::ValType::F32, _) => enc::Instruction::F32Add, + (ast::BinaryOp::Add, enc::ValType::F64, _) => enc::Instruction::F64Add, + // Subtraction + (ast::BinaryOp::Subtract, enc::ValType::I32, _) => enc::Instruction::I32Sub, + (ast::BinaryOp::Subtract, enc::ValType::I64, _) => enc::Instruction::I64Sub, + (ast::BinaryOp::Subtract, enc::ValType::F32, _) => enc::Instruction::F32Sub, + (ast::BinaryOp::Subtract, enc::ValType::F64, _) => enc::Instruction::F64Sub, + // Logical Bit Shifting + (ast::BinaryOp::BitShiftL, enc::ValType::I32, _) => enc::Instruction::I32Shl, + (ast::BinaryOp::BitShiftL, enc::ValType::I64, _) => enc::Instruction::I64Shl, + (ast::BinaryOp::BitShiftR, enc::ValType::I32, _) => 
enc::Instruction::I32ShrU, + (ast::BinaryOp::BitShiftR, enc::ValType::I64, _) => enc::Instruction::I64ShrU, + // Arithmetic Bit Shifting + (ast::BinaryOp::ArithShiftR, enc::ValType::I32, S) => enc::Instruction::I32ShrS, + (ast::BinaryOp::ArithShiftR, enc::ValType::I32, U) => enc::Instruction::I32ShrU, + (ast::BinaryOp::ArithShiftR, enc::ValType::I64, S) => enc::Instruction::I64ShrS, + (ast::BinaryOp::ArithShiftR, enc::ValType::I64, U) => enc::Instruction::I64ShrU, + // Less than + (ast::BinaryOp::LessThan, enc::ValType::I32, S) => enc::Instruction::I32LtS, + (ast::BinaryOp::LessThan, enc::ValType::I32, U) => enc::Instruction::I32LtU, + (ast::BinaryOp::LessThan, enc::ValType::I64, S) => enc::Instruction::I64LtS, + (ast::BinaryOp::LessThan, enc::ValType::I64, U) => enc::Instruction::I64LtU, + (ast::BinaryOp::LessThan, enc::ValType::F32, _) => enc::Instruction::F32Lt, + (ast::BinaryOp::LessThan, enc::ValType::F64, _) => enc::Instruction::F64Lt, + // Less than equal + (ast::BinaryOp::LessThanEqual, enc::ValType::I32, S) => enc::Instruction::I32LeS, + (ast::BinaryOp::LessThanEqual, enc::ValType::I32, U) => enc::Instruction::I32LeU, + (ast::BinaryOp::LessThanEqual, enc::ValType::I64, S) => enc::Instruction::I64LeS, + (ast::BinaryOp::LessThanEqual, enc::ValType::I64, U) => enc::Instruction::I64LeU, + (ast::BinaryOp::LessThanEqual, enc::ValType::F32, _) => enc::Instruction::F32Le, + (ast::BinaryOp::LessThanEqual, enc::ValType::F64, _) => enc::Instruction::F64Le, + // Greater than + (ast::BinaryOp::GreaterThan, enc::ValType::I32, S) => enc::Instruction::I32GtS, + (ast::BinaryOp::GreaterThan, enc::ValType::I32, U) => enc::Instruction::I32GtU, + (ast::BinaryOp::GreaterThan, enc::ValType::I64, S) => enc::Instruction::I64GtS, + (ast::BinaryOp::GreaterThan, enc::ValType::I64, U) => enc::Instruction::I64GtU, + (ast::BinaryOp::GreaterThan, enc::ValType::F32, _) => enc::Instruction::F32Gt, + (ast::BinaryOp::GreaterThan, enc::ValType::F64, _) => enc::Instruction::F64Gt, + // Greater 
than or equal + (ast::BinaryOp::GreaterThanEqual, enc::ValType::I32, S) => enc::Instruction::I32GeS, + (ast::BinaryOp::GreaterThanEqual, enc::ValType::I32, U) => enc::Instruction::I32GeU, + (ast::BinaryOp::GreaterThanEqual, enc::ValType::I64, S) => enc::Instruction::I64GeS, + (ast::BinaryOp::GreaterThanEqual, enc::ValType::I64, U) => enc::Instruction::I64GeU, + (ast::BinaryOp::GreaterThanEqual, enc::ValType::F32, _) => enc::Instruction::F32Ge, + (ast::BinaryOp::GreaterThanEqual, enc::ValType::F64, _) => enc::Instruction::F64Ge, + // Equal + (ast::BinaryOp::Equals, enc::ValType::I32, _) => enc::Instruction::I32Eq, + (ast::BinaryOp::Equals, enc::ValType::I64, _) => enc::Instruction::I64Eq, + (ast::BinaryOp::Equals, enc::ValType::F32, _) => enc::Instruction::F32Eq, + (ast::BinaryOp::Equals, enc::ValType::F64, _) => enc::Instruction::F64Eq, + // Not equal + (ast::BinaryOp::NotEquals, enc::ValType::I32, _) => enc::Instruction::I32Ne, + (ast::BinaryOp::NotEquals, enc::ValType::I64, _) => enc::Instruction::I64Ne, + (ast::BinaryOp::NotEquals, enc::ValType::F32, _) => enc::Instruction::F32Ne, + (ast::BinaryOp::NotEquals, enc::ValType::F64, _) => enc::Instruction::F64Ne, + // Bitwise and + (ast::BinaryOp::BitAnd, enc::ValType::I32, _) => enc::Instruction::I32And, + (ast::BinaryOp::BitAnd, enc::ValType::I64, _) => enc::Instruction::I64And, + // Bitwise xor + (ast::BinaryOp::BitXor, enc::ValType::I32, _) => enc::Instruction::I32Xor, + (ast::BinaryOp::BitXor, enc::ValType::I64, _) => enc::Instruction::I64Xor, + // Bitwise or + (ast::BinaryOp::BitOr, enc::ValType::I32, _) => enc::Instruction::I32Or, + (ast::BinaryOp::BitOr, enc::ValType::I64, _) => enc::Instruction::I64Or, + // Logical and/or + (ast::BinaryOp::LogicalAnd, enc::ValType::I32, _) => enc::Instruction::I32And, + (ast::BinaryOp::LogicalOr, enc::ValType::I32, _) => enc::Instruction::I32Or, + // Fallback + (operator, valtype, _) => panic!( + "Cannot apply binary operator {:?} to type {:?}", + operator, valtype + ), + };
+ code_gen.instruction(&instruction); + + if let Some(mask) = mask { + code_gen.instruction(&enc::Instruction::I32Const(mask)); + code_gen.instruction(&enc::Instruction::I32And); + } +} diff --git a/crates/codegen/src/function.rs b/crates/codegen/src/function.rs new file mode 100644 index 0000000..beb8fc0 --- /dev/null +++ b/crates/codegen/src/function.rs @@ -0,0 +1,144 @@ +use ast::{FnTypeInfo, TypeId}; +use claw_ast as ast; + +use wasm_encoder as enc; + +use crate::{ + types::{align_to, EncodeType}, + ModuleBuilder, ModuleTypeIndex, +}; + +const MAX_FLAT_PARAMS: u8 = 16; +const MAX_FLAT_RESULTS: u8 = 1; + +pub struct FunctionGenerator { + pub params: Vec, + pub param_types: Vec, + pub return_type: ReturnInfo, +} + +impl FunctionGenerator { + pub fn new(fn_type: &FnType, comp: &ast::Component) -> Self { + // Layout parameters + let ParamConfig { + params, + param_types, + } = prepare_params(fn_type, comp); + + // Layout return types + let return_type = fn_type.get_return_type(); + let return_type = prepare_return_type(return_type, comp); + + Self { + params, + param_types, + return_type, + } + } + + pub fn encode_func_type(&self, module: &mut ModuleBuilder) -> ModuleTypeIndex { + let params = self.param_types.iter().copied(); + match self.return_type { + ReturnInfo::Flat(return_type) => module.func_type(params, [return_type]), + ReturnInfo::Spilled => module.func_type(params, [enc::ValType::I32]), + ReturnInfo::None => module.func_type(params, []), + } + } +} + +pub struct ParamConfig { + pub params: Vec, + pub param_types: Vec, +} + +pub enum ParamInfo { + Local(LocalParamInfo), + Spilled(SpilledParamInfo), +} + +pub struct LocalParamInfo { + pub index_offset: u32, +} + +pub struct SpilledParamInfo { + pub mem_offset: u32, +} + +fn prepare_params(fn_type: &FnType, comp: &ast::Component) -> ParamConfig { + // Flatten parameters + let mut flat_params = Vec::new(); + for (_name, type_id) in fn_type.get_args() { + type_id.append_flattened(comp, &mut flat_params); + } 
+ // Either generate as locals or spill to memory based on flattened size + if flat_params.len() <= MAX_FLAT_PARAMS as usize { + let params = param_info_local(fn_type, comp); + let param_types = flat_params; + ParamConfig { + params, + param_types, + } + } else { + let params = param_info_spilled(fn_type, comp); + let param_types = vec![enc::ValType::I32]; + ParamConfig { + params, + param_types, + } + } +} + +fn param_info_local(fn_type: &FnType, comp: &ast::Component) -> Vec { + let mut params = Vec::new(); + let mut index_offset = 0; + for (_name, type_id) in fn_type.get_args() { + params.push(ParamInfo::Local(LocalParamInfo { index_offset })); + index_offset += type_id.flat_size(comp); + } + params +} + +fn param_info_spilled( + fn_type: &FnType, + comp: &ast::Component, +) -> Vec { + let mut params = Vec::new(); + let mut mem_offset = 0; + for (_name, type_id) in fn_type.get_args() { + let align = type_id.align(comp); + let size = type_id.mem_size(comp); + mem_offset = align_to(mem_offset, align); + params.push(ParamInfo::Spilled(SpilledParamInfo { mem_offset })); + mem_offset += size; + } + params +} + +pub enum ReturnInfo { + Flat(enc::ValType), + Spilled, + None, +} + +impl ReturnInfo { + pub fn spill(&self) -> bool { + match self { + ReturnInfo::Flat(_) | ReturnInfo::None => false, + ReturnInfo::Spilled => true, + } + } +} + +fn prepare_return_type(return_type: Option, comp: &ast::Component) -> ReturnInfo { + if let Some(return_type) = return_type { + if return_type.flat_size(comp) > MAX_FLAT_RESULTS as u32 { + ReturnInfo::Spilled + } else { + let return_types = return_type.flatten(comp); + assert_eq!(return_types.len(), 1); + ReturnInfo::Flat(return_types[0]) + } + } else { + ReturnInfo::None + } +} diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 48f0ba0..15c5c2d 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -1,38 +1,130 @@ #![allow(clippy::single_match)] -mod component_builder; +mod builders; +mod 
code; mod expression; -mod module_builder; +mod function; mod statement; +mod types; use std::collections::HashMap; -use component_builder::*; +use builders::component::*; +use builders::module::*; +use code::CodeGenerator; pub use expression::*; -use module_builder::*; +use function::FunctionGenerator; pub use statement::*; -use ast::{FunctionId, ImportId, PrimitiveType, TypeId}; +use ast::{FunctionId, ImportId, PrimitiveType}; use claw_ast as ast; -use claw_resolver::{FunctionResolver, ResolvedComponent, ResolvedType, ResolverError}; +use claw_resolver::{ResolvedComponent, ResolverError}; use miette::Diagnostic; use thiserror::Error; +use types::EncodeType; use wasm_encoder as enc; -use wasm_encoder::Instruction; -pub struct CodeGenerator { - module: ModuleBuilder, - component: ComponentBuilder, +pub fn generate(resolved_comp: &ResolvedComponent) -> Result, GenerationError> { + let component = generate_component(resolved_comp)?; + Ok(component.finalize().finish()) +} + +fn generate_component( + resolved_comp: &ResolvedComponent, +) -> Result { + let comp = &resolved_comp.component; + let mut component = ComponentBuilder::default(); + + let alloc_module = component.module_bytes(gen_allocator()); + let code_module = component.module(generate_module(resolved_comp)?); + + let mut inline_export_args = Vec::new(); + for import in comp.imports.values() { + let import_name = comp.get_name(import.ident); + + match &import.external_type { + ast::ExternalType::Function(fn_type) => { + // Encode Component Type and Import + let type_idx = encode_comp_func_type(fn_type, comp, &mut component); + let func_idx = component.import_func(import_name, type_idx); + + // Lower the Import + let core_func_idx = component.lower_func(func_idx); + + inline_export_args.push(( + import_name.to_owned(), + InlineExportItem::Func(core_func_idx), + )); + } + } + } + + let imports_instance = component.inline_export(inline_export_args); + + let alloc_instance = component.instantiate(alloc_module, 
vec![]); + + let args = vec![ + ( + "claw".to_string(), + ModuleInstiateArgs::Instance(imports_instance), + ), + ( + "alloc".to_string(), + ModuleInstiateArgs::Instance(alloc_instance), + ), + ]; + let code_instance = component.instantiate(code_module, args); + + let memory = component.alias_memory(alloc_instance, "memory"); + let realloc = component.alias_func(alloc_instance, "realloc"); + + for function in resolved_comp.component.functions.values() { + if function.exported { + let name = comp.get_name(function.ident); + // Alias module instance export into component + let core_func_idx = component.alias_func(code_instance, name); + // Encode component func type + let type_idx = encode_comp_func_type(function, comp, &mut component); + // Lift aliased function to component function + let func_idx = component.lift_func(core_func_idx, type_idx, memory, realloc); + // Export component function + component.export_func(name, func_idx, type_idx); + } + } + + Ok(component) +} + +fn encode_comp_func_type( + fn_type: &dyn ast::FnTypeInfo, + comp: &ast::Component, + builder: &mut ComponentBuilder, +) -> ComponentTypeIndex { + let params = fn_type.get_args().iter().map(|(name, type_id)| { + let name = comp.get_name(*name); + let valtype = comp.get_type(*type_id); + (name, valtype.to_comp_valtype(comp)) + }); + + let result = fn_type.get_return_type().map(|return_type| { + let valtype = comp.get_type(return_type); + valtype.to_comp_valtype(comp) + }); + builder.func_type(params, result) +} + +fn generate_module(resolved_comp: &ResolvedComponent) -> Result { + let generator = ComponentGenerator::default(); + generator.generate(resolved_comp) +} - imports_instance: ComponentModuleInstanceIndex, - code_module: ComponentModuleIndex, - code_instance: ComponentModuleInstanceIndex, +#[derive(Default)] +pub struct ComponentGenerator { + module: ModuleBuilder, func_idx_for_import: HashMap, func_idx_for_func: HashMap, - - inline_export_args: Vec<(String, InlineExportItem)>, } 
#[derive(Error, Debug, Diagnostic)] @@ -42,115 +134,68 @@ pub enum GenerationError { Resolver(#[from] ResolverError), } -impl Default for CodeGenerator { - fn default() -> Self { - Self::new() - } -} - -impl CodeGenerator { - pub fn new() -> Self { - let mut component = ComponentBuilder::default(); - - let alloc_module = component.module_bytes(gen_allocator()); - let code_module = component.reserve_module(); - - let imports_instance = component.reserve_inline_export(); - let alloc_instance = component.instantiate(alloc_module, vec![]); - - let args = vec![ - ( - "claw".to_string(), - ModuleInstiateArgs::Instance(imports_instance), - ), - ( - "alloc".to_string(), - ModuleInstiateArgs::Instance(alloc_instance), - ), - ]; - let code_instance = component.instantiate(code_module, args); - - Self { - module: ModuleBuilder::default(), - component, - imports_instance, - code_module, - code_instance, - func_idx_for_import: Default::default(), - func_idx_for_func: Default::default(), - inline_export_args: Default::default(), - } - } - +impl ComponentGenerator { pub fn generate( mut self, resolved_comp: &ResolvedComponent, - ) -> Result, GenerationError> { - self.encode_globals(resolved_comp); - - self.encode_import_allocator(); + ) -> Result { + let comp = &resolved_comp.component; + // There is only ever one memory, memory zero + let (_memory, realloc) = self.encode_import_allocator(); - for (id, import) in resolved_comp.component.imports.iter() { - self.encode_import(id, import, resolved_comp); + for (id, import) in comp.imports.iter() { + self.encode_import(id, import, comp); } + self.encode_globals(resolved_comp)?; + + let mut functions = Vec::new(); for (id, function) in resolved_comp.component.functions.iter() { - self.encode_func(id, function, resolved_comp)?; + let func_gen = FunctionGenerator::new(function, comp); + functions.push((id, func_gen)); + self.encode_func(id, function, comp)?; } - for (id, function) in resolved_comp.component.functions.iter() { - 
self.encode_code(id, function, resolved_comp)?; + for (id, function) in functions { + let func_gen = CodeGenerator::new(&mut self, resolved_comp, function, realloc, id)?; + func_gen.finalize()?; } - Ok(self.emit_bytes()) + Ok(self.module.finalize()) } - fn encode_import_allocator(&mut self) { - let _memory = self.module.import_memory("alloc", "memory"); + fn encode_import_allocator(&mut self) -> (ModuleMemoryIndex, ModuleFunctionIndex) { + let memory: ModuleMemoryIndex = self.module.import_memory("alloc", "memory"); let realloc_type = self .module .func_type(vec![enc::ValType::I32; 4], vec![enc::ValType::I32; 1]); - self.module.import_func("alloc", "realloc", realloc_type); + let realloc = self.module.import_func("alloc", "realloc", realloc_type); + (memory, realloc) } - fn encode_import( - &mut self, - id: ImportId, - import: &ast::Import, - resolved_comp: &ResolvedComponent, - ) { - let import_name = resolved_comp.component.get_name(import.ident); + fn encode_import(&mut self, id: ImportId, import: &ast::Import, comp: &ast::Component) { + let import_name = comp.get_name(import.ident); - let comp = &resolved_comp.component; + let comp = &comp; match &import.external_type { ast::ExternalType::Function(fn_type) => { - // Encode Module Type and Import - let mod_type_idx = self.encode_mod_func_type(fn_type, comp); - let mod_func_idx = self.module.import_func("claw", import_name, mod_type_idx); - - self.func_idx_for_import.insert(id, mod_func_idx); - - // Encode Component Type and Import - let comp_type_idx = self.encode_comp_func_type(fn_type, comp); - let comp_func_idx = self.component.import_func(import_name, comp_type_idx); - - // Lower the Import - let comp_core_func_idx = self.component.lower_func(comp_func_idx); - - self.inline_export_args.push(( - import_name.to_owned(), - InlineExportItem::Func(comp_core_func_idx), - )); + let func_gen = FunctionGenerator::new(fn_type, comp); + let type_idx = func_gen.encode_func_type(&mut self.module); + let func_idx =
self.module.import_func("claw", import_name, type_idx); + self.func_idx_for_import.insert(id, func_idx); } } } - fn encode_globals(&mut self, component: &ResolvedComponent) { - for (id, global) in component.component.globals.iter() { - let valtype = type_id_to_core_valtype(global.type_id, &component.component); + fn encode_globals(&mut self, resolved_comp: &ResolvedComponent) -> Result<(), GenerationError> { + let comp = &resolved_comp.component; + for (id, global) in comp.globals.iter() { + let valtypes = global.type_id.flatten(comp); + assert_eq!(valtypes.len(), 1, "Cannot use non-primitive globals"); + let valtype = valtypes[0]; - let init_expr = if let Some(init_value) = component.global_vals.get(&id) { - let valtype = component.component.get_type(global.type_id); + let init_expr = if let Some(init_value) = resolved_comp.global_vals.get(&id) { + let valtype = comp.get_type(global.type_id); match valtype { ast::ValType::Result { .. } => todo!(), ast::ValType::String => todo!(), @@ -162,131 +207,30 @@ impl CodeGenerator { self.module.global(global.mutable, valtype, &init_expr); } + Ok(()) } fn encode_func( &mut self, id: FunctionId, function: &ast::Function, - context: &ResolvedComponent, + comp: &ast::Component, ) -> Result<(), GenerationError> { - let comp = &context.component; + let func_gen = FunctionGenerator::new(function, comp); + let type_idx = func_gen.encode_func_type(&mut self.module); + let func_idx = self.module.function(type_idx); - let mod_type_idx = self.encode_mod_func_type(function, comp); - let mod_func_idx = self.module.function(mod_type_idx); - - self.func_idx_for_func.insert(id, mod_func_idx); + self.func_idx_for_func.insert(id, func_idx); if function.exported { - self.encode_func_export(mod_func_idx, function, context); - } - - Ok(()) - } - - fn encode_code( - &mut self, - id: FunctionId, - function: &ast::Function, - context: &ResolvedComponent, - ) -> Result<(), GenerationError> { - let resolver = context.resolved_funcs.get(&id).unwrap(); 
- let locals = encode_locals(resolver, context)?; - let mut builder = enc::Function::new(locals); - - for statement in function.body.iter() { - encode_statement(self, context, *statement, id, &mut builder)?; + let ident = function.ident; + let name = comp.get_name(ident); + // Export function from module + self.module.export_func(name, func_idx); } - builder.instruction(&Instruction::End); - let mod_func_idx = *self.func_idx_for_func.get(&id).unwrap(); - self.module.code(mod_func_idx, builder); Ok(()) } - - fn encode_func_export( - &mut self, - mod_func_idx: ModuleFunctionIndex, - function: &ast::Function, - context: &ResolvedComponent, - ) { - let comp = &context.component; - let ident = function.ident; - let name = context.component.get_name(ident); - - // Export function from module - self.module.export_func(name, mod_func_idx); - // Alias module instance export into component - let comp_core_func_idx = self.component.alias_func(self.code_instance, name); - // Encode component func type - let comp_type_idx = self.encode_comp_func_type(function, comp); - // Lift aliased function to component function - let comp_func_idx = self.component.lift_func(comp_core_func_idx, comp_type_idx); - // Export component function - self.component - .export_func(name, comp_func_idx, comp_type_idx); - } - - fn emit_bytes(mut self) -> Vec { - // Fill in imports instance - self.component - .fill_inline_export_args(self.imports_instance, self.inline_export_args); - - // Fill in code module & instance - let module = self.module.finalize(); - self.component.fill_module(self.code_module, module); - - self.component.finalize().finish() - } - - fn encode_mod_func_type( - &mut self, - fn_type: &dyn ast::FnTypeInfo, - comp: &ast::Component, - ) -> ModuleTypeIndex { - let params = fn_type - .get_args() - .iter() - .map(|(_name, type_id)| type_id_to_core_valtype(*type_id, comp)); - - match fn_type.get_return_type() { - Some(return_type) => { - let result_type = 
type_id_to_core_valtype(return_type, comp); - self.module.func_type(params, [result_type]) - } - None => self.module.func_type(params, []), - } - } - - fn encode_comp_func_type( - &mut self, - fn_type: &dyn ast::FnTypeInfo, - comp: &ast::Component, - ) -> ComponentTypeIndex { - let params = fn_type.get_args().iter().map(|(name, type_id)| { - let name = comp.get_name(*name); - let valtype = comp.get_type(*type_id); - (name, valtype_to_comp_valtype(valtype)) - }); - - let result = fn_type.get_return_type().map(|return_type| { - let valtype = comp.get_type(return_type); - valtype_to_comp_valtype(valtype) - }); - self.component.func_type(params, result) - } -} - -fn encode_locals( - resolver: &FunctionResolver, - resolved_comp: &ResolvedComponent, -) -> Result, GenerationError> { - let mut locals = Vec::with_capacity(resolver.locals.len()); - for (id, _local) in resolver.locals.iter() { - let rtype = resolver.get_resolved_local_type(id, &resolved_comp.component)?; - locals.push((1, rtype_to_core_valtype(rtype, &resolved_comp.component))); - } - Ok(locals) } // Literal @@ -306,106 +250,6 @@ fn literal_to_const_expr(literal: &ast::Literal, ptype: ast::PrimitiveType) -> e } } -// ResolvedType - -fn rtype_to_core_valtype(rtype: ResolvedType, component: &ast::Component) -> enc::ValType { - match rtype { - ResolvedType::Unit => panic!("Not able to encode as valtype"), - ResolvedType::Primitive(ptype) => ptype_to_valtype(ptype), - ResolvedType::ValType(type_id) => type_id_to_core_valtype(type_id, component), - } -} - -pub fn rtype_to_ptype(rtype: ResolvedType, component: &ast::Component) -> Option { - match rtype { - ResolvedType::Unit => None, - ResolvedType::Primitive(ptype) => Some(ptype), - ResolvedType::ValType(type_id) => match component.get_type(type_id) { - ast::ValType::Result { .. 
} => None, - ast::ValType::String => None, - ast::ValType::Primitive(ptype) => Some(*ptype), - }, - } -} - -// TypeId - -fn type_id_to_core_valtype(type_id: TypeId, component: &ast::Component) -> enc::ValType { - let valtype = component.get_type(type_id); - valtype_to_core_valtype(valtype) -} - -// ast::ValType - -fn valtype_to_core_valtype(valtype: &ast::ValType) -> enc::ValType { - match valtype { - ast::ValType::Primitive(ptype) => ptype_to_valtype(*ptype), - _ => panic!("Cannot encode non-primitive as a valtype"), - } -} - -fn valtype_to_comp_valtype(valtype: &ast::ValType) -> enc::ComponentValType { - match valtype { - ast::ValType::Result { .. } => todo!(), - ast::ValType::String => todo!(), - ast::ValType::Primitive(ptype) => ptype_to_comp_valtype(*ptype), - } -} - -// PrimitiveType - -fn ptype_to_valtype(ptype: PrimitiveType) -> enc::ValType { - use ast::PrimitiveType as PType; - match ptype { - PType::U32 | PType::S32 | PType::U16 | PType::S16 | PType::U8 | PType::S8 | PType::Bool => { - enc::ValType::I32 - } - - PType::U64 | PType::S64 => enc::ValType::I64, - - PType::F32 => enc::ValType::F32, - PType::F64 => enc::ValType::F64, - } -} - -fn ptype_to_ptype_valtype(ptype: PrimitiveType) -> enc::PrimitiveValType { - use ast::PrimitiveType as PType; - match ptype { - PType::U64 => enc::PrimitiveValType::U64, - PType::U32 => enc::PrimitiveValType::U32, - PType::U16 => enc::PrimitiveValType::U16, - PType::U8 => enc::PrimitiveValType::U8, - PType::S64 => enc::PrimitiveValType::S64, - PType::S32 => enc::PrimitiveValType::S32, - PType::S16 => enc::PrimitiveValType::S16, - PType::S8 => enc::PrimitiveValType::S8, - PType::F32 => enc::PrimitiveValType::Float32, - PType::F64 => enc::PrimitiveValType::Float64, - PType::Bool => enc::PrimitiveValType::Bool, - } -} - -fn ptype_to_comp_valtype(ptype: PrimitiveType) -> enc::ComponentValType { - enc::ComponentValType::Primitive(ptype_to_ptype_valtype(ptype)) -} - -fn ptype_to_core_valtype(ptype: PrimitiveType) -> 
enc::ValType { - use ast::PrimitiveType as PType; - - match ptype { - PType::Bool => enc::ValType::I32, - - PType::U64 | PType::S64 => enc::ValType::I64, - - PType::U32 | PType::U16 | PType::U8 | PType::S32 | PType::S16 | PType::S8 => { - enc::ValType::I32 - } - - PType::F32 => enc::ValType::F32, - PType::F64 => enc::ValType::F64, - } -} - // ValType pub fn gen_allocator() -> Vec { diff --git a/crates/codegen/src/statement.rs b/crates/codegen/src/statement.rs index 28df9a7..02504a4 100644 --- a/crates/codegen/src/statement.rs +++ b/crates/codegen/src/statement.rs @@ -1,161 +1,181 @@ -use super::{encode_expression, CodeGenerator, GenerationError}; -use ast::{ExpressionId, FunctionId, NameId, StatementId}; +use crate::code::{CodeGenerator, ExpressionAllocator}; + +use super::GenerationError; +use ast::{ExpressionId, NameId, Statement}; use claw_ast as ast; -use claw_resolver::{ItemId, ResolvedComponent}; +use claw_resolver::ItemId; use cranelift_entity::EntityRef; use wasm_encoder as enc; use wasm_encoder::Instruction; -pub fn encode_statement( - generator: &CodeGenerator, - component: &ResolvedComponent, - statement: StatementId, - func: FunctionId, - builder: &mut enc::Function, -) -> Result<(), GenerationError> { - let s: &dyn EncodeStatement = match &component.component.get_statement(statement) { - ast::Statement::Let(statement) => statement, - ast::Statement::Assign(statement) => statement, - ast::Statement::Call(statement) => statement, - ast::Statement::If(statement) => statement, - ast::Statement::Return(statement) => statement, - }; - s.encode_statement(generator, component, func, builder)?; - Ok(()) +pub trait EncodeStatement { + fn alloc_expr_locals(&self, allocator: &mut ExpressionAllocator) + -> Result<(), GenerationError>; + + fn encode(&self, code_gen: &mut CodeGenerator) -> Result<(), GenerationError>; } -pub trait EncodeStatement { - fn encode_statement( +impl EncodeStatement for Statement { + fn alloc_expr_locals( &self, - generator: &CodeGenerator, 
- component: &ResolvedComponent, - func: FunctionId, - builder: &mut enc::Function, - ) -> Result<(), GenerationError>; + allocator: &mut ExpressionAllocator, + ) -> Result<(), GenerationError> { + let statement: &dyn EncodeStatement = match self { + Statement::Let(statement) => statement, + Statement::Assign(statement) => statement, + Statement::Call(statement) => statement, + Statement::If(statement) => statement, + Statement::Return(statement) => statement, + }; + statement.alloc_expr_locals(allocator) + } + + fn encode(&self, code_gen: &mut CodeGenerator) -> Result<(), GenerationError> { + let statement: &dyn EncodeStatement = match self { + Statement::Let(statement) => statement, + Statement::Assign(statement) => statement, + Statement::Call(statement) => statement, + Statement::If(statement) => statement, + Statement::Return(statement) => statement, + }; + statement.encode(code_gen) + } } impl EncodeStatement for ast::Let { - fn encode_statement( + fn alloc_expr_locals( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - func: FunctionId, - builder: &mut enc::Function, + allocator: &mut ExpressionAllocator, ) -> Result<(), GenerationError> { - encode_assignment( - generator, - component, - func, - self.ident, - self.expression, - builder, - ) + allocator.alloc_child(self.expression) + } + + fn encode(&self, code_gen: &mut CodeGenerator) -> Result<(), GenerationError> { + encode_assignment(self.ident, self.expression, code_gen) } } impl EncodeStatement for ast::Assign { - fn encode_statement( + fn alloc_expr_locals( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - func: FunctionId, - builder: &mut enc::Function, + allocator: &mut ExpressionAllocator, ) -> Result<(), GenerationError> { - encode_assignment( - generator, - component, - func, - self.ident, - self.expression, - builder, - ) + allocator.alloc_child(self.expression) + } + + fn encode(&self, code_gen: &mut CodeGenerator) -> Result<(), GenerationError> { + 
encode_assignment(self.ident, self.expression, code_gen) } } impl EncodeStatement for ast::Call { - fn encode_statement( + fn alloc_expr_locals( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - func: FunctionId, - builder: &mut enc::Function, + allocator: &mut ExpressionAllocator, ) -> Result<(), GenerationError> { - let resolver = component.resolved_funcs.get(&func).unwrap(); + for arg in self.args.iter() { + allocator.alloc_child(*arg)?; + } + Ok(()) + } + fn encode(&self, code_gen: &mut CodeGenerator) -> Result<(), GenerationError> { for arg in self.args.iter() { - encode_expression(generator, component, *arg, func, builder)?; + code_gen.encode_child(*arg)?; } - let index = match resolver.bindings.get(&self.ident).unwrap() { - ItemId::Import(import) => *generator.func_idx_for_import.get(import).unwrap(), - ItemId::Function(function) => *generator.func_idx_for_func.get(function).unwrap(), - _ => panic!(""), - }; - builder.instruction(&Instruction::Call(index.into())); + let item = code_gen.lookup_name(self.ident); + code_gen.encode_call(item)?; Ok(()) } } impl EncodeStatement for ast::If { - fn encode_statement( + fn alloc_expr_locals( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - func: FunctionId, - builder: &mut enc::Function, + allocator: &mut ExpressionAllocator, ) -> Result<(), GenerationError> { - encode_expression(generator, component, self.condition, func, builder)?; - builder.instruction(&Instruction::If(enc::BlockType::Empty)); + allocator.alloc_child(self.condition)?; + for statement in self.block.iter() { + allocator.alloc_statement(*statement)?; + } + Ok(()) + } + + fn encode(&self, code_gen: &mut CodeGenerator) -> Result<(), GenerationError> { + code_gen.encode_child(self.condition)?; + let fields = code_gen.fields(self.condition)?; + assert_eq!(fields.len(), 1); + code_gen.read_expr_field(self.condition, &fields[0]); + code_gen.instruction(&Instruction::If(enc::BlockType::Empty)); for statement in 
self.block.iter() { - encode_statement(generator, component, *statement, func, builder)?; + code_gen.encode_statement(*statement)?; } - builder.instruction(&Instruction::End); + code_gen.instruction(&Instruction::End); Ok(()) } } impl EncodeStatement for ast::Return { - fn encode_statement( + fn alloc_expr_locals( &self, - generator: &CodeGenerator, - component: &ResolvedComponent, - func: FunctionId, - builder: &mut enc::Function, + allocator: &mut ExpressionAllocator, ) -> Result<(), GenerationError> { if let Some(expression) = self.expression { - encode_expression(generator, component, expression, func, builder)?; + allocator.alloc_child(expression)?; } - builder.instruction(&Instruction::Return); + Ok(()) + } + + fn encode(&self, code_gen: &mut CodeGenerator) -> Result<(), GenerationError> { + if let Some(expression) = self.expression { + code_gen.encode_child(expression)?; + + let fields = code_gen.fields(expression)?; + for field in fields.iter() { + if code_gen.spill_return() { + code_gen.read_return_ptr()?; + code_gen.field_address(field); + code_gen.read_expr_field(expression, field); + code_gen.write_mem(field); + } else { + code_gen.read_expr_field(expression, field); + } + } + + if code_gen.spill_return() { + code_gen.read_return_ptr()?; + } + } + code_gen.instruction(&Instruction::Return); Ok(()) } } fn encode_assignment( - generator: &CodeGenerator, - component: &ResolvedComponent, - func: FunctionId, ident: NameId, expression: ExpressionId, - builder: &mut enc::Function, + code_gen: &mut CodeGenerator, ) -> Result<(), GenerationError> { - let resolver = component.resolved_funcs.get(&func).unwrap(); - encode_expression(generator, component, expression, func, builder)?; - match resolver.bindings.get(&ident).unwrap() { - ItemId::Import(_) => unimplemented!(), + code_gen.encode_child(expression)?; + let fields = code_gen.fields(expression)?; + match code_gen.lookup_name(ident) { + ItemId::Import(_) => panic!("Assigning to imports isn't allowed!!"), 
ItemId::Global(global) => { - builder.instruction(&Instruction::GlobalSet(global.index() as u32)); - } - ItemId::Param(param) => { - let local_index = param.index() as u32; - builder.instruction(&Instruction::LocalSet(local_index)); + // TODO handle composite globals + for field in fields { + code_gen.read_expr_field(expression, &field); + code_gen.instruction(&Instruction::GlobalSet(global.index() as u32)); + } } + ItemId::Param(_) => panic!("Assigning to parameters isn't allowed!!"), ItemId::Local(local) => { - let func = component.component.functions.get(func).unwrap(); - let local_index = local.index() + func.arguments.len(); - let local_index = local_index as u32; - builder.instruction(&Instruction::LocalSet(local_index)); + for field in fields { + code_gen.read_expr_field(expression, &field); + code_gen.write_local_field(local, &field); + } } - ItemId::Function(_) => unimplemented!(), + ItemId::Function(_) => panic!("Assigning to functions isn't allowed!!"), } Ok(()) } diff --git a/crates/codegen/src/types.rs b/crates/codegen/src/types.rs new file mode 100644 index 0000000..a1e1ca5 --- /dev/null +++ b/crates/codegen/src/types.rs @@ -0,0 +1,365 @@ +use ast::TypeId; +use claw_ast as ast; + +use claw_resolver::ResolvedType; +use wasm_encoder as enc; + +const PTYPE_FLAT_SIZE: u32 = 1; + +const STRING_FLAT_SIZE: u32 = 2; +const STRING_COMP_VALTYPE: enc::ComponentValType = + enc::ComponentValType::Primitive(enc::PrimitiveValType::String); +const STRING_ALIGNMENT: u32 = 2; +const STRING_MEM_SIZE: u32 = 8; + +pub trait EncodeType { + fn flat_size(&self, comp: &ast::Component) -> u32; + + fn append_flattened(&self, comp: &ast::Component, out: &mut Vec); + + fn flatten(&self, comp: &ast::Component) -> Vec { + let mut out = Vec::new(); + self.append_flattened(comp, &mut out); + out + } + + fn append_fields(&self, comp: &ast::Component, out: &mut Vec); + + fn fields(&self, comp: &ast::Component) -> Vec { + let mut out = Vec::new(); + self.append_fields(comp, &mut out); + 
out + } + + fn to_comp_valtype(&self, comp: &ast::Component) -> enc::ComponentValType; + + fn mem_arg(&self, comp: &ast::Component) -> enc::MemArg { + enc::MemArg { + align: self.align(comp), + offset: 0, + memory_index: 0, + } + } + + fn align(&self, comp: &ast::Component) -> u32; + + fn mem_size(&self, comp: &ast::Component) -> u32; +} + +impl EncodeType for ResolvedType { + fn flat_size(&self, comp: &ast::Component) -> u32 { + match *self { + ResolvedType::Primitive(ptype) => ptype.flat_size(comp), + ResolvedType::ValType(type_id) => type_id.flat_size(comp), + } + } + + fn append_flattened(&self, comp: &ast::Component, out: &mut Vec) { + match *self { + ResolvedType::Primitive(ptype) => ptype.append_flattened(comp, out), + ResolvedType::ValType(type_id) => type_id.append_flattened(comp, out), + } + } + + fn append_fields(&self, comp: &ast::Component, out: &mut Vec) { + match *self { + ResolvedType::Primitive(ptype) => ptype.append_fields(comp, out), + ResolvedType::ValType(type_id) => type_id.append_fields(comp, out), + } + } + + fn to_comp_valtype(&self, comp: &ast::Component) -> enc::ComponentValType { + match *self { + ResolvedType::Primitive(ptype) => ptype.to_comp_valtype(comp), + ResolvedType::ValType(type_id) => type_id.to_comp_valtype(comp), + } + } + + fn align(&self, comp: &ast::Component) -> u32 { + match *self { + ResolvedType::Primitive(ptype) => ptype.align(comp), + ResolvedType::ValType(type_id) => type_id.align(comp), + } + } + + fn mem_size(&self, comp: &ast::Component) -> u32 { + match *self { + ResolvedType::Primitive(ptype) => ptype.mem_size(comp), + ResolvedType::ValType(type_id) => type_id.mem_size(comp), + } + } +} + +impl EncodeType for TypeId { + fn flat_size(&self, comp: &ast::Component) -> u32 { + let valtype = comp.get_type(*self); + valtype.flat_size(comp) + } + + fn append_flattened(&self, comp: &ast::Component, out: &mut Vec) { + let valtype = comp.get_type(*self); + valtype.append_flattened(comp, out); + } + + fn 
append_fields(&self, comp: &ast::Component, out: &mut Vec) { + let valtype = comp.get_type(*self); + valtype.append_fields(comp, out); + } + + fn to_comp_valtype(&self, comp: &ast::Component) -> enc::ComponentValType { + let valtype = comp.get_type(*self); + valtype.to_comp_valtype(comp) + } + + fn align(&self, comp: &ast::Component) -> u32 { + let valtype = comp.get_type(*self); + valtype.align(comp) + } + + fn mem_size(&self, comp: &ast::Component) -> u32 { + let valtype = comp.get_type(*self); + valtype.mem_size(comp) + } +} + +impl EncodeType for ast::ValType { + fn flat_size(&self, comp: &ast::Component) -> u32 { + match *self { + ast::ValType::Result { .. } => todo!(), + ast::ValType::String => STRING_FLAT_SIZE, + ast::ValType::Primitive(ptype) => ptype.flat_size(comp), + } + } + + fn append_flattened(&self, comp: &ast::Component, out: &mut Vec) { + match *self { + ast::ValType::Result { .. } => todo!(), + ast::ValType::String => { + out.push(enc::ValType::I32); + out.push(enc::ValType::I32); + } + ast::ValType::Primitive(ptype) => ptype.append_flattened(comp, out), + } + } + + fn append_fields(&self, comp: &ast::Component, out: &mut Vec) { + match *self { + ast::ValType::Result { .. } => todo!(), + ast::ValType::String => { + out.push(STRING_OFFSET_FIELD); + out.push(STRING_LENGTH_FIELD); + } + ast::ValType::Primitive(ptype) => ptype.append_fields(comp, out), + } + } + + fn to_comp_valtype(&self, comp: &ast::Component) -> enc::ComponentValType { + match *self { + ast::ValType::Result { .. } => todo!(), + ast::ValType::String => STRING_COMP_VALTYPE, + ast::ValType::Primitive(ptype) => ptype.to_comp_valtype(comp), + } + } + + fn align(&self, comp: &ast::Component) -> u32 { + match *self { + ast::ValType::Result { .. } => todo!(), + ast::ValType::String => STRING_ALIGNMENT, + ast::ValType::Primitive(ptype) => ptype.align(comp), + } + } + + fn mem_size(&self, comp: &ast::Component) -> u32 { + match *self { + ast::ValType::Result { .. 
} => todo!(), + ast::ValType::String => STRING_MEM_SIZE, + ast::ValType::Primitive(ptype) => ptype.mem_size(comp), + } + } +} + +impl EncodeType for ast::PrimitiveType { + fn flat_size(&self, _comp: &ast::Component) -> u32 { + PTYPE_FLAT_SIZE + } + + fn append_flattened(&self, _comp: &ast::Component, out: &mut Vec) { + let valtype = match *self { + ast::PrimitiveType::Bool + | ast::PrimitiveType::U8 + | ast::PrimitiveType::S8 + | ast::PrimitiveType::U16 + | ast::PrimitiveType::S16 + | ast::PrimitiveType::U32 + | ast::PrimitiveType::S32 => enc::ValType::I32, + ast::PrimitiveType::U64 | ast::PrimitiveType::S64 => enc::ValType::I64, + ast::PrimitiveType::F32 => enc::ValType::F32, + ast::PrimitiveType::F64 => enc::ValType::F64, + }; + out.push(valtype); + } + + fn append_fields(&self, _comp: &ast::Component, out: &mut Vec) { + let field = match self { + ast::PrimitiveType::Bool => BOOL_FIELD, + ast::PrimitiveType::U8 => U8_FIELD, + ast::PrimitiveType::S8 => S8_FIELD, + ast::PrimitiveType::U16 => U16_FIELD, + ast::PrimitiveType::S16 => S16_FIELD, + ast::PrimitiveType::U32 => U32_FIELD, + ast::PrimitiveType::S32 => S32_FIELD, + ast::PrimitiveType::U64 => U64_FIELD, + ast::PrimitiveType::S64 => S64_FIELD, + ast::PrimitiveType::F32 => F32_FIELD, + ast::PrimitiveType::F64 => F64_FIELD, + }; + out.push(field); + } + + fn to_comp_valtype(&self, _comp: &ast::Component) -> enc::ComponentValType { + enc::ComponentValType::Primitive(ptype_to_pvaltype(*self)) + } + + fn align(&self, _comp: &ast::Component) -> u32 { + ptype_align(*self) + } + + fn mem_size(&self, _comp: &ast::Component) -> u32 { + ptype_mem_size(*self) + } +} + +pub fn ptype_mem_arg(ptype: ast::PrimitiveType) -> enc::MemArg { + enc::MemArg { + offset: 0, + align: ptype_align(ptype), + memory_index: 0, + } +} + +fn ptype_align(ptype: ast::PrimitiveType) -> u32 { + match ptype { + ast::PrimitiveType::Bool | ast::PrimitiveType::U8 | ast::PrimitiveType::S8 => 0, + ast::PrimitiveType::U16 | ast::PrimitiveType::S16 => 1, 
+ ast::PrimitiveType::U32 | ast::PrimitiveType::S32 | ast::PrimitiveType::F32 => 2, + ast::PrimitiveType::U64 | ast::PrimitiveType::S64 | ast::PrimitiveType::F64 => 3, + } +} + +fn ptype_mem_size(ptype: ast::PrimitiveType) -> u32 { + match ptype { + ast::PrimitiveType::Bool | ast::PrimitiveType::U8 | ast::PrimitiveType::S8 => 1, + ast::PrimitiveType::U16 | ast::PrimitiveType::S16 => 2, + ast::PrimitiveType::U32 | ast::PrimitiveType::S32 | ast::PrimitiveType::F32 => 4, + ast::PrimitiveType::U64 | ast::PrimitiveType::S64 | ast::PrimitiveType::F64 => 8, + } +} + +pub fn ptype_to_pvaltype(ptype: ast::PrimitiveType) -> enc::PrimitiveValType { + use ast::PrimitiveType as PType; + match ptype { + PType::U64 => enc::PrimitiveValType::U64, + PType::U32 => enc::PrimitiveValType::U32, + PType::U16 => enc::PrimitiveValType::U16, + PType::U8 => enc::PrimitiveValType::U8, + PType::S64 => enc::PrimitiveValType::S64, + PType::S32 => enc::PrimitiveValType::S32, + PType::S16 => enc::PrimitiveValType::S16, + PType::S8 => enc::PrimitiveValType::S8, + PType::F32 => enc::PrimitiveValType::Float32, + PType::F64 => enc::PrimitiveValType::Float64, + PType::Bool => enc::PrimitiveValType::Bool, + } +} + +pub fn align_to(offset: u32, alignment: u32) -> u32 { + offset.div_ceil(alignment) * alignment +} + +/// Info about a field required for reading/writing it +#[derive(Debug)] +pub struct FieldInfo { + pub ptype: ast::PrimitiveType, + pub index_offset: u32, + pub mem_offset: u32, +} + +// Statically known field info + +pub const BOOL_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::Bool, + index_offset: 0, + mem_offset: 0, +}; + +pub const U8_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::U8, + index_offset: 0, + mem_offset: 0, +}; + +pub const S8_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::S8, + index_offset: 0, + mem_offset: 0, +}; + +pub const U16_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::U16, + index_offset: 0, + mem_offset: 0, +}; 
+ +pub const S16_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::S16, + index_offset: 0, + mem_offset: 0, +}; + +pub const U32_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::U32, + index_offset: 0, + mem_offset: 0, +}; + +pub const S32_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::S32, + index_offset: 0, + mem_offset: 0, +}; + +pub const U64_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::U64, + index_offset: 0, + mem_offset: 0, +}; + +pub const S64_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::S64, + index_offset: 0, + mem_offset: 0, +}; + +pub const F32_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::F32, + index_offset: 0, + mem_offset: 0, +}; + +pub const F64_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::F64, + index_offset: 0, + mem_offset: 0, +}; + +pub const STRING_OFFSET_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::U32, + index_offset: 0, + mem_offset: 0, +}; + +pub const STRING_LENGTH_FIELD: FieldInfo = FieldInfo { + ptype: ast::PrimitiveType::U32, + index_offset: 1, + mem_offset: 4, +}; diff --git a/crates/lib/src/lib.rs b/crates/lib/src/lib.rs index 57c67ba..eee1301 100644 --- a/crates/lib/src/lib.rs +++ b/crates/lib/src/lib.rs @@ -1,4 +1,4 @@ -use claw_codegen::CodeGenerator; +use claw_codegen::generate; use claw_common::{make_source, OkPretty}; use claw_parser::{parse, tokenize}; use claw_resolver::resolve; @@ -12,6 +12,5 @@ pub fn compile(source_name: String, source_code: &str) -> Option> { let resolved = resolve(src, ast).ok_pretty()?; - let gen = CodeGenerator::default(); - gen.generate(&resolved).ok_pretty() + generate(&resolved).ok_pretty() } diff --git a/crates/lib/tests/programs/claw.wit b/crates/lib/tests/programs/claw.wit index 7cd3c99..b084f83 100644 --- a/crates/lib/tests/programs/claw.wit +++ b/crates/lib/tests/programs/claw.wit @@ -38,6 +38,10 @@ world quadratic { export quad-f64-let: func(a: float64, b: float64, c: float64, x: 
float64) -> float64; } +world strings { + export identity: func(s: string) -> string; +} + world unary { export set: func(v: s32) -> s32; export get-inverse: func() -> s32; diff --git a/crates/lib/tests/programs/strings.claw b/crates/lib/tests/programs/strings.claw new file mode 100644 index 0000000..d58fcb2 --- /dev/null +++ b/crates/lib/tests/programs/strings.claw @@ -0,0 +1,11 @@ +export func identity(s: string) -> string { + return s; +} + +// export func hello-world() -> string { +// return "Hello, world!" +// } + +// export func concat(left: string, right: string) -> string { +// return left + right; +// } \ No newline at end of file diff --git a/crates/lib/tests/runtime.rs b/crates/lib/tests/runtime.rs index 3f8fab0..347e184 100644 --- a/crates/lib/tests/runtime.rs +++ b/crates/lib/tests/runtime.rs @@ -110,11 +110,13 @@ fn test_factorial() { Factorial::instantiate(&mut runtime.store, &runtime.component, &runtime.linker).unwrap(); for (i, val) in [1, 1, 2, 6, 24, 120].iter().enumerate() { + let fact = factorial + .call_factorial(&mut runtime.store, i as u64) + .unwrap(); assert_eq!( - factorial - .call_factorial(&mut runtime.store, i as u64) - .unwrap(), - *val + fact, *val, + "factorial({}) was {} instead of {}", + i, fact, *val ); } } @@ -233,6 +235,32 @@ fn test_quadratic() { } } +#[test] +fn test_strings() { + bindgen!("strings" in "tests/programs"); + + let mut runtime = Runtime::new("strings"); + + let (strings, _) = + Strings::instantiate(&mut runtime.store, &runtime.component, &runtime.linker).unwrap(); + + let long_string = "Z".repeat(1000); + let cases = [ + "", + "asdf", + "673hlksdfkjh5r;4hj6s", + "a", + long_string.as_str(), + ]; + + for case in cases { + assert_eq!( + case, + strings.call_identity(&mut runtime.store, case).unwrap() + ); + } +} + #[test] fn test_unary() { bindgen!("unary" in "tests/programs"); diff --git a/crates/parser/src/types.rs b/crates/parser/src/types.rs index b663f92..8971932 100644 --- a/crates/parser/src/types.rs +++ 
b/crates/parser/src/types.rs @@ -8,16 +8,21 @@ pub fn parse_valtype(input: &mut ParseInput, comp: &mut Component) -> Result ValType::Primitive(PrimitiveType::Bool), + // Unsigned Integers Token::U8 => ValType::Primitive(PrimitiveType::U8), Token::U16 => ValType::Primitive(PrimitiveType::U16), Token::U32 => ValType::Primitive(PrimitiveType::U32), Token::U64 => ValType::Primitive(PrimitiveType::U64), + // Signed Integers Token::S8 => ValType::Primitive(PrimitiveType::S8), Token::S16 => ValType::Primitive(PrimitiveType::S16), Token::S32 => ValType::Primitive(PrimitiveType::S32), Token::S64 => ValType::Primitive(PrimitiveType::S64), + // Floats Token::F32 => ValType::Primitive(PrimitiveType::F32), Token::F64 => ValType::Primitive(PrimitiveType::F64), + // String + Token::String => ValType::String, _ => return Err(input.unexpected_token("Not a legal type")), }; let name_id = comp.new_type(valtype, span); diff --git a/crates/resolver/src/lib.rs b/crates/resolver/src/lib.rs index 0459762..1d502f8 100644 --- a/crates/resolver/src/lib.rs +++ b/crates/resolver/src/lib.rs @@ -185,7 +185,7 @@ pub fn resolve(src: Source, component: ast::Component) -> Result, + pub params: PrimaryMap, // Name Resolution /// Entries for each unique local @@ -489,7 +489,6 @@ impl FunctionResolver { #[derive(Clone, Copy, Debug)] pub enum ResolvedType { - Unit, Primitive(ast::PrimitiveType), ValType(TypeId), } @@ -497,7 +496,6 @@ pub enum ResolvedType { impl std::fmt::Display for ResolvedType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ResolvedType::Unit => f.write_str("Unit"), ResolvedType::Primitive(p) => (p as &dyn std::fmt::Debug).fmt(f), ResolvedType::ValType(v) => (v as &dyn std::fmt::Debug).fmt(f), } @@ -521,7 +519,6 @@ impl ResolvedType { impl<'ctx> ResolvedTypeContext<'ctx> { pub fn type_eq(&self, other: &ResolvedType) -> bool { match (self.rtype, *other) { - (ResolvedType::Unit, ResolvedType::Unit) => true, (ResolvedType::Primitive(left), 
ResolvedType::Primitive(right)) => left == right, (ResolvedType::ValType(left), ResolvedType::ValType(right)) => { let l_valtype = self.context.get_type(left); @@ -536,7 +533,6 @@ impl<'ctx> ResolvedTypeContext<'ctx> { _ => false, } } - _ => false, } } } diff --git a/src/bin.rs b/src/bin.rs index a8dd4b8..9c5c2a9 100644 --- a/src/bin.rs +++ b/src/bin.rs @@ -2,7 +2,7 @@ use std::{fs, path::PathBuf, sync::Arc}; use clap::Parser; -use claw_codegen::CodeGenerator; +use claw_codegen::generate; use claw_common::OkPretty; use claw_parser::{parse, tokenize}; use claw_resolver::resolve; @@ -39,8 +39,7 @@ impl Compile { let resolved = resolve(src, ast).ok_pretty()?; - let generator = CodeGenerator::default(); - let wasm = generator.generate(&resolved).ok_pretty()?; + let wasm = generate(&resolved).ok_pretty()?; match fs::write(&self.output, wasm) { Ok(_) => println!("Done"), diff --git a/test.wat b/test.wat new file mode 100644 index 0000000..87647d5 --- /dev/null +++ b/test.wat @@ -0,0 +1,119 @@ +(component + (core module (;0;) + (type (;0;) (func (param i32 i32 i32 i32) (result i32))) + (type (;1;) (func)) + (func $realloc (;0;) (type 0) (param $old_ptr i32) (param $old_size i32) (param $align i32) (param $new_size i32) (result i32) + (local $ret i32) + local.get $old_ptr + if ;; label = @1 + local.get $old_size + local.get $new_size + i32.gt_u + if ;; label = @2 + local.get $old_ptr + return + end + end + global.get $last + local.get $align + i32.const -1 + i32.add + i32.add + local.get $align + i32.const -1 + i32.add + i32.const -1 + i32.xor + i32.and + global.set $last + global.get $last + local.set $ret + global.get $last + local.get $new_size + i32.add + global.set $last + loop $loop ;; label = @1 + memory.size + i32.const 65536 + i32.mul + global.get $last + i32.lt_u + if ;; label = @2 + i32.const 1 + memory.grow + i32.const -1 + i32.eq + if ;; label = @3 + unreachable + end + br 1 (;@1;) + end + end + local.get $ret + i32.const 222 + local.get $new_size + 
memory.fill + local.get $old_ptr + if ;; label = @1 + local.get $ret + local.get $old_ptr + local.get $old_size + memory.copy + end + local.get $ret + ) + (func $clear (;1;) (type 1) + i32.const 8 + global.set $last + ) + (memory $memory (;0;) 1) + (global $last (;0;) (mut i32) i32.const 8) + (export "memory" (memory $memory)) + (export "realloc" (func $realloc)) + (export "clear" (func $clear)) + ) + (core module (;1;) + (type (;0;) (func (param i32 i32 i32 i32) (result i32))) + (type (;1;) (func (param i32 i32) (result i32))) + (import "alloc" "memory" (memory (;0;) 1)) + (import "alloc" "realloc" (func (;0;) (type 0))) + (func (;1;) (type 1) (param i32 i32) (result i32) + (local i32 i32 i32) + i32.const 0 + i32.const 0 + i32.const 4 + i32.const 8 + call 0 + local.set 2 + local.get 0 + local.set 3 + local.get 1 + local.set 4 + local.get 2 + i32.const 0 + i32.add + local.get 3 + i32.store align=16 + local.get 2 + i32.const 4 + i32.add + local.get 4 + i32.store align=16 + return + ) + (export "identity" (func 1)) + ) + (core instance (;0;)) + (core instance (;1;) (instantiate 0)) + (core instance (;2;) (instantiate 1 + (with "claw" (instance 0)) + (with "alloc" (instance 1)) + ) + ) + (alias core export 1 "memory" (core memory (;0;))) + (alias core export 1 "realloc" (core func (;0;))) + (alias core export 2 "identity" (core func (;1;))) + (type (;0;) (func (param "s" string) (result string))) + (func (;0;) (type 0) (canon lift (core func 1) (memory 0) (realloc 0))) + (export (;1;) "identity" (func 0) (func (type 0))) +) \ No newline at end of file