diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 263d334..25ea24d 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -17,11 +17,8 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Cargo binstall - run: curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash - - - name: Instal cargo-workspaces - run: cargo binstall cargo-workspaces --locked -y + - name: Install cargo-workspaces + run: cargo install cargo-workspaces - name: Run tests run: cargo test --all-features --workspace @@ -31,7 +28,7 @@ jobs: - name: Unused dependencies run: | - cargo binstall cargo-machete --locked -y + cargo install cargo-machete --locked cargo machete - name: Fmt diff --git a/crates/rue-compiler/src/compiler.rs b/crates/rue-compiler/src/compiler.rs index 756c003..0c153dc 100644 --- a/crates/rue-compiler/src/compiler.rs +++ b/crates/rue-compiler/src/compiler.rs @@ -12,6 +12,7 @@ use crate::{ database::{Database, HirId, ScopeId, SymbolId}, hir::{Hir, Op}, scope::Scope, + symbol::{Function, Symbol}, value::{GuardPath, Value}, ErrorKind, }; @@ -46,8 +47,8 @@ pub struct Compiler<'a> { // The type definition stack is used for calculating types referenced in types. type_definition_stack: Vec, - // The type guard stack is used for overriding types in certain contexts. - type_guard_stack: Vec>, + // Overridden symbol types due to type guards. + type_overrides: Vec>, // The generic type stack is used for overriding generic types that are being checked against. generic_type_stack: Vec>, @@ -74,7 +75,7 @@ impl<'a> Compiler<'a> { scope_stack: vec![builtins.scope_id], symbol_stack: Vec::new(), type_definition_stack: Vec::new(), - type_guard_stack: Vec::new(), + type_overrides: Vec::new(), generic_type_stack: Vec::new(), allow_generic_inference_stack: vec![false], is_callee: false, @@ -169,13 +170,50 @@ impl<'a> Compiler<'a> { Value::new(self.builtins.unknown, self.ty.std().unknown) } - fn symbol_type(&self, guard_path: &GuardPath) -> Option { - for guards in self.type_guard_stack.iter().rev() { - if let Some(guard) = guards.get(guard_path) { - return Some(*guard); + fn build_overrides(&mut self, guards: HashMap) -> HashMap { + type GuardItem = (Vec, TypeId); + + let mut symbol_guards: HashMap> = HashMap::new(); + + for (guard_path, type_id) in guards { + symbol_guards + .entry(guard_path.symbol_id) + .or_default() + .push((guard_path.items, type_id)); + } + + let mut overrides = HashMap::new(); + + for (symbol_id, mut items) in symbol_guards { + // Order by length. + items.sort_by_key(|(items, _)| items.len()); + + let mut type_id = self.symbol_type(symbol_id); + + for (path_items, new_type_id) in items { + type_id = self.ty.replace(type_id, new_type_id, &path_items); + } + + overrides.insert(symbol_id, type_id); + } + + overrides + } + + fn symbol_type(&self, symbol_id: SymbolId) -> TypeId { + for guards in self.type_overrides.iter().rev() { + if let Some(type_id) = guards.get(&symbol_id) { + return *type_id; } } - None + + match self.db.symbol(symbol_id) { + Symbol::Unknown | Symbol::Module(..) => unreachable!(), + Symbol::Function(Function { type_id, .. }) + | Symbol::InlineFunction(Function { type_id, .. }) + | Symbol::Parameter(type_id) => *type_id, + Symbol::Let(value) | Symbol::Const(value) | Symbol::InlineConst(value) => value.type_id, + } } fn scope(&self) -> &Scope { diff --git a/crates/rue-compiler/src/compiler/block.rs b/crates/rue-compiler/src/compiler/block.rs index 984f1ea..35a1cfb 100644 --- a/crates/rue-compiler/src/compiler/block.rs +++ b/crates/rue-compiler/src/compiler/block.rs @@ -50,7 +50,8 @@ impl Compiler<'_> { // Push the type guards onto the stack. // This will be popped in reverse order later after all statements have been lowered. - self.type_guard_stack.push(else_guards); + let overrides = self.build_overrides(else_guards); + self.type_overrides.push(overrides); statements.push(Statement::If(condition_hir, then_hir)); } @@ -103,8 +104,8 @@ impl Compiler<'_> { // If the condition is false, we raise an error. // So we can assume that the condition is true from this point on. // This will be popped in reverse order later after all statements have been lowered. - - self.type_guard_stack.push(condition.then_guards()); + let overrides = self.build_overrides(condition.then_guards()); + self.type_overrides.push(overrides); let not_condition = self.db.alloc_hir(Hir::Op(Op::Not, condition.hir_id)); let raise = self.db.alloc_hir(Hir::Raise(None)); @@ -126,7 +127,8 @@ impl Compiler<'_> { assume_stmt.syntax().text_range(), ); - self.type_guard_stack.push(expr.then_guards()); + let overrides = self.build_overrides(expr.then_guards()); + self.type_overrides.push(overrides); statements.push(Statement::Assume); } } @@ -158,7 +160,7 @@ impl Compiler<'_> { body = value; } Statement::If(condition, then_block) => { - self.type_guard_stack.pop().unwrap(); + self.type_overrides.pop().unwrap(); body = Value::new( self.db @@ -167,7 +169,7 @@ impl Compiler<'_> { ); } Statement::Assume => { - self.type_guard_stack.pop().unwrap(); + self.type_overrides.pop().unwrap(); } } } diff --git a/crates/rue-compiler/src/compiler/expr/binary_expr.rs b/crates/rue-compiler/src/compiler/expr/binary_expr.rs index 1272f0e..1f31aab 100644 --- a/crates/rue-compiler/src/compiler/expr/binary_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/binary_expr.rs @@ -172,7 +172,7 @@ impl Compiler<'_> { let else_type = self.ty.difference(rhs.type_id, self.ty.std().nil); value .guards - .insert(guard_path, Guard::new(then_type, else_type)); + .insert(guard_path, Guard::new(Some(then_type), Some(else_type))); } } @@ -182,7 +182,7 @@ impl Compiler<'_> { let else_type = self.ty.difference(lhs.type_id, self.ty.std().nil); value .guards - .insert(guard_path, Guard::new(then_type, else_type)); + .insert(guard_path, Guard::new(Some(then_type), Some(else_type))); } } @@ -250,13 +250,14 @@ impl Compiler<'_> { } fn op_and(&mut self, lhs: Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { - self.type_guard_stack.push(lhs.then_guards()); + let overrides = self.build_overrides(lhs.then_guards()); + self.type_overrides.push(overrides); let rhs = rhs .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); - self.type_guard_stack.pop().unwrap(); + self.type_overrides.pop().unwrap(); self.type_check(lhs.type_id, self.ty.std().bool, text_range); self.type_check(rhs.type_id, self.ty.std().bool, text_range); @@ -267,19 +268,28 @@ impl Compiler<'_> { rhs.hir_id, self.ty.std().bool, ); - value.guards.extend(lhs.guards); - value.guards.extend(rhs.guards); + value.guards.extend( + lhs.guards + .into_iter() + .map(|(path, guard)| (path, Guard::new(guard.then_type, None))), + ); + value.guards.extend( + rhs.guards + .into_iter() + .map(|(path, guard)| (path, Guard::new(guard.then_type, None))), + ); value } fn op_or(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { - self.type_guard_stack.push(lhs.then_guards()); + let overrides = self.build_overrides(lhs.else_guards()); + self.type_overrides.push(overrides); let rhs = rhs .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); - self.type_guard_stack.pop().unwrap(); + self.type_overrides.pop().unwrap(); self.type_check(lhs.type_id, self.ty.std().bool, text_range); self.type_check(rhs.type_id, self.ty.std().bool, text_range); diff --git a/crates/rue-compiler/src/compiler/expr/field_access_expr.rs b/crates/rue-compiler/src/compiler/expr/field_access_expr.rs index dffcc67..a61700b 100644 --- a/crates/rue-compiler/src/compiler/expr/field_access_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/field_access_expr.rs @@ -22,8 +22,8 @@ impl Compiler<'_> { return self.unknown(); }; - let mut new_value = match self.ty.get(old_value.type_id).clone() { - Type::Unknown => return self.unknown(), + match self.ty.get(old_value.type_id).clone() { + Type::Unknown => self.unknown(), Type::Struct(ty) => { let Some(value) = self.compile_struct_field_access(old_value, &ty, &name) else { return self.unknown(); @@ -55,17 +55,9 @@ impl Compiler<'_> { ), name.text_range(), ); - return self.unknown(); - } - }; - - if let Some(guard_path) = new_value.guard_path.as_ref() { - if let Some(type_override) = self.symbol_type(guard_path) { - new_value.type_id = type_override; + self.unknown() } } - - new_value } fn compile_pair_field_access( @@ -113,7 +105,7 @@ impl Compiler<'_> { ) -> Option { let fields = deconstruct_items(self.ty, ty.type_id, ty.field_names.len(), ty.nil_terminated) - .expect("invalid struct type"); + .unwrap(); let Some(index) = ty.field_names.get_index_of(name.text()) else { self.db @@ -157,7 +149,7 @@ impl Compiler<'_> { .as_ref() .map(|field_names| { deconstruct_items(self.ty, type_id, field_names.len(), ty.nil_terminated) - .expect("invalid struct type") + .unwrap() }) .unwrap_or_default() } else { diff --git a/crates/rue-compiler/src/compiler/expr/guard_expr.rs b/crates/rue-compiler/src/compiler/expr/guard_expr.rs index 8f2d51e..8b973ff 100644 --- a/crates/rue-compiler/src/compiler/expr/guard_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/guard_expr.rs @@ -62,7 +62,9 @@ impl Compiler<'_> { if let Some(guard_path) = expr.guard_path { let difference = self.ty.difference(expr.type_id, rhs); - value.guards.insert(guard_path, Guard::new(rhs, difference)); + value + .guards + .insert(guard_path, Guard::new(Some(rhs), Some(difference))); } value diff --git a/crates/rue-compiler/src/compiler/expr/if_expr.rs b/crates/rue-compiler/src/compiler/expr/if_expr.rs index 9c226b4..62b39f1 100644 --- a/crates/rue-compiler/src/compiler/expr/if_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/if_expr.rs @@ -10,7 +10,8 @@ impl Compiler<'_> { .map(|condition| self.compile_expr(&condition, Some(self.ty.std().bool))); if let Some(condition) = condition.as_ref() { - self.type_guard_stack.push(condition.then_guards()); + let overrides = self.build_overrides(condition.then_guards()); + self.type_overrides.push(overrides); } let then_block = if_expr @@ -18,11 +19,12 @@ impl Compiler<'_> { .map(|then_block| self.compile_block_expr(&then_block, expected_type)); if condition.is_some() { - self.type_guard_stack.pop().unwrap(); + self.type_overrides.pop().unwrap(); } if let Some(condition) = condition.as_ref() { - self.type_guard_stack.push(condition.else_guards()); + let overrides = self.build_overrides(condition.else_guards()); + self.type_overrides.push(overrides); } let expected_type = @@ -33,7 +35,7 @@ impl Compiler<'_> { .map(|else_block| self.compile_block_expr(&else_block, expected_type)); if condition.is_some() { - self.type_guard_stack.pop().unwrap(); + self.type_overrides.pop().unwrap(); } if let Some(condition_type) = condition.as_ref().map(|condition| condition.type_id) { diff --git a/crates/rue-compiler/src/compiler/expr/initializer_expr.rs b/crates/rue-compiler/src/compiler/expr/initializer_expr.rs index b9ddd97..d8940f1 100644 --- a/crates/rue-compiler/src/compiler/expr/initializer_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/initializer_expr.rs @@ -13,7 +13,7 @@ impl Compiler<'_> { .path() .map(|path| self.compile_path_type(&path.items(), path.syntax().text_range())); - match ty.map(|ty| self.ty.get(ty)).cloned() { + match ty.map(|ty| self.ty.get_unaliased(ty)).cloned() { Some(Type::Struct(struct_type)) => { let fields = deconstruct_items( self.ty, @@ -85,7 +85,7 @@ impl Compiler<'_> { self.unknown() } } - Some(_) => { + Some(..) => { self.db.error( ErrorKind::UninitializableType(self.type_name(ty.unwrap())), initializer.path().unwrap().syntax().text_range(), diff --git a/crates/rue-compiler/src/compiler/expr/path_expr.rs b/crates/rue-compiler/src/compiler/expr/path_expr.rs index f857f4d..fe4538c 100644 --- a/crates/rue-compiler/src/compiler/expr/path_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/path_expr.rs @@ -8,7 +8,7 @@ use crate::{ Compiler, }, hir::Hir, - symbol::{Function, Symbol}, + symbol::Symbol, value::{GuardPath, Value}, ErrorKind, }; @@ -75,18 +75,16 @@ impl Compiler<'_> { return self.unknown(); } - let type_override = self.symbol_type(&GuardPath::new(symbol_id)); + let type_id = self.symbol_type(symbol_id); let reference = self.db.alloc_hir(Hir::Reference(symbol_id, text_range)); let mut value = match self.db.symbol(symbol_id).clone() { Symbol::Unknown | Symbol::Module(..) => unreachable!(), - Symbol::Function(Function { type_id, .. }) - | Symbol::InlineFunction(Function { type_id, .. }) - | Symbol::Parameter(type_id) => Value::new(reference, type_override.unwrap_or(type_id)), + Symbol::Function(..) | Symbol::InlineFunction(..) | Symbol::Parameter(..) => { + Value::new(reference, type_id) + } Symbol::Let(mut value) | Symbol::Const(mut value) | Symbol::InlineConst(mut value) => { - if let Some(type_id) = type_override { - value.type_id = type_id; - } + value.type_id = type_id; value.hir_id = reference; value } diff --git a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs index 9927551..727580b 100644 --- a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs +++ b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs @@ -35,7 +35,8 @@ impl Compiler<'_> { let scope_id = self.db.alloc_scope(Scope::default()); // We can apply any type guards from the condition. - self.type_guard_stack.push(condition.then_guards()); + let overrides = self.build_overrides(condition.then_guards()); + self.type_overrides.push(overrides); // Compile the then block. self.scope_stack.push(scope_id); @@ -43,7 +44,7 @@ impl Compiler<'_> { self.scope_stack.pop().unwrap(); // Pop the type guards, since we've left the scope. - self.type_guard_stack.pop().unwrap(); + self.type_overrides.pop().unwrap(); // If there's an implicit return, we want to raise an error. // This could technically work but makes the intent of the code unclear. diff --git a/crates/rue-compiler/src/value.rs b/crates/rue-compiler/src/value.rs index 89b61ac..5344b1f 100644 --- a/crates/rue-compiler/src/value.rs +++ b/crates/rue-compiler/src/value.rs @@ -29,26 +29,26 @@ impl Value { pub fn then_guards(&self) -> HashMap { self.guards .iter() - .map(|(guard_path, guard)| (guard_path.clone(), guard.then_type)) + .filter_map(|(guard_path, guard)| Some((guard_path.clone(), guard.then_type?))) .collect() } pub fn else_guards(&self) -> HashMap { self.guards .iter() - .map(|(guard_path, guard)| (guard_path.clone(), guard.else_type)) + .filter_map(|(guard_path, guard)| Some((guard_path.clone(), guard.else_type?))) .collect() } } #[derive(Debug, Clone, Copy)] pub struct Guard { - pub then_type: TypeId, - pub else_type: TypeId, + pub then_type: Option, + pub else_type: Option, } impl Guard { - pub fn new(then_type: TypeId, else_type: TypeId) -> Self { + pub fn new(then_type: Option, else_type: Option) -> Self { Self { then_type, else_type, diff --git a/crates/rue-typing/src/replace_type.rs b/crates/rue-typing/src/replace_type.rs index 54565d4..5179f8c 100644 --- a/crates/rue-typing/src/replace_type.rs +++ b/crates/rue-typing/src/replace_type.rs @@ -11,10 +11,19 @@ pub(crate) fn replace_type( } match types.get(type_id) { - Type::Pair(first, rest) => match path[0] { - TypePath::First => replace_type(types, *first, replace_type_id, &path[1..]), - TypePath::Rest => replace_type(types, *rest, replace_type_id, &path[1..]), - }, + Type::Pair(first, rest) => { + let (first, rest) = (*first, *rest); + match path[0] { + TypePath::First => { + let first = replace_type(types, first, replace_type_id, &path[1..]); + types.alloc(Type::Pair(first, rest)) + } + TypePath::Rest => { + let rest = replace_type(types, rest, replace_type_id, &path[1..]); + types.alloc(Type::Pair(first, rest)) + } + } + } Type::Alias(alias) => { let alias = alias.clone(); let new_type_id = replace_type(types, alias.type_id, replace_type_id, path); diff --git a/crates/rue-typing/src/stringify_type.rs b/crates/rue-typing/src/stringify_type.rs index 4cf18a7..2119040 100644 --- a/crates/rue-typing/src/stringify_type.rs +++ b/crates/rue-typing/src/stringify_type.rs @@ -62,8 +62,21 @@ pub(crate) fn stringify_type( name + &generics } Type::Alias(alias) => stringify_type(types, alias.type_id, names, visited), - Type::Struct(Struct { type_id, .. }) | Type::Variant(Variant { type_id, .. }) => { - stringify_type(types, *type_id, names, visited) + Type::Struct(Struct { + type_id, + original_type_id, + .. + }) + | Type::Variant(Variant { + type_id, + original_type_id, + .. + }) => { + if type_id == original_type_id { + stringify_type(types, *type_id, names, visited) + } else { + stringify_type(types, *original_type_id, names, visited) + } } Type::Enum(Enum { type_id, .. }) => stringify_type(types, *type_id, names, visited), Type::Callable(Callable { diff --git a/crates/rue-typing/src/type_system.rs b/crates/rue-typing/src/type_system.rs index 90c9e18..f6d02bb 100644 --- a/crates/rue-typing/src/type_system.rs +++ b/crates/rue-typing/src/type_system.rs @@ -109,6 +109,13 @@ impl TypeSystem { } } + pub fn get_unaliased(&self, type_id: TypeId) -> &Type { + match self.get(type_id) { + Type::Alias(ty) => self.get_unaliased(ty.type_id), + ty => ty, + } + } + pub fn get(&self, type_id: TypeId) -> &Type { match &self.arena[type_id] { Type::Ref(type_id) => self.get(*type_id),