diff --git a/Cargo.lock b/Cargo.lock index 4968a86..35bb919 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,19 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +39,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" + [[package]] name = "anstream" version = "0.6.13" @@ -76,9 +95,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.81" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] name = "arbitrary" @@ -787,9 +806,13 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "heck" @@ -1321,9 +1344,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", @@ -1410,6 +1433,7 @@ dependencies = [ "rowan", "rue-clvm", "rue-parser", + "rue-typing", ] [[package]] @@ -1455,6 +1479,21 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rue-typing" +version = "0.1.1" +dependencies = [ + "ahash", + "anyhow", + "clvmr 0.6.1", + "hashbrown", + "id-arena", + "indexmap", + "num-bigint", + "num-traits", + "thiserror", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -2326,6 +2365,26 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.61", +] + [[package]] name = "zeroize" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index a17e72f..f455dee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,12 +8,12 @@ keywords = ["chia", "blockchain", "crypto"] categories = ["cryptography::cryptocurrencies", "development-tools"] [workspace.lints.rust] +rust_2018_idioms = { level = "deny", priority = -1 } +rust_2021_compatibility = { level = "deny", priority = -1 } +future_incompatible = { level = "deny", priority = -1 } +nonstandard_style = { level = "deny", priority = -1 } unsafe_code = "deny" -rust_2018_idioms = "deny" -rust_2021_compatibility = "deny" -future_incompatible = "deny" non_ascii_idents = "deny" -nonstandard_style = "deny" unused_extern_crates = "deny" trivial_casts = "deny" trivial_numeric_casts = "deny" @@ -26,7 +26,7 @@ missing_debug_implementations = "warn" missing_copy_implementations = "warn" [workspace.lints.clippy] -all = "deny" +all = { level = "deny", priority = -1 } pedantic = { level = "warn", priority = -1 } missing_errors_doc = "allow" missing_panics_doc = "allow" @@ -34,3 +34,33 @@ module_name_repetitions = "allow" multiple_crate_versions = "allow" must_use_candidate = "allow" too_many_lines = "allow" + +[workspace.dependencies] +rue-parser = { path = "./crates/rue-parser", version = "0.1.1" } +rue-compiler = { path = "./crates/rue-compiler", version = "0.1.1" } +rue-typing = { path = "./crates/rue-typing", version = "0.1.1" } +rue-clvm = { path = "./crates/rue-clvm", version = "0.1.1" } +rue-lexer = { path = "./crates/rue-lexer", version = "0.1.1" } +clvmr_old = { version = "0.3.2", package = "clvmr" } +clvmr = "0.6.1" +clap = "4.5.4" +hex = "0.4.3" +clvm_tools_rs = "0.1.41" +thiserror = "1.0.61" +num-bigint = "0.4.6" +num-traits = "0.2.19" +num-derive = "0.4.2" +id-arena = "2.2.1" +indexmap = "2.2.6" +rowan = "0.15.15" +log = "0.4.21" +indoc = "2.0.5" +tokio = "1.37.0" +tower-lsp = "0.20.0" +clvm-utils = "0.6.0" +toml = "0.8.12" +serde = "1.0.197" +walkdir = "2.5.0" +anyhow = "1.0.86" +hashbrown = "0.14.5" +ahash = "0.8.11" diff --git a/crates/rue-cli/Cargo.toml b/crates/rue-cli/Cargo.toml index 952c6ac..9bc0913 100644 --- a/crates/rue-cli/Cargo.toml +++ b/crates/rue-cli/Cargo.toml @@ -15,12 +15,12 @@ categories = { workspace = true } workspace = true [dependencies] -clap = { version = "4.5.4", features = ["derive"] } -rue-parser = { path = "../../crates/rue-parser", version = "0.1.0" } -rue-compiler = { path = "../../crates/rue-compiler", version = "0.1.0" } -rue-clvm = { path = "../../crates/rue-clvm", version = "0.1.0" } -clvmr = "0.6.1" -hex = "0.4.3" +clap = { workspace = true, features = ["derive"] } +rue-parser = { workspace = true } +rue-compiler = { workspace = true } +rue-clvm = { workspace = true } +clvmr = { workspace = true } +hex = { workspace = true } [[bin]] name = "rue" diff --git a/crates/rue-cli/src/main.rs b/crates/rue-cli/src/main.rs index a64d8cc..c32db5c 100644 --- a/crates/rue-cli/src/main.rs +++ b/crates/rue-cli/src/main.rs @@ -5,7 +5,7 @@ use std::fs; use clap::Parser; use clvmr::{serde::node_to_bytes, Allocator, NodePtr}; use rue_clvm::{parse_clvm, run_clvm, stringify_clvm}; -use rue_compiler::{compile, Diagnostic, DiagnosticKind}; +use rue_compiler::{compile_raw, Diagnostic, DiagnosticKind}; use rue_parser::{line_col, parse, LineCol}; /// CLI tools for working with the Rue compiler. @@ -20,23 +20,31 @@ enum Command { /// A list of parameters to run the compiled program with. #[clap(long, short = 'r')] run: Option>, + + /// Whether to exclude the standard library. + #[clap(long, short = 'n')] + no_std: bool, }, /// Check a Rue source file for errors. Check { /// The source file to check. file: String, + + /// Whether to exclude the standard library. + #[clap(long, short = 'n')] + no_std: bool, }, } fn main() { match Command::parse() { - Command::Build { file, run } => build(file, true, &run), - Command::Check { file } => build(file, false, &None), + Command::Build { file, run, no_std } => build(file, true, &run, no_std), + Command::Check { file, no_std } => build(file, false, &None, no_std), } } -fn build(file: String, should_compile: bool, run: &Option>) { +fn build(file: String, should_compile: bool, run: &Option>, no_std: bool) { let source = fs::read_to_string(file).expect("could not read source file"); let (ast, errors) = parse(&source); @@ -49,7 +57,12 @@ fn build(file: String, should_compile: bool, run: &Option>) { } let mut allocator = Allocator::new(); - let output = compile(&mut allocator, &ast, should_compile && errors.is_empty()); + let output = compile_raw( + &mut allocator, + &ast, + should_compile && errors.is_empty(), + !no_std, + ); if print_diagnostics(&source, &output.diagnostics) { return; diff --git a/crates/rue-clvm/Cargo.toml b/crates/rue-clvm/Cargo.toml index 5e2f733..4bb4bed 100644 --- a/crates/rue-clvm/Cargo.toml +++ b/crates/rue-clvm/Cargo.toml @@ -15,9 +15,9 @@ categories = { workspace = true } workspace = true [dependencies] -clvmr = "0.6.1" -clvmr_old = { version = "0.3.2", package = "clvmr" } -clvm_tools_rs = "0.1.41" -thiserror = "1.0.61" -num-bigint = "0.4.6" -num-traits = "0.2.19" +clvmr = { workspace = true } +clvmr_old = { workspace = true } +clvm_tools_rs = { workspace = true } +thiserror = { workspace = true } +num-bigint = { workspace = true } +num-traits = { workspace = true } diff --git a/crates/rue-compiler/Cargo.toml b/crates/rue-compiler/Cargo.toml index 40c2574..bbdb618 100644 --- a/crates/rue-compiler/Cargo.toml +++ b/crates/rue-compiler/Cargo.toml @@ -15,14 +15,15 @@ categories = { workspace = true } workspace = true [dependencies] -rue-parser = { path = "../rue-parser", version = "0.1.0" } -rue-clvm = { path = "../rue-clvm", version = "0.1.0" } -clvmr = "0.6.1" -id-arena = "2.2.1" -indexmap = "2.2.6" -rowan = "0.15.15" -num-traits = "0.2.18" -num-bigint = "0.4.4" -log = "0.4.21" -hex = "0.4.3" -indoc = "2.0.5" +rue-parser = { workspace = true } +rue-clvm = { workspace = true } +rue-typing = { workspace = true } +clvmr = { workspace = true } +id-arena = { workspace = true } +indexmap = { workspace = true } +rowan = { workspace = true } +num-traits = { workspace = true } +num-bigint = { workspace = true } +log = { workspace = true } +hex = { workspace = true } +indoc = { workspace = true } diff --git a/crates/rue-compiler/src/compiler.rs b/crates/rue-compiler/src/compiler.rs index ad7cc33..756c003 100644 --- a/crates/rue-compiler/src/compiler.rs +++ b/crates/rue-compiler/src/compiler.rs @@ -1,17 +1,18 @@ #![allow(clippy::map_unwrap_or)] -use std::collections::HashMap; +use rue_typing::{HashMap, TypePath}; pub(crate) use builtins::Builtins; -use indexmap::IndexSet; + use rowan::TextRange; +use rue_typing::{Comparison, TypeId, TypeSystem}; use symbol_table::SymbolTable; use crate::{ - database::{Database, HirId, ScopeId, SymbolId, TypeId}, + database::{Database, HirId, ScopeId, SymbolId}, hir::{Hir, Op}, scope::Scope, - value::{GuardPath, Mutation, PairType, Type, TypeOverride, Value}, + value::{GuardPath, Value}, ErrorKind, }; @@ -25,8 +26,6 @@ mod stmt; mod symbol_table; mod ty; -#[cfg(test)] -pub use builtins::*; pub use context::*; /// Responsible for lowering the AST into the HIR. @@ -35,6 +34,9 @@ pub struct Compiler<'a> { // The database is mutable because we need to allocate new symbols and types. db: &'a mut Database, + // The type system is responsible for type checking and type inference. + ty: &'a mut TypeSystem, + // The scope stack is used to keep track of the current scope. scope_stack: Vec, @@ -45,7 +47,7 @@ pub struct Compiler<'a> { type_definition_stack: Vec, // The type guard stack is used for overriding types in certain contexts. - type_guard_stack: Vec>, + type_guard_stack: Vec>, // The generic type stack is used for overriding generic types that are being checked against. generic_type_stack: Vec>, @@ -65,9 +67,10 @@ pub struct Compiler<'a> { } impl<'a> Compiler<'a> { - pub fn new(db: &'a mut Database, builtins: Builtins) -> Self { + pub fn new(db: &'a mut Database, ty: &'a mut TypeSystem, builtins: Builtins) -> Self { Self { db, + ty, scope_stack: vec![builtins.scope_id], symbol_stack: Vec::new(), type_definition_stack: Vec::new(), @@ -85,15 +88,14 @@ impl<'a> Compiler<'a> { self.sym } - fn compile_index(&mut self, value: HirId, index: usize, rest: bool) -> HirId { - let mut result = value; - for _ in 0..index { - result = self.db.alloc_hir(Hir::Op(Op::Rest, result)); - } - if !rest { - result = self.db.alloc_hir(Hir::Op(Op::First, result)); + fn hir_path(&mut self, mut value: HirId, path_items: &[TypePath]) -> HirId { + for path in path_items { + match path { + TypePath::First => value = self.db.alloc_hir(Hir::Op(Op::First, value)), + TypePath::Rest => value = self.db.alloc_hir(Hir::Op(Op::Rest, value)), + } } - result + value } fn type_reference(&mut self, referenced_type_id: TypeId) { @@ -117,123 +119,29 @@ impl<'a> Compiler<'a> { None } - fn type_name(&self, ty: TypeId) -> String { - self.type_name_visitor(ty, &mut IndexSet::new()) - } + fn type_name(&self, type_id: TypeId) -> String { + let mut names = HashMap::new(); - fn type_name_visitor(&self, ty: TypeId, stack: &mut IndexSet) -> String { - for &scope_id in self.scope_stack.iter().rev() { - if let Some(name) = self.db.scope(scope_id).type_name(ty) { - return name.to_string(); - } - } - - if stack.contains(&ty) { - return "".to_string(); - } - - stack.insert(ty); - - let name = match self.db.ty(ty) { - Type::Unknown => "{unknown}".to_string(), - Type::Generic => "{generic}".to_string(), - Type::Nil => "Nil".to_string(), - Type::Any => "Any".to_string(), - Type::Int => "Int".to_string(), - Type::Bool => "Bool".to_string(), - Type::Bytes => "Bytes".to_string(), - Type::Bytes32 => "Bytes32".to_string(), - Type::PublicKey => "PublicKey".to_string(), - Type::List(items) => { - let inner = self.type_name_visitor(*items, stack); - format!("{inner}[]") - } - Type::Pair(PairType { first, rest }) => { - let first = self.type_name_visitor(*first, stack); - let rest = self.type_name_visitor(*rest, stack); - format!("({first}, {rest})") - } - Type::Struct(struct_type) => { - if struct_type.original_type_id == ty { - let fields: Vec = struct_type - .fields - .iter() - .map(|(name, ty)| { - format!("{}: {}", name, self.type_name_visitor(*ty, stack)) - }) - .collect(); - - format!("{{{}}}", fields.join(", ")) - } else { - self.type_name_visitor(struct_type.original_type_id, stack) + for &scope_id in &self.scope_stack { + for type_id in self.db.scope(scope_id).local_types() { + if let Some(name) = self.db.scope(scope_id).type_name(type_id) { + names.insert(type_id, name.to_string()); } } - Type::Enum { .. } => "".to_string(), - Type::EnumVariant(enum_variant) => { - let enum_name = self.type_name_visitor(enum_variant.enum_type, stack); - - let fields: Option> = enum_variant.fields.as_ref().map(|fields| { - fields - .iter() - .map(|(name, ty)| { - format!("{}: {}", name, self.type_name_visitor(*ty, stack)) - }) - .collect() - }); - - let variant_name = match self.db.ty(enum_variant.enum_type) { - Type::Enum(enum_type) => enum_type - .variants - .iter() - .find(|item| *item.1 == enum_variant.original_type_id) - .expect("enum type is missing variant") - .0 - .clone(), - _ => unreachable!(), - }; - - if let Some(fields) = fields { - format!("{enum_name}::{variant_name} {{{}}}", fields.join(", ")) - } else { - format!("{enum_name}::{variant_name}") - } - } - Type::Function(function_type) => { - let params: Vec = function_type - .param_types - .iter() - .map(|&ty| self.type_name_visitor(ty, stack)) - .collect(); - - let ret = self.type_name_visitor(function_type.return_type, stack); - - format!("fun({}) -> {}", params.join(", "), ret) - } - Type::Alias(..) => unreachable!(), - Type::Nullable(ty) => { - let inner = self.type_name_visitor(*ty, stack); - format!("{inner}?") - } - Type::Optional(ty) => { - let inner = self.type_name_visitor(*ty, stack); - format!("optional {inner}") - } - }; - - stack.pop().unwrap(); + } - name + self.ty.stringify_named(type_id, names) } fn type_check(&mut self, from: TypeId, to: TypeId, range: TextRange) { let comparison = if self.allow_generic_inference_stack.last().copied().unwrap() { - self.db - .compare_type_with_generics(from, to, &mut self.generic_type_stack) + self.ty + .compare_with_generics(from, to, &mut self.generic_type_stack, true) } else { - self.db.compare_type(from, to) + self.ty.compare(from, to) }; - if !comparison.is_assignable() { + if comparison > Comparison::Assignable { self.db.error( ErrorKind::TypeMismatch(self.type_name(from), self.type_name(to)), range, @@ -243,13 +151,13 @@ impl<'a> Compiler<'a> { fn cast_check(&mut self, from: TypeId, to: TypeId, range: TextRange) { let comparison = if self.allow_generic_inference_stack.last().copied().unwrap() { - self.db - .compare_type_with_generics(from, to, &mut self.generic_type_stack) + self.ty + .compare_with_generics(from, to, &mut self.generic_type_stack, true) } else { - self.db.compare_type(from, to) + self.ty.compare(from, to) }; - if !comparison.is_castable() { + if comparison > Comparison::Castable { self.db.error( ErrorKind::CastMismatch(self.type_name(from), self.type_name(to)), range, @@ -258,10 +166,10 @@ impl<'a> Compiler<'a> { } fn unknown(&self) -> Value { - Value::new(self.builtins.unknown_hir, self.builtins.unknown) + Value::new(self.builtins.unknown, self.ty.std().unknown) } - fn symbol_type(&self, guard_path: &GuardPath) -> Option { + fn symbol_type(&self, guard_path: &GuardPath) -> Option { for guards in self.type_guard_stack.iter().rev() { if let Some(guard) = guards.get(guard_path) { return Some(*guard); @@ -270,13 +178,6 @@ impl<'a> Compiler<'a> { None } - fn apply_mutation(&mut self, hir_id: HirId, mutation: Mutation) -> HirId { - match mutation { - Mutation::None => hir_id, - Mutation::UnwrapOptional => self.db.alloc_hir(Hir::Op(Op::First, hir_id)), - } - } - fn scope(&self) -> &Scope { self.db .scope(self.scope_stack.last().copied().expect("no scope found")) diff --git a/crates/rue-compiler/src/compiler/block.rs b/crates/rue-compiler/src/compiler/block.rs index 83d103d..984f1ea 100644 --- a/crates/rue-compiler/src/compiler/block.rs +++ b/crates/rue-compiler/src/compiler/block.rs @@ -1,9 +1,10 @@ use rue_parser::{AstNode, Block, Stmt}; +use rue_typing::TypeId; use crate::{ hir::{Hir, Op}, value::Value, - ErrorKind, TypeId, + ErrorKind, }; use super::{stmt::Statement, Compiler}; @@ -62,7 +63,7 @@ impl Compiler<'_> { // Make sure that the return value matches the expected type. self.type_check( value.type_id, - expected_type.unwrap_or(self.builtins.unknown), + expected_type.unwrap_or(self.ty.std().unknown), return_stmt.syntax().text_range(), ); @@ -83,19 +84,19 @@ impl Compiler<'_> { terminator = BlockTerminator::Raise; is_terminated = true; - statements.push(Statement::Return(Value::new(hir_id, self.builtins.unknown))); + statements.push(Statement::Return(Value::new(hir_id, self.ty.std().never))); } Stmt::AssertStmt(assert_stmt) => { // Compile the condition expression. let condition = assert_stmt .expr() - .map(|condition| self.compile_expr(&condition, Some(self.builtins.bool))) + .map(|condition| self.compile_expr(&condition, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); // Make sure that the condition is a boolean. self.type_check( condition.type_id, - self.builtins.bool, + self.ty.std().bool, assert_stmt.syntax().text_range(), ); @@ -115,13 +116,13 @@ impl Compiler<'_> { // Compile the expression. let expr = assume_stmt .expr() - .map(|expr| self.compile_expr(&expr, Some(self.builtins.bool))) + .map(|expr| self.compile_expr(&expr, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); // Make sure that the condition is a boolean. self.type_check( expr.type_id, - self.builtins.bool, + self.ty.std().bool, assume_stmt.syntax().text_range(), ); diff --git a/crates/rue-compiler/src/compiler/builtins.rs b/crates/rue-compiler/src/compiler/builtins.rs index 46d908b..1045357 100644 --- a/crates/rue-compiler/src/compiler/builtins.rs +++ b/crates/rue-compiler/src/compiler/builtins.rs @@ -1,69 +1,50 @@ +use indexmap::indexset; use rowan::TextRange; +use rue_typing::{Callable, Type, TypeSystem}; use crate::{ hir::{BinOp, Hir, Op}, scope::Scope, symbol::{Function, Symbol}, - value::{FunctionType, PairType, Rest, Type}, - Database, HirId, ScopeId, SymbolId, TypeId, + Database, HirId, ScopeId, SymbolId, }; /// These are the built-in types and most commonly used HIR nodes. pub struct Builtins { pub scope_id: ScopeId, - pub any: TypeId, - pub int: TypeId, - pub bool: TypeId, - pub bytes: TypeId, - pub bytes32: TypeId, - pub public_key: TypeId, - pub nil: TypeId, - pub nil_hir: HirId, - pub unknown: TypeId, - pub unknown_hir: HirId, + pub nil: HirId, + pub unknown: HirId, } /// Defines intrinsics that cannot be implemented in Rue. -pub fn builtins(db: &mut Database) -> Builtins { +pub fn builtins(db: &mut Database, ty: &mut TypeSystem) -> Builtins { let mut scope = Scope::default(); - let int = db.alloc_type(Type::Int); - let bool = db.alloc_type(Type::Bool); - let bytes = db.alloc_type(Type::Bytes); - let bytes32 = db.alloc_type(Type::Bytes32); - let public_key = db.alloc_type(Type::PublicKey); - let any = db.alloc_type(Type::Any); - let nil = db.alloc_type(Type::Nil); - let nil_hir = db.alloc_hir(Hir::Atom(Vec::new())); - let unknown = db.alloc_type(Type::Unknown); - let unknown_hir = db.alloc_hir(Hir::Unknown); - - scope.define_type("Nil".to_string(), nil); - scope.define_type("Int".to_string(), int); - scope.define_type("Bool".to_string(), bool); - scope.define_type("Bytes".to_string(), bytes); - scope.define_type("Bytes32".to_string(), bytes32); - scope.define_type("PublicKey".to_string(), public_key); - scope.define_type("Any".to_string(), any); + let nil = db.alloc_hir(Hir::Atom(Vec::new())); + let unknown = db.alloc_hir(Hir::Unknown); + + scope.define_type("Int".to_string(), ty.std().int); + scope.define_type("Bool".to_string(), ty.std().bool); + scope.define_type("Bytes".to_string(), ty.std().bytes); + scope.define_type("Bytes32".to_string(), ty.std().bytes32); + scope.define_type("PublicKey".to_string(), ty.std().public_key); + scope.define_type("Any".to_string(), ty.std().any); + scope.define_type("List".to_string(), ty.std().unmapped_list); let builtins = Builtins { scope_id: db.alloc_scope(scope), - any, - int, - bool, - bytes, - bytes32, - public_key, nil, - nil_hir, unknown, - unknown_hir, }; - let sha256 = sha256(db, &builtins); - let pubkey_for_exp = pubkey_for_exp(db, &builtins); - let divmod = divmod(db, &builtins); - let substr = substr(db, &builtins); + let cast = cast(db, ty); + let sha256 = sha256(db, ty); + let pubkey_for_exp = pubkey_for_exp(db, ty); + let divmod = divmod(db, ty); + let substr = substr(db, ty); + + db.scope_mut(builtins.scope_id) + .define_symbol("cast".to_string(), cast); db.scope_mut(builtins.scope_id) .define_symbol("sha256".to_string(), sha256); @@ -80,51 +61,94 @@ pub fn builtins(db: &mut Database) -> Builtins { builtins } -fn sha256(db: &mut Database, builtins: &Builtins) -> SymbolId { +fn cast(db: &mut Database, ty: &mut TypeSystem) -> SymbolId { let mut scope = Scope::default(); - let param = db.alloc_symbol(Symbol::Parameter(builtins.bytes)); + let param = db.alloc_symbol(Symbol::Parameter(ty.std().any)); + scope.define_symbol("value".to_string(), param); + let param_ref = db.alloc_hir(Hir::Reference(param, TextRange::default())); + let scope_id = db.alloc_scope(scope); + + let generic = ty.alloc(Type::Generic); + let type_id = ty.alloc(Type::Unknown); + + *ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: indexset!["value".to_string()], + parameters: ty.alloc(Type::Pair(ty.std().any, ty.std().nil)), + nil_terminated: true, + return_type: generic, + generic_types: vec![generic], + }); + + db.alloc_symbol(Symbol::InlineFunction(Function { + scope_id, + hir_id: param_ref, + type_id, + nil_terminated: true, + })) +} + +fn sha256(db: &mut Database, ty: &mut TypeSystem) -> SymbolId { + let mut scope = Scope::default(); + + let param = db.alloc_symbol(Symbol::Parameter(ty.std().bytes)); scope.define_symbol("bytes".to_string(), param); let param_ref = db.alloc_hir(Hir::Reference(param, TextRange::default())); let hir_id = db.alloc_hir(Hir::Op(Op::Sha256, param_ref)); let scope_id = db.alloc_scope(scope); + let type_id = ty.alloc(Type::Unknown); + + *ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: indexset!["bytes".to_string()], + parameters: ty.alloc(Type::Pair(ty.std().bytes, ty.std().nil)), + nil_terminated: true, + return_type: ty.std().bytes32, + generic_types: Vec::new(), + }); + db.alloc_symbol(Symbol::InlineFunction(Function { scope_id, hir_id, - ty: FunctionType { - param_types: vec![builtins.bytes], - rest: Rest::Nil, - return_type: builtins.bytes32, - generic_types: Vec::new(), - }, + type_id, + nil_terminated: true, })) } -fn pubkey_for_exp(db: &mut Database, builtins: &Builtins) -> SymbolId { +fn pubkey_for_exp(db: &mut Database, ty: &mut TypeSystem) -> SymbolId { let mut scope = Scope::default(); - let param = db.alloc_symbol(Symbol::Parameter(builtins.bytes32)); + let param = db.alloc_symbol(Symbol::Parameter(ty.std().bytes32)); scope.define_symbol("exponent".to_string(), param); let param_ref = db.alloc_hir(Hir::Reference(param, TextRange::default())); let hir_id = db.alloc_hir(Hir::Op(Op::PubkeyForExp, param_ref)); let scope_id = db.alloc_scope(scope); + let type_id = ty.alloc(Type::Unknown); + + *ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: indexset!["exponent".to_string()], + parameters: ty.alloc(Type::Pair(ty.std().bytes32, ty.std().nil)), + nil_terminated: true, + return_type: ty.std().public_key, + generic_types: Vec::new(), + }); + db.alloc_symbol(Symbol::InlineFunction(Function { scope_id, hir_id, - ty: FunctionType { - param_types: vec![builtins.bytes32], - rest: Rest::Nil, - return_type: builtins.public_key, - generic_types: Vec::new(), - }, + type_id, + nil_terminated: true, })) } -fn divmod(db: &mut Database, builtins: &Builtins) -> SymbolId { +fn divmod(db: &mut Database, ty: &mut TypeSystem) -> SymbolId { let mut scope = Scope::default(); - let lhs = db.alloc_symbol(Symbol::Parameter(builtins.int)); - let rhs = db.alloc_symbol(Symbol::Parameter(builtins.int)); + + let lhs = db.alloc_symbol(Symbol::Parameter(ty.std().int)); + let rhs = db.alloc_symbol(Symbol::Parameter(ty.std().int)); scope.define_symbol("lhs".to_string(), lhs); scope.define_symbol("rhs".to_string(), rhs); let lhs_ref = db.alloc_hir(Hir::Reference(lhs, TextRange::default())); @@ -132,28 +156,36 @@ fn divmod(db: &mut Database, builtins: &Builtins) -> SymbolId { let hir_id = db.alloc_hir(Hir::BinaryOp(BinOp::DivMod, lhs_ref, rhs_ref)); let scope_id = db.alloc_scope(scope); - let int_pair = db.alloc_type(Type::Pair(PairType { - first: builtins.int, - rest: builtins.int, - })); + let int_pair = ty.alloc(Type::Pair(ty.std().int, ty.std().int)); + + let type_id = ty.alloc(Type::Unknown); + + let parameters = ty.alloc(Type::Pair(ty.std().int, ty.std().nil)); + let parameters = ty.alloc(Type::Pair(ty.std().int, parameters)); + + *ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: indexset!["lhs".to_string(), "rhs".to_string()], + parameters, + nil_terminated: true, + return_type: int_pair, + generic_types: Vec::new(), + }); db.alloc_symbol(Symbol::InlineFunction(Function { scope_id, hir_id, - ty: FunctionType { - param_types: vec![builtins.int, builtins.int], - rest: Rest::Nil, - return_type: int_pair, - generic_types: Vec::new(), - }, + type_id, + nil_terminated: true, })) } -fn substr(db: &mut Database, builtins: &Builtins) -> SymbolId { +fn substr(db: &mut Database, ty: &mut TypeSystem) -> SymbolId { let mut scope = Scope::default(); - let value = db.alloc_symbol(Symbol::Parameter(builtins.bytes)); - let start = db.alloc_symbol(Symbol::Parameter(builtins.int)); - let end = db.alloc_symbol(Symbol::Parameter(builtins.int)); + + let value = db.alloc_symbol(Symbol::Parameter(ty.std().bytes)); + let start = db.alloc_symbol(Symbol::Parameter(ty.std().int)); + let end = db.alloc_symbol(Symbol::Parameter(ty.std().int)); scope.define_symbol("value".to_string(), value); scope.define_symbol("start".to_string(), start); scope.define_symbol("end".to_string(), end); @@ -163,14 +195,25 @@ fn substr(db: &mut Database, builtins: &Builtins) -> SymbolId { let hir_id = db.alloc_hir(Hir::Substr(value_ref, start_ref, end_ref)); let scope_id = db.alloc_scope(scope); + let end = ty.alloc(Type::Pair(ty.std().int, ty.std().nil)); + let int = ty.alloc(Type::Pair(ty.std().int, end)); + let parameters = ty.alloc(Type::Pair(ty.std().bytes, int)); + + let type_id = ty.alloc(Type::Unknown); + + *ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: indexset!["value".to_string(), "start".to_string(), "end".to_string()], + parameters, + nil_terminated: true, + return_type: ty.std().bytes, + generic_types: Vec::new(), + }); + db.alloc_symbol(Symbol::InlineFunction(Function { scope_id, hir_id, - ty: FunctionType { - param_types: vec![builtins.bytes, builtins.int, builtins.int], - rest: Rest::Nil, - return_type: builtins.bytes, - generic_types: Vec::new(), - }, + type_id, + nil_terminated: true, })) } diff --git a/crates/rue-compiler/src/compiler/context.rs b/crates/rue-compiler/src/compiler/context.rs index d81e92c..bb11625 100644 --- a/crates/rue-compiler/src/compiler/context.rs +++ b/crates/rue-compiler/src/compiler/context.rs @@ -1,8 +1,9 @@ -use std::collections::HashSet; +use rue_typing::HashSet; use clvmr::{Allocator, NodePtr}; use indexmap::IndexMap; use rue_parser::{parse, Root}; +use rue_typing::{Type, TypeSystem}; use crate::{ codegen::Codegen, @@ -11,7 +12,6 @@ use crate::{ optimizer::Optimizer, scope::Scope, symbol::{Module, Symbol}, - value::Type, Database, SymbolId, }; @@ -22,9 +22,9 @@ pub struct CompilerContext<'a> { roots: IndexMap, } -pub fn setup_compiler(db: &mut Database) -> CompilerContext<'_> { - let builtins = builtins(db); - let compiler = Compiler::new(db, builtins); +pub fn setup_compiler<'a>(db: &'a mut Database, ty: &'a mut TypeSystem) -> CompilerContext<'a> { + let builtins = builtins(db, ty); + let compiler = Compiler::new(db, ty, builtins); CompilerContext { compiler, roots: IndexMap::new(), @@ -88,6 +88,7 @@ pub fn compile_modules(mut ctx: CompilerContext<'_>) -> SymbolTable { pub fn build_graph( db: &mut Database, + ty: &TypeSystem, symbol_table: &SymbolTable, main_module_id: SymbolId, library_module_ids: &[SymbolId], @@ -104,7 +105,7 @@ pub fn build_graph( } for type_id in ignored_types.clone() { - let Type::Enum(enum_type) = db.ty(type_id) else { + let Type::Enum(enum_type) = ty.get(type_id) else { continue; }; ignored_types.extend(enum_type.variants.values()); @@ -114,7 +115,11 @@ pub fn build_graph( let Symbol::Function(function) = db.symbol_mut(symbol_id).clone() else { continue; }; - ignored_types.extend(function.ty.generic_types.iter().copied()); + ignored_types.extend( + ty.get_callable(function.type_id) + .map(|callable| callable.generic_types.clone()) + .unwrap_or_default(), + ); } let Symbol::Module(module) = db.symbol_mut(main_module_id).clone() else { @@ -122,7 +127,7 @@ pub fn build_graph( }; let graph = DependencyGraph::build(db, &module); - symbol_table.calculate_unused(db, &graph, &ignored_symbols, &ignored_types); + symbol_table.calculate_unused(db, ty, &graph, &ignored_symbols, &ignored_types); graph } diff --git a/crates/rue-compiler/src/compiler/expr.rs b/crates/rue-compiler/src/compiler/expr.rs index 7ab0bd7..cfa0ef0 100644 --- a/crates/rue-compiler/src/compiler/expr.rs +++ b/crates/rue-compiler/src/compiler/expr.rs @@ -1,19 +1,18 @@ use rue_parser::{AstNode, Expr}; +use rue_typing::TypeId; -use crate::{value::Value, TypeId}; +use crate::value::Value; use super::Compiler; mod binary_expr; mod block_expr; mod cast_expr; -mod exists_expr; mod field_access_expr; mod function_call_expr; mod group_expr; mod guard_expr; mod if_expr; -mod index_access_expr; mod initializer_expr; mod lambda_expr; mod list_expr; @@ -31,11 +30,11 @@ impl Compiler<'_> { let value = match expr { Expr::PathExpr(path) => { - self.compile_path_expr(&path.idents(), path.syntax().text_range()) + self.compile_path_expr(&path.items(), path.syntax().text_range()) } Expr::InitializerExpr(initializer) => self.compile_initializer_expr(initializer), Expr::LiteralExpr(literal) => self.compile_literal_expr(literal), - Expr::ListExpr(list) => self.compile_list_expr(list, expected_type), + Expr::ListExpr(list) => self.compile_list_expr(list), Expr::PairExpr(pair) => self.compile_pair_expr(pair, expected_type), Expr::Block(block) => self.compile_block_expr(block, expected_type), Expr::LambdaExpr(lambda) => self.compile_lambda_expr(lambda, expected_type), @@ -47,8 +46,6 @@ impl Compiler<'_> { Expr::IfExpr(if_expr) => self.compile_if_expr(if_expr, expected_type), Expr::FunctionCallExpr(call) => self.compile_function_call_expr(call), Expr::FieldAccessExpr(field_access) => self.compile_field_access_expr(field_access), - Expr::IndexAccessExpr(index_access) => self.compile_index_access_expr(index_access), - Expr::ExistsExpr(exists) => self.compile_exists_expr(exists), }; self.is_callee = false; diff --git a/crates/rue-compiler/src/compiler/expr/binary_expr.rs b/crates/rue-compiler/src/compiler/expr/binary_expr.rs index 4c81442..1272f0e 100644 --- a/crates/rue-compiler/src/compiler/expr/binary_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/binary_expr.rs @@ -1,11 +1,12 @@ use rowan::TextRange; use rue_parser::{AstNode, BinaryExpr, BinaryOp, Expr}; +use rue_typing::{Comparison, Type, TypeId}; use crate::{ compiler::Compiler, hir::{BinOp, Hir, Op}, - value::{Guard, Type, TypeOverride, Value}, - ErrorKind, HirId, TypeId, + value::{Guard, Value}, + ErrorKind, HirId, }; impl Compiler<'_> { @@ -51,98 +52,90 @@ impl Compiler<'_> { } fn op_add(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { - if matches!(self.db.ty(lhs.type_id), Type::Unknown) { + if matches!(self.ty.get(lhs.type_id), Type::Unknown) { if let Some(rhs) = rhs { self.compile_expr(rhs, None); } return self.unknown(); } - if self - .db - .compare_type(lhs.type_id, self.builtins.public_key) - .is_equal() - { + if self.ty.compare(lhs.type_id, self.ty.std().public_key) == Comparison::Equal { return self.add_public_key(lhs.hir_id, rhs, text_range); } - if self - .db - .compare_type(lhs.type_id, self.builtins.bytes) - .is_equal() - { + if self.ty.compare(lhs.type_id, self.ty.std().bytes) <= Comparison::Assignable { return self.add_bytes(lhs.hir_id, rhs, text_range); } let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::Add, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::Add, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn add_public_key(&mut self, lhs: HirId, rhs: Option<&Expr>, text_range: TextRange) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.public_key))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().public_key))) .unwrap_or_else(|| self.unknown()); - self.type_check(rhs.type_id, self.builtins.public_key, text_range); - self.binary_op(BinOp::PointAdd, lhs, rhs.hir_id, self.builtins.public_key) + self.type_check(rhs.type_id, self.ty.std().public_key, text_range); + self.binary_op(BinOp::PointAdd, lhs, rhs.hir_id, self.ty.std().public_key) } fn add_bytes(&mut self, lhs: HirId, rhs: Option<&Expr>, text_range: TextRange) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.bytes))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bytes))) .unwrap_or_else(|| self.unknown()); - self.type_check(rhs.type_id, self.builtins.bytes, text_range); - self.binary_op(BinOp::Concat, lhs, rhs.hir_id, self.builtins.bytes) + self.type_check(rhs.type_id, self.ty.std().bytes, text_range); + self.binary_op(BinOp::Concat, lhs, rhs.hir_id, self.ty.std().bytes) } fn op_subtract(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::Subtract, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::Subtract, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn op_multiply(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::Multiply, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::Multiply, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn op_divide(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::Divide, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::Divide, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn op_remainder(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::Remainder, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::Remainder, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn op_equals(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { - if matches!(self.db.ty(lhs.type_id), Type::Unknown) { + if matches!(self.ty.get(lhs.type_id), Type::Unknown) { if let Some(rhs) = rhs { self.compile_expr(rhs, None); } @@ -153,15 +146,11 @@ impl Compiler<'_> { .map(|rhs| self.compile_expr(rhs, Some(lhs.type_id))) .unwrap_or_else(|| self.unknown()); - let mut value = self.binary_op(BinOp::Equals, lhs.hir_id, rhs.hir_id, self.builtins.bool); + let mut value = self.binary_op(BinOp::Equals, lhs.hir_id, rhs.hir_id, self.ty.std().bool); let mut is_atom = true; - if !self - .db - .compare_type(lhs.type_id, self.builtins.bytes) - .is_castable() - { + if self.ty.compare(lhs.type_id, self.ty.std().bytes) > Comparison::Castable { self.db.error( ErrorKind::NonAtomEquality(self.type_name(lhs.type_id)), text_range, @@ -169,11 +158,7 @@ impl Compiler<'_> { is_atom = false; } - if !self - .db - .compare_type(rhs.type_id, self.builtins.bytes) - .is_castable() - { + if self.ty.compare(rhs.type_id, self.ty.std().bytes) > Comparison::Castable { self.db.error( ErrorKind::NonAtomEquality(self.type_name(rhs.type_id)), text_range, @@ -181,33 +166,23 @@ impl Compiler<'_> { is_atom = false; } - if self - .db - .compare_type(lhs.type_id, self.builtins.nil) - .is_equal() - { + if self.ty.compare(lhs.type_id, self.ty.std().nil) == Comparison::Equal { if let Some(guard_path) = rhs.guard_path { - let then_type = self.builtins.nil; - let else_type = self.db.non_nullable(rhs.type_id); - value.guards.insert( - guard_path, - Guard::new(TypeOverride::new(then_type), TypeOverride::new(else_type)), - ); + let then_type = self.ty.std().nil; + let else_type = self.ty.difference(rhs.type_id, self.ty.std().nil); + value + .guards + .insert(guard_path, Guard::new(then_type, else_type)); } } - if self - .db - .compare_type(rhs.type_id, self.builtins.nil) - .is_equal() - { + if self.ty.compare(rhs.type_id, self.ty.std().nil) == Comparison::Equal { if let Some(guard_path) = lhs.guard_path.clone() { - let then_type = self.builtins.nil; - let else_type = self.db.non_nullable(lhs.type_id); - value.guards.insert( - guard_path, - Guard::new(TypeOverride::new(then_type), TypeOverride::new(else_type)), - ); + let then_type = self.ty.std().nil; + let else_type = self.ty.difference(lhs.type_id, self.ty.std().nil); + value + .guards + .insert(guard_path, Guard::new(then_type, else_type)); } } @@ -223,7 +198,7 @@ impl Compiler<'_> { let mut value = Value::new( self.db.alloc_hir(Hir::Op(Op::Not, comparison.hir_id)), - self.builtins.bool, + self.ty.std().bool, ); for (symbol_id, guard) in comparison.guards { @@ -240,11 +215,7 @@ impl Compiler<'_> { op: BinaryOp, text_range: TextRange, ) -> Value { - if self - .db - .compare_type(lhs.type_id, self.builtins.bytes) - .is_assignable() - { + if self.ty.compare(lhs.type_id, self.ty.std().bytes) <= Comparison::Assignable { let op = match op { BinaryOp::GreaterThan => BinOp::GreaterThanBytes, BinaryOp::LessThan => BinOp::LessThanBytes, @@ -254,11 +225,11 @@ impl Compiler<'_> { }; let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.bytes))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bytes))) .unwrap_or_else(|| self.unknown()); - self.type_check(rhs.type_id, self.builtins.bytes, text_range); - return self.binary_op(op, lhs.hir_id, rhs.hir_id, self.builtins.bool); + self.type_check(rhs.type_id, self.ty.std().bytes, text_range); + return self.binary_op(op, lhs.hir_id, rhs.hir_id, self.ty.std().bool); } let op = match op { @@ -270,31 +241,31 @@ impl Compiler<'_> { }; let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(op, lhs.hir_id, rhs.hir_id, self.builtins.bool) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(op, lhs.hir_id, rhs.hir_id, self.ty.std().bool) } fn op_and(&mut self, lhs: Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { self.type_guard_stack.push(lhs.then_guards()); let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.bool))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); self.type_guard_stack.pop().unwrap(); - self.type_check(lhs.type_id, self.builtins.bool, text_range); - self.type_check(rhs.type_id, self.builtins.bool, text_range); + self.type_check(lhs.type_id, self.ty.std().bool, text_range); + self.type_check(rhs.type_id, self.ty.std().bool, text_range); let mut value = self.binary_op( BinOp::LogicalAnd, lhs.hir_id, rhs.hir_id, - self.builtins.bool, + self.ty.std().bool, ); value.guards.extend(lhs.guards); value.guards.extend(rhs.guards); @@ -305,74 +276,66 @@ impl Compiler<'_> { self.type_guard_stack.push(lhs.then_guards()); let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.bool))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); self.type_guard_stack.pop().unwrap(); - self.type_check(lhs.type_id, self.builtins.bool, text_range); - self.type_check(rhs.type_id, self.builtins.bool, text_range); - self.binary_op(BinOp::LogicalOr, lhs.hir_id, rhs.hir_id, self.builtins.bool) + self.type_check(lhs.type_id, self.ty.std().bool, text_range); + self.type_check(rhs.type_id, self.ty.std().bool, text_range); + self.binary_op(BinOp::LogicalOr, lhs.hir_id, rhs.hir_id, self.ty.std().bool) } fn op_bitwise_and(&mut self, lhs: Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { - if self - .db - .compare_type(lhs.type_id, self.builtins.bool) - .is_assignable() - { + if self.ty.compare(lhs.type_id, self.ty.std().bool) <= Comparison::Assignable { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.bool))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); - self.type_check(rhs.type_id, self.builtins.bool, text_range); + self.type_check(rhs.type_id, self.ty.std().bool, text_range); - let mut value = self.binary_op(BinOp::All, lhs.hir_id, rhs.hir_id, self.builtins.bool); + let mut value = self.binary_op(BinOp::All, lhs.hir_id, rhs.hir_id, self.ty.std().bool); value.guards.extend(lhs.guards); value.guards.extend(rhs.guards); return value; } let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::BitwiseAnd, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::BitwiseAnd, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn op_bitwise_or(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { - if self - .db - .compare_type(lhs.type_id, self.builtins.bool) - .is_assignable() - { + if self.ty.compare(lhs.type_id, self.ty.std().bool) <= Comparison::Assignable { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.bool))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); - self.type_check(rhs.type_id, self.builtins.bool, text_range); - return self.binary_op(BinOp::Any, lhs.hir_id, rhs.hir_id, self.builtins.bool); + self.type_check(rhs.type_id, self.ty.std().bool, text_range); + return self.binary_op(BinOp::Any, lhs.hir_id, rhs.hir_id, self.ty.std().bool); } let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::BitwiseOr, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::BitwiseOr, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn op_bitwise_xor(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); - self.binary_op(BinOp::BitwiseXor, lhs.hir_id, rhs.hir_id, self.builtins.int) + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); + self.binary_op(BinOp::BitwiseXor, lhs.hir_id, rhs.hir_id, self.ty.std().int) } fn op_left_arith_shift( @@ -382,16 +345,16 @@ impl Compiler<'_> { text_range: TextRange, ) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); self.binary_op( BinOp::LeftArithShift, lhs.hir_id, rhs.hir_id, - self.builtins.int, + self.ty.std().int, ) } @@ -402,16 +365,16 @@ impl Compiler<'_> { text_range: TextRange, ) -> Value { let rhs = rhs - .map(|rhs| self.compile_expr(rhs, Some(self.builtins.int))) + .map(|rhs| self.compile_expr(rhs, Some(self.ty.std().int))) .unwrap_or_else(|| self.unknown()); - self.type_check(lhs.type_id, self.builtins.int, text_range); - self.type_check(rhs.type_id, self.builtins.int, text_range); + self.type_check(lhs.type_id, self.ty.std().int, text_range); + self.type_check(rhs.type_id, self.ty.std().int, text_range); self.binary_op( BinOp::RightArithShift, lhs.hir_id, rhs.hir_id, - self.builtins.int, + self.ty.std().int, ) } } diff --git a/crates/rue-compiler/src/compiler/expr/block_expr.rs b/crates/rue-compiler/src/compiler/expr/block_expr.rs index 8d63073..88eb2c8 100644 --- a/crates/rue-compiler/src/compiler/expr/block_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/block_expr.rs @@ -1,10 +1,11 @@ use rue_parser::{AstNode, Block}; +use rue_typing::TypeId; use crate::{ compiler::{block::BlockTerminator, Compiler}, scope::Scope, value::Value, - ErrorKind, TypeId, + ErrorKind, }; impl Compiler<'_> { diff --git a/crates/rue-compiler/src/compiler/expr/cast_expr.rs b/crates/rue-compiler/src/compiler/expr/cast_expr.rs index 225ac8a..c4a4b73 100644 --- a/crates/rue-compiler/src/compiler/expr/cast_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/cast_expr.rs @@ -7,7 +7,7 @@ impl Compiler<'_> { // It's fine to default to unknown, since the cast check will succeed anyways. let type_id = cast .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); // Let's used the cast type as the expected type. let Some(expr) = cast diff --git a/crates/rue-compiler/src/compiler/expr/exists_expr.rs b/crates/rue-compiler/src/compiler/expr/exists_expr.rs deleted file mode 100644 index 153a317..0000000 --- a/crates/rue-compiler/src/compiler/expr/exists_expr.rs +++ /dev/null @@ -1,38 +0,0 @@ -use rue_parser::{AstNode, ExistsExpr}; - -use crate::{ - compiler::Compiler, - hir::{Hir, Op}, - value::{Guard, Mutation, Type, TypeOverride, Value}, - ErrorKind, -}; - -impl Compiler<'_> { - pub fn compile_exists_expr(&mut self, exists: &ExistsExpr) -> Value { - let Some(value) = exists.expr().map(|expr| self.compile_expr(&expr, None)) else { - return self.unknown(); - }; - - let Type::Optional(inner) = self.db.ty(value.type_id).clone() else { - self.db.error( - ErrorKind::InvalidExistanceCheck(self.type_name(value.type_id)), - exists.syntax().text_range(), - ); - return self.unknown(); - }; - - let exists = self.db.alloc_hir(Hir::Op(Op::Listp, value.hir_id)); - let mut new_value = Value::new(exists, self.builtins.bool); - - if let Some(guard_path) = value.guard_path { - let mut unwrap = TypeOverride::new(inner); - unwrap.mutation = Mutation::UnwrapOptional; - new_value.guards.insert( - guard_path, - Guard::new(unwrap, TypeOverride::new(value.type_id)), - ); - } - - new_value - } -} diff --git a/crates/rue-compiler/src/compiler/expr/field_access_expr.rs b/crates/rue-compiler/src/compiler/expr/field_access_expr.rs index d76d6a4..dffcc67 100644 --- a/crates/rue-compiler/src/compiler/expr/field_access_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/field_access_expr.rs @@ -1,9 +1,10 @@ -use rue_parser::FieldAccessExpr; +use rue_parser::{FieldAccessExpr, SyntaxToken}; +use rue_typing::{deconstruct_items, index_to_path, Struct, Type, TypeId, TypePath, Variant}; use crate::{ compiler::Compiler, hir::{Hir, Op}, - value::{GuardPathItem, PairType, Rest, Type, Value}, + value::Value, ErrorKind, }; @@ -17,97 +18,42 @@ impl Compiler<'_> { return self.unknown(); }; - let Some(field_name) = field_access.field() else { + let Some(name) = field_access.field() else { return self.unknown(); }; - let mut new_value = match self.db.ty(old_value.type_id).clone() { - Type::Struct(struct_type) => { - if let Some((index, _, &field_type)) = - struct_type.fields.get_full(field_name.text()) - { - let mut type_id = field_type; - - if index == struct_type.fields.len() - 1 && struct_type.rest == Rest::Optional { - type_id = self.db.alloc_type(Type::Optional(type_id)); - } - - Value::new( - self.compile_index( - old_value.hir_id, - index, - index == struct_type.fields.len() - 1 && struct_type.rest != Rest::Nil, - ), - type_id, - ) - .extend_guard_path(old_value, GuardPathItem::Field(field_name.to_string())) - } else { - self.db.error( - ErrorKind::UnknownField(field_name.to_string()), - field_name.text_range(), - ); + let mut new_value = match self.ty.get(old_value.type_id).clone() { + Type::Unknown => return self.unknown(), + Type::Struct(ty) => { + let Some(value) = self.compile_struct_field_access(old_value, &ty, &name) else { return self.unknown(); - } + }; + value } - Type::EnumVariant(variant_type) => { - let fields = variant_type.fields.unwrap_or_default(); - - if let Some((index, _, &field_type)) = fields.get_full(field_name.text()) { - let mut type_id = field_type; - - if index == fields.len() - 1 && variant_type.rest == Rest::Optional { - type_id = self.db.alloc_type(Type::Optional(type_id)); - } - - let fields_hir_id = self.db.alloc_hir(Hir::Op(Op::Rest, old_value.hir_id)); - - Value::new( - self.compile_index( - fields_hir_id, - index, - index == fields.len() - 1 && variant_type.rest != Rest::Nil, - ), - type_id, - ) - .extend_guard_path(old_value, GuardPathItem::Field(field_name.to_string())) - } else { - self.db.error( - ErrorKind::UnknownField(field_name.to_string()), - field_name.text_range(), - ); + Type::Variant(ty) => { + let Some(value) = self.compile_variant_field_access(old_value, &ty, &name) else { return self.unknown(); - } + }; + value } - Type::Pair(PairType { first, rest }) => match field_name.text() { - "first" => Value::new( - self.db.alloc_hir(Hir::Op(Op::First, old_value.hir_id)), - first, - ) - .extend_guard_path(old_value, GuardPathItem::First), - "rest" => Value::new(self.db.alloc_hir(Hir::Op(Op::Rest, old_value.hir_id)), rest) - .extend_guard_path(old_value, GuardPathItem::Rest), - _ => { - self.db.error( - ErrorKind::InvalidFieldAccess( - field_name.to_string(), - self.type_name(old_value.type_id), - ), - field_name.text_range(), - ); + Type::Pair(first, rest) => { + let Some(value) = self.compile_pair_field_access(old_value, first, rest, &name) + else { return self.unknown(); - } - }, - Type::Bytes | Type::Bytes32 if field_name.text() == "length" => Value::new( + }; + value + } + Type::Bytes | Type::Bytes32 if name.text() == "length" => Value::new( self.db.alloc_hir(Hir::Op(Op::Strlen, old_value.hir_id)), - self.builtins.int, + self.ty.std().int, ), _ => { self.db.error( ErrorKind::InvalidFieldAccess( - field_name.to_string(), + name.to_string(), self.type_name(old_value.type_id), ), - field_name.text_range(), + name.text_range(), ); return self.unknown(); } @@ -115,11 +61,126 @@ impl Compiler<'_> { if let Some(guard_path) = new_value.guard_path.as_ref() { if let Some(type_override) = self.symbol_type(guard_path) { - new_value.type_id = type_override.type_id; - new_value.hir_id = self.apply_mutation(new_value.hir_id, type_override.mutation); + new_value.type_id = type_override; } } new_value } + + fn compile_pair_field_access( + &mut self, + old_value: Value, + first: TypeId, + rest: TypeId, + name: &SyntaxToken, + ) -> Option { + let path = match name.text() { + "first" => TypePath::First, + "rest" => TypePath::Rest, + _ => { + self.db.error( + ErrorKind::InvalidFieldAccess( + name.to_string(), + self.type_name(old_value.type_id), + ), + name.text_range(), + ); + return None; + } + }; + + let type_id = match path { + TypePath::First => first, + TypePath::Rest => rest, + }; + + let mut value = Value::new(self.hir_path(old_value.hir_id, &[path]), type_id); + + value.guard_path = old_value.guard_path.map(|mut guard_path| { + guard_path.items.push(path); + guard_path + }); + + Some(value) + } + + fn compile_struct_field_access( + &mut self, + old_value: Value, + ty: &Struct, + name: &SyntaxToken, + ) -> Option { + let fields = + deconstruct_items(self.ty, ty.type_id, ty.field_names.len(), ty.nil_terminated) + .expect("invalid struct type"); + + let Some(index) = ty.field_names.get_index_of(name.text()) else { + self.db + .error(ErrorKind::UnknownField(name.to_string()), name.text_range()); + return None; + }; + + let type_id = fields[index]; + + let path_items = index_to_path( + index, + index != ty.field_names.len() - 1 || ty.nil_terminated, + ); + + let mut value = Value::new(self.hir_path(old_value.hir_id, &path_items), type_id); + + value.guard_path = old_value.guard_path.map(|mut guard_path| { + guard_path.items.extend(path_items); + guard_path + }); + + Some(value) + } + + fn compile_variant_field_access( + &mut self, + old_value: Value, + ty: &Variant, + name: &SyntaxToken, + ) -> Option { + let field_names = ty.field_names.clone().unwrap_or_default(); + + let Type::Enum(enum_type) = self.ty.get(ty.original_enum_type_id) else { + unreachable!(); + }; + + let fields = if enum_type.has_fields { + let type_id = self.ty.get_pair(ty.type_id).expect("expected a pair").1; + + ty.field_names + .as_ref() + .map(|field_names| { + deconstruct_items(self.ty, type_id, field_names.len(), ty.nil_terminated) + .expect("invalid struct type") + }) + .unwrap_or_default() + } else { + Vec::new() + }; + + let Some(index) = field_names.get_index_of(name.text()) else { + self.db + .error(ErrorKind::UnknownField(name.to_string()), name.text_range()); + return None; + }; + + let type_id = fields[index]; + + let path_items = index_to_path(index + 1, index != fields.len() - 1 || ty.nil_terminated); + + let mut value = Value::new(self.hir_path(old_value.hir_id, &path_items), type_id); + + value.guard_path = old_value.guard_path.map(|mut guard_path| { + guard_path.items.extend(path_items); + guard_path + }); + + Some(value) + } } diff --git a/crates/rue-compiler/src/compiler/expr/function_call_expr.rs b/crates/rue-compiler/src/compiler/expr/function_call_expr.rs index 3650762..90a6bf9 100644 --- a/crates/rue-compiler/src/compiler/expr/function_call_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/function_call_expr.rs @@ -1,14 +1,10 @@ -use std::collections::HashMap; +use rue_typing::HashMap; use rowan::TextRange; use rue_parser::{AstNode, FunctionCallExpr}; +use rue_typing::{deconstruct_items, unwrap_list, Callable, TypeId}; -use crate::{ - compiler::Compiler, - hir::Hir, - value::{FunctionType, Rest, Type, Value}, - ErrorKind, -}; +use crate::{compiler::Compiler, hir::Hir, value::Value, ErrorKind}; impl Compiler<'_> { pub fn compile_function_call_expr(&mut self, call: &FunctionCallExpr) -> Value { @@ -18,13 +14,19 @@ impl Compiler<'_> { let callee = call.callee().map(|callee| self.compile_expr(&callee, None)); // Get the function type of the callee. - let function_type = - callee - .as_ref() - .and_then(|callee| match self.db.ty(callee.type_id).clone() { - Type::Function(function_type) => Some(function_type), - _ => None, - }); + let function_type = callee + .as_ref() + .and_then(|callee| self.ty.get_callable_recursive(callee.type_id).cloned()); + + let parameter_types = function_type.as_ref().map(|ty| { + deconstruct_items( + self.ty, + ty.parameters, + ty.parameter_names.len(), + ty.nil_terminated, + ) + .expect("invalid function type") + }); // Make sure the callee is callable, if present. if let Some(callee) = callee.as_ref() { @@ -36,8 +38,48 @@ impl Compiler<'_> { } } + let generic_args = if let Some(generic_params) = function_type.as_ref().and_then(|fun| { + if fun.generic_types.is_empty() { + None + } else { + Some(fun.generic_types.clone()) + } + }) { + if let Some(generic_args) = call.generic_args() { + let types = generic_args.types(); + + if types.len() == generic_params.len() { + let mut generic_types = HashMap::new(); + + for (i, ty) in types.into_iter().enumerate() { + let type_id = self.compile_type(ty); + generic_types.insert(generic_params[i], type_id); + } + + generic_types + } else { + self.db.error( + ErrorKind::GenericArgsMismatch(types.len(), generic_params.len()), + generic_args.syntax().text_range(), + ); + + HashMap::new() + } + } else { + HashMap::new() + } + } else { + if let Some(generic_args) = call.generic_args() { + self.db.error( + ErrorKind::UnexpectedGenericArgs, + generic_args.syntax().text_range(), + ); + } + HashMap::new() + }; + // Push a generic type context for the function, and allow inference. - self.generic_type_stack.push(HashMap::new()); + self.generic_type_stack.push(generic_args); self.allow_generic_inference_stack.push(true); // Compile the arguments naively, and defer type checking until later. @@ -51,16 +93,23 @@ impl Compiler<'_> { .unwrap_or(false); if let Some(function_type) = &function_type { - self.check_argument_length(function_type, len, call.syntax().text_range()); + self.check_argument_length( + function_type, + parameter_types.as_ref().unwrap(), + len, + call.syntax().text_range(), + ); } for (i, arg) in call_args.iter().enumerate() { // Determine the expected type. let expected_type = function_type.as_ref().and_then(|ty| { - if i < ty.param_types.len() { - Some(ty.param_types[i]) - } else if ty.rest == Rest::Spread { - self.db.unwrap_list(*ty.param_types.last().unwrap()) + let parameter_types = parameter_types.as_ref().unwrap(); + + if i < parameter_types.len() { + Some(parameter_types[i]) + } else if !ty.nil_terminated { + unwrap_list(self.ty, *parameter_types.last().unwrap()) } else { None } @@ -90,29 +139,31 @@ impl Compiler<'_> { continue; }; + let parameter_types = parameter_types.as_ref().unwrap(); + if last && spread { - if function.rest != Rest::Spread { + if function.nil_terminated { self.db.error( ErrorKind::UnsupportedFunctionSpread, call_args[i].syntax().text_range(), ); - } else if i >= function.param_types.len() - 1 { - let expected_type = *function.param_types.last().unwrap(); + } else if i >= parameter_types.len() - 1 { + let expected_type = *parameter_types.last().unwrap(); self.type_check(type_id, expected_type, call_args[i].syntax().text_range()); } - } else if function.rest == Rest::Spread && i >= function.param_types.len() - 1 { + } else if !function.nil_terminated && i >= parameter_types.len() - 1 { if let Some(inner_list_type) = - self.db.unwrap_list(*function.param_types.last().unwrap()) + unwrap_list(self.ty, *parameter_types.last().unwrap()) { self.type_check(type_id, inner_list_type, call_args[i].syntax().text_range()); - } else if i == function.param_types.len() - 1 && !spread { + } else if i == parameter_types.len() - 1 && !spread { self.db.error( ErrorKind::RequiredFunctionSpread, call_args[i].syntax().text_range(), ); } - } else if i < function.param_types.len() { - let param_type = function.param_types[i]; + } else if i < parameter_types.len() { + let param_type = parameter_types[i]; self.type_check(type_id, param_type, call_args[i].syntax().text_range()); } } @@ -123,16 +174,15 @@ impl Compiler<'_> { // Calculate the return type. let mut type_id = - function_type.map_or(self.builtins.unknown, |expected| expected.return_type); + function_type.map_or(self.ty.std().unknown, |expected| expected.return_type); if !generic_types.is_empty() { - type_id = self.db.substitute_type(type_id, &generic_types); + type_id = self.ty.substitute(type_id, generic_types); } // Build the HIR for the function call. - let hir_id = self.db.alloc_hir(Hir::FunctionCall( - callee.map_or(self.builtins.unknown_hir, |callee| callee.hir_id), + callee.map_or(self.builtins.unknown, |callee| callee.hir_id), args.iter().map(|arg| arg.hir_id).collect(), spread, )); @@ -142,47 +192,30 @@ impl Compiler<'_> { fn check_argument_length( &mut self, - function: &FunctionType, + function: &Callable, + parameter_types: &[TypeId], length: usize, text_range: TextRange, ) { - match function.rest { - Rest::Nil => { - if length != function.param_types.len() { - self.db.error( - ErrorKind::ArgumentMismatch(length, function.param_types.len()), - text_range, - ); - } - } - Rest::Optional => { - if length != function.param_types.len() && length != function.param_types.len() - 1 - { - self.db.error( - ErrorKind::ArgumentMismatchOptional(length, function.param_types.len()), - text_range, - ); - } + if function.nil_terminated { + if length != parameter_types.len() { + self.db.error( + ErrorKind::ArgumentMismatch(length, parameter_types.len()), + text_range, + ); } - Rest::Spread => { - if self - .db - .unwrap_list(*function.param_types.last().unwrap()) - .is_some() - { - if length < function.param_types.len() - 1 { - self.db.error( - ErrorKind::ArgumentMismatchSpread(length, function.param_types.len()), - text_range, - ); - } - } else if length != function.param_types.len() { - self.db.error( - ErrorKind::ArgumentMismatch(length, function.param_types.len()), - text_range, - ); - } + } else if unwrap_list(self.ty, *parameter_types.last().unwrap()).is_some() { + if length < parameter_types.len() - 1 { + self.db.error( + ErrorKind::ArgumentMismatchSpread(length, parameter_types.len()), + text_range, + ); } + } else if length != parameter_types.len() { + self.db.error( + ErrorKind::ArgumentMismatch(length, parameter_types.len()), + text_range, + ); } } } diff --git a/crates/rue-compiler/src/compiler/expr/group_expr.rs b/crates/rue-compiler/src/compiler/expr/group_expr.rs index a176058..e8245de 100644 --- a/crates/rue-compiler/src/compiler/expr/group_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/group_expr.rs @@ -1,6 +1,7 @@ use rue_parser::GroupExpr; +use rue_typing::TypeId; -use crate::{compiler::Compiler, value::Value, TypeId}; +use crate::{compiler::Compiler, value::Value}; impl Compiler<'_> { pub fn compile_group_expr( diff --git a/crates/rue-compiler/src/compiler/expr/guard_expr.rs b/crates/rue-compiler/src/compiler/expr/guard_expr.rs index c08a01e..8f2d51e 100644 --- a/crates/rue-compiler/src/compiler/expr/guard_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/guard_expr.rs @@ -1,11 +1,11 @@ -use rowan::TextRange; use rue_parser::{AstNode, GuardExpr}; +use rue_typing::{bigint_to_bytes, Check, TypeId}; use crate::{ compiler::Compiler, hir::{BinOp, Hir, Op}, - value::{Guard, PairType, Type, TypeOverride, Value}, - Comparison, ErrorKind, HirId, TypeId, WarningKind, + value::{Guard, Value}, + ErrorKind, HirId, WarningKind, }; impl Compiler<'_> { @@ -21,203 +21,128 @@ impl Compiler<'_> { return self.unknown(); }; - let ty = guard + let rhs = guard .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); - let Some((guard, hir_id)) = - self.guard_into(expr.type_id, ty, expr.hir_id, guard.syntax().text_range()) - else { - return Value::new(self.builtins.unknown_hir, ty); + let Ok(check) = self.ty.check(expr.type_id, rhs) else { + self.db.error( + ErrorKind::RecursiveTypeCheck(self.type_name(expr.type_id), self.type_name(rhs)), + guard.syntax().text_range(), + ); + return self.unknown(); }; - let mut value = Value::new(hir_id, self.builtins.bool); + match check { + Check::True => { + self.db.warning( + WarningKind::UnnecessaryTypeCheck( + self.type_name(expr.type_id), + self.type_name(rhs), + ), + guard.syntax().text_range(), + ); + } + Check::False => { + self.db.error( + ErrorKind::ImpossibleTypeCheck( + self.type_name(expr.type_id), + self.type_name(rhs), + ), + guard.syntax().text_range(), + ); + return self.unknown(); + } + _ => {} + } + + let hir_id = self.check_hir(expr.hir_id, check); + + let mut value = Value::new(hir_id, self.ty.std().bool); if let Some(guard_path) = expr.guard_path { - value.guards.insert(guard_path, guard); + let difference = self.ty.difference(expr.type_id, rhs); + value.guards.insert(guard_path, Guard::new(rhs, difference)); } value } - fn guard_into( - &mut self, - from: TypeId, - to: TypeId, - hir_id: HirId, - text_range: TextRange, - ) -> Option<(Guard, HirId)> { - if self.db.compare_type(from, to) <= Comparison::Assignable { - self.db.warning( - WarningKind::RedundantTypeCheck(self.type_name(from)), - text_range, - ); - return Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(self.builtins.bool)), - hir_id, - )); - } - - match (self.db.ty(from).clone(), self.db.ty(to).clone()) { - (Type::Any, Type::Pair(PairType { first, rest })) => { - if !self.db.compare_type(first, self.builtins.any).is_equal() { - self.db.error(ErrorKind::NonAnyPairTypeGuard, text_range); - } - - if !self.db.compare_type(rest, self.builtins.any).is_equal() { - self.db.error(ErrorKind::NonAnyPairTypeGuard, text_range); - } - - let hir_id = self.db.alloc_hir(Hir::Op(Op::Listp, hir_id)); - Some(( - Guard::new( - TypeOverride::new(to), - TypeOverride::new(self.builtins.bytes), - ), - hir_id, - )) - } - (Type::Any, Type::Bytes) => { - let pair_type = self.db.alloc_type(Type::Pair(PairType { - first: self.builtins.any, - rest: self.builtins.any, - })); - let is_cons = self.db.alloc_hir(Hir::Op(Op::Listp, hir_id)); - let hir_id = self.db.alloc_hir(Hir::Op(Op::Not, is_cons)); - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(pair_type)), - hir_id, - )) - } - (Type::List(inner), Type::Pair(PairType { first, rest })) => { - if !self.db.compare_type(first, inner).is_equal() { - self.db.error(ErrorKind::NonListPairTypeGuard, text_range); - } - - if !self.db.compare_type(rest, from).is_equal() { - self.db.error(ErrorKind::NonListPairTypeGuard, text_range); - } - - let hir_id = self.db.alloc_hir(Hir::Op(Op::Listp, hir_id)); - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(self.builtins.nil)), - hir_id, - )) + fn check_hir(&mut self, hir_id: HirId, check: Check) -> HirId { + match check { + Check::True => self.db.alloc_hir(Hir::Atom(vec![1])), + Check::False => self.db.alloc_hir(Hir::Atom(Vec::new())), + Check::IsAtom => { + let listp = self.db.alloc_hir(Hir::Op(Op::Listp, hir_id)); + self.db.alloc_hir(Hir::Op(Op::Not, listp)) } - (Type::List(inner), Type::Nil) => { - let pair_type = self.db.alloc_type(Type::Pair(PairType { - first: inner, - rest: from, - })); - let is_cons = self.db.alloc_hir(Hir::Op(Op::Listp, hir_id)); - let hir_id = self.db.alloc_hir(Hir::Op(Op::Not, is_cons)); - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(pair_type)), - hir_id, - )) + Check::IsPair => self.db.alloc_hir(Hir::Op(Op::Listp, hir_id)), + Check::Value(value) => { + let value = self.db.alloc_hir(Hir::Atom(bigint_to_bytes(value))); + self.db + .alloc_hir(Hir::BinaryOp(BinOp::Equals, hir_id, value)) } - (Type::Bytes, Type::Bytes32) => { + Check::Length(length) => { let strlen = self.db.alloc_hir(Hir::Op(Op::Strlen, hir_id)); - let length = self.db.alloc_hir(Hir::Atom(vec![32])); - let hir_id = self - .db - .alloc_hir(Hir::BinaryOp(BinOp::Equals, strlen, length)); - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(from)), - hir_id, - )) + let length = self.db.alloc_hir(Hir::Atom(bigint_to_bytes(length.into()))); + self.db + .alloc_hir(Hir::BinaryOp(BinOp::Equals, strlen, length)) } - (Type::Bytes, Type::PublicKey) => { - let strlen = self.db.alloc_hir(Hir::Op(Op::Strlen, hir_id)); - let length = self.db.alloc_hir(Hir::Atom(vec![48])); - let hir_id = self - .db - .alloc_hir(Hir::BinaryOp(BinOp::Equals, strlen, length)); - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(from)), - hir_id, - )) + Check::If(cond, a, b) => { + let cond = self.check_hir(hir_id, *cond); + let a = self.check_hir(hir_id, *a); + let b = self.check_hir(hir_id, *b); + self.db.alloc_hir(Hir::If(cond, a, b)) } - (Type::Enum(enum_type), Type::EnumVariant(variant_type)) => { - if variant_type.enum_type != from { - self.db.error( - ErrorKind::UnsupportedTypeGuard(self.type_name(from), self.type_name(to)), - text_range, - ); - return None; - } - - let hir_id = if enum_type.has_fields { - let first = self.db.alloc_hir(Hir::Op(Op::First, hir_id)); - self.db.alloc_hir(Hir::BinaryOp( - BinOp::Equals, - first, - variant_type.discriminant, - )) + Check::And(mut items) => { + if items.is_empty() { + self.db.alloc_hir(Hir::Atom(vec![1])) + } else if items.len() == 1 { + self.check_hir(hir_id, items.remove(0)) } else { - self.db.alloc_hir(Hir::BinaryOp( - BinOp::Equals, - hir_id, - variant_type.discriminant, - )) - }; - - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(from)), - hir_id, - )) - } - (Type::Int, Type::EnumVariant(variant_type)) => { - let Type::Enum(enum_type) = self.db.ty(variant_type.enum_type).clone() else { - self.db.error( - ErrorKind::UnsupportedTypeGuard(self.type_name(from), self.type_name(to)), - text_range, - ); - return None; - }; - - if enum_type.has_fields { - self.db.error( - ErrorKind::UnsupportedTypeGuard(self.type_name(from), self.type_name(to)), - text_range, - ); - return None; - } + let a = self.check_hir(hir_id, items.pop().unwrap()); + let b = self.check_hir(hir_id, items.pop().unwrap()); + + let mut result = self.db.alloc_hir(Hir::BinaryOp(BinOp::LogicalAnd, b, a)); - let hir_id = self.db.alloc_hir(Hir::BinaryOp( - BinOp::Equals, - hir_id, - variant_type.discriminant, - )); + while let Some(item) = items.pop() { + let next = self.check_hir(hir_id, item); + result = self + .db + .alloc_hir(Hir::BinaryOp(BinOp::LogicalAnd, next, result)); + } - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(from)), - hir_id, - )) + result + } } - (Type::Nullable(inner), Type::Nil) => { - let hir_id = self.db.alloc_hir(Hir::Op(Op::Not, hir_id)); + Check::Or(mut items) => { + if items.is_empty() { + unreachable!() + } else if items.len() == 1 { + self.check_hir(hir_id, items.remove(1)) + } else { + let a = self.check_hir(hir_id, items.pop().unwrap()); + let b = self.check_hir(hir_id, items.pop().unwrap()); - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(inner)), - hir_id, - )) + let mut result = self.db.alloc_hir(Hir::BinaryOp(BinOp::LogicalOr, b, a)); + + while let Some(item) = items.pop() { + let next = self.check_hir(hir_id, item); + result = self + .db + .alloc_hir(Hir::BinaryOp(BinOp::LogicalOr, next, result)); + } + + result + } } - (Type::Nullable(inner), _) if self.db.compare_type(to, inner).is_equal() => { - let hir_id = self.db.alloc_hir(Hir::Op(Op::Not, hir_id)); - let hir_id = self.db.alloc_hir(Hir::Op(Op::Not, hir_id)); - - Some(( - Guard::new(TypeOverride::new(to), TypeOverride::new(inner)), - hir_id, - )) + Check::First(check) => { + let first = self.db.alloc_hir(Hir::Op(Op::First, hir_id)); + self.check_hir(first, *check) } - _ => { - self.db.error( - ErrorKind::UnsupportedTypeGuard(self.type_name(from), self.type_name(to)), - text_range, - ); - None + Check::Rest(check) => { + let rest = self.db.alloc_hir(Hir::Op(Op::Rest, hir_id)); + self.check_hir(rest, *check) } } } diff --git a/crates/rue-compiler/src/compiler/expr/if_expr.rs b/crates/rue-compiler/src/compiler/expr/if_expr.rs index c70e6dc..9c226b4 100644 --- a/crates/rue-compiler/src/compiler/expr/if_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/if_expr.rs @@ -1,12 +1,13 @@ use rue_parser::{AstNode, IfExpr}; +use rue_typing::TypeId; -use crate::{compiler::Compiler, hir::Hir, value::Value, TypeId}; +use crate::{compiler::Compiler, hir::Hir, value::Value}; impl Compiler<'_> { pub fn compile_if_expr(&mut self, if_expr: &IfExpr, expected_type: Option) -> Value { let condition = if_expr .condition() - .map(|condition| self.compile_expr(&condition, Some(self.builtins.bool))); + .map(|condition| self.compile_expr(&condition, Some(self.ty.std().bool))); if let Some(condition) = condition.as_ref() { self.type_guard_stack.push(condition.then_guards()); @@ -38,7 +39,7 @@ impl Compiler<'_> { if let Some(condition_type) = condition.as_ref().map(|condition| condition.type_id) { self.type_check( condition_type, - self.builtins.bool, + self.ty.std().bool, if_expr.condition().unwrap().syntax().text_range(), ); } @@ -54,7 +55,7 @@ impl Compiler<'_> { let ty = then_block .as_ref() .or(else_block.as_ref()) - .map_or(self.builtins.unknown, |block| block.type_id); + .map_or(self.ty.std().unknown, |block| block.type_id); let value = condition.and_then(|condition| { then_block.and_then(|then_block| { @@ -68,6 +69,6 @@ impl Compiler<'_> { }) }); - Value::new(value.unwrap_or(self.builtins.unknown_hir), ty) + Value::new(value.unwrap_or(self.builtins.unknown), ty) } } diff --git a/crates/rue-compiler/src/compiler/expr/index_access_expr.rs b/crates/rue-compiler/src/compiler/expr/index_access_expr.rs deleted file mode 100644 index 2b04cc5..0000000 --- a/crates/rue-compiler/src/compiler/expr/index_access_expr.rs +++ /dev/null @@ -1,38 +0,0 @@ -use rue_parser::{AstNode, IndexAccessExpr}; - -use crate::{ - compiler::Compiler, - value::{Type, Value}, - ErrorKind, -}; - -impl Compiler<'_> { - pub fn compile_index_access_expr(&mut self, index_access: &IndexAccessExpr) -> Value { - let Some(value) = index_access - .expr() - .map(|expr| self.compile_expr(&expr, None)) - else { - return self.unknown(); - }; - - let Some(index_token) = index_access.index() else { - return self.unknown(); - }; - - let index = index_token - .text() - .replace('_', "") - .parse() - .expect("failed to parse integer literal"); - - let Type::List(item_type) = self.db.ty(value.type_id).clone() else { - self.db.error( - ErrorKind::InvalidIndexAccess(self.type_name(value.type_id)), - index_access.expr().unwrap().syntax().text_range(), - ); - return self.unknown(); - }; - - Value::new(self.compile_index(value.hir_id, index, false), item_type) - } -} diff --git a/crates/rue-compiler/src/compiler/expr/initializer_expr.rs b/crates/rue-compiler/src/compiler/expr/initializer_expr.rs index bbb3759..b9ddd97 100644 --- a/crates/rue-compiler/src/compiler/expr/initializer_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/initializer_expr.rs @@ -1,27 +1,31 @@ -use std::collections::HashMap; +use rue_typing::HashMap; use indexmap::IndexMap; use rowan::TextRange; use rue_parser::{AstNode, InitializerExpr, InitializerField}; +use rue_typing::{bigint_to_bytes, deconstruct_items, Type, TypeId}; -use crate::{ - compiler::Compiler, - hir::Hir, - value::{Rest, Type, Value}, - ErrorKind, HirId, TypeId, -}; +use crate::{compiler::Compiler, hir::Hir, value::Value, ErrorKind, HirId}; impl Compiler<'_> { pub fn compile_initializer_expr(&mut self, initializer: &InitializerExpr) -> Value { let ty = initializer .path() - .map(|path| self.compile_path_type(&path.idents(), path.syntax().text_range())); + .map(|path| self.compile_path_type(&path.items(), path.syntax().text_range())); - match ty.map(|ty| self.db.ty(ty)).cloned() { + match ty.map(|ty| self.ty.get(ty)).cloned() { Some(Type::Struct(struct_type)) => { + let fields = deconstruct_items( + self.ty, + struct_type.type_id, + struct_type.field_names.len(), + struct_type.nil_terminated, + ) + .expect("invalid variant type"); + let hir_id = self.compile_initializer_fields( - &struct_type.fields, - struct_type.rest, + &struct_type.field_names.into_iter().zip(fields).collect(), + struct_type.nil_terminated, initializer.fields(), initializer.syntax().text_range(), ); @@ -31,18 +35,43 @@ impl Compiler<'_> { None => self.unknown(), } } - Some(Type::EnumVariant(enum_variant)) => { - if let Some(fields) = enum_variant.fields { + Some(Type::Variant(enum_variant)) => { + if let Some(field_names) = enum_variant.field_names { + let Type::Enum(enum_type) = self.ty.get(enum_variant.original_enum_type_id) + else { + unreachable!(); + }; + + let fields = if enum_type.has_fields { + let type_id = self + .ty + .get_pair(enum_variant.type_id) + .expect("expected a pair") + .1; + + deconstruct_items( + self.ty, + type_id, + field_names.len(), + enum_variant.nil_terminated, + ) + .expect("invalid struct type") + } else { + Vec::new() + }; + let fields_hir_id = self.compile_initializer_fields( - &fields, - enum_variant.rest, + &field_names.into_iter().zip(fields).collect(), + enum_variant.nil_terminated, initializer.fields(), initializer.syntax().text_range(), ); - let hir_id = self + let discriminant = self .db - .alloc_hir(Hir::Pair(enum_variant.discriminant, fields_hir_id)); + .alloc_hir(Hir::Atom(bigint_to_bytes(enum_variant.discriminant))); + + let hir_id = self.db.alloc_hir(Hir::Pair(discriminant, fields_hir_id)); match ty { Some(struct_type) => Value::new(hir_id, struct_type), @@ -70,12 +99,11 @@ impl Compiler<'_> { fn compile_initializer_fields( &mut self, struct_fields: &IndexMap, - rest: Rest, + nil_terminated: bool, initializer_fields: Vec, text_range: TextRange, ) -> HirId { let mut specified_fields = HashMap::new(); - let mut optional = false; for field in initializer_fields { let Some(name) = field.name() else { @@ -84,23 +112,15 @@ impl Compiler<'_> { let expected_type = struct_fields.get(name.text()).copied(); - let mut value = field + let value = field .expr() .map(|expr| self.compile_expr(&expr, expected_type)) .unwrap_or(self.unknown()); - // Resolve optional fields. - if rest == Rest::Optional - && struct_fields.get_index_of(name.text()) == Some(struct_fields.len() - 1) - { - optional |= matches!(self.db.ty(value.type_id), Type::Optional(..)); - value.type_id = self.db.non_undefined(value.type_id); - } - // Check the type of the field initializer. self.type_check( value.type_id, - expected_type.unwrap_or(self.builtins.unknown), + expected_type.unwrap_or(self.ty.std().unknown), field.syntax().text_range(), ); @@ -123,14 +143,8 @@ impl Compiler<'_> { // Check for any missing fields and report them. let missing_fields: Vec = struct_fields .keys() - .enumerate() - .filter(|&(i, name)| { - if rest == Rest::Optional && i == struct_fields.len() - 1 { - return false; - } - !specified_fields.contains_key(name) - }) - .map(|(_, name)| name.to_string()) + .filter(|name| !specified_fields.contains_key(*name)) + .cloned() .collect(); if !missing_fields.is_empty() { @@ -140,19 +154,15 @@ impl Compiler<'_> { ); } - let mut hir_id = self.builtins.nil_hir; + let mut hir_id = self.builtins.nil; // Construct a nil-terminated list from the arguments. for (i, field) in struct_fields.keys().rev().enumerate() { let value = specified_fields.get(field).copied(); - if i == 0 && rest == Rest::Optional && value.is_none() { - continue; - } - - let field = value.unwrap_or(self.builtins.unknown_hir); + let field = value.unwrap_or(self.builtins.unknown); - if i == 0 && (rest == Rest::Spread || (rest == Rest::Optional && optional)) { + if i == 0 && !nil_terminated { hir_id = field; } else { hir_id = self.db.alloc_hir(Hir::Pair(field, hir_id)); diff --git a/crates/rue-compiler/src/compiler/expr/lambda_expr.rs b/crates/rue-compiler/src/compiler/expr/lambda_expr.rs index d969a6a..9ef1e4a 100644 --- a/crates/rue-compiler/src/compiler/expr/lambda_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/lambda_expr.rs @@ -1,14 +1,16 @@ -use std::collections::HashMap; +use rue_typing::HashMap; +use indexmap::IndexSet; use rue_parser::{AstNode, LambdaExpr}; +use rue_typing::{construct_items, deconstruct_items, Callable, Type, TypeId}; use crate::{ compiler::Compiler, hir::Hir, scope::Scope, symbol::{Function, Symbol}, - value::{FunctionType, Rest, Type, Value}, - ErrorKind, TypeId, + value::Value, + ErrorKind, }; impl Compiler<'_> { @@ -25,9 +27,14 @@ impl Compiler<'_> { } // Determine the expected type of the lambda expression. - let expected = expected_type.and_then(|ty| match self.db.ty(ty) { - Type::Function(function) => Some(function.clone()), - _ => None, + let expected = expected_type.and_then(|ty| self.ty.get_callable(ty).cloned()); + let expected_params = expected.as_ref().and_then(|callable| { + deconstruct_items( + self.ty, + callable.parameters, + callable.parameter_names.len(), + callable.nil_terminated, + ) }); // Add the scope so you can track generic types. @@ -38,12 +45,12 @@ impl Compiler<'_> { // Add the generic types to the scope. for generic_type in lambda_expr - .generic_types() - .map(|generics| generics.idents()) + .generic_params() + .map(|generics| generics.names()) .unwrap_or_default() { // Create the generic type id. - let type_id = self.db.alloc_type(Type::Generic); + let type_id = self.ty.alloc(Type::Generic); // Check for duplicate generic types. if self.scope().ty(generic_type.text()).is_some() { @@ -64,7 +71,8 @@ impl Compiler<'_> { } let mut param_types = Vec::new(); - let mut rest = Rest::Nil; + let mut param_names = IndexSet::new(); + let mut nil_terminated = true; let len = lambda_expr.params().len(); @@ -73,57 +81,40 @@ impl Compiler<'_> { let type_id = param .ty() .map(|ty| self.compile_type(ty)) - .or(expected + .or(expected_params .as_ref() - .and_then(|expected| expected.param_types.get(i).copied())) + .and_then(|expected| expected.get(i).copied())) .unwrap_or_else(|| { self.db .error(ErrorKind::CannotInferType, param.syntax().text_range()); - self.builtins.unknown + self.ty.std().unknown }); // Substitute generic types in the parameter type. - let type_id = self.db.substitute_type(type_id, &substitutions); + let type_id = self.ty.substitute(type_id, substitutions.clone()); param_types.push(type_id); if let Some(name) = param.name() { - let param_type_id = if param.optional().is_some() { - // If the parameter is optional, wrap the type in a possibly undefined type. - // This prevents referencing the parameter until it's checked for undefined. - self.db.alloc_type(Type::Optional(type_id)) - } else { - type_id - }; - - let symbol_id = self.db.alloc_symbol(Symbol::Parameter(param_type_id)); + let symbol_id = self.db.alloc_symbol(Symbol::Parameter(type_id)); self.scope_mut().define_symbol(name.to_string(), symbol_id); + param_names.insert(name.to_string()); self.db.insert_symbol_token(symbol_id, name); + } else { + param_names.insert(format!("#{i}")); }; let last = i + 1 == len; let spread = param.spread().is_some(); - let optional = param.optional().is_some(); - if spread && optional { - self.db.error( - ErrorKind::OptionalParameterSpread, - param.syntax().text_range(), - ); - } else if spread && !last { - self.db.error( - ErrorKind::InvalidSpreadParameter, - param.syntax().text_range(), - ); - } else if optional && !last { - self.db.error( - ErrorKind::InvalidOptionalParameter, - param.syntax().text_range(), - ); - } else if spread { - rest = Rest::Spread; - } else if optional { - rest = Rest::Optional; + if spread { + if !last { + self.db.error( + ErrorKind::InvalidSpreadParameter, + param.syntax().text_range(), + ); + } + nil_terminated = false; } } @@ -150,23 +141,29 @@ impl Compiler<'_> { lambda_expr.body().unwrap().syntax().text_range(), ); - let ty = FunctionType { - param_types: param_types.clone(), - rest, + let type_id = self.ty.alloc(Type::Unknown); + let parameters = construct_items(self.ty, param_types.into_iter(), nil_terminated); + + *self.ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: param_names, + parameters, + nil_terminated, return_type, generic_types: Vec::new(), - }; + }); let symbol_id = self.db.alloc_symbol(Symbol::Function(Function { scope_id, hir_id: body.hir_id, - ty: ty.clone(), + type_id, + nil_terminated, })); Value::new( self.db .alloc_hir(Hir::Reference(symbol_id, lambda_expr.syntax().text_range())), - self.db.alloc_type(Type::Function(ty)), + type_id, ) } } diff --git a/crates/rue-compiler/src/compiler/expr/list_expr.rs b/crates/rue-compiler/src/compiler/expr/list_expr.rs index 6ec6063..7a17374 100644 --- a/crates/rue-compiler/src/compiler/expr/list_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/list_expr.rs @@ -1,64 +1,24 @@ -use rue_parser::{AstNode, ListExpr}; +use rue_parser::ListExpr; +use rue_typing::construct_items; -use crate::{ - compiler::Compiler, - hir::Hir, - value::{Type, Value}, - ErrorKind, TypeId, -}; +use crate::{compiler::Compiler, hir::Hir, value::Value, ErrorKind}; impl Compiler<'_> { - pub fn compile_list_expr( - &mut self, - list_expr: &ListExpr, - expected_expr_type: Option, - ) -> Value { + pub fn compile_list_expr(&mut self, list_expr: &ListExpr) -> Value { let mut items = Vec::new(); + let mut types = Vec::new(); let mut nil_terminated = true; - let mut list_type = expected_expr_type; - let mut item_type = expected_expr_type.and_then(|ty| match self.db.ty(ty) { - Type::List(ty) => Some(*ty), - _ => None, - }); - - let len = list_expr.items().len(); + let length = list_expr.items().len(); for (i, item) in list_expr.items().into_iter().enumerate() { - let expected_item_type = if item.spread().is_some() { - list_type - } else { - item_type - }; - let output = item .expr() - .map(|expr| self.compile_expr(&expr, expected_item_type)) + .map(|expr| self.compile_expr(&expr, None)) .unwrap_or(self.unknown()); - if let Some(expected_item_type) = expected_item_type { - self.type_check( - output.type_id, - expected_item_type, - item.syntax().text_range(), - ); - } - - if i == 0 && item_type.is_none() { - if item.spread().is_some() { - list_type = Some(output.type_id); - item_type = match self.db.ty(output.type_id) { - Type::List(ty) => Some(*ty), - _ => None, - }; - } else { - list_type = Some(self.db.alloc_type(Type::List(output.type_id))); - item_type = Some(output.type_id); - } - } - if let Some(spread) = item.spread() { - if i + 1 == len { + if i + 1 == length { nil_terminated = false; } else { self.db @@ -67,9 +27,10 @@ impl Compiler<'_> { } items.push(output.hir_id); + types.push(output.type_id); } - let mut hir_id = self.builtins.nil_hir; + let mut hir_id = self.builtins.nil; for (i, item) in items.into_iter().rev().enumerate() { if i == 0 && !nil_terminated { @@ -79,10 +40,8 @@ impl Compiler<'_> { } } - Value::new( - hir_id, - self.db - .alloc_type(Type::List(item_type.unwrap_or(self.builtins.unknown))), - ) + let type_id = construct_items(self.ty, types.into_iter(), nil_terminated); + + Value::new(hir_id, type_id) } } diff --git a/crates/rue-compiler/src/compiler/expr/literal_expr.rs b/crates/rue-compiler/src/compiler/expr/literal_expr.rs index bacacde..05dfefa 100644 --- a/crates/rue-compiler/src/compiler/expr/literal_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/literal_expr.rs @@ -1,8 +1,7 @@ -use clvmr::Allocator; -use num_bigint::BigInt; use rue_parser::{LiteralExpr, SyntaxKind, SyntaxToken}; +use rue_typing::bigint_to_bytes; -use crate::{compiler::Compiler, hir::Hir, value::Value, ErrorKind}; +use crate::{compiler::Compiler, hir::Hir, value::Value}; impl Compiler<'_> { pub fn compile_literal_expr(&mut self, literal: &LiteralExpr) -> Value { @@ -23,11 +22,11 @@ impl Compiler<'_> { fn compile_bool_literal(&mut self, value: bool) -> Value { let atom = if value { vec![1] } else { vec![] }; - Value::new(self.db.alloc_hir(Hir::Atom(atom)), self.builtins.bool) + Value::new(self.db.alloc_hir(Hir::Atom(atom)), self.ty.std().bool) } fn compile_nil_literal(&mut self) -> Value { - Value::new(self.db.alloc_hir(Hir::Atom(Vec::new())), self.builtins.nil) + Value::new(self.db.alloc_hir(Hir::Atom(Vec::new())), self.ty.std().nil) } fn compile_int_literal(&mut self, int: &SyntaxToken) -> Value { @@ -39,13 +38,10 @@ impl Compiler<'_> { .parse() .expect("failed to parse integer literal"); - let atom = Self::bigint_to_bytes(bigint).unwrap_or_else(|| { - self.db.error(ErrorKind::IntegerTooLarge, int.text_range()); - Vec::new() - }); + let atom = bigint_to_bytes(bigint); // Extract the atom representation of the number. - Value::new(self.db.alloc_hir(Hir::Atom(atom)), self.builtins.int) + Value::new(self.db.alloc_hir(Hir::Atom(atom)), self.ty.std().int) } fn compile_hex_literal(&mut self, hex: &SyntaxToken) -> Value { @@ -67,15 +63,15 @@ impl Compiler<'_> { if bytes_len == 32 { // We'll assume this is a `Bytes32` since it's the correct length. // This makes putting hashes in the code more convenient. - self.builtins.bytes32 + self.ty.std().bytes32 } else if bytes_len == 48 { // We'll assume this is a `PublicKey` since it's the correct length. // It's unlikely to intend the type being `Bytes`, but you can cast if needed. - self.builtins.public_key + self.ty.std().public_key } else { // Everything else is just `Bytes`. // Leading zeros are not removed, so `0x00` is different than `0`. - self.builtins.bytes + self.ty.std().bytes }, ) } @@ -89,18 +85,7 @@ impl Compiler<'_> { Value::new( self.db .alloc_hir(Hir::Atom(text.replace(quote, "").as_bytes().to_vec())), - self.builtins.bytes, + self.ty.std().bytes, ) } - - pub fn bigint_to_bytes(bigint: BigInt) -> Option> { - // Create a CLVM allocator. - let mut allocator = Allocator::new(); - - // Try to allocate the number. - let ptr = allocator.new_number(bigint).ok()?; - - // Extract the atom representation of the number. - Some(allocator.atom(ptr).as_ref().to_vec()) - } } diff --git a/crates/rue-compiler/src/compiler/expr/pair_expr.rs b/crates/rue-compiler/src/compiler/expr/pair_expr.rs index ad60ba3..901319a 100644 --- a/crates/rue-compiler/src/compiler/expr/pair_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/pair_expr.rs @@ -1,17 +1,13 @@ use rue_parser::{AstNode, PairExpr}; +use rue_typing::{Type, TypeId}; -use crate::{ - compiler::Compiler, - hir::Hir, - value::{PairType, Type, Value}, - TypeId, -}; +use crate::{compiler::Compiler, hir::Hir, value::Value}; impl Compiler<'_> { pub fn compile_pair_expr(&mut self, pair: &PairExpr, expected_type: Option) -> Value { // Extract the first and rest type out of the expected type. - let first = expected_type.and_then(|type_id| self.db.first_type(type_id)); - let rest = expected_type.and_then(|type_id| self.db.rest_type(type_id)); + let first = expected_type.and_then(|type_id| Some(self.ty.get_pair(type_id)?.0)); + let rest = expected_type.and_then(|type_id| Some(self.ty.get_pair(type_id)?.1)); // Compile the first expression, if present. // It's a parser error if not, so it's fine to return unknown. @@ -21,7 +17,7 @@ impl Compiler<'_> { let value = self.compile_expr(&expr, first); self.type_check( value.type_id, - first.unwrap_or(self.builtins.unknown), + first.unwrap_or(self.ty.std().unknown), expr.syntax().text_range(), ); value @@ -36,7 +32,7 @@ impl Compiler<'_> { let value = self.compile_expr(&expr, rest); self.type_check( value.type_id, - rest.unwrap_or(self.builtins.unknown), + rest.unwrap_or(self.ty.std().unknown), expr.syntax().text_range(), ); value @@ -44,10 +40,7 @@ impl Compiler<'_> { .unwrap_or_else(|| self.unknown()); let hir_id = self.db.alloc_hir(Hir::Pair(first.hir_id, rest.hir_id)); - let type_id = self.db.alloc_type(Type::Pair(PairType { - first: first.type_id, - rest: rest.type_id, - })); + let type_id = self.ty.alloc(Type::Pair(first.type_id, rest.type_id)); // We throw away type guards by creating this new value. // They shouldn't be relevant since the type is not `Bool`. diff --git a/crates/rue-compiler/src/compiler/expr/path_expr.rs b/crates/rue-compiler/src/compiler/expr/path_expr.rs index 0631c36..f857f4d 100644 --- a/crates/rue-compiler/src/compiler/expr/path_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/path_expr.rs @@ -1,97 +1,90 @@ use rowan::TextRange; -use rue_parser::SyntaxToken; +use rue_parser::PathItem; +use rue_typing::{bigint_to_bytes, Type}; use crate::{ compiler::{ - path::{PathItem, PathKind}, + path::{Path, PathKind}, Compiler, }, hir::Hir, symbol::{Function, Symbol}, - value::{GuardPath, Type, Value}, + value::{GuardPath, Value}, ErrorKind, }; impl Compiler<'_> { - pub fn compile_path_expr(&mut self, idents: &[SyntaxToken], text_range: TextRange) -> Value { - let Some(mut item) = - self.resolve_base_path(&idents[0], PathKind::Symbol, idents.len() == 1) + pub fn compile_path_expr(&mut self, items: &[PathItem], text_range: TextRange) -> Value { + let Some(mut path) = self.resolve_base_path(&items[0], PathKind::Symbol, items.len() == 1) else { return self.unknown(); }; - let mut last_ident = idents[0].to_string(); + let mut last_name = items[0].name().unwrap().to_string(); - for (i, name) in idents.iter().enumerate().skip(1) { - let Some(next_item) = - self.resolve_next_path(item, name, PathKind::Symbol, i == idents.len() - 1) + for (i, item) in items.iter().enumerate().skip(1) { + let Some(next_path) = + self.resolve_next_path(path, item, PathKind::Symbol, i == items.len() - 1) else { return self.unknown(); }; - last_ident = name.to_string(); - item = next_item; + last_name = item.name().unwrap().to_string(); + path = next_path; } - let symbol_id = match item { - PathItem::Symbol(symbol_id) => symbol_id, - PathItem::Type(type_id) => { - if let Type::EnumVariant(variant_type) = self.db.ty(type_id).clone() { - if variant_type.fields.is_some() { + let symbol_id = match path { + Path::Symbol(symbol_id) => symbol_id, + Path::Type(type_id) => { + if let Type::Variant(variant) = self.ty.get(type_id).clone() { + if variant.field_names.is_some() { self.db.error( ErrorKind::InvalidEnumVariantReference(self.type_name(type_id)), text_range, ); } - let Type::Enum(enum_type) = self.db.ty(variant_type.enum_type) else { + let Type::Enum(enum_type) = self.ty.get(variant.original_enum_type_id) else { unreachable!(); }; - let mut hir_id = variant_type.discriminant; + let mut hir_id = self + .db + .alloc_hir(Hir::Atom(bigint_to_bytes(variant.discriminant))); if enum_type.has_fields { - hir_id = self.db.alloc_hir(Hir::Pair(hir_id, self.builtins.nil_hir)); + hir_id = self.db.alloc_hir(Hir::Pair(hir_id, self.builtins.nil)); } return Value::new(hir_id, type_id); } self.db - .error(ErrorKind::ExpectedSymbolPath(last_ident), text_range); + .error(ErrorKind::ExpectedSymbolPath(last_name), text_range); return self.unknown(); } }; if matches!(self.db.symbol(symbol_id), Symbol::Module(..)) { self.db - .error(ErrorKind::ModuleReference(last_ident), text_range); + .error(ErrorKind::ModuleReference(last_name), text_range); return self.unknown(); } if !self.is_callee && matches!(self.db.symbol(symbol_id), Symbol::InlineFunction(..)) { self.db - .error(ErrorKind::InlineFunctionReference(last_ident), text_range); + .error(ErrorKind::InlineFunctionReference(last_name), text_range); return self.unknown(); } let type_override = self.symbol_type(&GuardPath::new(symbol_id)); - let override_type_id = type_override.map(|ty| ty.type_id); - let mut reference = self.db.alloc_hir(Hir::Reference(symbol_id, text_range)); - - if let Some(mutation) = type_override.map(|ty| ty.mutation) { - reference = self.apply_mutation(reference, mutation); - } + let reference = self.db.alloc_hir(Hir::Reference(symbol_id, text_range)); let mut value = match self.db.symbol(symbol_id).clone() { Symbol::Unknown | Symbol::Module(..) => unreachable!(), - Symbol::Function(Function { ty, .. }) | Symbol::InlineFunction(Function { ty, .. }) => { - let type_id = self.db.alloc_type(Type::Function(ty.clone())); - Value::new(reference, override_type_id.unwrap_or(type_id)) - } - Symbol::Parameter(type_id) => { - Value::new(reference, override_type_id.unwrap_or(type_id)) - } + Symbol::Function(Function { type_id, .. }) + | Symbol::InlineFunction(Function { type_id, .. }) + | Symbol::Parameter(type_id) => Value::new(reference, type_override.unwrap_or(type_id)), Symbol::Let(mut value) | Symbol::Const(mut value) | Symbol::InlineConst(mut value) => { - if let Some(type_id) = override_type_id { + if let Some(type_id) = type_override { value.type_id = type_id; } value.hir_id = reference; diff --git a/crates/rue-compiler/src/compiler/expr/prefix_expr.rs b/crates/rue-compiler/src/compiler/expr/prefix_expr.rs index 3369395..c4e8a49 100644 --- a/crates/rue-compiler/src/compiler/expr/prefix_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/prefix_expr.rs @@ -11,9 +11,9 @@ impl Compiler<'_> { // Determine the expected type based on the prefix operator. let expected_type = match prefix_expr.op() { Some(PrefixOp::BitwiseNot | PrefixOp::Positive | PrefixOp::Negative) => { - Some(self.builtins.int) + Some(self.ty.std().int) } - Some(PrefixOp::Not) => Some(self.builtins.bool), + Some(PrefixOp::Not) => Some(self.ty.std().bool), None => None, }; @@ -32,7 +32,7 @@ impl Compiler<'_> { // Check the type of the expression. self.type_check( expr.type_id, - expected_type.unwrap_or(self.builtins.unknown), + expected_type.unwrap_or(self.ty.std().unknown), prefix_expr .expr() .map_or(prefix_expr.syntax().text_range(), |ast| { @@ -45,7 +45,7 @@ impl Compiler<'_> { // Negate the expression and its type guards. let mut value = Value::new( self.db.alloc_hir(Hir::Op(Op::Not, expr.hir_id)), - self.builtins.bool, + self.ty.std().bool, ); for (symbol_id, guard) in expr.guards { @@ -58,10 +58,10 @@ impl Compiler<'_> { // Subtract the expression from nil. self.db.alloc_hir(Hir::BinaryOp( BinOp::Subtract, - self.builtins.nil_hir, + self.builtins.nil, expr.hir_id, )), - self.builtins.int, + self.ty.std().int, ), PrefixOp::Positive => { // Return the expression as is. @@ -71,7 +71,7 @@ impl Compiler<'_> { // Negate the expression and its type guards. Value::new( self.db.alloc_hir(Hir::Op(Op::BitwiseNot, expr.hir_id)), - self.builtins.int, + self.ty.std().int, ) } } diff --git a/crates/rue-compiler/src/compiler/item.rs b/crates/rue-compiler/src/compiler/item.rs index 19e5345..d5fc38e 100644 --- a/crates/rue-compiler/src/compiler/item.rs +++ b/crates/rue-compiler/src/compiler/item.rs @@ -1,8 +1,9 @@ -use std::collections::HashSet; +use rue_typing::HashSet; use rue_parser::Item; +use rue_typing::{Type, TypeId}; -use crate::{symbol::Symbol, value::Type, ErrorKind, SymbolId, TypeId}; +use crate::{symbol::Symbol, ErrorKind, ScopeId, SymbolId}; use super::Compiler; @@ -17,6 +18,7 @@ mod type_alias_item; pub struct Declarations { pub type_ids: Vec, pub symbol_ids: Vec, + pub scope_ids: Vec, pub exported_types: Vec, pub exported_symbols: Vec, } @@ -29,12 +31,17 @@ impl Compiler<'_> { let mut type_ids = Vec::new(); let mut symbol_ids = Vec::new(); + let mut scope_ids = Vec::new(); let mut exported_types = Vec::new(); let mut exported_symbols = Vec::new(); for item in items { match item { - Item::TypeAliasItem(ty) => type_ids.push(self.declare_type_alias_item(ty)), + Item::TypeAliasItem(ty) => { + let (type_id, scope_id) = self.declare_type_alias_item(ty); + type_ids.push(type_id); + scope_ids.push(scope_id); + } Item::StructItem(struct_item) => { type_ids.push(self.declare_struct_item(struct_item)); } @@ -71,6 +78,7 @@ impl Compiler<'_> { Declarations { type_ids, symbol_ids, + scope_ids, exported_types, exported_symbols, } @@ -84,7 +92,7 @@ impl Compiler<'_> { Item::TypeAliasItem(ty) => { let type_id = declarations.type_ids.remove(0); self.type_definition_stack.push(type_id); - self.compile_type_alias_item(ty, type_id); + self.compile_type_alias_item(ty, type_id, declarations.scope_ids.remove(0)); self.type_definition_stack.pop().unwrap(); } Item::StructItem(struct_item) => { @@ -164,7 +172,7 @@ impl Compiler<'_> { .local_types() .into_iter() .filter_map(|type_id| { - if let Type::Enum(..) = self.db.ty(type_id).clone() { + if let Type::Enum(..) = self.ty.get(type_id).clone() { Some(self.scope().type_name(type_id).unwrap().to_string()) } else { None diff --git a/crates/rue-compiler/src/compiler/item/const_item.rs b/crates/rue-compiler/src/compiler/item/const_item.rs index 33197f4..ab2ec3e 100644 --- a/crates/rue-compiler/src/compiler/item/const_item.rs +++ b/crates/rue-compiler/src/compiler/item/const_item.rs @@ -11,7 +11,7 @@ impl Compiler<'_> { let type_id = const_item .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); let hir_id = self.db.alloc_hir(Hir::Unknown); diff --git a/crates/rue-compiler/src/compiler/item/enum_item.rs b/crates/rue-compiler/src/compiler/item/enum_item.rs index e3f6edb..50f679f 100644 --- a/crates/rue-compiler/src/compiler/item/enum_item.rs +++ b/crates/rue-compiler/src/compiler/item/enum_item.rs @@ -1,16 +1,12 @@ -use std::collections::HashSet; +use rue_typing::HashSet; use indexmap::IndexMap; use num_bigint::BigInt; use num_traits::Zero; use rue_parser::EnumItem; +use rue_typing::{construct_items, Enum, Type, TypeId, Variant}; -use crate::{ - compiler::Compiler, - hir::Hir, - value::{EnumType, EnumVariantType, Type}, - ErrorKind, TypeId, -}; +use crate::{compiler::Compiler, ErrorKind}; impl Compiler<'_> { pub fn declare_enum_item(&mut self, enum_item: &EnumItem) -> TypeId { @@ -41,30 +37,38 @@ impl Compiler<'_> { // Allocate a new type for the variant. // It has to be `Unknown` for now, since field types may not be declared yet. - let type_id = self.db.alloc_type(Type::Unknown); + let variant_type_id = self.ty.alloc(Type::Unknown); // Add the variant to the enum and define the token for the variant. - variants.insert(name.to_string(), type_id); - self.db.insert_type_token(type_id, name); + variants.insert(name.to_string(), variant_type_id); + self.db.insert_type_token(variant_type_id, name); } // Allocate a new type for the enum. - let type_id = self.db.alloc_type(Type::Enum(EnumType { + let enum_structure = self + .ty + .alloc(Type::Union(variants.values().copied().collect())); + + let enum_type_id = self.ty.alloc(Type::Unknown); + + *self.ty.get_mut(enum_type_id) = Type::Enum(Enum { + original_type_id: enum_type_id, + type_id: enum_structure, has_fields, variants, - })); + }); // Add the enum to the scope and define the token for the enum. if let Some(name) = enum_item.name() { - self.scope_mut().define_type(name.to_string(), type_id); - self.db.insert_type_token(type_id, name); + self.scope_mut().define_type(name.to_string(), enum_type_id); + self.db.insert_type_token(enum_type_id, name); } - type_id + enum_type_id } pub fn compile_enum_item(&mut self, enum_item: &EnumItem, enum_type_id: TypeId) { - let Type::Enum(enum_type) = self.db.ty(enum_type_id).clone() else { + let Type::Enum(enum_type) = self.ty.get(enum_type_id).clone() else { unreachable!(); }; @@ -95,7 +99,7 @@ impl Compiler<'_> { self.type_definition_stack.push(variant_type_id); // Compile the fields of the variant. - let (fields, rest) = variant + let (fields, nil_terminated) = variant .fields() .map(|ast| self.compile_struct_fields(ast.fields())) .unwrap_or_default(); @@ -132,29 +136,33 @@ impl Compiler<'_> { BigInt::zero() }; - let atom = Self::bigint_to_bytes(discriminant).unwrap_or_else(|| { - self.db.error( - ErrorKind::EnumDiscriminantTooLarge, - variant - .discriminant() - .map_or(name.text_range(), |token| token.text_range()), - ); - Vec::new() - }); - - let discriminant = self.db.alloc_hir(Hir::Atom(atom)); - // Update the variant to use the real `EnumVariant` type. - *self.db.ty_mut(variant_type_id) = Type::EnumVariant(EnumVariantType { - enum_type: enum_type_id, + let discriminant_type = self.ty.alloc(Type::Value(discriminant.clone())); + + let type_id = if enum_type.has_fields { + construct_items( + self.ty, + [discriminant_type] + .into_iter() + .chain(fields.values().copied()), + nil_terminated, + ) + } else { + discriminant_type + }; + + *self.ty.get_mut(variant_type_id) = Type::Variant(Variant { original_type_id: variant_type_id, - fields: if variant.fields().is_some() { - Some(fields) + original_enum_type_id: enum_type_id, + field_names: if variant.fields().is_some() { + Some(fields.keys().cloned().collect()) } else { None }, - rest, + type_id, + nil_terminated, discriminant, + generic_types: Vec::new(), }); self.type_definition_stack.pop().unwrap(); diff --git a/crates/rue-compiler/src/compiler/item/function_item.rs b/crates/rue-compiler/src/compiler/item/function_item.rs index f6cd53c..cdbbfc9 100644 --- a/crates/rue-compiler/src/compiler/item/function_item.rs +++ b/crates/rue-compiler/src/compiler/item/function_item.rs @@ -1,11 +1,11 @@ use rue_parser::{AstNode, FunctionItem}; +use rue_typing::{construct_items, Callable, Type}; use crate::{ compiler::Compiler, hir::Hir, scope::Scope, symbol::{Function, Symbol}, - value::{FunctionType, Rest, Type}, ErrorKind, SymbolId, }; @@ -22,34 +22,34 @@ impl Compiler<'_> { let mut generic_types = Vec::new(); // Add the generic types to the scope. - for generic_type in function_item - .generic_types() - .map(|generics| generics.idents()) + for name in function_item + .generic_params() + .map(|generics| generics.names()) .unwrap_or_default() { // Create the generic type id. - let type_id = self.db.alloc_type(Type::Generic); + let type_id = self.ty.alloc(Type::Generic); // Check for duplicate generic types. - if self.scope().ty(generic_type.text()).is_some() { + if self.scope().ty(name.text()).is_some() { self.db.error( - ErrorKind::DuplicateType(generic_type.text().to_string()), - generic_type.text_range(), + ErrorKind::DuplicateType(name.text().to_string()), + name.text_range(), ); } // Add the generic type to the scope and define the token for the generic type. - self.scope_mut() - .define_type(generic_type.to_string(), type_id); + self.scope_mut().define_type(name.to_string(), type_id); - self.db.insert_type_token(type_id, generic_type); + self.db.insert_type_token(type_id, name); // Add the generic type to the list so it can be added to the function type. generic_types.push(type_id); } let mut param_types = Vec::new(); - let mut rest = Rest::Nil; + let mut param_names = Vec::new(); + let mut nil_terminated = true; let params = function_item.params(); let len = params.len(); @@ -65,19 +65,12 @@ impl Compiler<'_> { // Otherwise, it's a parser error. let type_id = param .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); // Add the parameter type to the list and update the parameter symbol. param_types.push(type_id); - *self.db.symbol_mut(symbol_id) = Symbol::Parameter(if param.optional().is_some() { - // If the parameter is optional, wrap the type in a possibly undefined type. - // This prevents referencing the parameter until it's checked for undefined. - self.db.alloc_type(Type::Optional(type_id)) - } else { - // Otherwise, just use the type. - type_id - }); + *self.db.symbol_mut(symbol_id) = Symbol::Parameter(type_id); // Add the parameter to the scope and define the token for the parameter. if let Some(name) = param.name() { @@ -88,34 +81,25 @@ impl Compiler<'_> { ); } + param_names.push(name.to_string()); self.scope_mut().define_symbol(name.to_string(), symbol_id); self.db.insert_symbol_token(symbol_id, name); + } else { + param_names.push(format!("#{i}")); } // Check if it's a spread or optional parameter. let last = i + 1 == len; let spread = param.spread().is_some(); - let optional = param.optional().is_some(); - if spread && optional { - self.db.error( - ErrorKind::OptionalParameterSpread, - param.syntax().text_range(), - ); - } else if spread && !last { - self.db.error( - ErrorKind::InvalidSpreadParameter, - param.syntax().text_range(), - ); - } else if optional && !last { - self.db.error( - ErrorKind::InvalidOptionalParameter, - param.syntax().text_range(), - ); - } else if spread { - rest = Rest::Spread; - } else if optional { - rest = Rest::Optional; + if spread { + if !last { + self.db.error( + ErrorKind::InvalidSpreadParameter, + param.syntax().text_range(), + ); + } + nil_terminated = false; } self.symbol_stack.pop().unwrap(); @@ -125,7 +109,7 @@ impl Compiler<'_> { // Otherwise, it's a parser error. let return_type = function_item .return_type() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); self.scope_stack.pop().unwrap(); @@ -133,25 +117,31 @@ impl Compiler<'_> { let hir_id = self.db.alloc_hir(Hir::Unknown); // Create the function's type. - let ty = FunctionType { - generic_types, - param_types, - rest, + let type_id = self.ty.alloc(Type::Unknown); + + *self.ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: param_names.into_iter().collect(), + parameters: construct_items(self.ty, param_types.into_iter(), nil_terminated), return_type, - }; + nil_terminated, + generic_types, + }); // Update the symbol with the function. if function_item.inline().is_some() { *self.db.symbol_mut(symbol_id) = Symbol::InlineFunction(Function { scope_id, hir_id, - ty, + type_id, + nil_terminated, }); } else { *self.db.symbol_mut(symbol_id) = Symbol::Function(Function { scope_id, hir_id, - ty, + type_id, + nil_terminated, }); } @@ -173,26 +163,36 @@ impl Compiler<'_> { }; // Get the function's scope and type. - let (Symbol::Function(Function { scope_id, ty, .. }) - | Symbol::InlineFunction(Function { scope_id, ty, .. })) = - self.db.symbol(symbol_id).clone() + let (Symbol::Function(Function { + scope_id, type_id, .. + }) + | Symbol::InlineFunction(Function { + scope_id, type_id, .. + })) = self.db.symbol(symbol_id).clone() else { unreachable!(); }; + let return_type = self + .ty + .get_callable(type_id) + .map(|callable| callable.return_type); + // We don't care about explicit returns in this context. self.scope_stack.push(scope_id); self.allow_generic_inference_stack.push(false); - let value = self.compile_block(&body, Some(ty.return_type)).value; + let value = self.compile_block(&body, return_type).value; self.allow_generic_inference_stack.pop().unwrap(); self.scope_stack.pop().unwrap(); // Ensure that the body is assignable to the return type. - self.type_check( - value.type_id, - ty.return_type, - function.body().unwrap().syntax().text_range(), - ); + if let Some(return_type) = return_type { + self.type_check( + value.type_id, + return_type, + function.body().unwrap().syntax().text_range(), + ); + } // Update the function's HIR with the body's HIR, for code generation purposes. let (Symbol::Function(Function { hir_id, .. }) diff --git a/crates/rue-compiler/src/compiler/item/struct_item.rs b/crates/rue-compiler/src/compiler/item/struct_item.rs index 3c18103..3cb6ad2 100644 --- a/crates/rue-compiler/src/compiler/item/struct_item.rs +++ b/crates/rue-compiler/src/compiler/item/struct_item.rs @@ -1,16 +1,13 @@ use indexmap::IndexMap; use rue_parser::{AstNode, StructField, StructItem}; +use rue_typing::{construct_items, Struct, Type, TypeId}; -use crate::{ - compiler::Compiler, - value::{Rest, StructType, Type}, - ErrorKind, TypeId, -}; +use crate::{compiler::Compiler, ErrorKind}; impl Compiler<'_> { /// Define a type for a struct in the current scope, but leave it as unknown for now. pub fn declare_struct_item(&mut self, struct_item: &StructItem) -> TypeId { - let type_id = self.db.alloc_type(Type::Unknown); + let type_id = self.ty.alloc(Type::Unknown); if let Some(name) = struct_item.name() { self.scope_mut().define_type(name.to_string(), type_id); self.db.insert_type_token(type_id, name); @@ -19,14 +16,20 @@ impl Compiler<'_> { } /// Compile and resolve a struct type. - pub fn compile_struct_item(&mut self, struct_item: &StructItem, type_id: TypeId) { - self.type_definition_stack.push(type_id); - let (fields, rest) = self.compile_struct_fields(struct_item.fields()); - *self.db.ty_mut(type_id) = Type::Struct(StructType { - original_type_id: type_id, - fields, - rest, + pub fn compile_struct_item(&mut self, struct_item: &StructItem, struct_type_id: TypeId) { + self.type_definition_stack.push(struct_type_id); + + let (fields, nil_terminated) = self.compile_struct_fields(struct_item.fields()); + let type_id = construct_items(self.ty, fields.values().copied(), nil_terminated); + + *self.ty.get_mut(struct_type_id) = Type::Struct(Struct { + original_type_id: struct_type_id, + field_names: fields.keys().cloned().collect(), + type_id, + nil_terminated, + generic_types: Vec::new(), }); + self.type_definition_stack.pop().unwrap(); } @@ -34,35 +37,27 @@ impl Compiler<'_> { pub fn compile_struct_fields( &mut self, fields: Vec, - ) -> (IndexMap, Rest) { + ) -> (IndexMap, bool) { let mut named_fields = IndexMap::new(); - let mut rest = Rest::Nil; + let mut nil_terminated = true; let len = fields.len(); for (i, field) in fields.into_iter().enumerate() { let type_id = field .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); // Check if it's a spread or optional parameter. let last = i + 1 == len; let spread = field.spread().is_some(); - let optional = field.optional().is_some(); - if spread && optional { - self.db - .error(ErrorKind::OptionalFieldSpread, field.syntax().text_range()); - } else if spread && !last { - self.db - .error(ErrorKind::InvalidSpreadField, field.syntax().text_range()); - } else if optional && !last { - self.db - .error(ErrorKind::InvalidOptionalField, field.syntax().text_range()); - } else if spread { - rest = Rest::Spread; - } else if optional { - rest = Rest::Optional; + if spread { + if !last { + self.db + .error(ErrorKind::InvalidSpreadField, field.syntax().text_range()); + } + nil_terminated = false; } if let Some(name) = field.name() { @@ -70,6 +65,6 @@ impl Compiler<'_> { }; } - (named_fields, rest) + (named_fields, nil_terminated) } } diff --git a/crates/rue-compiler/src/compiler/item/type_alias_item.rs b/crates/rue-compiler/src/compiler/item/type_alias_item.rs index c37d70e..4b98aaf 100644 --- a/crates/rue-compiler/src/compiler/item/type_alias_item.rs +++ b/crates/rue-compiler/src/compiler/item/type_alias_item.rs @@ -1,43 +1,94 @@ use rue_parser::TypeAliasItem; +use rue_typing::{Alias, Type, TypeId}; -use crate::{compiler::Compiler, value::Type, ErrorKind, TypeId}; +use crate::{compiler::Compiler, scope::Scope, ErrorKind, ScopeId}; impl Compiler<'_> { /// Define a type for an alias in the current scope, but leave it as unknown for now. - pub fn declare_type_alias_item(&mut self, type_alias: &TypeAliasItem) -> TypeId { - let type_id = self.db.alloc_type(Type::Unknown); + pub fn declare_type_alias_item(&mut self, type_alias: &TypeAliasItem) -> (TypeId, ScopeId) { + // Add the scope so you can track generic types. + let scope_id = self.db.alloc_scope(Scope::default()); + self.scope_stack.push(scope_id); + + let mut generic_types = Vec::new(); + + // Add the generic types to the scope. + for name in type_alias + .generic_params() + .map(|generics| generics.names()) + .unwrap_or_default() + { + // Create the generic type id. + let type_id = self.ty.alloc(Type::Generic); + + // Check for duplicate generic types. + if self.scope().ty(name.text()).is_some() { + self.db.error( + ErrorKind::DuplicateType(name.text().to_string()), + name.text_range(), + ); + } + + // Add the generic type to the scope and define the token for the generic type. + self.scope_mut().define_type(name.to_string(), type_id); + + self.db.insert_type_token(type_id, name); + + // Add the generic type to the list so it can be added to the function type. + generic_types.push(type_id); + } + + self.scope_stack.pop().unwrap(); + + // Create the alias type. + let ref_type_id = self.ty.alloc(Type::Ref(self.ty.std().unknown)); + + let type_id = self.ty.alloc(Type::Unknown); + + *self.ty.get_mut(type_id) = Type::Alias(Alias { + original_type_id: type_id, + type_id: ref_type_id, + generic_types, + }); + if let Some(name) = type_alias.name() { self.scope_mut().define_type(name.to_string(), type_id); self.db.insert_type_token(type_id, name); } - type_id + + (type_id, scope_id) } /// Compile and resolve the type that the alias points to. - pub fn compile_type_alias_item(&mut self, type_alias: &TypeAliasItem, alias_type_id: TypeId) { + pub fn compile_type_alias_item( + &mut self, + type_alias: &TypeAliasItem, + alias_type_id: TypeId, + scope_id: ScopeId, + ) { self.type_definition_stack.push(alias_type_id); + // Add the scope so you can use generic types. + self.scope_stack.push(scope_id); let type_id = type_alias .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); + self.scope_stack.pop().unwrap(); // Set the alias type to the resolved type. - *self.db.ty_mut(alias_type_id) = Type::Alias(type_id); + let Type::Alias(Alias { + type_id: ref_type_id, + .. + }) = self.ty.get(alias_type_id).clone() + else { + unreachable!(); + }; - // A cycle between type aliases has been detected. - // We set it to unknown to prevent stack overflow issues later. - if self.db.is_cyclic(alias_type_id) { - let name = type_alias - .name() - .expect("the name should exist if it's in a cyclic reference"); + let Type::Ref(reference) = self.ty.get_raw_mut(ref_type_id) else { + unreachable!(); + }; - self.db.error( - ErrorKind::RecursiveTypeAlias(name.to_string()), - name.text_range(), - ); - - *self.db.ty_mut(alias_type_id) = Type::Unknown; - } + *reference = type_id; self.type_definition_stack.pop().unwrap(); } diff --git a/crates/rue-compiler/src/compiler/path.rs b/crates/rue-compiler/src/compiler/path.rs index de98a7a..cba8f87 100644 --- a/crates/rue-compiler/src/compiler/path.rs +++ b/crates/rue-compiler/src/compiler/path.rs @@ -1,11 +1,16 @@ -use rue_parser::SyntaxToken; +use rue_typing::HashMap; -use crate::{symbol::Symbol, value::Type, ErrorKind, SymbolId, TypeId}; +use indexmap::IndexMap; +use rowan::TextRange; +use rue_parser::{AstNode, GenericArgs, PathItem}; +use rue_typing::{Lazy, Type, TypeId}; + +use crate::{symbol::Symbol, ErrorKind, SymbolId}; use super::Compiler; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PathItem { +pub enum Path { Symbol(SymbolId), Type(TypeId), } @@ -19,10 +24,12 @@ pub enum PathKind { impl Compiler<'_> { pub fn resolve_base_path( &mut self, - name: &SyntaxToken, + item: &PathItem, path_kind: PathKind, last: bool, - ) -> Option { + ) -> Option { + let name = item.name()?; + for &scope_id in self.scope_stack.iter().rev() { let type_id = self.db.scope(scope_id).ty(name.text()); let symbol_id = self.db.scope(scope_id).symbol(name.text()); @@ -30,31 +37,39 @@ impl Compiler<'_> { if let (Some(type_id), Some(symbol_id)) = (type_id, symbol_id) { if let Symbol::Module(..) = self.db.symbol(symbol_id) { if !last { - return Some(PathItem::Symbol(symbol_id)); + return Some(Path::Symbol(symbol_id)); } } - if let Type::Enum(..) = self.db.ty(type_id) { + if let Type::Enum(..) = self.ty.get(type_id) { if !last { self.type_reference(type_id); - return Some(PathItem::Type(type_id)); + return Some(Path::Type(type_id)); } } match path_kind { PathKind::Type => { self.type_reference(type_id); - return Some(PathItem::Type(type_id)); + return Some(Path::Type(self.handle_generics( + type_id, + item.generic_args(), + item.syntax().text_range(), + )?)); } PathKind::Symbol => { - return Some(PathItem::Symbol(symbol_id)); + return Some(Path::Symbol(symbol_id)); } } } else if let Some(type_id) = type_id { self.type_reference(type_id); - return Some(PathItem::Type(type_id)); + return Some(Path::Type(self.handle_generics( + type_id, + item.generic_args(), + item.syntax().text_range(), + )?)); } else if let Some(symbol_id) = symbol_id { - return Some(PathItem::Symbol(symbol_id)); + return Some(Path::Symbol(symbol_id)); } } @@ -71,14 +86,16 @@ impl Compiler<'_> { pub fn resolve_next_path( &mut self, - item: PathItem, - name: &SyntaxToken, + path: Path, + item: &PathItem, path_kind: PathKind, last: bool, - ) -> Option { - match item { - PathItem::Type(type_id) => { - let Type::Enum(enum_type) = self.db.ty(type_id) else { + ) -> Option { + let name = item.name()?; + + match path { + Path::Type(type_id) => { + let Type::Enum(enum_type) = self.ty.get(type_id) else { self.db.error( ErrorKind::InvalidTypePath(self.type_name(type_id)), name.text_range(), @@ -95,9 +112,9 @@ impl Compiler<'_> { }; self.type_reference(variant_type); - Some(PathItem::Type(variant_type)) + Some(Path::Type(variant_type)) } - PathItem::Symbol(module_id) => { + Path::Symbol(module_id) => { let Symbol::Module(module) = self.db.symbol(module_id) else { self.db.error( ErrorKind::InvalidSymbolPath(self.symbol_name(module_id)), @@ -123,29 +140,37 @@ impl Compiler<'_> { if let (Some(type_id), Some(symbol_id)) = (type_id, symbol_id) { if let Symbol::Module(..) = self.db.symbol(symbol_id) { if !last { - return Some(PathItem::Symbol(symbol_id)); + return Some(Path::Symbol(symbol_id)); } } - if let Type::Enum(..) = self.db.ty(type_id) { + if let Type::Enum(..) = self.ty.get(type_id) { if !last { self.type_reference(type_id); - return Some(PathItem::Type(type_id)); + return Some(Path::Type(type_id)); } } match path_kind { PathKind::Type => { self.type_reference(type_id); - Some(PathItem::Type(type_id)) + Some(Path::Type(self.handle_generics( + type_id, + item.generic_args(), + item.syntax().text_range(), + )?)) } - PathKind::Symbol => Some(PathItem::Symbol(symbol_id)), + PathKind::Symbol => Some(Path::Symbol(symbol_id)), } } else if let Some(type_id) = type_id { self.type_reference(type_id); - Some(PathItem::Type(type_id)) + Some(Path::Type(self.handle_generics( + type_id, + item.generic_args(), + item.syntax().text_range(), + )?)) } else if let Some(symbol_id) = symbol_id { - Some(PathItem::Symbol(symbol_id)) + Some(Path::Symbol(symbol_id)) } else if private_type { self.db.error( ErrorKind::PrivateType(name.text().to_string()), @@ -168,4 +193,68 @@ impl Compiler<'_> { } } } + + fn handle_generics( + &mut self, + type_id: TypeId, + generic_args: Option, + text_range: TextRange, + ) -> Option { + self.handle_generics_impl(type_id, generic_args, text_range) + .map(|type_id| self.ty.substitute(type_id, HashMap::new())) + } + + fn handle_generics_impl( + &mut self, + mut type_id: TypeId, + generic_args: Option, + text_range: TextRange, + ) -> Option { + let Type::Alias(alias) = self.ty.get(type_id) else { + if generic_args.is_some() { + self.db.error(ErrorKind::UnexpectedGenericArgs, text_range); + return None; + } + return Some(type_id); + }; + + if generic_args.is_some() && alias.generic_types.is_empty() { + self.db.error(ErrorKind::UnexpectedGenericArgs, text_range); + None + } else if generic_args.is_none() && !alias.generic_types.is_empty() { + self.db.error(ErrorKind::ExpectedGenericArgs, text_range); + None + } else if let Some(generic_args) = generic_args { + let generic_args = generic_args.types(); + + if generic_args.len() != alias.generic_types.len() { + self.db.error( + ErrorKind::GenericArgsMismatch(generic_args.len(), alias.generic_types.len()), + text_range, + ); + return None; + } + + let mut substitutions = IndexMap::new(); + for (generic_type, arg) in alias.generic_types.clone().into_iter().zip(generic_args) { + let arg = self.compile_type(arg); + substitutions.insert(generic_type, arg); + } + + if self.type_definition_stack.is_empty() { + type_id = self + .ty + .substitute(type_id, substitutions.into_iter().collect()); + } else { + type_id = self.ty.alloc(Type::Lazy(Lazy { + type_id, + substitutions, + })); + } + + Some(type_id) + } else { + Some(type_id) + } + } } diff --git a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs index 0495bc4..9927551 100644 --- a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs +++ b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs @@ -1,12 +1,13 @@ -use std::collections::HashMap; +use rue_typing::HashMap; use rue_parser::{AstNode, IfStmt}; +use rue_typing::TypeId; use crate::{ compiler::{block::BlockTerminator, Compiler}, scope::Scope, - value::{GuardPath, TypeOverride}, - ErrorKind, HirId, TypeId, + value::GuardPath, + ErrorKind, HirId, }; impl Compiler<'_> { @@ -15,17 +16,17 @@ impl Compiler<'_> { &mut self, if_stmt: &IfStmt, expected_type: Option, - ) -> (HirId, HirId, HashMap) { + ) -> (HirId, HirId, HashMap) { // Compile the condition expression. let condition = if_stmt .condition() - .map(|condition| self.compile_expr(&condition, Some(self.builtins.bool))) + .map(|condition| self.compile_expr(&condition, Some(self.ty.std().bool))) .unwrap_or_else(|| self.unknown()); // Check that the condition is a boolean. self.type_check( condition.type_id, - self.builtins.bool, + self.ty.std().bool, if_stmt.syntax().text_range(), ); @@ -61,7 +62,7 @@ impl Compiler<'_> { // Check that the output matches the expected type. self.type_check( then_block.type_id, - expected_type.unwrap_or(self.builtins.unknown), + expected_type.unwrap_or(self.ty.std().unknown), if_stmt.syntax().text_range(), ); diff --git a/crates/rue-compiler/src/compiler/symbol_table.rs b/crates/rue-compiler/src/compiler/symbol_table.rs index eecaddc..77dea9c 100644 --- a/crates/rue-compiler/src/compiler/symbol_table.rs +++ b/crates/rue-compiler/src/compiler/symbol_table.rs @@ -1,11 +1,9 @@ -use std::collections::HashSet; +use rue_typing::HashSet; use indexmap::{IndexMap, IndexSet}; +use rue_typing::{Type, TypeId, TypeSystem}; -use crate::{ - dependency_graph::DependencyGraph, symbol::Symbol, value::Type, Database, SymbolId, TypeId, - WarningKind, -}; +use crate::{dependency_graph::DependencyGraph, symbol::Symbol, Database, SymbolId, WarningKind}; #[derive(Debug, Default)] pub struct SymbolTable { @@ -44,6 +42,7 @@ impl SymbolTable { pub fn calculate_unused( &self, db: &mut Database, + ty: &TypeSystem, dependency_graph: &DependencyGraph, ignored_symbols: &HashSet, ignored_types: &HashSet, @@ -126,12 +125,12 @@ impl SymbolTable { continue; } let token = db.type_token(*type_id).unwrap(); - let kind = match db.ty_raw(*type_id) { + let kind = match ty.get_raw(*type_id) { Type::Generic => WarningKind::UnusedGenericType(token.to_string()), Type::Alias(..) => WarningKind::UnusedTypeAlias(token.to_string()), Type::Struct(..) => WarningKind::UnusedStruct(token.to_string()), Type::Enum(..) => WarningKind::UnusedEnum(token.to_string()), - Type::EnumVariant(..) => WarningKind::UnusedEnumVariant(token.to_string()), + Type::Variant(..) => WarningKind::UnusedEnumVariant(token.to_string()), _ => continue, }; db.warning(kind, token.text_range()); diff --git a/crates/rue-compiler/src/compiler/ty.rs b/crates/rue-compiler/src/compiler/ty.rs index adcf983..e1fdae5 100644 --- a/crates/rue-compiler/src/compiler/ty.rs +++ b/crates/rue-compiler/src/compiler/ty.rs @@ -1,25 +1,24 @@ use rue_parser::{AstNode, Type}; - -use crate::TypeId; +use rue_typing::TypeId; use super::Compiler; mod function_type; -mod list_type; -mod nullable_type; +mod literal_type; mod pair_type; mod path_type; +mod union_type; impl Compiler<'_> { pub fn compile_type(&mut self, ty: Type) -> TypeId { match ty { + Type::LiteralType(lit) => self.compile_literal_type(&lit), Type::PathType(path) => { - self.compile_path_type(&path.idents(), path.syntax().text_range()) + self.compile_path_type(&path.items(), path.syntax().text_range()) } - Type::ListType(list) => self.compile_list_type(&list), Type::FunctionType(function) => self.compile_function_type(&function), Type::PairType(tuple) => self.compile_pair_type(&tuple), - Type::NullableType(optional) => self.compile_nullable_type(&optional), + Type::UnionType(union) => self.compile_union_type(&union), } } } diff --git a/crates/rue-compiler/src/compiler/ty/function_type.rs b/crates/rue-compiler/src/compiler/ty/function_type.rs index 6ed034d..260b09a 100644 --- a/crates/rue-compiler/src/compiler/ty/function_type.rs +++ b/crates/rue-compiler/src/compiler/ty/function_type.rs @@ -1,68 +1,52 @@ -use std::collections::HashSet; - +use indexmap::IndexSet; use rue_parser::{AstNode, FunctionType as Ast}; +use rue_typing::{construct_items, Callable, Type, TypeId}; -use crate::{ - compiler::Compiler, - value::{FunctionType, Rest, Type}, - ErrorKind, TypeId, -}; +use crate::{compiler::Compiler, ErrorKind}; impl Compiler<'_> { pub fn compile_function_type(&mut self, function: &Ast) -> TypeId { - let mut param_types = Vec::new(); - let mut type_names = HashSet::new(); - let mut rest = Rest::Nil; + let mut parameters = Vec::new(); + let mut parameter_names = IndexSet::new(); + let mut nil_terminated = true; let params = function.params(); let len = params.len(); for (i, param) in params.into_iter().enumerate() { - // We don't actually use the names yet, - // but go ahead and check for duplicates. - // TODO: Use the name in the actual type? - if let Some(name) = param.name() { - if !type_names.insert(name.to_string()) { - self.db.error( - ErrorKind::DuplicateSymbol(name.to_string()), - name.text_range(), - ); - } + let name = param + .name() + .map(|token| token.to_string()) + .unwrap_or(format!("#{i}")); + + if !parameter_names.insert(name.to_string()) { + self.db.error( + ErrorKind::DuplicateSymbol(name.to_string()), + param.name().unwrap().text_range(), + ); } // Compile the type of the parameter, if present. // Otherwise, it's a parser error. let type_id = param .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); // Add the parameter type to the list. - param_types.push(type_id); + parameters.push(type_id); // Check if it's a spread or optional parameter. let last = i + 1 == len; let spread = param.spread().is_some(); - let optional = param.optional().is_some(); - if spread && optional { - self.db.error( - ErrorKind::OptionalParameterSpread, - param.syntax().text_range(), - ); - } else if spread && !last { - self.db.error( - ErrorKind::InvalidSpreadParameter, - param.syntax().text_range(), - ); - } else if optional && !last { - self.db.error( - ErrorKind::InvalidOptionalParameter, - param.syntax().text_range(), - ); - } else if spread { - rest = Rest::Spread; - } else if optional { - rest = Rest::Optional; + if spread { + if !last { + self.db.error( + ErrorKind::InvalidSpreadParameter, + param.syntax().text_range(), + ); + } + nil_terminated = false; } } @@ -70,15 +54,23 @@ impl Compiler<'_> { // Otherwise, it's a parser error. let return_type = function .return_type() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); + + let parameters = construct_items(self.ty, parameters.into_iter(), nil_terminated); // Allocate a new type for the function. // TODO: Support generic types. - self.db.alloc_type(Type::Function(FunctionType { - param_types, - rest, + let type_id = self.ty.alloc(Type::Unknown); + + *self.ty.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names, + parameters, + nil_terminated, return_type, generic_types: Vec::new(), - })) + }); + + type_id } } diff --git a/crates/rue-compiler/src/compiler/ty/list_type.rs b/crates/rue-compiler/src/compiler/ty/list_type.rs deleted file mode 100644 index e2cd9f0..0000000 --- a/crates/rue-compiler/src/compiler/ty/list_type.rs +++ /dev/null @@ -1,14 +0,0 @@ -use rue_parser::ListType; - -use crate::{compiler::Compiler, value::Type, TypeId}; - -impl Compiler<'_> { - pub fn compile_list_type(&mut self, list: &ListType) -> TypeId { - let Some(inner) = list.ty() else { - return self.builtins.unknown; - }; - - let item_type = self.compile_type(inner); - self.db.alloc_type(Type::List(item_type)) - } -} diff --git a/crates/rue-compiler/src/compiler/ty/literal_type.rs b/crates/rue-compiler/src/compiler/ty/literal_type.rs new file mode 100644 index 0000000..c56e5c9 --- /dev/null +++ b/crates/rue-compiler/src/compiler/ty/literal_type.rs @@ -0,0 +1,32 @@ +use rue_parser::{LiteralType, SyntaxKind, SyntaxToken}; +use rue_typing::{Type, TypeId}; + +use crate::compiler::Compiler; + +impl Compiler<'_> { + pub fn compile_literal_type(&mut self, literal: &LiteralType) -> TypeId { + let Some(value) = literal.value() else { + return self.ty.std().unknown; + }; + + match value.kind() { + SyntaxKind::Int => self.compile_int_type(&value), + SyntaxKind::True => self.ty.std().true_bool, + SyntaxKind::False => self.ty.std().false_bool, + SyntaxKind::Nil => self.ty.std().nil, + _ => unreachable!(), + } + } + + fn compile_int_type(&mut self, int: &SyntaxToken) -> TypeId { + // Parse the literal into `BigInt`. + // It should not be possible to have a syntax error at this point. + let bigint = int + .text() + .replace('_', "") + .parse() + .expect("failed to parse integer literal"); + + self.ty.alloc(Type::Value(bigint)) + } +} diff --git a/crates/rue-compiler/src/compiler/ty/nullable_type.rs b/crates/rue-compiler/src/compiler/ty/nullable_type.rs deleted file mode 100644 index 16d25fb..0000000 --- a/crates/rue-compiler/src/compiler/ty/nullable_type.rs +++ /dev/null @@ -1,21 +0,0 @@ -use rue_parser::{AstNode, NullableType}; - -use crate::{compiler::Compiler, value::Type, TypeId, WarningKind}; - -impl Compiler<'_> { - pub fn compile_nullable_type(&mut self, optional: &NullableType) -> TypeId { - let ty = optional - .ty() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); - - if let Type::Nullable(inner) = self.db.ty_raw(ty).clone() { - self.db.warning( - WarningKind::RedundantNullableType(self.type_name(ty)), - optional.syntax().text_range(), - ); - return inner; - } - - self.db.alloc_type(Type::Nullable(ty)) - } -} diff --git a/crates/rue-compiler/src/compiler/ty/pair_type.rs b/crates/rue-compiler/src/compiler/ty/pair_type.rs index 97c0f1b..84fb306 100644 --- a/crates/rue-compiler/src/compiler/ty/pair_type.rs +++ b/crates/rue-compiler/src/compiler/ty/pair_type.rs @@ -1,21 +1,18 @@ -use rue_parser::PairType as Ast; +use rue_parser::PairType; +use rue_typing::{Type, TypeId}; -use crate::{ - compiler::Compiler, - value::{PairType, Type}, - TypeId, -}; +use crate::compiler::Compiler; impl Compiler<'_> { - pub fn compile_pair_type(&mut self, pair_type: &Ast) -> TypeId { + pub fn compile_pair_type(&mut self, pair_type: &PairType) -> TypeId { let first = pair_type .first() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); let rest = pair_type .rest() - .map_or(self.builtins.unknown, |ty| self.compile_type(ty)); + .map_or(self.ty.std().unknown, |ty| self.compile_type(ty)); - self.db.alloc_type(Type::Pair(PairType { first, rest })) + self.ty.alloc(Type::Pair(first, rest)) } } diff --git a/crates/rue-compiler/src/compiler/ty/path_type.rs b/crates/rue-compiler/src/compiler/ty/path_type.rs index 4886596..43b9f9e 100644 --- a/crates/rue-compiler/src/compiler/ty/path_type.rs +++ b/crates/rue-compiler/src/compiler/ty/path_type.rs @@ -1,39 +1,40 @@ use rowan::TextRange; -use rue_parser::SyntaxToken; +use rue_parser::PathItem; +use rue_typing::TypeId; use crate::{ compiler::{ - path::{PathItem, PathKind}, + path::{Path, PathKind}, Compiler, }, - ErrorKind, TypeId, + ErrorKind, }; impl Compiler<'_> { - pub fn compile_path_type(&mut self, idents: &[SyntaxToken], text_range: TextRange) -> TypeId { - let Some(mut item) = self.resolve_base_path(&idents[0], PathKind::Type, idents.len() == 1) + pub fn compile_path_type(&mut self, items: &[PathItem], text_range: TextRange) -> TypeId { + let Some(mut path) = self.resolve_base_path(&items[0], PathKind::Type, items.len() == 1) else { - return self.builtins.unknown; + return self.ty.std().unknown; }; - let mut last_ident = idents[0].to_string(); + let mut last_name = items[0].name().unwrap().to_string(); - for (i, name) in idents.iter().enumerate().skip(1) { - let Some(next_item) = - self.resolve_next_path(item, name, PathKind::Type, i == idents.len() - 1) + for (i, item) in items.iter().enumerate().skip(1) { + let Some(next_path) = + self.resolve_next_path(path, item, PathKind::Type, i == items.len() - 1) else { - return self.builtins.unknown; + return self.ty.std().unknown; }; - last_ident = name.to_string(); - item = next_item; + last_name = item.name().unwrap().to_string(); + path = next_path; } - match item { - PathItem::Type(type_id) => type_id, - PathItem::Symbol(..) => { + match path { + Path::Type(type_id) => type_id, + Path::Symbol(..) => { self.db - .error(ErrorKind::ExpectedTypePath(last_ident), text_range); - self.builtins.unknown + .error(ErrorKind::ExpectedTypePath(last_name), text_range); + self.ty.std().unknown } } } diff --git a/crates/rue-compiler/src/compiler/ty/union_type.rs b/crates/rue-compiler/src/compiler/ty/union_type.rs new file mode 100644 index 0000000..57bd268 --- /dev/null +++ b/crates/rue-compiler/src/compiler/ty/union_type.rs @@ -0,0 +1,14 @@ +use rue_parser::UnionType; +use rue_typing::{Type, TypeId}; + +use crate::compiler::Compiler; + +impl Compiler<'_> { + pub fn compile_union_type(&mut self, union: &UnionType) -> TypeId { + let mut types = Vec::new(); + for ty in union.types() { + types.push(self.compile_type(ty)); + } + self.ty.alloc(Type::Union(types)) + } +} diff --git a/crates/rue-compiler/src/database.rs b/crates/rue-compiler/src/database.rs index 6c2f9e4..eef5a0e 100644 --- a/crates/rue-compiler/src/database.rs +++ b/crates/rue-compiler/src/database.rs @@ -5,10 +5,10 @@ use rue_parser::SyntaxToken; mod comparison; mod ids; -mod type_system; pub use comparison::*; pub use ids::*; +use rue_typing::TypeId; use crate::{ environment::Environment, @@ -17,7 +17,6 @@ use crate::{ mir::Mir, scope::Scope, symbol::Symbol, - value::Type, Diagnostic, DiagnosticKind, ErrorKind, WarningKind, }; @@ -26,7 +25,6 @@ pub struct Database { diagnostics: Vec, scopes: Arena, symbols: Arena, - types: Arena, hir: Arena, mir: Arena, lir: Arena, @@ -49,10 +47,6 @@ impl Database { SymbolId(self.symbols.alloc(symbol)) } - pub(crate) fn alloc_type(&mut self, ty: Type) -> TypeId { - TypeId(self.types.alloc(ty)) - } - pub(crate) fn alloc_hir(&mut self, hir: Hir) -> HirId { HirId(self.hir.alloc(hir)) } @@ -77,17 +71,6 @@ impl Database { &self.symbols[id.0] } - pub fn ty_raw(&self, id: TypeId) -> &Type { - &self.types[id.0] - } - - pub fn ty(&self, mut id: TypeId) -> &Type { - while let Type::Alias(alias) = self.ty_raw(id) { - id = *alias; - } - self.ty_raw(id) - } - pub fn hir(&self, id: HirId) -> &Hir { &self.hir[id.0] } @@ -104,10 +87,6 @@ impl Database { &self.environments[id.0] } - pub(crate) fn ty_mut(&mut self, id: TypeId) -> &mut Type { - &mut self.types[id.0] - } - pub(crate) fn env_mut(&mut self, id: EnvironmentId) -> &mut Environment { &mut self.environments[id.0] } diff --git a/crates/rue-compiler/src/database/ids.rs b/crates/rue-compiler/src/database/ids.rs index 00cffdd..489eab6 100644 --- a/crates/rue-compiler/src/database/ids.rs +++ b/crates/rue-compiler/src/database/ids.rs @@ -1,9 +1,6 @@ use id_arena::Id; -use crate::{ - environment::Environment, hir::Hir, lir::Lir, mir::Mir, scope::Scope, symbol::Symbol, - value::Type, -}; +use crate::{environment::Environment, hir::Hir, lir::Lir, mir::Mir, scope::Scope, symbol::Symbol}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct SymbolId(pub(super) Id); @@ -11,9 +8,6 @@ pub struct SymbolId(pub(super) Id); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ScopeId(pub(super) Id); -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TypeId(pub(super) Id); - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct HirId(pub(super) Id); diff --git a/crates/rue-compiler/src/database/type_system.rs b/crates/rue-compiler/src/database/type_system.rs deleted file mode 100644 index 555d8aa..0000000 --- a/crates/rue-compiler/src/database/type_system.rs +++ /dev/null @@ -1,841 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use crate::{ - value::{EnumType, EnumVariantType, FunctionType, PairType, StructType, Type}, - Comparison, Database, TypeId, -}; - -#[derive(Debug)] -struct ComparisonContext<'a> { - visited: HashSet<(TypeId, TypeId)>, - generic_type_stack: &'a mut Vec>, - infer_generics: bool, -} - -impl<'a> ComparisonContext<'a> { - fn new(generic_type_stack: &'a mut Vec>, infer_generics: bool) -> Self { - Self { - visited: HashSet::new(), - generic_type_stack, - infer_generics, - } - } -} - -impl Database { - pub fn substitute_type( - &mut self, - type_id: TypeId, - substitutions: &HashMap, - ) -> TypeId { - self.substitute_type_visitor(type_id, substitutions, &mut HashSet::new()) - } - - pub fn compare_type(&self, lhs: TypeId, rhs: TypeId) -> Comparison { - self.compare_type_visitor( - lhs, - rhs, - &mut ComparisonContext::new(&mut Vec::new(), false), - ) - } - - pub fn compare_type_with_generics( - &self, - lhs: TypeId, - rhs: TypeId, - generic_type_stack: &mut Vec>, - ) -> Comparison { - self.compare_type_visitor( - lhs, - rhs, - &mut ComparisonContext::new(generic_type_stack, true), - ) - } - - pub fn is_cyclic(&self, type_id: TypeId) -> bool { - self.is_cyclic_visitor(type_id, &mut HashSet::new()) - } - - pub fn first_type(&self, type_id: TypeId) -> Option { - let Type::Pair(pair) = self.ty(type_id) else { - return None; - }; - Some(pair.first) - } - - pub fn rest_type(&self, type_id: TypeId) -> Option { - let Type::Pair(pair) = self.ty(type_id) else { - return None; - }; - Some(pair.rest) - } - - pub fn non_nullable(&mut self, ty: TypeId) -> TypeId { - match self.ty(ty) { - Type::Nullable(inner) => self.non_nullable(*inner), - _ => ty, - } - } - - pub fn non_undefined(&mut self, ty: TypeId) -> TypeId { - match self.ty(ty) { - Type::Optional(inner) => self.non_undefined(*inner), - _ => ty, - } - } - - pub fn unwrap_list(&mut self, ty: TypeId) -> Option { - match self.ty(ty) { - Type::List(inner) => Some(*inner), - _ => None, - } - } - - pub fn substitute_type_visitor( - &mut self, - type_id: TypeId, - substitutions: &HashMap, - visited: &mut HashSet, - ) -> TypeId { - if let Some(type_id) = substitutions.get(&type_id).copied() { - return type_id; - } - - if !visited.insert(type_id) { - return type_id; - } - - match self.ty(type_id).clone() { - Type::Alias(..) => unreachable!(), - Type::Pair(pair) => { - let new_pair = PairType { - first: self.substitute_type_visitor(pair.first, substitutions, visited), - rest: self.substitute_type_visitor(pair.rest, substitutions, visited), - }; - - if new_pair == pair { - type_id - } else { - self.alloc_type(Type::Pair(new_pair)) - } - } - Type::Enum(enum_type) => { - let new_enum = EnumType { - has_fields: enum_type.has_fields, - variants: enum_type - .variants - .iter() - .map(|(k, v)| { - ( - k.clone(), - self.substitute_type_visitor(*v, substitutions, visited), - ) - }) - .collect(), - }; - - if new_enum == enum_type { - type_id - } else { - self.alloc_type(Type::Enum(new_enum)) - } - } - Type::EnumVariant(enum_variant) => { - let new_variant = EnumVariantType { - enum_type: enum_variant.enum_type, - original_type_id: enum_variant.original_type_id, - fields: enum_variant.fields.as_ref().map(|fields| { - fields - .iter() - .map(|(k, v)| { - ( - k.clone(), - self.substitute_type_visitor(*v, substitutions, visited), - ) - }) - .collect() - }), - rest: enum_variant.rest, - discriminant: enum_variant.discriminant, - }; - - if new_variant == enum_variant { - type_id - } else { - self.alloc_type(Type::EnumVariant(new_variant)) - } - } - Type::List(inner) => { - let new_inner = self.substitute_type_visitor(inner, substitutions, visited); - - if new_inner == inner { - type_id - } else { - self.alloc_type(Type::List(new_inner)) - } - } - Type::Struct(struct_type) => { - let new_struct = StructType { - original_type_id: struct_type.original_type_id, - fields: struct_type - .fields - .iter() - .map(|(k, v)| { - ( - k.clone(), - self.substitute_type_visitor(*v, substitutions, visited), - ) - }) - .collect(), - rest: struct_type.rest, - }; - - if new_struct == struct_type { - type_id - } else { - self.alloc_type(Type::Struct(new_struct)) - } - } - Type::Function(function) => { - let new_function = FunctionType { - param_types: function - .param_types - .iter() - .map(|ty| self.substitute_type_visitor(*ty, substitutions, visited)) - .collect(), - rest: function.rest, - return_type: self.substitute_type_visitor( - function.return_type, - substitutions, - visited, - ), - generic_types: function - .generic_types - .iter() - .map(|ty| self.substitute_type_visitor(*ty, substitutions, visited)) - .collect(), - }; - - if new_function == function { - type_id - } else { - self.alloc_type(Type::Function(new_function)) - } - } - Type::Nullable(inner) => { - let new_inner = self.substitute_type_visitor(inner, substitutions, visited); - - if new_inner == inner { - type_id - } else { - self.alloc_type(Type::Nullable(inner)) - } - } - Type::Optional(inner) => { - let new_inner = self.substitute_type_visitor(inner, substitutions, visited); - - if new_inner == inner { - type_id - } else { - self.alloc_type(Type::Optional(inner)) - } - } - Type::Unknown - | Type::Generic - | Type::Any - | Type::Bool - | Type::Bytes - | Type::Bytes32 - | Type::Int - | Type::Nil - | Type::PublicKey => type_id, - } - } - - #[allow(clippy::match_same_arms, clippy::too_many_lines)] - fn compare_type_visitor( - &self, - lhs: TypeId, - rhs: TypeId, - ctx: &mut ComparisonContext<'_>, - ) -> Comparison { - let key = (lhs, rhs); - if lhs == rhs || !ctx.visited.insert(key) { - return Comparison::Equal; - } - - let comparison = match (self.ty(lhs), self.ty(rhs)) { - // Aliases are already resolved at this point. - (Type::Alias(..), _) | (_, Type::Alias(..)) => unreachable!(), - - // We need to infer `Generic` types, and return `Unrelated` if incompatible. - (_, Type::Generic) => { - let mut found = None; - - for generics in ctx.generic_type_stack.iter().rev() { - if let Some(&generic) = generics.get(&rhs) { - found = Some(generic); - } - } - - if let Some(found) = found { - self.compare_type_visitor(lhs, found, ctx) - } else if ctx.infer_generics { - ctx.generic_type_stack.last_mut().unwrap().insert(rhs, lhs); - Comparison::Assignable - } else { - Comparison::Unrelated - } - } - - // We should have already given a diagnostic for `Unknown`. - // So we will go ahead and pretend they are equal to everything. - (Type::Unknown, _) | (_, Type::Unknown) => Comparison::Equal, - - // Possibly undefined is a special type. - (Type::Optional(..), _) | (_, Type::Optional(..)) => Comparison::Unrelated, - - // These are of course equal atomic types. - (Type::Any, Type::Any) => Comparison::Equal, - (Type::Int, Type::Int) => Comparison::Equal, - (Type::Bool, Type::Bool) => Comparison::Equal, - (Type::Bytes, Type::Bytes) => Comparison::Equal, - (Type::Bytes32, Type::Bytes32) => Comparison::Equal, - (Type::PublicKey, Type::PublicKey) => Comparison::Equal, - (Type::Nil, Type::Nil) => Comparison::Equal, - - // We should treat `Bytes32` as a subtype of `Bytes` and therefore equal. - // `PublicKey` however, can have different meaning, so is not equal. - // `Bytes` is also not equal to `Bytes32` since it's unsized. - (Type::Bytes32, Type::Bytes) => Comparison::Equal, - - // You can cast `Any` to anything, but it's not implicit. - // So many languages make `Any` implicit and it makes it hard to debug. - (Type::Any, _) => Comparison::Castable, - - // However, anything can be assigned to `Any` implicitly. - (_, Type::Any) => Comparison::Assignable, - - // `Generic` types are unrelated to everything else except their specific instance. - // The only exception is the `Any` type above. - (Type::Generic, _) => Comparison::Unrelated, - - // You have to explicitly convert between other atom types. - // This is because the compiled output can change depending on the type. - (Type::Int, Type::Bytes) => Comparison::Castable, - (Type::Bool, Type::Bytes) => Comparison::Castable, - (Type::Bool, Type::Int) => Comparison::Castable, - (Type::Nil, Type::Int) => Comparison::Castable, - (Type::Nil, Type::Bool) => Comparison::Castable, - (Type::PublicKey, Type::Bytes) => Comparison::Castable, - (Type::PublicKey, Type::Int) => Comparison::Castable, - (Type::Bytes, Type::Int) => Comparison::Castable, - (Type::Bytes32, Type::Int) => Comparison::Castable, - - // Let's allow assigning `Nil` to `Bytes` for ease of use. - // The only alternative without needing a cast is empty strings. - (Type::Nil, Type::Bytes) => Comparison::Assignable, - - // Size changing conversions are not possible without a type guard. - (Type::Bytes, Type::Bytes32) => Comparison::Superset, - (Type::Bytes, Type::PublicKey) => Comparison::Superset, - (Type::Bytes, Type::Nil) => Comparison::Superset, - (Type::Bytes, Type::Bool) => Comparison::Superset, - (Type::Int, Type::Bytes32) => Comparison::Superset, - (Type::Int, Type::PublicKey) => Comparison::Superset, - (Type::Int, Type::Nil) => Comparison::Superset, - (Type::Int, Type::Bool) => Comparison::Superset, - (Type::Bool, Type::PublicKey) => Comparison::Unrelated, - (Type::Bool, Type::Bytes32) => Comparison::Unrelated, - (Type::Bool, Type::Nil) => Comparison::Unrelated, - (Type::Nil, Type::Bytes32) => Comparison::Unrelated, - (Type::Nil, Type::PublicKey) => Comparison::Unrelated, - (Type::PublicKey, Type::Bytes32) => Comparison::Unrelated, - (Type::PublicKey, Type::Bool) => Comparison::Unrelated, - (Type::PublicKey, Type::Nil) => Comparison::Unrelated, - (Type::Bytes32, Type::PublicKey) => Comparison::Unrelated, - (Type::Bytes32, Type::Bool) => Comparison::Unrelated, - (Type::Bytes32, Type::Nil) => Comparison::Unrelated, - - // These are the variants of the `Nullable` type. - (Type::Nil, Type::Nullable(..)) => Comparison::Assignable, - (Type::Nullable(lhs), Type::Nullable(rhs)) => { - self.compare_type_visitor(*lhs, *rhs, ctx) - } - (_, Type::Nullable(inner)) => self.compare_type_visitor(lhs, *inner, ctx), - (Type::Nullable(inner), Type::Bytes) => { - Comparison::Castable & self.compare_type_visitor(*inner, rhs, ctx) - } - - // TODO: Unions would make this more generalized and useful. - // I should add unions back. - (Type::Nullable(_inner), _) => Comparison::Unrelated, - - // Compare both sides of a `Pair`. - (Type::Pair(lhs), Type::Pair(rhs)) => { - let first = self.compare_type_visitor(lhs.first, rhs.first, ctx); - let rest = self.compare_type_visitor(lhs.rest, rhs.rest, ctx); - first & rest - } - - // A `Pair` is a valid `List` if its first type is the same as the list's inner type - // and the rest is also a valid `List` of the same type. - // However, it's not considered equal but rather assignable. - (Type::Pair(pair), Type::List(inner)) => { - let inner = self.compare_type_visitor(pair.first, *inner, ctx); - let rest = self.compare_type_visitor(pair.rest, rhs, ctx); - Comparison::Assignable & inner & rest - } - - // A `List` is not a valid pair since `Nil` is also a valid list. - // It's a `Superset` only if the opposite comparison is not `Unrelated`. - (Type::List(..), Type::Pair(..)) => { - Comparison::Superset & self.compare_type_visitor(rhs, lhs, ctx) - } - - // Nothing else can be assigned to or from a `Pair`. - (Type::Pair(..), _) | (_, Type::Pair(..)) => Comparison::Unrelated, - - // A `List` just compares with the inner type of another `List`. - (Type::List(lhs), Type::List(rhs)) => self.compare_type_visitor(*lhs, *rhs, ctx), - - // `Nil` is a valid list. - (Type::Nil, Type::List(..)) => Comparison::Assignable, - (Type::List(..), Type::Nil) => Comparison::Superset, - - // `List` is not compatible with atoms. - (Type::Bytes, Type::List(..)) => Comparison::Unrelated, - (Type::Bytes32, Type::List(..)) => Comparison::Unrelated, - (Type::PublicKey, Type::List(..)) => Comparison::Unrelated, - (Type::Int, Type::List(..)) => Comparison::Unrelated, - (Type::Bool, Type::List(..)) => Comparison::Unrelated, - (Type::List(..), Type::Bytes) => Comparison::Unrelated, - (Type::List(..), Type::Bytes32) => Comparison::Unrelated, - (Type::List(..), Type::PublicKey) => Comparison::Unrelated, - (Type::List(..), Type::Int) => Comparison::Unrelated, - (Type::List(..), Type::Bool) => Comparison::Unrelated, - - // A `Struct` is castable to others only if the fields are identical. - (Type::Struct(lhs), Type::Struct(rhs)) => { - if lhs.fields.len() == rhs.fields.len() { - let mut result = Comparison::Castable; - for i in 0..lhs.fields.len() { - result &= self.compare_type_visitor(lhs.fields[i], rhs.fields[i], ctx); - } - result - } else { - Comparison::Unrelated - } - } - (Type::Struct(..), _) | (_, Type::Struct(..)) => Comparison::Unrelated, - - // Enum variants are assignable only to their respective enum. - (Type::EnumVariant(variant_type), Type::Enum(..)) => { - if variant_type.enum_type == rhs { - Comparison::Assignable - } else { - Comparison::Unrelated - } - } - - // But not the other way around. - (Type::Enum(..), Type::EnumVariant(variant_type)) => { - if variant_type.enum_type == lhs { - Comparison::Superset - } else { - Comparison::Unrelated - } - } - - // You can cast numeric enums to `Int`. - (Type::EnumVariant(variant_type), Type::Bytes | Type::Int) => { - self.compare_type_visitor(variant_type.enum_type, rhs, ctx) - } - - (Type::Enum(enum_type), Type::Bytes | Type::Int) => { - if enum_type.has_fields { - Comparison::Unrelated - } else { - Comparison::Castable - } - } - - // Enums and their variants are not assignable to anything else. - (Type::Enum(..), _) | (_, Type::Enum(..)) => Comparison::Unrelated, - (Type::EnumVariant(..), _) | (_, Type::EnumVariant(..)) => Comparison::Unrelated, - - // The parameters and return types of functions are checked. - // This can be made more flexible later. - (Type::Function(lhs), Type::Function(rhs)) => { - if lhs.param_types.len() != rhs.param_types.len() || lhs.rest != rhs.rest { - Comparison::Unrelated - } else { - let mut result = - self.compare_type_visitor(lhs.return_type, rhs.return_type, ctx); - - for i in 0..lhs.param_types.len() { - result &= - self.compare_type_visitor(lhs.param_types[i], rhs.param_types[i], ctx); - } - - result - } - } - - // There is likely nothing that should relate to a function. - (Type::Function(..), _) | (_, Type::Function(..)) => Comparison::Unrelated, - }; - - ctx.visited.remove(&key); - comparison - } - - fn is_cyclic_visitor(&self, ty: TypeId, visited_aliases: &mut HashSet) -> bool { - match self.ty_raw(ty).clone() { - Type::Pair(pair) => { - self.is_cyclic_visitor(pair.first, visited_aliases) - || self.is_cyclic_visitor(pair.rest, visited_aliases) - } - Type::Alias(alias) => { - if !visited_aliases.insert(alias) { - return true; - } - self.is_cyclic_visitor(alias, visited_aliases) - } - Type::List(..) - | Type::Struct(..) - | Type::Enum(..) - | Type::EnumVariant(..) - | Type::Function(..) - | Type::Unknown - | Type::Generic - | Type::Nil - | Type::Any - | Type::Int - | Type::Bool - | Type::Bytes - | Type::Bytes32 - | Type::PublicKey => false, - Type::Nullable(ty) | Type::Optional(ty) => self.is_cyclic_visitor(ty, visited_aliases), - } - } -} - -#[cfg(test)] -mod tests { - use indexmap::IndexMap; - - use crate::{ - compiler::{builtins, Builtins}, - value::{EnumType, EnumVariantType, FunctionType, PairType, Rest, StructType}, - }; - - use super::*; - - fn setup() -> (Database, Builtins) { - let mut db = Database::new(); - let ty = builtins(&mut db); - (db, ty) - } - - fn fields(items: &[TypeId]) -> IndexMap { - let mut fields = IndexMap::new(); - for (i, item) in items.iter().enumerate() { - fields.insert(format!("field_{i}"), *item); - } - fields - } - - #[test] - fn test_substitution() { - let (mut db, ty) = setup(); - - let a = db.alloc_type(Type::Generic); - let b = db.alloc_type(Type::Generic); - - let a_list = db.alloc_type(Type::List(a)); - - let function = db.alloc_type(Type::Function(FunctionType { - param_types: vec![a_list], - rest: Rest::Nil, - return_type: b, - generic_types: vec![a, b], - })); - - let int_list = db.alloc_type(Type::List(ty.int)); - - let substitutions = [(a, ty.int), (b, ty.bool)].into_iter().collect(); - let substituted = db.substitute_type(function, &substitutions); - - let expected = db.alloc_type(Type::Function(FunctionType { - param_types: vec![int_list], - rest: Rest::Nil, - return_type: ty.bool, - generic_types: vec![a, b], - })); - - assert_eq!(db.compare_type(substituted, expected), Comparison::Equal); - } - - #[test] - fn test_alias_resolution() { - let (mut db, ty) = setup(); - - let int_alias = db.alloc_type(Type::Alias(ty.int)); - assert_eq!(db.compare_type(int_alias, ty.int), Comparison::Equal); - assert_eq!(db.compare_type(ty.int, int_alias), Comparison::Equal); - assert_eq!(db.compare_type(int_alias, int_alias), Comparison::Equal); - - let double_alias = db.alloc_type(Type::Alias(int_alias)); - assert_eq!(db.compare_type(double_alias, ty.int), Comparison::Equal); - assert_eq!(db.compare_type(ty.int, double_alias), Comparison::Equal); - assert_eq!( - db.compare_type(double_alias, double_alias), - Comparison::Equal - ); - assert_eq!(db.compare_type(double_alias, int_alias), Comparison::Equal); - assert_eq!(db.compare_type(int_alias, double_alias), Comparison::Equal); - } - - #[test] - fn test_generic_type() { - let (mut db, ty) = setup(); - let a = db.alloc_type(Type::Generic); - let b = db.alloc_type(Type::Generic); - assert_eq!(db.compare_type(a, b), Comparison::Unrelated); - assert_eq!(db.compare_type(a, ty.int), Comparison::Unrelated); - assert_eq!(db.compare_type(ty.int, a), Comparison::Unrelated); - assert_eq!(db.compare_type(a, a), Comparison::Equal); - } - - #[test] - fn test_any_type() { - let (db, ty) = setup(); - assert_eq!(db.compare_type(ty.any, ty.int), Comparison::Castable); - assert_eq!(db.compare_type(ty.any, ty.any), Comparison::Equal); - assert_eq!(db.compare_type(ty.int, ty.any), Comparison::Assignable); - } - - #[test] - fn test_unknown_type() { - let (db, ty) = setup(); - assert_eq!(db.compare_type(ty.any, ty.unknown), Comparison::Equal); - assert_eq!(db.compare_type(ty.int, ty.unknown), Comparison::Equal); - assert_eq!(db.compare_type(ty.unknown, ty.int), Comparison::Equal); - assert_eq!(db.compare_type(ty.unknown, ty.any), Comparison::Equal); - } - - #[test] - fn test_atom_types() { - let (db, ty) = setup(); - assert_eq!(db.compare_type(ty.bytes, ty.bytes32), Comparison::Superset); - assert_eq!(db.compare_type(ty.bytes32, ty.bytes), Comparison::Equal); - assert_eq!( - db.compare_type(ty.bytes, ty.public_key), - Comparison::Superset - ); - assert_eq!( - db.compare_type(ty.public_key, ty.bytes), - Comparison::Castable - ); - assert_eq!(db.compare_type(ty.bytes, ty.int), Comparison::Castable); - assert_eq!(db.compare_type(ty.int, ty.bytes), Comparison::Castable); - assert_eq!(db.compare_type(ty.int, ty.int), Comparison::Equal); - } - - #[test] - fn test_nil_type() { - let (mut db, ty) = setup(); - - let list = db.alloc_type(Type::List(ty.int)); - assert_eq!(db.compare_type(ty.nil, ty.int), Comparison::Castable); - assert_eq!(db.compare_type(ty.nil, ty.bool), Comparison::Castable); - assert_eq!(db.compare_type(ty.nil, ty.bytes), Comparison::Assignable); - assert_eq!(db.compare_type(ty.nil, ty.bytes32), Comparison::Unrelated); - assert_eq!( - db.compare_type(ty.nil, ty.public_key), - Comparison::Unrelated - ); - assert_eq!(db.compare_type(ty.nil, list), Comparison::Assignable); - } - - #[test] - fn test_pair_type() { - let (mut db, ty) = setup(); - - let int_pair = db.alloc_type(Type::Pair(PairType { - first: ty.int, - rest: ty.int, - })); - assert_eq!(db.compare_type(int_pair, int_pair), Comparison::Equal); - - let bytes_pair = db.alloc_type(Type::Pair(PairType { - first: ty.bytes, - rest: ty.bytes, - })); - assert_eq!(db.compare_type(int_pair, bytes_pair), Comparison::Castable); - assert_eq!(db.compare_type(bytes_pair, int_pair), Comparison::Castable); - - let bytes32_pair = db.alloc_type(Type::Pair(PairType { - first: ty.bytes32, - rest: ty.bytes32, - })); - assert_eq!( - db.compare_type(bytes_pair, bytes32_pair), - Comparison::Superset - ); - assert_eq!(db.compare_type(bytes32_pair, bytes_pair), Comparison::Equal); - - let bytes32_bytes = db.alloc_type(Type::Pair(PairType { - first: ty.bytes32, - rest: ty.bytes, - })); - assert_eq!( - db.compare_type(bytes_pair, bytes32_bytes), - Comparison::Superset - ); - assert_eq!( - db.compare_type(bytes32_bytes, bytes_pair), - Comparison::Equal - ); - } - - #[test] - fn test_list_types() { - let (mut db, ty) = setup(); - - let int_list = db.alloc_type(Type::List(ty.int)); - assert_eq!(db.compare_type(int_list, int_list), Comparison::Equal); - - let pair_list = db.alloc_type(Type::Pair(PairType { - first: ty.int, - rest: int_list, - })); - assert_eq!(db.compare_type(pair_list, int_list), Comparison::Assignable); - assert_eq!(db.compare_type(int_list, pair_list), Comparison::Superset); - - let pair_nil = db.alloc_type(Type::Pair(PairType { - first: ty.int, - rest: ty.nil, - })); - assert_eq!(db.compare_type(pair_nil, int_list), Comparison::Assignable); - - let pair_unrelated_first = db.alloc_type(Type::Pair(PairType { - first: pair_list, - rest: ty.nil, - })); - assert_eq!( - db.compare_type(pair_unrelated_first, int_list), - Comparison::Unrelated - ); - - let pair_unrelated_rest = db.alloc_type(Type::Pair(PairType { - first: ty.int, - rest: ty.int, - })); - assert_eq!( - db.compare_type(pair_unrelated_rest, int_list), - Comparison::Unrelated - ); - } - - #[test] - fn test_struct_types() { - let (mut db, ty) = setup(); - - let two_ints = db.alloc_type(Type::Unknown); - *db.ty_mut(two_ints) = Type::Struct(StructType { - original_type_id: two_ints, - fields: fields(&[ty.int, ty.int]), - rest: Rest::Nil, - }); - assert_eq!(db.compare_type(two_ints, two_ints), Comparison::Equal); - - let one_int = db.alloc_type(Type::Unknown); - *db.ty_mut(one_int) = Type::Struct(StructType { - original_type_id: one_int, - fields: fields(&[ty.int]), - rest: Rest::Nil, - }); - assert_eq!(db.compare_type(one_int, two_ints), Comparison::Unrelated); - - let empty_struct = db.alloc_type(Type::Unknown); - *db.ty_mut(empty_struct) = Type::Struct(StructType { - original_type_id: empty_struct, - fields: fields(&[]), - rest: Rest::Nil, - }); - assert_eq!( - db.compare_type(empty_struct, empty_struct), - Comparison::Equal - ); - } - - #[test] - fn test_enum_types() { - let (mut db, ty) = setup(); - - let enum_type = db.alloc_type(Type::Unknown); - - let variant = db.alloc_type(Type::Unknown); - *db.ty_mut(variant) = Type::EnumVariant(EnumVariantType { - enum_type, - original_type_id: variant, - fields: Some(fields(&[ty.int])), - discriminant: ty.unknown_hir, - rest: Rest::Nil, - }); - - *db.ty_mut(enum_type) = Type::Enum(EnumType { - has_fields: true, - variants: fields(&[variant]), - }); - - assert_eq!(db.compare_type(enum_type, variant), Comparison::Superset); - assert_eq!(db.compare_type(variant, enum_type), Comparison::Assignable); - } - - #[test] - fn test_function_types() { - let (mut db, ty) = setup(); - - let int_to_bool = db.alloc_type(Type::Function(FunctionType { - param_types: vec![ty.int], - rest: Rest::Nil, - return_type: ty.bool, - generic_types: Vec::new(), - })); - assert_eq!(db.compare_type(int_to_bool, int_to_bool), Comparison::Equal); - - let int_to_int = db.alloc_type(Type::Function(FunctionType { - param_types: vec![ty.int], - rest: Rest::Nil, - return_type: ty.int, - generic_types: Vec::new(), - })); - assert_eq!( - db.compare_type(int_to_bool, int_to_int), - Comparison::Castable - ); - assert_eq!( - db.compare_type(int_to_int, int_to_bool), - Comparison::Superset - ); - - let int_list = db.alloc_type(Type::List(ty.int)); - let int_list_to_bool = db.alloc_type(Type::Function(FunctionType { - param_types: vec![int_list], - rest: Rest::Nil, - return_type: ty.bool, - generic_types: Vec::new(), - })); - assert_eq!( - db.compare_type(int_to_bool, int_list_to_bool), - Comparison::Unrelated - ); - assert_eq!( - db.compare_type(int_list_to_bool, int_to_bool), - Comparison::Unrelated - ); - } -} diff --git a/crates/rue-compiler/src/dependency_graph.rs b/crates/rue-compiler/src/dependency_graph.rs index 5727086..c95b73e 100644 --- a/crates/rue-compiler/src/dependency_graph.rs +++ b/crates/rue-compiler/src/dependency_graph.rs @@ -5,7 +5,6 @@ use crate::{ environment::Environment, hir::Hir, symbol::{Function, Module, Symbol}, - value::Rest, Database, EnvironmentId, ErrorKind, HirId, ScopeId, SymbolId, }; @@ -134,10 +133,9 @@ impl<'a> GraphBuilder<'a> { .filter(|&symbol_id| matches!(self.db.symbol(symbol_id), Symbol::Parameter(_))) .collect(); - let environment_id = self.db.alloc_env(Environment::function( - parameters, - function.ty.rest != Rest::Nil, - )); + let environment_id = self + .db + .alloc_env(Environment::function(parameters, !function.nil_terminated)); self.graph .environments diff --git a/crates/rue-compiler/src/error.rs b/crates/rue-compiler/src/error.rs index 96fc2fa..7b6f7a6 100644 --- a/crates/rue-compiler/src/error.rs +++ b/crates/rue-compiler/src/error.rs @@ -49,33 +49,7 @@ pub enum WarningKind { UnusedEnumVariant(String), UnusedStruct(String), UnusedTypeAlias(String), - RedundantNullableType(String), - RedundantTypeCheck(String), -} - -impl fmt::Display for WarningKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let message = match self { - Self::UnusedFunction(name) => format!("Unused function `{name}`."), - Self::UnusedInlineFunction(name) => format!("Unused inline function `{name}`."), - Self::UnusedParameter(name) => format!("Unused parameter `{name}`."), - Self::UnusedConst(name) => format!("Unused constant `{name}`."), - Self::UnusedInlineConst(name) => format!("Unused inline constant `{name}`."), - Self::UnusedLet(name) => format!("Unused let binding `{name}`."), - Self::UnusedGenericType(name) => format!("Unused generic type `{name}`."), - Self::UnusedEnum(name) => format!("Unused enum `{name}`."), - Self::UnusedEnumVariant(name) => format!("Unused enum variant `{name}`."), - Self::UnusedStruct(name) => format!("Unused struct `{name}`."), - Self::UnusedTypeAlias(name) => format!("Unused type alias `{name}`."), - Self::RedundantNullableType(ty) => { - format!("This has no effect, since `{ty}` is already a nullable type.") - } - Self::RedundantTypeCheck(ty) => format!( - "It's redundant to guard against `{ty}`, since the value already has that type." - ), - }; - write!(f, "{}", message.trim()) - } + UnnecessaryTypeCheck(String, String), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -102,7 +76,6 @@ pub enum ErrorKind { UncallableType(String), ArgumentMismatch(usize, usize), ArgumentMismatchSpread(usize, usize), - ArgumentMismatchOptional(usize, usize), // Field initialization. UninitializableType(String), @@ -117,15 +90,11 @@ pub enum ErrorKind { InvalidFieldAccess(String, String), InvalidIndexAccess(String), - // Spread and optional. + // Spread syntax. InvalidSpreadItem, InvalidSpreadArgument, InvalidSpreadParameter, InvalidSpreadField, - InvalidOptionalParameter, - InvalidOptionalField, - OptionalParameterSpread, - OptionalFieldSpread, UnsupportedFunctionSpread, RequiredFunctionSpread, @@ -143,12 +112,13 @@ pub enum ErrorKind { InvalidSymbolPath(Option), ExpectedTypePath(String), ExpectedSymbolPath(String), + UnexpectedGenericArgs, + ExpectedGenericArgs, + GenericArgsMismatch(usize, usize), // Type guards. - UnsupportedTypeGuard(String, String), - NonAnyPairTypeGuard, - NonListPairTypeGuard, - InvalidExistanceCheck(String), + ImpossibleTypeCheck(String, String), + RecursiveTypeCheck(String, String), // Blocks. ImplicitReturnInIf, @@ -165,28 +135,50 @@ pub enum ErrorKind { RecursiveInlineFunctionCall, } +impl fmt::Display for WarningKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let message = match self { + Self::UnusedFunction(name) => format!("Unused function `{name}`"), + Self::UnusedInlineFunction(name) => format!("Unused inline function `{name}`"), + Self::UnusedParameter(name) => format!("Unused parameter `{name}`"), + Self::UnusedConst(name) => format!("Unused constant `{name}`"), + Self::UnusedInlineConst(name) => format!("Unused inline constant `{name}`"), + Self::UnusedLet(name) => format!("Unused let binding `{name}`"), + Self::UnusedGenericType(name) => format!("Unused generic type `{name}`"), + Self::UnusedEnum(name) => format!("Unused enum `{name}`"), + Self::UnusedEnumVariant(name) => format!("Unused enum variant `{name}`"), + Self::UnusedStruct(name) => format!("Unused struct `{name}`"), + Self::UnusedTypeAlias(name) => format!("Unused type alias `{name}`"), + Self::UnnecessaryTypeCheck(from, to) => { + format!("Checking `{from}` against `{to}` has no effect") + } + }; + write!(f, "{}", message.trim()) + } +} + impl fmt::Display for ErrorKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let message = match self { // Duplicate definitions. - Self::DuplicateType(name) => format!("There is already a type named `{name}` in this scope."), - Self::DuplicateSymbol(name) => format!("There is already a symbol named `{name}` in this scope."), + Self::DuplicateType(name) => format!("There is already a type named `{name}` in this scope"), + Self::DuplicateSymbol(name) => format!("There is already a symbol named `{name}` in this scope"), Self::ModuleNameTakenByEnum(name) => formatdoc!(" There is already an enum type named `{name}` in this scope. \ - This isn't allowed to prevent ambiguity when referencing items. + This isn't allowed to prevent ambiguity when referencing items "), Self::EnumNameTakenByModule(name) => formatdoc!(" There is already a module named `{name}` in this scope. \ - This isn't allowed to prevent ambiguity when referencing items. + This isn't allowed to prevent ambiguity when referencing items "), // Invalid symbol references. - Self::UnknownSymbol(name) => format!("Reference to unknown symbol `{name}`."), - Self::UnknownType(name) => format!("Reference to unknown type `{name}`."), + Self::UnknownSymbol(name) => format!("Reference to unknown symbol `{name}`"), + Self::UnknownType(name) => format!("Reference to unknown type `{name}`"), Self::InlineFunctionReference(name) => formatdoc!(" Cannot reference inline function `{name}`, since it is not a value. \ Inline functions must be resolved at compile time. \ - Try calling the function instead. + Try calling the function instead "), Self::ModuleReference(name) => formatdoc!(" Cannot reference module `{name}`, since it is not a value. \ @@ -198,126 +190,116 @@ impl fmt::Display for ErrorKind { Cycle detected when resolving type alias `{name}`. \ Type aliases cannot reference themselves. "), - Self::TypeMismatch(found, expected) => format!("Expected type `{expected}`, but found `{found}`."), - Self::CastMismatch(found, expected) => format!("Cannot cast type `{found}` to `{expected}`."), - Self::CannotInferType => "Lambda parameter type could not be inferred.".to_string(), + Self::TypeMismatch(found, expected) => format!("Expected type `{expected}`, but found `{found}`"), + Self::CastMismatch(found, expected) => format!("Cannot cast type `{found}` to `{expected}`"), + Self::CannotInferType => "Lambda parameter type could not be inferred".to_string(), // Function calls. - Self::UncallableType(ty) => format!("Expression with type `{ty}` cannot be called, since it is not a function."), + Self::UncallableType(ty) => format!("Expression with type `{ty}` cannot be called, since it is not a function"), Self::ArgumentMismatch(found, expected) => { format!( - "Expected {expected} argument{}, but found {found}.", + "Expected {expected} argument{}, but found {found}", if *expected == 1 { "" } else { "s" } ) } Self::ArgumentMismatchSpread (found, expected) => { format!( - "Expected at least {expected} argument{}, but found {found}.", + "Expected at least {expected} argument{}, but found {found}", if *expected == 1 { "" } else { "s" } ) } - Self::ArgumentMismatchOptional (found, expected)=> { - format!("Expected either {} or {expected} arguments, but found {found}.", expected - 1) - } // Field initialization. Self::UninitializableType(ty) => formatdoc!(" Cannot initializable type `{ty}`. \ - Only structs and enum variants with fields can be initialized. + Only structs and enum variants with fields can be initialized "), Self::InvalidEnumVariantReference(name) => formatdoc!(" Cannot reference enum variant `{name}`. \ Enum variants with fields cannot be referenced directly. \ - Consider initializing the enum variant instead. + Consider initializing the enum variant instead "), Self::InvalidEnumVariantInitializer(name) => formatdoc!(" Cannot initialize enum variant `{name}`. \ Enum variants without fields cannot be initialized. \ - Consider referencing the enum variant directly. + Consider referencing the enum variant directly "), - Self::DuplicateInitializerField(name) => format!("Duplicate field `{name}` specified in initializer."), - Self::UnknownInitializerField(name) => format!("Unknown field `{name}` specified in initializer."), - Self::MissingInitializerFields(fields) => format!("Missing fields in initializer: {}.", join_names(fields)), + Self::DuplicateInitializerField(name) => format!("Duplicate field `{name}` specified in initializer"), + Self::UnknownInitializerField(name) => format!("Unknown field `{name}` specified in initializer"), + Self::MissingInitializerFields(fields) => format!("Missing fields in initializer: {}", join_names(fields)), // Field access. - Self::UnknownField(name) => format!("Cannot reference unknown field `{name}`."), - Self::InvalidFieldAccess(field, ty) => format!("Cannot reference field `{field}` of type `{ty}`."), - Self::InvalidIndexAccess(ty) => format!("Cannot index into type `{ty}`."), + Self::UnknownField(name) => format!("Cannot reference unknown field `{name}`"), + Self::InvalidFieldAccess(field, ty) => format!("Cannot reference field `{field}` of type `{ty}`"), + Self::InvalidIndexAccess(ty) => format!("Cannot index into type `{ty}`"), - // Spread and optional. + // Spread syntax. Self::InvalidSpreadItem => formatdoc!(" The spread operator can only be used on the last item in a list. \ This is because it requires recursion at runtime to concatenate lists together. \ - By only allowing it on the last item by default, this additional complexity and runtime cost is avoided. + By only allowing it on the last item by default, this additional complexity and runtime cost is avoided "), Self::InvalidSpreadArgument => formatdoc!(" The spread operator can only be used on the last argument in a function call. \ This is because it requires recursion at runtime to concatenate lists together. \ - By only allowing it on the last item by default, this additional complexity and runtime cost is avoided. + By only allowing it on the last item by default, this additional complexity and runtime cost is avoided "), Self::InvalidSpreadParameter => formatdoc!(" The spread operator can only be used on the last parameter in a function. \ - Otherwise, it would be ambiguous where the parameter should start and end. + Otherwise, it would be ambiguous where the parameter should start and end "), Self::InvalidSpreadField => formatdoc!(" The spread operator can only be used on the last field. \ - Otherwise, it would be ambiguous where the field should start and end. - "), - Self::InvalidOptionalParameter => formatdoc!(" - Only the last parameter in a function can be optional. \ - Otherwise, it would be ambiguous which optional parameter was specified. + Otherwise, it would be ambiguous where the field should start and end "), - Self::InvalidOptionalField => formatdoc!(" - Only the last field can be optional. \ - Otherwise, it would be ambiguous which optional field was specified. - "), - Self::OptionalParameterSpread => "The spread operator cannot be used on optional parameters.".to_string(), - Self::OptionalFieldSpread => "The spread operator cannot be used on optional fields.".to_string(), - Self::UnsupportedFunctionSpread => "This function does not support the spread operator on its last argument.".to_string(), - Self::RequiredFunctionSpread => "This function requires the spread operator on its last argument.".to_string(), + Self::UnsupportedFunctionSpread => "This function does not support the spread operator on its last argument".to_string(), + Self::RequiredFunctionSpread => "This function requires the spread operator on its last argument".to_string(), // Enum variant definitions. - Self::DuplicateEnumVariant(name) => format!("Duplicate enum variant `{name}` specified."), - Self::DuplicateEnumDiscriminant(discriminant) => format!("Duplicate enum discriminant `{discriminant}` specified."), - Self::EnumDiscriminantTooLarge => "Enum discriminant is too large to allocate in CLVM.".to_string(), + Self::DuplicateEnumVariant(name) => format!("Duplicate enum variant `{name}` specified"), + Self::DuplicateEnumDiscriminant(discriminant) => format!("Duplicate enum discriminant `{discriminant}` specified"), + Self::EnumDiscriminantTooLarge => "Enum discriminant is too large to allocate in CLVM".to_string(), // Paths. - Self::UnknownEnumVariantPath(name) => format!("Unknown enum variant `{name}`."), - Self::UnknownModulePath(name) => format!("Could not resolve `{name}` in module."), - Self::PrivateSymbol(name) => format!("Cannot access private symbol `{name}` in module."), - Self::PrivateType(name) => format!("Cannot access private type `{name}` in module."), - Self::InvalidTypePath(ty) => format!("Cannot path into type `{ty}`."), + Self::UnknownEnumVariantPath(name) => format!("Unknown enum variant `{name}`"), + Self::UnknownModulePath(name) => format!("Could not resolve `{name}` in module"), + Self::PrivateSymbol(name) => format!("Cannot access private symbol `{name}` in module"), + Self::PrivateType(name) => format!("Cannot access private type `{name}` in module"), + Self::InvalidTypePath(ty) => format!("Cannot path into type `{ty}`"), Self::InvalidSymbolPath(name) => if let Some(name) = name { - format!("Cannot path into symbol `{name}`.") + format!("Cannot path into symbol `{name}`") } else { "Cannot path into symbol.".to_string() }, - Self::ExpectedTypePath(name) => format!("Expected type, but found symbol `{name}` instead."), - Self::ExpectedSymbolPath(name) => format!("Expected symbol, but found type `{name}` instead."), + Self::ExpectedTypePath(name) => format!("Expected type, but found symbol `{name}` instead"), + Self::ExpectedSymbolPath(name) => format!("Expected symbol, but found type `{name}` instead"), + Self::UnexpectedGenericArgs => "Unexpected generic arguments".to_string(), + Self::ExpectedGenericArgs => "Expected generic arguments".to_string(), + Self::GenericArgsMismatch(found, expected) => { + format!("Expected {expected} generic argument{}, but found {found}", if *expected == 1 { "" } else { "s" }) + } // Type guards. - Self::UnsupportedTypeGuard(from, to) => format!("Cannot check type `{from}` against `{to}`."), - Self::NonAnyPairTypeGuard => "Cannot check `Any` against pair types other than `(Any, Any)`.".to_string(), - Self::NonListPairTypeGuard => "Cannot check `T[]` against pair types other than `(T, T[])`.".to_string(), - Self::InvalidExistanceCheck(ty) => format!("Cannot check existence of value with type `{ty}`, since it can't be undefined."), + Self::ImpossibleTypeCheck(from, to) => format!("Cannot check type `{from}` against `{to}`"), + Self::RecursiveTypeCheck(from, to) => format!("Checking type `{from}` against `{to}` would result in infinite recursion at runtime"), // Blocks. Self::ImplicitReturnInIf => formatdoc!(" Implicit returns are not allowed in if statements. \ Either use an explicit return statement at the end of the block, \ - or raise an error. + or raise an error "), - Self::ExplicitReturnInExpr => "Explicit return is not allowed within expressions.".to_string(), - Self::EmptyBlock => "Blocks must either return an expression or raise an error.".to_string(), + Self::ExplicitReturnInExpr => "Explicit return is not allowed within expressions".to_string(), + Self::EmptyBlock => "Blocks must either return an expression or raise an error".to_string(), // Atoms. - Self::NonAtomEquality(ty) => format!("Cannot check equality on non-atom type `{ty}`."), - Self::IntegerTooLarge => "Integer literal is too large to allocate in CLVM.".to_string(), + Self::NonAtomEquality(ty) => format!("Cannot check equality on non-atom type `{ty}`"), + Self::IntegerTooLarge => "Integer literal is too large to allocate in CLVM".to_string(), // Recursive constants. - Self::RecursiveConstantReference => "Cannot recursively reference constant.".to_string(), - Self::RecursiveInlineConstantReference => "Cannot recursively reference inline constant.".to_string(), - Self::RecursiveInlineFunctionCall => "Cannot recursively call inline function.".to_string(), + Self::RecursiveConstantReference => "Cannot recursively reference constant".to_string(), + Self::RecursiveInlineConstantReference => "Cannot recursively reference inline constant".to_string(), + Self::RecursiveInlineFunctionCall => "Cannot recursively call inline function".to_string(), }; write!(f, "{}", message.trim()) } diff --git a/crates/rue-compiler/src/lib.rs b/crates/rue-compiler/src/lib.rs index 2d8a7fd..4f09cd9 100644 --- a/crates/rue-compiler/src/lib.rs +++ b/crates/rue-compiler/src/lib.rs @@ -22,6 +22,7 @@ use rue_parser::Root; pub use database::*; pub use error::*; +use rue_typing::TypeSystem; #[derive(Debug)] pub struct Output { @@ -29,20 +30,40 @@ pub struct Output { pub node_ptr: NodePtr, } -pub fn compile(allocator: &mut Allocator, root: &Root, mut should_codegen: bool) -> Output { - let mut db = Database::default(); - let mut ctx = setup_compiler(&mut db); +pub fn compile(allocator: &mut Allocator, root: &Root, should_codegen: bool) -> Output { + compile_raw(allocator, root, should_codegen, true) +} + +pub fn compile_raw( + allocator: &mut Allocator, + root: &Root, + mut should_codegen: bool, + should_stdlib: bool, +) -> Output { + let mut db = Database::new(); + let mut ty = TypeSystem::new(); + let mut ctx = setup_compiler(&mut db, &mut ty); + + let stdlib = if should_stdlib { + Some(load_standard_library(&mut ctx)) + } else { + None + }; - let stdlib = load_standard_library(&mut ctx); let main_module_id = load_module(&mut ctx, root); let symbol_table = compile_modules(ctx); let main = try_export_main(&mut db, main_module_id); let graph = build_graph( &mut db, + &ty, &symbol_table, main_module_id, - &[main_module_id, stdlib], + &if let Some(stdlib) = stdlib { + [main_module_id, stdlib].to_vec() + } else { + [main_module_id].to_vec() + }, ); should_codegen &= !db.diagnostics().iter().any(Diagnostic::is_error); diff --git a/crates/rue-compiler/src/lowerer.rs b/crates/rue-compiler/src/lowerer.rs index e81155d..3b956b8 100644 --- a/crates/rue-compiler/src/lowerer.rs +++ b/crates/rue-compiler/src/lowerer.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use rue_typing::HashMap; use indexmap::IndexSet; @@ -8,7 +8,6 @@ use crate::{ hir::Hir, mir::Mir, symbol::{Function, Symbol}, - value::Rest, Database, EnvironmentId, HirId, MirId, ScopeId, SymbolId, }; @@ -322,7 +321,7 @@ impl<'a> Lowerer<'a> { let mut param_map = HashMap::new(); for (i, &symbol_id) in params.iter().enumerate() { - if i + 1 != params.len() || function.ty.rest != Rest::Spread { + if i + 1 != params.len() || function.nil_terminated { let mir_id = self.lower_hir(env_id, args[i]); param_map.insert(symbol_id, mir_id); continue; diff --git a/crates/rue-compiler/src/scope.rs b/crates/rue-compiler/src/scope.rs index 328bb3a..f9d9a35 100644 --- a/crates/rue-compiler/src/scope.rs +++ b/crates/rue-compiler/src/scope.rs @@ -1,6 +1,7 @@ use indexmap::IndexMap; +use rue_typing::TypeId; -use crate::{database::TypeId, SymbolId}; +use crate::SymbolId; #[derive(Debug, Default)] pub struct Scope { diff --git a/crates/rue-compiler/src/symbol.rs b/crates/rue-compiler/src/symbol.rs index 9491aae..b61f816 100644 --- a/crates/rue-compiler/src/symbol.rs +++ b/crates/rue-compiler/src/symbol.rs @@ -1,8 +1,9 @@ use indexmap::IndexSet; +use rue_typing::TypeId; use crate::{ - database::{HirId, ScopeId, TypeId}, - value::{FunctionType, Value}, + database::{HirId, ScopeId}, + value::Value, SymbolId, }; @@ -41,11 +42,12 @@ impl Symbol { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct Function { pub scope_id: ScopeId, pub hir_id: HirId, - pub ty: FunctionType, + pub type_id: TypeId, + pub nil_terminated: bool, } #[derive(Debug, Clone)] diff --git a/crates/rue-compiler/src/value.rs b/crates/rue-compiler/src/value.rs index 64db5be..89b61ac 100644 --- a/crates/rue-compiler/src/value.rs +++ b/crates/rue-compiler/src/value.rs @@ -1,14 +1,12 @@ -use std::collections::HashMap; +use std::ops::Not; -mod guard; -mod guard_path; -mod ty; +use rue_typing::HashMap; -pub use guard::*; -pub use guard_path::*; -pub use ty::*; +use rue_typing::TypeId; +use rue_typing::TypePath; -use crate::{HirId, TypeId}; +use crate::HirId; +use crate::SymbolId; #[derive(Debug, Clone)] pub struct Value { @@ -28,28 +26,58 @@ impl Value { } } - pub fn then_guards(&self) -> HashMap { + pub fn then_guards(&self) -> HashMap { self.guards .iter() - .map(|(k, v)| (k.clone(), v.then_type)) + .map(|(guard_path, guard)| (guard_path.clone(), guard.then_type)) .collect() } - pub fn else_guards(&self) -> HashMap { + pub fn else_guards(&self) -> HashMap { self.guards .iter() - .map(|(k, v)| (k.clone(), v.else_type)) + .map(|(guard_path, guard)| (guard_path.clone(), guard.else_type)) .collect() } +} + +#[derive(Debug, Clone, Copy)] +pub struct Guard { + pub then_type: TypeId, + pub else_type: TypeId, +} + +impl Guard { + pub fn new(then_type: TypeId, else_type: TypeId) -> Self { + Self { + then_type, + else_type, + } + } +} + +impl Not for Guard { + type Output = Self; - pub fn extend_guard_path(mut self, old_value: Value, item: GuardPathItem) -> Self { - match old_value.guard_path { - Some(mut path) => { - path.items.push(item); - self.guard_path = Some(path); - self - } - None => self, + fn not(self) -> Self::Output { + Self { + then_type: self.else_type, + else_type: self.then_type, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct GuardPath { + pub symbol_id: SymbolId, + pub items: Vec, +} + +impl GuardPath { + pub fn new(symbol_id: SymbolId) -> Self { + Self { + symbol_id, + items: Vec::new(), } } } diff --git a/crates/rue-compiler/src/value/guard.rs b/crates/rue-compiler/src/value/guard.rs deleted file mode 100644 index 351a083..0000000 --- a/crates/rue-compiler/src/value/guard.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::ops::Not; - -use crate::TypeId; - -#[derive(Debug, Clone, Copy)] -pub struct Guard { - pub then_type: TypeOverride, - pub else_type: TypeOverride, -} - -impl Guard { - pub fn new(then_type: TypeOverride, else_type: TypeOverride) -> Self { - Self { - then_type, - else_type, - } - } -} - -impl Not for Guard { - type Output = Self; - - fn not(self) -> Self::Output { - Self { - then_type: self.else_type, - else_type: self.then_type, - } - } -} - -#[derive(Debug, Clone, Copy)] -pub struct TypeOverride { - pub type_id: TypeId, - pub mutation: Mutation, -} - -impl TypeOverride { - pub fn new(type_id: TypeId) -> Self { - Self { - type_id, - mutation: Mutation::None, - } - } -} - -#[derive(Debug, Default, Clone, Copy)] -pub enum Mutation { - #[default] - None, - UnwrapOptional, -} diff --git a/crates/rue-compiler/src/value/guard_path.rs b/crates/rue-compiler/src/value/guard_path.rs deleted file mode 100644 index 9b1e9a1..0000000 --- a/crates/rue-compiler/src/value/guard_path.rs +++ /dev/null @@ -1,23 +0,0 @@ -use crate::SymbolId; - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct GuardPath { - pub symbol_id: SymbolId, - pub items: Vec, -} - -impl GuardPath { - pub fn new(symbol_id: SymbolId) -> Self { - Self { - symbol_id, - items: Vec::new(), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum GuardPathItem { - Field(String), - First, - Rest, -} diff --git a/crates/rue-compiler/src/value/ty.rs b/crates/rue-compiler/src/value/ty.rs deleted file mode 100644 index e724aed..0000000 --- a/crates/rue-compiler/src/value/ty.rs +++ /dev/null @@ -1,69 +0,0 @@ -use indexmap::IndexMap; - -use crate::database::{HirId, TypeId}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Type { - Unknown, - Generic, - Nil, - Any, - Int, - Bool, - Bytes, - Bytes32, - PublicKey, - Pair(PairType), - List(TypeId), - Struct(StructType), - Enum(EnumType), - EnumVariant(EnumVariantType), - Function(FunctionType), - Alias(TypeId), - Nullable(TypeId), - Optional(TypeId), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct PairType { - pub first: TypeId, - pub rest: TypeId, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StructType { - pub original_type_id: TypeId, - pub fields: IndexMap, - pub rest: Rest, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct EnumType { - pub has_fields: bool, - pub variants: IndexMap, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct EnumVariantType { - pub enum_type: TypeId, - pub original_type_id: TypeId, - pub fields: Option>, - pub rest: Rest, - pub discriminant: HirId, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct FunctionType { - pub generic_types: Vec, - pub param_types: Vec, - pub rest: Rest, - pub return_type: TypeId, -} - -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] -pub enum Rest { - #[default] - Nil, - Spread, - Optional, -} diff --git a/crates/rue-compiler/stdlib.rue b/crates/rue-compiler/stdlib.rue index 27b8540..f265304 100644 --- a/crates/rue-compiler/stdlib.rue +++ b/crates/rue-compiler/stdlib.rue @@ -1,6 +1,6 @@ export enum Condition { Remark = 1 { - value?: Any, + ...value: (Any, nil) | nil, }, AggSigParent = 43 { public_key: PublicKey, @@ -37,7 +37,7 @@ export enum Condition { CreateCoin = 51 { puzzle_hash: Bytes32, amount: Int, - memos?: Bytes[], + ...memos: (List, nil) | nil, }, ReserveFee = 52 { amount: Int, @@ -105,12 +105,12 @@ export enum Condition { }, Softfork = 90 { cost: Int, - value?: Any, + ...value: (Any, nil) | nil, }, } -export fun concat(a: T[], b: T[]) -> T[] { - if a is (T, T[]) { +export fun concat(a: List, b: List) -> List { + if a is (T, List) { return [a.first, ...concat(a.rest, b)]; } b @@ -167,8 +167,8 @@ inline fun update_hash_with_parameter( sha256(CONS_PREIMAGE_PREFIX + two_item_list_hash(quote_hash(parameter_hash), environment_hash)) } -fun curried_params_hash(parameters: Bytes32[]) -> Bytes32 { - if parameters is Nil { +fun curried_params_hash(parameters: List) -> Bytes32 { + if parameters is nil { return ONE_TREE_HASH; } update_hash_with_parameter(parameters.first, curried_params_hash(parameters.rest)) @@ -176,7 +176,7 @@ fun curried_params_hash(parameters: Bytes32[]) -> Bytes32 { export inline fun curry_tree_hash( mod_hash: Bytes32, - ...parameters: Bytes32[] + ...parameters: List ) -> Bytes32 { apply_hash(mod_hash, curried_params_hash(parameters)) } @@ -191,15 +191,15 @@ export fun calculate_coin_id( sha256(parent_coin_id + puzzle_hash + amount as Bytes) } -export fun map(list: T[], fn: fun(item: T) -> U) -> U[] { - if list is Nil { +export fun map(list: List, fn: fun(item: T) -> U) -> List { + if list is nil { return nil; } [fn(list.first), ...map(list.rest, fn)] } -export fun filter(list: T[], fn: fun(item: T) -> Bool) -> T[] { - if list is Nil { +export fun filter(list: List, fn: fun(item: T) -> Bool) -> List { + if list is nil { return nil; } if fn(list.first) { @@ -208,8 +208,8 @@ export fun filter(list: T[], fn: fun(item: T) -> Bool) -> T[] { filter(list.rest, fn) } -export fun fold(list: T[], initial: U, fn: fun(acc: U, item: T) -> U) -> U { - if list is Nil { +export fun fold(list: List, initial: U, fn: fun(acc: U, item: T) -> U) -> U { + if list is nil { return initial; } fold(list.rest, fn(initial, list.first), fn) diff --git a/crates/rue-lsp/Cargo.toml b/crates/rue-lsp/Cargo.toml index 0440666..c45472a 100644 --- a/crates/rue-lsp/Cargo.toml +++ b/crates/rue-lsp/Cargo.toml @@ -15,8 +15,8 @@ categories = { workspace = true } workspace = true [dependencies] -rue-parser = { path = "../rue-parser", version = "0.1.0" } -rue-compiler = { path = "../rue-compiler", version = "0.1.0" } -tokio = { version = "1.37.0", features = ["full"] } -tower-lsp = "0.20.0" -clvmr = "0.6.1" +rue-parser = { workspace = true } +rue-compiler = { workspace = true } +tokio = { workspace = true, features = ["full"] } +tower-lsp = { workspace = true } +clvmr = { workspace = true } diff --git a/crates/rue-lsp/src/main.rs b/crates/rue-lsp/src/main.rs index 9390cb8..2036fab 100644 --- a/crates/rue-lsp/src/main.rs +++ b/crates/rue-lsp/src/main.rs @@ -1,5 +1,5 @@ use clvmr::Allocator; -use rue_compiler::{compile, DiagnosticKind}; +use rue_compiler::{compile_raw, DiagnosticKind}; use rue_parser::{line_col, parse, LineCol}; use tower_lsp::jsonrpc::Result; use tower_lsp::lsp_types::{ @@ -60,7 +60,7 @@ impl LanguageServer for Backend { /// This is a hack to get around a Rust compiler error. #[allow(clippy::needless_pass_by_value)] fn analyze_owned(root: rue_parser::Root) -> Vec { - compile(&mut Allocator::new(), &root, false).diagnostics + compile_raw(&mut Allocator::new(), &root, false, true).diagnostics } impl Backend { diff --git a/crates/rue-parser/Cargo.toml b/crates/rue-parser/Cargo.toml index 4fc88b0..f1ae2bf 100644 --- a/crates/rue-parser/Cargo.toml +++ b/crates/rue-parser/Cargo.toml @@ -15,8 +15,8 @@ categories = { workspace = true } workspace = true [dependencies] -indexmap = "2.2.6" -num-derive = "0.4.2" -num-traits = "0.2.18" -rowan = "0.15.15" -rue-lexer = { path = "../rue-lexer", version = "0.1.0" } +rue-lexer = { workspace = true } +indexmap = { workspace = true } +num-derive = { workspace = true } +num-traits = { workspace = true } +rowan = { workspace = true } diff --git a/crates/rue-parser/src/ast.rs b/crates/rue-parser/src/ast.rs index d5a274c..d644c6e 100644 --- a/crates/rue-parser/src/ast.rs +++ b/crates/rue-parser/src/ast.rs @@ -97,8 +97,6 @@ ast_enum!( IfExpr, FunctionCallExpr, FieldAccessExpr, - IndexAccessExpr, - ExistsExpr ); ast_node!(PathExpr); ast_node!(InitializerExpr); @@ -116,27 +114,24 @@ ast_node!(IfExpr); ast_node!(FunctionCallExpr); ast_node!(FunctionCallArg); ast_node!(FieldAccessExpr); -ast_node!(IndexAccessExpr); -ast_node!(ExistsExpr); ast_node!(LambdaExpr); ast_node!(LambdaParam); ast_enum!( Type, + LiteralType, PathType, - ListType, PairType, FunctionType, - NullableType + UnionType ); +ast_node!(LiteralType); ast_node!(PathType); -ast_node!(ListType); -ast_node!(ListTypeItem); ast_node!(PairType); ast_node!(FunctionType); ast_node!(FunctionTypeParam); -ast_node!(NullableType); +ast_node!(UnionType); ast_enum!(Stmt, LetStmt, IfStmt, ReturnStmt, RaiseStmt, AssertStmt, AssumeStmt); ast_node!(LetStmt); @@ -146,7 +141,9 @@ ast_node!(RaiseStmt); ast_node!(AssertStmt); ast_node!(AssumeStmt); -ast_node!(GenericTypes); +ast_node!(GenericArgs); +ast_node!(GenericParams); +ast_node!(PathItem); impl Root { pub fn items(&self) -> Vec { @@ -196,8 +193,8 @@ impl FunctionItem { .find(|token| token.kind() == SyntaxKind::Ident) } - pub fn generic_types(&self) -> Option { - self.syntax().children().find_map(GenericTypes::cast) + pub fn generic_params(&self) -> Option { + self.syntax().children().find_map(GenericParams::cast) } pub fn params(&self) -> Vec { @@ -238,13 +235,6 @@ impl FunctionParam { .find(|token| token.kind() == SyntaxKind::Spread) } - pub fn optional(&self) -> Option { - self.syntax() - .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .find(|token| token.kind() == SyntaxKind::Question) - } - pub fn name(&self) -> Option { self.syntax() .children_with_tokens() @@ -265,6 +255,10 @@ impl TypeAliasItem { .find(|token| token.kind() == SyntaxKind::Ident) } + pub fn generic_params(&self) -> Option { + self.syntax().children().find_map(GenericParams::cast) + } + pub fn ty(&self) -> Option { self.syntax().children().find_map(Type::cast) } @@ -308,13 +302,6 @@ impl StructField { .find(|token| token.kind() == SyntaxKind::Spread) } - pub fn optional(&self) -> Option { - self.syntax() - .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .find(|token| token.kind() == SyntaxKind::Question) - } - pub fn name(&self) -> Option { self.syntax() .children_with_tokens() @@ -512,11 +499,10 @@ impl AssumeStmt { } impl PathExpr { - pub fn idents(&self) -> Vec { + pub fn items(&self) -> Vec { self.syntax() - .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .filter(|token| token.kind() == SyntaxKind::Ident) + .children() + .filter_map(PathItem::cast) .collect() } } @@ -722,8 +708,8 @@ impl PairExpr { } impl LambdaExpr { - pub fn generic_types(&self) -> Option { - self.syntax().children().find_map(GenericTypes::cast) + pub fn generic_params(&self) -> Option { + self.syntax().children().find_map(GenericParams::cast) } pub fn params(&self) -> Vec { @@ -757,13 +743,6 @@ impl LambdaParam { .find(|token| token.kind() == SyntaxKind::Ident) } - pub fn optional(&self) -> Option { - self.syntax() - .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .find(|token| token.kind() == SyntaxKind::Question) - } - pub fn ty(&self) -> Option { self.syntax().children().find_map(Type::cast) } @@ -784,6 +763,10 @@ impl IfExpr { } impl FunctionCallExpr { + pub fn generic_args(&self) -> Option { + self.syntax().children().find_map(GenericArgs::cast) + } + pub fn callee(&self) -> Option { self.syntax().children().find_map(Expr::cast) } @@ -822,54 +805,23 @@ impl FieldAccessExpr { } } -impl IndexAccessExpr { - pub fn expr(&self) -> Option { - self.syntax().children().find_map(Expr::cast) - } - - pub fn index(&self) -> Option { +impl LiteralType { + pub fn value(&self) -> Option { self.syntax() .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .find(|token| token.kind() == SyntaxKind::Int) - } -} - -impl ExistsExpr { - pub fn expr(&self) -> Option { - self.syntax().children().find_map(Expr::cast) + .find_map(SyntaxElement::into_token) } } impl PathType { - pub fn idents(&self) -> Vec { + pub fn items(&self) -> Vec { self.syntax() - .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .filter(|token| token.kind() == SyntaxKind::Ident) + .children() + .filter_map(PathItem::cast) .collect() } } -impl ListType { - pub fn ty(&self) -> Option { - self.syntax().children().find_map(Type::cast) - } -} - -impl ListTypeItem { - pub fn spread(&self) -> Option { - self.syntax() - .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .find(|token| token.kind() == SyntaxKind::Spread) - } - - pub fn ty(&self) -> Option { - self.syntax().children().find_map(Type::cast) - } -} - impl PairType { pub fn first(&self) -> Option { self.syntax().children().find_map(Type::cast) @@ -901,13 +853,6 @@ impl FunctionTypeParam { .find(|token| token.kind() == SyntaxKind::Ident) } - pub fn optional(&self) -> Option { - self.syntax() - .children_with_tokens() - .filter_map(SyntaxElement::into_token) - .find(|token| token.kind() == SyntaxKind::Question) - } - pub fn spread(&self) -> Option { self.syntax() .children_with_tokens() @@ -920,14 +865,14 @@ impl FunctionTypeParam { } } -impl NullableType { - pub fn ty(&self) -> Option { - self.syntax().children().find_map(Type::cast) +impl UnionType { + pub fn types(&self) -> Vec { + self.syntax().children().filter_map(Type::cast).collect() } } -impl GenericTypes { - pub fn idents(&self) -> Vec { +impl GenericParams { + pub fn names(&self) -> Vec { self.syntax() .children_with_tokens() .filter_map(SyntaxElement::into_token) @@ -935,3 +880,22 @@ impl GenericTypes { .collect() } } + +impl GenericArgs { + pub fn types(&self) -> Vec { + self.syntax().children().filter_map(Type::cast).collect() + } +} + +impl PathItem { + pub fn name(&self) -> Option { + self.syntax() + .children_with_tokens() + .filter_map(SyntaxElement::into_token) + .find(|token| token.kind() == SyntaxKind::Ident) + } + + pub fn generic_args(&self) -> Option { + self.syntax().children().find_map(GenericArgs::cast) + } +} diff --git a/crates/rue-parser/src/grammar.rs b/crates/rue-parser/src/grammar.rs index 98b16aa..233a44a 100644 --- a/crates/rue-parser/src/grammar.rs +++ b/crates/rue-parser/src/grammar.rs @@ -67,7 +67,7 @@ fn function_item(p: &mut Parser<'_>, cp: Checkpoint) { p.expect(SyntaxKind::Fun); p.expect(SyntaxKind::Ident); if p.at(SyntaxKind::LessThan) { - generic_types(p); + generic_params(p); } function_params(p); p.expect(SyntaxKind::Arrow); @@ -91,7 +91,6 @@ fn function_param(p: &mut Parser<'_>) { p.start(SyntaxKind::FunctionParam); p.try_eat(SyntaxKind::Spread); p.expect(SyntaxKind::Ident); - p.try_eat(SyntaxKind::Question); p.expect(SyntaxKind::Colon); ty(p); p.finish(); @@ -101,6 +100,9 @@ fn type_alias_item(p: &mut Parser<'_>, cp: Checkpoint) { p.start_at(cp, SyntaxKind::TypeAliasItem); p.expect(SyntaxKind::Type); p.expect(SyntaxKind::Ident); + if p.at(SyntaxKind::LessThan) { + generic_params(p); + } p.expect(SyntaxKind::Assign); ty(p); p.expect(SyntaxKind::Semicolon); @@ -126,7 +128,6 @@ fn struct_field(p: &mut Parser<'_>) { p.start(SyntaxKind::StructField); p.try_eat(SyntaxKind::Spread); p.expect(SyntaxKind::Ident); - p.try_eat(SyntaxKind::Question); p.expect(SyntaxKind::Colon); ty(p); p.finish(); @@ -339,6 +340,8 @@ fn expr(p: &mut Parser<'_>) { fn expr_binding_power(p: &mut Parser<'_>, minimum_binding_power: u8, allow_initializer: bool) { let checkpoint = p.checkpoint(); + let mut at_generic_args = false; + if p.at(SyntaxKind::Not) || p.at(SyntaxKind::Minus) || p.at(SyntaxKind::Plus) @@ -359,9 +362,9 @@ fn expr_binding_power(p: &mut Parser<'_>, minimum_binding_power: u8, allow_initi p.bump(); p.finish(); } else if p.at(SyntaxKind::Ident) { - path_expr(p); - - if p.at(SyntaxKind::OpenBrace) && allow_initializer { + if path_expr(p) { + at_generic_args = true; + } else if p.at(SyntaxKind::OpenBrace) && allow_initializer { p.start_at(checkpoint, SyntaxKind::InitializerExpr); p.bump(); while !p.at(SyntaxKind::CloseBrace) { @@ -402,9 +405,12 @@ fn expr_binding_power(p: &mut Parser<'_>, minimum_binding_power: u8, allow_initi } loop { - if p.at(SyntaxKind::OpenParen) { + if at_generic_args || p.at(SyntaxKind::OpenParen) { p.start_at(checkpoint, SyntaxKind::FunctionCallExpr); - p.bump(); + if at_generic_args { + generic_args(p); + } + p.expect(SyntaxKind::OpenParen); while !p.at(SyntaxKind::CloseParen) { function_call_arg(p); if !p.try_eat(SyntaxKind::Comma) { @@ -413,21 +419,12 @@ fn expr_binding_power(p: &mut Parser<'_>, minimum_binding_power: u8, allow_initi } p.expect(SyntaxKind::CloseParen); p.finish(); - } else if p.at(SyntaxKind::Question) { - p.start_at(checkpoint, SyntaxKind::ExistsExpr); - p.bump(); - p.finish(); + at_generic_args = false; } else if p.at(SyntaxKind::Dot) { p.start_at(checkpoint, SyntaxKind::FieldAccessExpr); p.bump(); p.expect(SyntaxKind::Ident); p.finish(); - } else if p.at(SyntaxKind::OpenBracket) { - p.start_at(checkpoint, SyntaxKind::IndexAccessExpr); - p.bump(); - p.expect(SyntaxKind::Int); - p.expect(SyntaxKind::CloseBracket); - p.finish(); } else if p.at(SyntaxKind::As) { p.start_at(checkpoint, SyntaxKind::CastExpr); p.bump(); @@ -498,13 +495,23 @@ fn expr_binding_power(p: &mut Parser<'_>, minimum_binding_power: u8, allow_initi } } -fn path_expr(p: &mut Parser<'_>) { +#[must_use] +fn path_expr(p: &mut Parser<'_>) -> bool { p.start(SyntaxKind::PathExpr); + p.start(SyntaxKind::PathItem); p.expect(SyntaxKind::Ident); + p.finish(); while p.try_eat(SyntaxKind::PathSeparator) { + if p.at(SyntaxKind::LessThan) { + p.finish(); + return true; + } + p.start(SyntaxKind::PathItem); p.expect(SyntaxKind::Ident); + p.finish(); } p.finish(); + false } fn function_call_arg(p: &mut Parser<'_>) { @@ -534,7 +541,7 @@ fn lambda_expr(p: &mut Parser<'_>) { p.start(SyntaxKind::LambdaExpr); p.expect(SyntaxKind::Fun); if p.at(SyntaxKind::LessThan) { - generic_types(p); + generic_params(p); } p.expect(SyntaxKind::OpenParen); while !p.at(SyntaxKind::CloseParen) { @@ -556,7 +563,6 @@ fn lambda_param(p: &mut Parser<'_>) { p.start(SyntaxKind::LambdaParam); p.try_eat(SyntaxKind::Spread); p.expect(SyntaxKind::Ident); - p.try_eat(SyntaxKind::Question); if p.try_eat(SyntaxKind::Colon) { ty(p); } @@ -566,10 +572,18 @@ fn lambda_param(p: &mut Parser<'_>) { const TYPE_RECOVERY_SET: &[SyntaxKind] = &[SyntaxKind::OpenBrace, SyntaxKind::CloseBrace]; fn ty(p: &mut Parser<'_>) { - let checkpoint = p.checkpoint(); + let cp = p.checkpoint(); if p.at(SyntaxKind::Ident) { path_type(p); + } else if p.at(SyntaxKind::Int) + || p.at(SyntaxKind::True) + || p.at(SyntaxKind::False) + || p.at(SyntaxKind::Nil) + { + p.start(SyntaxKind::LiteralType); + p.bump(); + p.finish(); } else if p.at(SyntaxKind::Fun) { p.start(SyntaxKind::FunctionType); p.bump(); @@ -597,27 +611,28 @@ fn ty(p: &mut Parser<'_>) { return p.error(TYPE_RECOVERY_SET); } - loop { - if p.at(SyntaxKind::OpenBracket) { - p.start_at(checkpoint, SyntaxKind::ListType); - p.bump(); - p.expect(SyntaxKind::CloseBracket); - p.finish(); - } else if p.at(SyntaxKind::Question) { - p.start_at(checkpoint, SyntaxKind::NullableType); - p.bump(); - p.finish(); - } else { - break; - } + while p.at(SyntaxKind::BitwiseOr) { + p.start_at(cp, SyntaxKind::UnionType); + p.bump(); + ty(p); + p.finish(); } } fn path_type(p: &mut Parser<'_>) { p.start(SyntaxKind::PathType); - p.expect(SyntaxKind::Ident); + path_type_item(p); while p.try_eat(SyntaxKind::PathSeparator) { - p.expect(SyntaxKind::Ident); + path_type_item(p); + } + p.finish(); +} + +fn path_type_item(p: &mut Parser<'_>) { + p.start(SyntaxKind::PathItem); + p.expect(SyntaxKind::Ident); + if p.at(SyntaxKind::LessThan) { + generic_args(p); } p.finish(); } @@ -626,14 +641,13 @@ fn function_type_param(p: &mut Parser<'_>) { p.start(SyntaxKind::FunctionTypeParam); p.try_eat(SyntaxKind::Spread); p.expect(SyntaxKind::Ident); - p.try_eat(SyntaxKind::Question); p.expect(SyntaxKind::Colon); ty(p); p.finish(); } -fn generic_types(p: &mut Parser<'_>) { - p.start(SyntaxKind::GenericTypes); +fn generic_params(p: &mut Parser<'_>) { + p.start(SyntaxKind::GenericParams); p.expect(SyntaxKind::LessThan); while !p.at(SyntaxKind::GreaterThan) { p.expect(SyntaxKind::Ident); @@ -644,3 +658,18 @@ fn generic_types(p: &mut Parser<'_>) { p.expect(SyntaxKind::GreaterThan); p.finish(); } + +fn generic_args(p: &mut Parser<'_>) { + p.start(SyntaxKind::GenericArgs); + p.expect(SyntaxKind::LessThan); + ty(p); + p.try_eat(SyntaxKind::Comma); + while !p.at(SyntaxKind::GreaterThan) { + ty(p); + if !p.try_eat(SyntaxKind::Comma) { + break; + } + } + p.expect(SyntaxKind::GreaterThan); + p.finish(); +} diff --git a/crates/rue-parser/src/syntax_kind.rs b/crates/rue-parser/src/syntax_kind.rs index e18f240..df99093 100644 --- a/crates/rue-parser/src/syntax_kind.rs +++ b/crates/rue-parser/src/syntax_kind.rs @@ -112,16 +112,15 @@ pub enum SyntaxKind { FunctionCallExpr, FunctionCallArg, FieldAccessExpr, - IndexAccessExpr, - ExistsExpr, + LiteralType, PathType, - ListType, - ListTypeItem, PairType, FunctionType, FunctionTypeParam, - NullableType, - GenericTypes, + UnionType, + GenericArgs, + GenericParams, + PathItem, } impl fmt::Display for SyntaxKind { @@ -235,17 +234,15 @@ impl fmt::Display for SyntaxKind { Self::FunctionCallExpr => "function call expression", Self::FunctionCallArg => "function call argument", Self::FieldAccessExpr => "field access expression", - Self::IndexAccessExpr => "index access expression", - Self::ExistsExpr => "exists expression", + Self::LiteralType => "literal type", Self::PathType => "path type", - Self::ListType => "list type", - Self::ListTypeItem => "list type item", Self::PairType => "pair type", Self::FunctionType => "function type", Self::FunctionTypeParam => "function type parameter", - Self::NullableType => "nullable type", - - Self::GenericTypes => "generic types", + Self::UnionType => "union type", + Self::GenericArgs => "generic args", + Self::GenericParams => "generic params", + Self::PathItem => "path item", } ) } diff --git a/crates/rue-tests/Cargo.toml b/crates/rue-tests/Cargo.toml index 114fea5..46627d5 100644 --- a/crates/rue-tests/Cargo.toml +++ b/crates/rue-tests/Cargo.toml @@ -16,14 +16,14 @@ categories = { workspace = true } workspace = true [dependencies] -rue-parser = { path = "../rue-parser", version = "0.1.0" } -rue-compiler = { path = "../rue-compiler", version = "0.1.0" } -rue-clvm = { path = "../rue-clvm", version = "0.1.0" } -clvm-utils = "0.6.0" -clvmr = "0.6.1" -hex = "0.4.3" -toml = "0.8.12" -serde = { version = "1.0.197", features = ["derive"] } -clap = { version = "4.5.4", features = ["derive"] } -walkdir = "2.5.0" -indexmap = { version = "2.2.6", features = ["serde"] } +rue-parser = { workspace = true } +rue-compiler = { workspace = true } +rue-clvm = { workspace = true } +clvm-utils = { workspace = true } +clvmr = { workspace = true } +hex = { workspace = true } +toml = { workspace = true } +walkdir = { workspace = true } +serde = { workspace = true, features = ["derive"] } +clap = { workspace = true, features = ["derive"] } +indexmap = { workspace = true, features = ["serde"] } diff --git a/crates/rue-typing/Cargo.toml b/crates/rue-typing/Cargo.toml new file mode 100644 index 0000000..c602775 --- /dev/null +++ b/crates/rue-typing/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "rue-typing" +version = "0.1.1" +edition = "2021" +license = "Apache-2.0" +description = "The type system used by the Rue compiler." +authors = ["Brandon Haggstrom "] +homepage = "https://github.com/rigidity/rue" +repository = "https://github.com/rigidity/rue" +readme = { workspace = true } +keywords = { workspace = true } +categories = { workspace = true } + +[lints] +workspace = true + +[dependencies] +id-arena = { workspace = true } +thiserror = { workspace = true } +indexmap = { workspace = true } +num-bigint = { workspace = true } +num-traits = { workspace = true } +clvmr = { workspace = true } +hashbrown = { workspace = true } +ahash = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } diff --git a/crates/rue-typing/src/bigint.rs b/crates/rue-typing/src/bigint.rs new file mode 100644 index 0000000..55ed663 --- /dev/null +++ b/crates/rue-typing/src/bigint.rs @@ -0,0 +1,9 @@ +use clvmr::Allocator; +use num_bigint::BigInt; + +pub fn bigint_to_bytes(value: BigInt) -> Vec { + let mut allocator = Allocator::new(); + let ptr = allocator.new_number(value).unwrap(); + let atom = allocator.atom(ptr); + atom.as_ref().to_vec() +} diff --git a/crates/rue-typing/src/check.rs b/crates/rue-typing/src/check.rs new file mode 100644 index 0000000..a993f7b --- /dev/null +++ b/crates/rue-typing/src/check.rs @@ -0,0 +1,41 @@ +use std::fmt; + +mod attributes; +mod check_error; +mod check_type; +mod simplify_and; +mod simplify_check; +mod simplify_or; +mod stringify_check; + +pub use check_error::*; + +pub(crate) use attributes::*; +pub(crate) use check_type::*; +pub(crate) use simplify_and::*; +pub(crate) use simplify_check::*; +pub(crate) use simplify_or::*; +pub(crate) use stringify_check::*; + +use num_bigint::BigInt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Check { + True, + False, + IsPair, + IsAtom, + Value(BigInt), + Length(usize), + And(Vec), + Or(Vec), + If(Box, Box, Box), + First(Box), + Rest(Box), +} + +impl fmt::Display for Check { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + stringify_check(self, f, &mut Vec::new()) + } +} diff --git a/crates/rue-typing/src/check/attributes.rs b/crates/rue-typing/src/check/attributes.rs new file mode 100644 index 0000000..5b6d31d --- /dev/null +++ b/crates/rue-typing/src/check/attributes.rs @@ -0,0 +1,138 @@ +use std::collections::VecDeque; + +use crate::{HashMap, HashSet}; + +use num_bigint::BigInt; +use num_traits::One; + +use crate::{Type, TypeId, TypeSystem}; + +use super::CheckError; + +pub(crate) struct Attributes { + pub atom_count: usize, + pub bytes32_count: usize, + pub public_key_count: usize, + pub pairs: Vec<(TypeId, TypeId)>, + pub values: HashMap, + pub length: usize, +} + +impl Attributes { + pub fn all_atoms(&self) -> bool { + self.atom_count == self.length + } + + pub fn all_bytes32(&self) -> bool { + self.bytes32_count == self.length + } + + pub fn all_public_key(&self) -> bool { + self.public_key_count == self.length + } + + pub fn all_pairs(&self) -> bool { + self.pairs.len() == self.length + } + + pub fn all_value(&self, value: &BigInt) -> bool { + self.values.get(value).copied().unwrap_or(0) == self.length + } + + pub fn atoms_are_bytes32(&self) -> bool { + self.bytes32_count == self.atom_count + } + + pub fn atoms_are_public_key(&self) -> bool { + self.public_key_count == self.atom_count + } + + pub fn atoms_are_value(&self, value: &BigInt) -> bool { + self.values.get(value).copied().unwrap_or(0) == self.atom_count + } +} + +pub(crate) fn union_attributes( + db: &TypeSystem, + items: &[TypeId], + is_lhs: bool, + other_type_id: TypeId, + visited: &mut HashSet<(TypeId, TypeId)>, +) -> Result { + let mut atom_count = 0; + let mut bytes32_count = 0; + let mut public_key_count = 0; + let mut pairs = Vec::new(); + let mut values = HashMap::new(); + + let mut items: VecDeque<_> = items.iter().copied().collect(); + let mut length = 0; + + while !items.is_empty() { + let item = items.remove(0).unwrap(); + length += 1; + + let key = if is_lhs { + (item, other_type_id) + } else { + (other_type_id, item) + }; + + if !visited.insert(key) { + return Err(CheckError::Recursive(key.0, key.1)); + } + + match db.get_recursive(item) { + Type::Ref(..) + | Type::Lazy(..) + | Type::Alias(..) + | Type::Struct(..) + | Type::Enum(..) + | Type::Variant(..) => unreachable!(), + Type::Generic | Type::Never | Type::Callable(..) => { + length -= 1; + } + Type::Any | Type::Unknown => {} + Type::Bytes | Type::Int => { + atom_count += 1; + } + Type::Bytes32 => { + atom_count += 1; + bytes32_count += 1; + } + Type::PublicKey => { + atom_count += 1; + public_key_count += 1; + } + Type::Nil | Type::False => { + atom_count += 1; + *values.entry(BigInt::ZERO).or_insert(0) += 1; + } + Type::True => { + atom_count += 1; + *values.entry(BigInt::one()).or_insert(0) += 1; + } + Type::Value(value) => { + atom_count += 1; + *values.entry(value.clone()).or_insert(0) += 1; + } + Type::Pair(first, rest) => { + pairs.push((*first, *rest)); + } + Type::Union(child_items) => { + items.extend(child_items); + } + } + + visited.remove(&key); + } + + Ok(Attributes { + atom_count, + bytes32_count, + public_key_count, + pairs, + values, + length, + }) +} diff --git a/crates/rue-typing/src/check/check_error.rs b/crates/rue-typing/src/check/check_error.rs new file mode 100644 index 0000000..c1b9963 --- /dev/null +++ b/crates/rue-typing/src/check/check_error.rs @@ -0,0 +1,9 @@ +use thiserror::Error; + +use crate::TypeId; + +#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] +pub enum CheckError { + #[error("recursive check")] + Recursive(TypeId, TypeId), +} diff --git a/crates/rue-typing/src/check/check_type.rs b/crates/rue-typing/src/check/check_type.rs new file mode 100644 index 0000000..2435f11 --- /dev/null +++ b/crates/rue-typing/src/check/check_type.rs @@ -0,0 +1,639 @@ +use crate::HashSet; + +use num_bigint::BigInt; +use num_traits::One; + +use crate::{bigint_to_bytes, Comparison, Type, TypeId, TypeSystem}; + +use super::{union_attributes, Check, CheckError}; + +/// Returns [`None`] for recursive checks. +pub(crate) fn check_type( + types: &mut TypeSystem, + lhs: TypeId, + rhs: TypeId, + visited: &mut HashSet<(TypeId, TypeId)>, +) -> Result { + if !visited.insert((lhs, rhs)) { + if types.compare(lhs, rhs) <= Comparison::Castable { + return Ok(Check::True); + } + return Err(CheckError::Recursive(lhs, rhs)); + } + + let check = match (types.get_recursive(lhs), types.get_recursive(rhs)) { + ( + Type::Ref(..) + | Type::Lazy(..) + | Type::Alias(..) + | Type::Struct(..) + | Type::Enum(..) + | Type::Variant(..), + _, + ) + | ( + _, + Type::Ref(..) + | Type::Lazy(..) + | Type::Alias(..) + | Type::Struct(..) + | Type::Enum(..) + | Type::Variant(..), + ) => { + unreachable!() + } + + (_, Type::Any | Type::Unknown) + | (Type::Never | Type::Unknown, _) + | (Type::Bytes | Type::Int | Type::Value(..), Type::Bytes | Type::Int) + | (Type::Bytes32, Type::Bytes | Type::Int | Type::Bytes32) + | (Type::PublicKey, Type::Bytes | Type::Int | Type::PublicKey) + | (Type::Nil | Type::False, Type::Bytes | Type::Int | Type::Nil | Type::False) + | (Type::True, Type::Bytes | Type::Int | Type::True) => Check::True, + + (Type::Any | Type::Generic | Type::Callable(..), Type::Bytes | Type::Int) => Check::IsAtom, + (Type::Any | Type::Generic | Type::Callable(..), Type::Bytes32) => { + Check::And(vec![Check::IsAtom, Check::Length(32)]) + } + (Type::Any | Type::Generic | Type::Callable(..), Type::PublicKey) => { + Check::And(vec![Check::IsAtom, Check::Length(48)]) + } + (Type::Any | Type::Generic | Type::Callable(..), Type::False | Type::Nil) => { + Check::And(vec![Check::IsAtom, Check::Value(BigInt::ZERO)]) + } + (Type::Any | Type::Generic | Type::Callable(..), Type::True) => { + Check::And(vec![Check::IsAtom, Check::Value(BigInt::one())]) + } + (Type::Any | Type::Generic | Type::Callable(..), Type::Value(value)) => { + Check::And(vec![Check::IsAtom, Check::Value(value.clone())]) + } + (Type::Any | Type::Generic | Type::Callable(..), Type::Pair(first, rest)) => { + let (first, rest) = (*first, *rest); + let first = check_type(types, types.std().any, first, visited)?; + let rest = check_type(types, types.std().any, rest, visited)?; + Check::And(vec![ + Check::IsPair, + Check::First(Box::new(first)), + Check::Rest(Box::new(rest)), + ]) + } + + (Type::Bytes | Type::Int, Type::Nil | Type::False) => Check::Value(BigInt::ZERO), + (Type::Bytes | Type::Int, Type::True) => Check::Value(BigInt::one()), + (Type::Bytes | Type::Int, Type::PublicKey) => Check::Length(48), + (Type::Bytes | Type::Int, Type::Bytes32) => Check::Length(32), + + (_, Type::Never | Type::Generic | Type::Callable(..)) + | (Type::PublicKey, Type::Bytes32 | Type::Nil | Type::True | Type::False) + | (Type::Bytes32, Type::PublicKey | Type::Nil | Type::True | Type::False) + | (Type::Nil | Type::False, Type::PublicKey | Type::Bytes32 | Type::True) + | (Type::True, Type::PublicKey | Type::Bytes32 | Type::Nil | Type::False) + | ( + Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + Type::Pair(..), + ) + | ( + Type::Pair(..), + Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + ) => Check::False, + + (Type::Value(value), Type::Bytes32) => { + if bigint_to_bytes(value.clone()).len() == 32 { + Check::True + } else { + Check::False + } + } + + (Type::Value(value), Type::PublicKey) => { + if bigint_to_bytes(value.clone()).len() == 48 { + Check::True + } else { + Check::False + } + } + + (Type::Value(value), Type::Nil | Type::False) + | (Type::Nil | Type::False, Type::Value(value)) => { + if value == &BigInt::ZERO { + Check::True + } else { + Check::False + } + } + + (Type::Value(value), Type::True) | (Type::True, Type::Value(value)) => { + if value == &BigInt::one() { + Check::True + } else { + Check::False + } + } + + (Type::Bytes | Type::Int, Type::Value(value)) => Check::Value(value.clone()), + + (Type::Bytes32, Type::Value(value)) => { + if bigint_to_bytes(value.clone()).len() == 32 { + Check::Value(value.clone()) + } else { + Check::False + } + } + (Type::PublicKey, Type::Value(value)) => { + if bigint_to_bytes(value.clone()).len() == 48 { + Check::Value(value.clone()) + } else { + Check::False + } + } + + (Type::Value(lhs_value), Type::Value(rhs_value)) => { + if lhs_value == rhs_value { + Check::True + } else { + Check::False + } + } + + (Type::Pair(lhs_first, lhs_rest), Type::Pair(rhs_first, rhs_rest)) => { + let (lhs_first, lhs_rest) = (*lhs_first, *lhs_rest); + let (rhs_first, rhs_rest) = (*rhs_first, *rhs_rest); + let first = check_type(types, lhs_first, rhs_first, visited)?; + let rest = check_type(types, lhs_rest, rhs_rest, visited)?; + Check::And(vec![ + Check::First(Box::new(first)), + Check::Rest(Box::new(rest)), + ]) + } + + (Type::Union(items), _) => { + let items = items.clone(); + check_union_against_rhs(types, lhs, &items, rhs, visited)? + } + + (_, Type::Union(items)) => { + let mut result = Vec::new(); + for item in items.clone() { + result.push(check_type(types, lhs, item, visited)?); + } + Check::Or(result) + } + }; + + visited.remove(&(lhs, rhs)); + + Ok(check) +} + +fn check_union_against_rhs( + types: &mut TypeSystem, + original_type_id: TypeId, + items: &[TypeId], + rhs: TypeId, + visited: &mut HashSet<(TypeId, TypeId)>, +) -> Result { + let union = types.alloc(Type::Union(items.to_vec())); + + if types.compare(union, rhs) <= Comparison::Castable { + return Ok(Check::True); + } + + if let Type::Union(union) = types.get_recursive(rhs) { + let rhs_items = union.clone(); + let mut result = Vec::new(); + for rhs_item in rhs_items { + if !visited.insert((original_type_id, rhs_item)) { + return Err(CheckError::Recursive(original_type_id, rhs_item)); + } + result.push(check_union_against_rhs( + types, + original_type_id, + items, + rhs_item, + visited, + )?); + } + return Ok(Check::Or(result)); + } + + let attrs = union_attributes(types, items, true, rhs, visited)?; + + Ok(match types.get_recursive(rhs) { + Type::Ref(..) + | Type::Lazy(..) + | Type::Union(..) + | Type::Alias(..) + | Type::Struct(..) + | Type::Enum(..) + | Type::Variant(..) => unreachable!(), + Type::Unknown | Type::Any => Check::True, + Type::Never | Type::Generic | Type::Callable(..) => Check::False, + Type::Bytes | Type::Int if attrs.all_atoms() => Check::True, + Type::Bytes | Type::Int => Check::IsAtom, + Type::Nil | Type::False if attrs.all_value(&BigInt::ZERO) => Check::True, + Type::Nil | Type::False if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom, + Type::Nil | Type::False if attrs.all_atoms() => Check::Value(BigInt::ZERO), + Type::Nil | Type::False => Check::And(vec![Check::IsAtom, Check::Value(BigInt::ZERO)]), + Type::True if attrs.all_value(&BigInt::one()) => Check::True, + Type::True if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom, + Type::True if attrs.all_atoms() => Check::Value(BigInt::one()), + Type::True => Check::And(vec![Check::IsAtom, Check::Value(BigInt::one())]), + Type::Value(value) if attrs.all_value(value) => Check::True, + Type::Value(value) if attrs.atoms_are_value(value) => Check::IsAtom, + Type::Value(value) if attrs.all_atoms() => Check::Value(value.clone()), + Type::Value(value) => Check::And(vec![Check::IsAtom, Check::Value(value.clone())]), + Type::Bytes32 if attrs.all_bytes32() => Check::True, + Type::Bytes32 if attrs.all_atoms() => Check::Length(32), + Type::Bytes32 if attrs.atoms_are_bytes32() => Check::IsAtom, + Type::Bytes32 => Check::And(vec![Check::IsAtom, Check::Length(32)]), + Type::PublicKey if attrs.all_public_key() => Check::True, + Type::PublicKey if attrs.all_atoms() => Check::Length(48), + Type::PublicKey if attrs.atoms_are_public_key() => Check::IsAtom, + Type::PublicKey => Check::And(vec![Check::IsAtom, Check::Length(48)]), + Type::Pair(..) if attrs.all_atoms() => Check::False, + Type::Pair(..) if attrs.pairs.len() == 1 && attrs.atom_count == attrs.length - 1 => { + Check::IsPair + } + Type::Pair(first, rest) => { + let (first, rest) = (*first, *rest); + + let first_items: Vec<_> = attrs.pairs.iter().map(|(first, _)| *first).collect(); + let rest_items: Vec<_> = attrs.pairs.iter().map(|(_, rest)| *rest).collect(); + + let first = + check_union_against_rhs(types, original_type_id, &first_items, first, visited)?; + let rest = + check_union_against_rhs(types, original_type_id, &rest_items, rest, visited)?; + + if attrs.all_pairs() { + Check::And(vec![ + Check::First(Box::new(first)), + Check::Rest(Box::new(rest)), + ]) + } else { + Check::And(vec![ + Check::IsPair, + Check::First(Box::new(first)), + Check::Rest(Box::new(rest)), + ]) + } + } + }) +} + +#[cfg(test)] +mod tests { + use indexmap::indexmap; + + use crate::{alloc_list, alloc_struct}; + + use super::*; + + fn check_str(db: &mut TypeSystem, lhs: TypeId, rhs: TypeId, expected: &str) { + assert_eq!(format!("{}", db.check(lhs, rhs).unwrap()), expected); + } + + fn check_recursive(db: &mut TypeSystem, lhs: TypeId, rhs: TypeId) { + assert!(matches!(db.check(lhs, rhs), Err(CheckError::Recursive(..)))); + } + + #[test] + fn test_check_any_bytes() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.any, types.bytes, "(not (l val))"); + } + + #[test] + fn test_check_any_bytes32() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str( + &mut db, + types.any, + types.bytes32, + "(and (not (l val)) (= (strlen val) 32))", + ); + } + + #[test] + fn test_check_any_public_key() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str( + &mut db, + types.any, + types.public_key, + "(and (not (l val)) (= (strlen val) 48))", + ); + } + + #[test] + fn test_check_any_int() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.any, types.int, "(not (l val))"); + } + + #[test] + fn test_check_any_bool() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str( + &mut db, + types.any, + types.bool, + "(and (not (l val)) (or (= val 0) (= val 1)))", + ); + } + + #[test] + fn test_check_any_nil() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str( + &mut db, + types.any, + types.nil, + "(and (not (l val)) (= val 0))", + ); + } + + #[test] + fn test_check_bytes_bytes() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes, types.bytes, "1"); + } + + #[test] + fn test_check_bytes32_bytes32() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes32, types.bytes32, "1"); + } + + #[test] + fn test_check_public_key_public_key() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.public_key, types.public_key, "1"); + } + + #[test] + fn test_check_int_int() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.int, types.int, "1"); + } + + #[test] + fn test_check_bool_bool() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bool, types.bool, "1"); + } + + #[test] + fn test_check_nil_nil() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.nil, types.nil, "1"); + } + + #[test] + fn test_check_bytes_bytes32() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes, types.bytes32, "(= (strlen val) 32)"); + } + + #[test] + fn test_check_bytes32_bytes() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes32, types.bytes, "1"); + } + + #[test] + fn test_check_bytes_public_key() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str( + &mut db, + types.bytes, + types.public_key, + "(= (strlen val) 48)", + ); + } + + #[test] + fn test_check_public_key_bytes() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.public_key, types.bytes, "1"); + } + + #[test] + fn test_check_bytes_nil() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes, types.nil, "(= val 0)"); + } + + #[test] + fn test_check_nil_bytes() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.nil, types.bytes, "1"); + } + + #[test] + fn test_check_bytes_bool() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes, types.bool, "(or (= val 0) (= val 1))"); + } + + #[test] + fn test_check_bool_bytes() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bool, types.bytes, "1"); + } + + #[test] + fn test_check_bytes32_public_key() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes32, types.public_key, "0"); + } + + #[test] + fn test_check_bytes_int() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bytes, types.int, "1"); + } + + #[test] + fn test_check_int_bytes() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.int, types.bytes, "1"); + } + + #[test] + fn test_check_bool_nil() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.bool, types.nil, "(= val 0)"); + } + + #[test] + fn test_check_nil_bool() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.nil, types.bool, "1"); + } + + #[test] + fn test_check_any_any() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.any, types.any, "1"); + } + + #[test] + fn test_check_bytes_any() { + let mut db = TypeSystem::new(); + let types = db.std(); + check_str(&mut db, types.any, types.any, "1"); + } + + #[test] + fn test_check_list_nil() { + let mut db = TypeSystem::new(); + let types = db.std(); + let list = alloc_list(&mut db, types.bytes); + check_str(&mut db, list, types.nil, "(not (l val))"); + } + + #[test] + fn test_check_list_pair() { + let mut db = TypeSystem::new(); + let types = db.std(); + let list = alloc_list(&mut db, types.bytes); + let pair = db.alloc(Type::Pair(types.bytes, list)); + check_str(&mut db, list, pair, "(l val)"); + } + + #[test] + fn test_check_list_pair_generic() { + let mut db = TypeSystem::new(); + let generic = db.alloc(Type::Generic); + let list = alloc_list(&mut db, generic); + let pair = db.alloc(Type::Pair(generic, list)); + check_str(&mut db, list, pair, "(l val)"); + } + + #[test] + fn test_check_any_list() { + let mut db = TypeSystem::new(); + let types = db.std(); + let list = alloc_list(&mut db, types.bytes); + check_recursive(&mut db, types.any, list); + } + + #[test] + fn test_check_any_point() { + let mut db = TypeSystem::new(); + let types = db.std(); + let point_end = db.alloc(Type::Pair(types.int, types.nil)); + let point = db.alloc(Type::Pair(types.int, point_end)); + check_str(&mut db, types.any, point, "(and (l val) (not (l (f val))) (l (r val)) (not (l (f (r val)))) (not (l (r (r val)))) (= (r (r val)) 0))"); + } + + #[test] + fn test_check_any_point_struct() { + let mut db = TypeSystem::new(); + let types = db.std(); + let point = alloc_struct( + &mut db, + &indexmap! { + "x".to_string() => types.int, + "y".to_string() => types.int, + }, + true, + ); + check_str(&mut db, types.any, point, "(and (l val) (not (l (f val))) (l (r val)) (not (l (f (r val)))) (not (l (r (r val)))) (= (r (r val)) 0))"); + } + + #[test] + fn test_check_condition_agg_sig_me() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let opcode = db.alloc(Type::Value(BigInt::from(49))); + let agg_sig_unsafe = alloc_struct( + &mut db, + &indexmap! { + "opcode".to_string() => opcode, + "public_key".to_string() => types.public_key, + "message".to_string() => types.bytes, + }, + true, + ); + + let opcode = db.alloc(Type::Value(BigInt::from(50))); + let agg_sig_me = alloc_struct( + &mut db, + &indexmap! { + "opcode".to_string() => opcode, + "public_key".to_string() => types.public_key, + "message".to_string() => types.bytes, + }, + true, + ); + + let condition = db.alloc(Type::Union(vec![agg_sig_unsafe, agg_sig_me])); + + check_str(&mut db, condition, agg_sig_me, "(= (f val) 50)"); + check_str(&mut db, types.any, agg_sig_me, "(and (l val) (not (l (f val))) (= (f val) 50) (l (r val)) (not (l (f (r val)))) (= (strlen (f (r val))) 48) (l (r (r val))) (not (l (f (r (r val))))) (not (l (r (r (r val))))) (= (r (r (r val))) 0))"); + } + + #[test] + fn test_check_three_int_int_list() { + let mut db = TypeSystem::new(); + let types = db.std(); + let inner = db.alloc(Type::Pair(types.int, types.int)); + let pair = db.alloc(Type::Pair(types.int, inner)); + let list = alloc_list(&mut db, types.int); + let union = db.alloc(Type::Union(vec![pair, list])); + + check_str(&mut db, union, union, "1"); + check_str(&mut db, union, types.int, "(not (l val))"); + check_str(&mut db, union, types.nil, "(not (l val))"); + check_recursive(&mut db, union, list); + check_str( + &mut db, + union, + pair, + "(and (l val) (l (r val)) (not (l (r (r val)))))", + ); + } +} diff --git a/crates/rue-typing/src/check/simplify_and.rs b/crates/rue-typing/src/check/simplify_and.rs new file mode 100644 index 0000000..ae29557 --- /dev/null +++ b/crates/rue-typing/src/check/simplify_and.rs @@ -0,0 +1,142 @@ +use std::collections::VecDeque; + +use super::{simplify_check, Check}; + +pub(crate) fn simplify_and_deep(items: Vec) -> Check { + let mut items = VecDeque::from(items); + + let iter = std::iter::from_fn(|| { + while let Some(item) = items.pop_front() { + match simplify_check(item) { + Check::And(children) => items.extend(children), + item => return Some(item), + } + } + None + }); + + simplify_and_shallow(iter) +} + +pub(crate) fn simplify_and_shallow(items: impl IntoIterator) -> Check { + let mut result = Vec::new(); + let mut is_atom = false; + let mut is_pair = false; + let mut value = false; + let mut length = false; + + for item in items { + match item { + Check::True => continue, + Check::IsAtom if is_atom => continue, + Check::IsAtom => is_atom = true, + Check::IsPair if is_pair => continue, + Check::IsPair => is_pair = true, + Check::Value(..) if value => continue, + Check::Value(..) => value = true, + Check::Length(..) if length => continue, + Check::Length(..) => length = true, + _ => {} + } + result.push(item); + } + + construct_and(result) +} + +pub(crate) fn construct_and(mut items: Vec) -> Check { + if items.is_empty() { + Check::True + } else if items.len() == 1 { + items.remove(0) + } else { + Check::And(items) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simplify_and_none() { + assert_eq!(simplify_and_shallow([Check::True]), Check::True); + } + + #[test] + fn test_simplify_none_and_none() { + assert_eq!( + simplify_and_shallow([Check::True, Check::True]), + Check::True + ); + } + + #[test] + fn test_simplify_check_and_none() { + assert_eq!( + simplify_and_shallow([Check::IsAtom, Check::True]), + Check::IsAtom + ); + } + + #[test] + fn test_simplify_none_and_check() { + assert_eq!( + simplify_and_shallow([Check::True, Check::IsAtom]), + Check::IsAtom + ); + } + + #[test] + fn test_simplify_and_one_check() { + assert_eq!(simplify_and_shallow([Check::IsAtom]), Check::IsAtom); + } + + #[test] + fn test_simplify_atom_and_atom() { + assert_eq!( + simplify_and_shallow([Check::IsAtom, Check::IsAtom]), + Check::IsAtom + ); + } + + #[test] + fn test_simplify_atom_and_pair() { + assert_eq!( + simplify_and_shallow([Check::IsAtom, Check::IsPair]), + Check::And(vec![Check::IsAtom, Check::IsPair]) + ); + } + + #[test] + fn test_simplify_pair_and_pair() { + assert_eq!( + simplify_and_shallow([Check::IsPair, Check::IsPair]), + Check::IsPair + ); + } + + #[test] + fn test_simplify_pair_and_atom() { + assert_eq!( + simplify_and_shallow([Check::IsPair, Check::IsAtom]), + Check::And(vec![Check::IsPair, Check::IsAtom]) + ); + } + + #[test] + fn test_simplify_and_shallow() { + assert_eq!( + simplify_and_shallow([Check::And(vec![Check::True, Check::True])]), + Check::And(vec![Check::True, Check::True]) + ); + } + + #[test] + fn test_simplify_and_deep() { + assert_eq!( + simplify_and_deep(vec![Check::And(vec![Check::True, Check::True])]), + Check::True + ); + } +} diff --git a/crates/rue-typing/src/check/simplify_check.rs b/crates/rue-typing/src/check/simplify_check.rs new file mode 100644 index 0000000..9ad9cdc --- /dev/null +++ b/crates/rue-typing/src/check/simplify_check.rs @@ -0,0 +1,60 @@ +use num_bigint::BigInt; + +use super::{simplify_and_deep, simplify_or_deep, Check}; + +pub(crate) fn simplify_check(check: Check) -> Check { + match check { + Check::True => Check::True, + Check::False => Check::False, + Check::IsAtom => Check::IsAtom, + Check::IsPair => Check::IsPair, + Check::Value(value) => Check::Value(value), + Check::Length(len) => { + if len == 0 { + Check::Value(BigInt::ZERO) + } else { + Check::Length(len) + } + } + Check::And(items) => simplify_and_deep(items), + Check::Or(items) => simplify_or_deep(items), + Check::If(cond, then, else_) => { + let cond = simplify_check(*cond); + let then = simplify_check(*then); + let else_ = simplify_check(*else_); + Check::If(Box::new(cond), Box::new(then), Box::new(else_)) + } + Check::First(first) => match simplify_check(*first) { + Check::True => Check::True, + Check::And(items) => Check::And( + items + .into_iter() + .map(|item| Check::First(Box::new(item))) + .collect(), + ), + Check::Or(items) => Check::Or( + items + .into_iter() + .map(|item| Check::First(Box::new(item))) + .collect(), + ), + first => Check::First(Box::new(first)), + }, + Check::Rest(rest) => match simplify_check(*rest) { + Check::True => Check::True, + Check::And(items) => Check::And( + items + .into_iter() + .map(|item| Check::Rest(Box::new(item))) + .collect(), + ), + Check::Or(items) => Check::Or( + items + .into_iter() + .map(|item| Check::Rest(Box::new(item))) + .collect(), + ), + rest => Check::Rest(Box::new(rest)), + }, + } +} diff --git a/crates/rue-typing/src/check/simplify_or.rs b/crates/rue-typing/src/check/simplify_or.rs new file mode 100644 index 0000000..2b0b054 --- /dev/null +++ b/crates/rue-typing/src/check/simplify_or.rs @@ -0,0 +1,237 @@ +use std::collections::VecDeque; + +use super::{construct_and, simplify_check, Check}; + +enum ShapeCheck { + None, + Any, + Or(Check), +} + +pub(crate) fn simplify_or_deep(items: Vec) -> Check { + let mut items = VecDeque::from(items); + + let iter = std::iter::from_fn(|| { + while let Some(item) = items.pop_front() { + match simplify_check(item) { + Check::Or(children) => items.extend(children), + item => return Some(item), + } + } + None + }); + + simplify_or_shallow(iter) +} + +pub(crate) fn simplify_or_shallow(items: impl IntoIterator) -> Check { + let mut result = Vec::new(); + + let mut atom_checks: Vec = Vec::new(); + let mut pair_checks: Vec = Vec::new(); + + let mut any_atom = false; + let mut any_pair = false; + + for item in items { + match item { + Check::True => return Check::True, + Check::IsAtom if any_atom => continue, + Check::IsPair if any_pair => continue, + Check::IsAtom => { + any_atom = true; + continue; + } + Check::IsPair => { + any_pair = true; + continue; + } + Check::And(children) => { + let (shape, checks) = extract_shape_check(children); + match shape { + Some(Check::IsAtom) => atom_checks.push(checks), + Some(Check::IsPair) => pair_checks.push(checks), + _ => result.push(checks), + } + continue; + } + item => result.push(item), + } + } + + let prefer_atom = atom_checks.len() > pair_checks.len(); + + let atom_check = if any_atom { + ShapeCheck::Any + } else if atom_checks.is_empty() { + ShapeCheck::None + } else { + ShapeCheck::Or(construct_or(atom_checks)) + }; + + let pair_check = if any_pair { + ShapeCheck::Any + } else if pair_checks.is_empty() { + ShapeCheck::None + } else { + ShapeCheck::Or(construct_or(pair_checks)) + }; + + match (atom_check, pair_check) { + (ShapeCheck::None, ShapeCheck::None) => {} + (ShapeCheck::Any, ShapeCheck::Any) => { + return Check::True; + } + (ShapeCheck::Any, ShapeCheck::None) => { + result.push(Check::IsAtom); + } + (ShapeCheck::None, ShapeCheck::Any) => { + result.push(Check::IsPair); + } + (ShapeCheck::Or(atom_check), ShapeCheck::None) => { + result.push(construct_and(vec![Check::IsAtom, atom_check])); + } + (ShapeCheck::None, ShapeCheck::Or(pair_check)) => { + result.push(construct_and(vec![Check::IsPair, pair_check])); + } + (ShapeCheck::Or(atom_check), ShapeCheck::Any) => { + result.push(Check::IsPair); + result.push(atom_check); + } + (ShapeCheck::Any, ShapeCheck::Or(pair_check)) => { + result.push(Check::IsAtom); + result.push(pair_check); + } + (ShapeCheck::Or(atom_check), ShapeCheck::Or(pair_check)) => { + if prefer_atom { + result.push(Check::If( + Box::new(Check::IsAtom), + Box::new(atom_check), + Box::new(pair_check), + )); + } else { + result.push(Check::If( + Box::new(Check::IsPair), + Box::new(pair_check), + Box::new(atom_check), + )); + } + } + } + + construct_or(result) +} + +fn extract_shape_check(items: Vec) -> (Option, Check) { + let mut result = Vec::new(); + let mut shape = None; + + for item in items { + match item { + Check::IsAtom => shape = Some(Check::IsAtom), + Check::IsPair => shape = Some(Check::IsPair), + _ => result.push(item), + } + } + + (shape, construct_and(result)) +} + +fn construct_or(mut items: Vec) -> Check { + if items.is_empty() { + unreachable!() + } else if items.len() == 1 { + items.remove(0) + } else { + Check::Or(items) + } +} + +#[cfg(test)] +mod tests { + use num_bigint::BigInt; + + use super::*; + + #[test] + fn test_simplify_or_none() { + assert_eq!(simplify_or_shallow([Check::True]), Check::True); + } + + #[test] + fn test_simplify_none_or_none() { + assert_eq!(simplify_or_shallow([Check::True, Check::True]), Check::True); + } + + #[test] + fn test_simplify_check_or_none() { + assert_eq!( + simplify_or_shallow([Check::IsAtom, Check::True]), + Check::True + ); + } + + #[test] + fn test_simplify_none_or_check() { + assert_eq!( + simplify_or_shallow([Check::True, Check::IsAtom]), + Check::True + ); + } + + #[test] + fn test_simplify_or_one_check() { + assert_eq!(simplify_or_shallow([Check::IsAtom]), Check::IsAtom); + } + + #[test] + fn test_simplify_or_two_checks() { + assert_eq!( + simplify_or_shallow([Check::IsPair, Check::Value(BigInt::ZERO)]), + Check::Or(vec![Check::Value(BigInt::ZERO), Check::IsPair]) + ); + } + + #[test] + fn test_simplify_atom_or_pair() { + assert_eq!( + simplify_or_shallow([Check::IsAtom, Check::IsPair]), + Check::True + ); + } + + #[test] + fn test_simplify_pair_or_atom() { + assert_eq!( + simplify_or_shallow([Check::IsPair, Check::IsAtom]), + Check::True + ); + } + + #[test] + fn test_simplify_atom_or_atom() { + assert_eq!( + simplify_or_shallow([Check::IsAtom, Check::IsAtom]), + Check::IsAtom + ); + } + + #[test] + fn test_simplify_pair_or_pair() { + assert_eq!( + simplify_or_shallow([Check::IsPair, Check::IsPair]), + Check::IsPair + ); + } + + #[test] + fn test_simplify_compound_atom_or_pair() { + assert_eq!( + simplify_or_shallow([ + Check::And(vec![Check::IsAtom, Check::Length(32)]), + Check::IsPair + ]), + Check::Or(vec![Check::IsPair, Check::Length(32)]) + ); + } +} diff --git a/crates/rue-typing/src/check/stringify_check.rs b/crates/rue-typing/src/check/stringify_check.rs new file mode 100644 index 0000000..e627ad9 --- /dev/null +++ b/crates/rue-typing/src/check/stringify_check.rs @@ -0,0 +1,89 @@ +use std::fmt::{self, Display}; + +use crate::TypePath; + +use super::Check; + +pub(crate) fn stringify_check( + check: &Check, + f: &mut fmt::Formatter<'_>, + path: &mut Vec, +) -> fmt::Result { + match check { + Check::True => write!(f, "1"), + Check::False => write!(f, "0"), + Check::IsPair => { + write!(f, "(l ")?; + stringify_value(f, path)?; + write!(f, ")") + } + Check::IsAtom => { + write!(f, "(not (l ")?; + stringify_value(f, path)?; + write!(f, "))") + } + Check::Value(value) => { + write!(f, "(= ")?; + stringify_value(f, path)?; + write!(f, " ")?; + value.fmt(f)?; + write!(f, ")") + } + Check::Length(len) => { + write!(f, "(= (strlen ")?; + stringify_value(f, path)?; + write!(f, ") {len})") + } + Check::And(checks) => { + write!(f, "(and")?; + for check in checks { + write!(f, " ")?; + stringify_check(check, f, path)?; + } + write!(f, ")") + } + Check::Or(checks) => { + write!(f, "(or")?; + for check in checks { + write!(f, " ")?; + stringify_check(check, f, path)?; + } + write!(f, ")") + } + Check::If(cond, then, else_) => { + write!(f, "(if ")?; + stringify_check(cond, f, path)?; + write!(f, " ")?; + stringify_check(then, f, path)?; + write!(f, " ")?; + stringify_check(else_, f, path)?; + write!(f, ")") + } + Check::First(first) => { + path.push(TypePath::First); + stringify_check(first, f, path)?; + path.pop().unwrap(); + Ok(()) + } + Check::Rest(rest) => { + path.push(TypePath::Rest); + stringify_check(rest, f, path)?; + path.pop().unwrap(); + Ok(()) + } + } +} + +fn stringify_value(f: &mut fmt::Formatter<'_>, path: &[TypePath]) -> fmt::Result { + for path in path.iter().rev() { + match path { + TypePath::First => write!(f, "(f ")?, + TypePath::Rest => write!(f, "(r ")?, + } + } + write!(f, "val")?; + for _ in 0..path.len() { + write!(f, ")")?; + } + Ok(()) +} diff --git a/crates/rue-typing/src/compare_type.rs b/crates/rue-typing/src/compare_type.rs new file mode 100644 index 0000000..c470f66 --- /dev/null +++ b/crates/rue-typing/src/compare_type.rs @@ -0,0 +1,866 @@ +use std::cmp::{max, min}; + +use num_bigint::BigInt; +use num_traits::One; + +use crate::{bigint_to_bytes, HashMap, HashSet, Type, TypeId, TypeSystem}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Comparison { + Equal, + Assignable, + Castable, + NotEqual, +} + +pub(crate) struct ComparisonContext<'a> { + pub visited: HashSet<(TypeId, TypeId)>, + pub inferred: &'a mut Vec>, + pub infer_generics: bool, + pub lhs_substitutions: Vec>, + pub rhs_substitutions: Vec>, +} + +pub(crate) fn compare_type( + db: &TypeSystem, + lhs: TypeId, + rhs: TypeId, + ctx: &mut ComparisonContext<'_>, +) -> Comparison { + if !ctx.visited.insert((lhs, rhs)) { + return Comparison::Assignable; + } + + let found_lhs = ctx + .lhs_substitutions + .iter() + .rev() + .find_map(|substitutions| substitutions.get(&lhs).copied()); + + let found_rhs = ctx + .rhs_substitutions + .iter() + .rev() + .find_map(|substitutions| substitutions.get(&rhs).copied()); + + let comparison = match (db.get(lhs), db.get(rhs)) { + (Type::Ref(..) | Type::Lazy(..), _) | (_, Type::Ref(..) | Type::Lazy(..)) => unreachable!(), + + // These types are identical. + (Type::Unknown, Type::Unknown) + | (Type::Never, Type::Never) + | (Type::Any, Type::Any) + | (Type::Bytes, Type::Bytes) + | (Type::Bytes32, Type::Bytes32) + | (Type::PublicKey, Type::PublicKey) + | (Type::Int, Type::Int) + | (Type::Nil, Type::Nil) + | (Type::True, Type::True) + | (Type::False, Type::False) => Comparison::Equal, + + // These should always be the case, regardless of the other type. + (_, Type::Any | Type::Unknown) | (Type::Unknown | Type::Never, _) => Comparison::Assignable, + + // Handle generics and substitutions. + (Type::Generic, _) if found_lhs.is_some() => compare_type(db, found_lhs.unwrap(), rhs, ctx), + (_, Type::Generic) if found_rhs.is_some() => compare_type(db, lhs, found_rhs.unwrap(), ctx), + + // Infer generics. + (_, Type::Generic) => { + if let Some(inferred) = ctx + .inferred + .iter() + .rev() + .find_map(|map| map.get(&rhs).copied()) + { + compare_type(db, lhs, inferred, ctx) + } else if lhs == rhs { + Comparison::Equal + } else if ctx.infer_generics { + ctx.inferred.last_mut().unwrap().insert(rhs, lhs); + Comparison::Assignable + } else { + Comparison::NotEqual + } + } + + (Type::Generic, _) => { + if let Some(inferred) = ctx + .inferred + .iter() + .rev() + .find_map(|map| map.get(&lhs).copied()) + { + compare_type(db, inferred, rhs, ctx) + } else if lhs == rhs { + Comparison::Equal + } else { + Comparison::NotEqual + } + } + + // These are assignable since the structure and semantics match. + (Type::Value(..), Type::Int) | (Type::Bytes32 | Type::Nil, Type::Bytes) => { + Comparison::Assignable + } + + // These are castable since the structure matches but the semantics differ. + ( + Type::Bytes | Type::Bytes32 | Type::PublicKey | Type::Nil | Type::True | Type::False, + Type::Int, + ) + | (Type::PublicKey | Type::Int | Type::True | Type::False | Type::Value(..), Type::Bytes) + | (Type::False, Type::Nil) + | (Type::Nil, Type::False) => Comparison::Castable, + + // These are incompatible since the structure differs. + ( + Type::Any, + Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..) + | Type::Pair(..), + ) + | ( + Type::Bytes | Type::Int, + Type::Bytes32 + | Type::PublicKey + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + ) + | ( + Type::Any + | Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..) + | Type::Pair(..) + | Type::Callable(..), + Type::Never, + ) + | (Type::Bytes32 | Type::PublicKey, Type::Value(..)) + | ( + Type::Pair(..), + Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + ) + | ( + Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + Type::Pair(..), + ) + | (Type::Bytes32, Type::PublicKey | Type::Nil | Type::True | Type::False) + | (Type::PublicKey, Type::Bytes32 | Type::Nil | Type::True | Type::False) + | (Type::Nil, Type::Bytes32 | Type::PublicKey | Type::True) + | (Type::True, Type::False | Type::Nil) + | (Type::False, Type::True) + | (Type::True | Type::False, Type::Bytes32 | Type::PublicKey) + | ( + Type::Any + | Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..) + | Type::Pair(..), + Type::Callable(..), + ) => Comparison::NotEqual, + + // Value is a subtype of Int, so it's castable to Bytes32 if it's 32 bytes long. + (Type::Value(value), Type::Bytes32) => { + if bigint_to_bytes(value.clone()).len() == 32 { + Comparison::Castable + } else { + Comparison::NotEqual + } + } + + // Value is a subtype of Int, so it's castable to PublicKey if it's 48 bytes long. + (Type::Value(value), Type::PublicKey) => { + if bigint_to_bytes(value.clone()).len() == 48 { + Comparison::Castable + } else { + Comparison::NotEqual + } + } + + // Nil and False are castable to Value only if the value is zero. + (Type::Nil | Type::False, Type::Value(value)) + | (Type::Value(value), Type::Nil | Type::False) => { + if value == &BigInt::ZERO { + Comparison::Castable + } else { + Comparison::NotEqual + } + } + + // True is castable to Value only if the value is one. + (Type::True, Type::Value(value)) | (Type::Value(value), Type::True) => { + if value == &BigInt::one() { + Comparison::Castable + } else { + Comparison::NotEqual + } + } + + // Value is equal to other instances of Value only if the values are equal. + (Type::Value(lhs), Type::Value(rhs)) => { + if lhs == rhs { + Comparison::Equal + } else { + Comparison::NotEqual + } + } + + // A comparison of pairs is done by using whichever comparison is the most restrictive. + (Type::Pair(lhs_first, lhs_rest), Type::Pair(rhs_first, rhs_rest)) => { + let first = compare_type(db, *lhs_first, *rhs_first, ctx); + let rest = compare_type(db, *lhs_rest, *rhs_rest, ctx); + max(first, rest) + } + + // Unions can be assigned to anything so long as each of the items in the union are also. + (Type::Union(items), _) => { + let items = items.clone(); + let mut result = Comparison::Assignable; + + for item in items { + let cmp = compare_type(db, item, rhs, ctx); + result = max(result, cmp); + } + + result + } + + // Anything can be assigned to a union so long as it's assignable to at least one of the items. + (_, Type::Union(items)) => { + let items = items.clone(); + let mut result = Comparison::NotEqual; + + for item in &items { + if matches!(db.get_recursive(*item), Type::Never) { + continue; + } + + let cmp = compare_type(db, lhs, *item, ctx); + result = min(result, cmp); + } + + max(result, Comparison::Assignable) + } + + // Resolve the alias to the type that it's pointing to. + (Type::Alias(alias), _) => compare_type(db, alias.type_id, rhs, ctx), + (_, Type::Alias(alias)) => compare_type(db, lhs, alias.type_id, ctx), + + // Structs are at best castable to other types, since they have different semantics. + (Type::Struct(lhs), Type::Struct(rhs)) if lhs.original_type_id == rhs.original_type_id => { + compare_type(db, lhs.type_id, rhs.type_id, ctx) + } + (Type::Struct(lhs), _) => max( + compare_type(db, lhs.type_id, rhs, ctx), + Comparison::Castable, + ), + (_, Type::Struct(rhs)) => max( + compare_type(db, lhs, rhs.type_id, ctx), + Comparison::Castable, + ), + + // Variants can be assigned to enums if the structure is assignable and it's the same enum. + (Type::Variant(variant), Type::Enum(ty)) => { + let comparison = compare_type(db, lhs, ty.type_id, ctx); + + if variant.original_enum_type_id == ty.original_type_id { + max(comparison, Comparison::Assignable) + } else { + max(comparison, Comparison::Castable) + } + } + + (Type::Enum(ty), Type::Variant(variant)) => { + let comparison = compare_type(db, ty.type_id, rhs, ctx); + + if variant.original_enum_type_id == ty.original_type_id { + max(comparison, Comparison::Assignable) + } else { + max(comparison, Comparison::Castable) + } + } + + // Enums can be assigned if the structure is assignable and it's the same enum. + (Type::Enum(lhs), Type::Enum(rhs)) if lhs.original_type_id == rhs.original_type_id => { + compare_type(db, lhs.type_id, rhs.type_id, ctx) + } + (Type::Enum(ty), _) => max(compare_type(db, ty.type_id, rhs, ctx), Comparison::Castable), + (_, Type::Enum(ty)) => max(compare_type(db, lhs, ty.type_id, ctx), Comparison::Castable), + + // Variants can be assigned if the structure is assignable and it's the same variant. + (Type::Variant(lhs), Type::Variant(rhs)) + if lhs.original_type_id == rhs.original_type_id => + { + compare_type(db, lhs.type_id, rhs.type_id, ctx) + } + (Type::Variant(lhs), _) => max( + compare_type(db, lhs.type_id, rhs, ctx), + Comparison::Castable, + ), + (_, Type::Variant(rhs)) => max( + compare_type(db, lhs, rhs.type_id, ctx), + Comparison::Castable, + ), + + // Functions can be assigned to other functions if the parameters and return type are assignable. + // They're treated like Never on the right hand side and Any on the left hand side. + (Type::Callable(lhs), Type::Callable(rhs)) => max( + compare_type(db, lhs.parameters, rhs.parameters, ctx), + compare_type(db, lhs.return_type, rhs.return_type, ctx), + ), + (Type::Callable(..), _) => compare_type(db, lhs, db.std().any, ctx), + }; + + ctx.visited.remove(&(lhs, rhs)); + + comparison +} + +#[cfg(test)] +mod tests { + use indexmap::indexmap; + + use crate::{alloc_list, alloc_struct, alloc_tuple_of, Enum, Struct, Variant}; + + use super::*; + + #[test] + fn test_compare_int_int() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.int, types.int), Comparison::Equal); + } + + #[test] + fn test_compare_int_bytes() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.int, types.bytes), Comparison::Castable); + } + + #[test] + fn test_compare_bytes_int() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.bytes, types.int), Comparison::Castable); + } + + #[test] + fn test_compare_bytes_bytes32() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.bytes, types.bytes32), Comparison::NotEqual); + } + + #[test] + fn test_compare_bytes32_bytes() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!( + db.compare(types.bytes32, types.bytes), + Comparison::Assignable + ); + } + + #[test] + fn test_compare_bytes_public_key() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!( + db.compare(types.bytes, types.public_key), + Comparison::NotEqual + ); + } + + #[test] + fn test_compare_public_key_bytes() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!( + db.compare(types.public_key, types.bytes), + Comparison::Castable + ); + } + + #[test] + fn test_compare_bytes32_public_key() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!( + db.compare(types.bytes32, types.public_key), + Comparison::NotEqual + ); + } + + #[test] + fn test_compare_public_key_bytes32() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!( + db.compare(types.public_key, types.bytes32), + Comparison::NotEqual + ); + } + + #[test] + fn test_compare_bytes_any() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.bytes, types.any), Comparison::Assignable); + } + + #[test] + fn test_compare_any_bytes() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.any, types.bytes), Comparison::NotEqual); + } + + #[test] + fn test_compare_bytes32_any() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.bytes32, types.any), Comparison::Assignable); + } + + #[test] + fn test_compare_any_bytes32() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.any, types.bytes32), Comparison::NotEqual); + } + + #[test] + fn test_compare_list_any() { + let mut db = TypeSystem::new(); + let types = db.std(); + let list = alloc_list(&mut db, types.int); + assert_eq!(db.compare(list, types.any), Comparison::Assignable); + } + + #[test] + fn test_compare_pair_any() { + let mut db = TypeSystem::new(); + let types = db.std(); + let pair = db.alloc(Type::Pair(types.int, types.public_key)); + assert_eq!(db.compare(pair, types.any), Comparison::Assignable); + } + + #[test] + fn test_compare_int_any() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.int, types.any), Comparison::Assignable); + } + + #[test] + fn test_compare_public_key_any() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!( + db.compare(types.public_key, types.any), + Comparison::Assignable + ); + } + + #[test] + fn test_compare_complex_any() { + let mut db = TypeSystem::new(); + let types = db.std(); + let pair_inner_inner = db.alloc(Type::Pair(types.any, types.nil)); + let pair_inner = db.alloc(Type::Pair(pair_inner_inner, pair_inner_inner)); + let pair = db.alloc(Type::Pair(types.int, pair_inner)); + let list = alloc_list(&mut db, pair); + assert_eq!(db.compare(list, types.any), Comparison::Assignable); + } + + #[test] + fn test_point_struct_any() { + let mut db = TypeSystem::new(); + let types = db.std(); + let point = alloc_struct( + &mut db, + &indexmap! { + "x".to_string() => types.int, + "y".to_string() => types.int, + }, + true, + ); + assert_eq!(db.compare(point, types.any), Comparison::Assignable); + } + + #[test] + fn test_compare_any_any() { + let db = TypeSystem::new(); + let types = db.std(); + assert_eq!(db.compare(types.any, types.any), Comparison::Equal); + } + + #[test] + fn test_compare_incompatible_pair() { + let mut db = TypeSystem::new(); + let types = db.std(); + let lhs = db.alloc(Type::Pair(types.int, types.public_key)); + let rhs = db.alloc(Type::Pair(types.bytes, types.nil)); + assert_eq!(db.compare(lhs, rhs), Comparison::NotEqual); + } + + #[test] + fn test_compare_castable_pair() { + let mut db = TypeSystem::new(); + let types = db.std(); + let lhs = db.alloc(Type::Pair(types.int, types.public_key)); + let rhs = db.alloc(Type::Pair(types.bytes, types.bytes)); + assert_eq!(db.compare(lhs, rhs), Comparison::Castable); + } + + #[test] + fn test_compare_assignable_pair() { + let mut db = TypeSystem::new(); + let types = db.std(); + let lhs = db.alloc(Type::Pair(types.int, types.public_key)); + let rhs = db.alloc(Type::Pair(types.any, types.any)); + assert_eq!(db.compare(lhs, rhs), Comparison::Assignable); + } + + #[test] + fn test_compare_nil_list() { + let mut db = TypeSystem::new(); + let types = db.std(); + let list = alloc_list(&mut db, types.int); + assert_eq!(db.compare(types.nil, list), Comparison::Assignable); + } + + #[test] + fn test_compare_pair_list() { + let mut db = TypeSystem::new(); + let types = db.std(); + let pair = db.alloc(Type::Pair(types.int, types.nil)); + let list = alloc_list(&mut db, types.int); + assert_eq!(db.compare(pair, list), Comparison::Assignable); + } + + #[test] + fn test_compare_generic_inference() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let generic = db.alloc(Type::Generic); + + let mut stack = vec![HashMap::new()]; + + assert_eq!( + db.compare_with_generics(types.int, generic, &mut stack, true), + Comparison::Assignable + ); + + assert_eq!(stack.len(), 1); + assert_eq!(stack[0].get(&generic), Some(&types.int)); + + for infer in [true, false] { + assert_eq!( + db.compare_with_generics(types.bytes, generic, &mut stack, infer), + Comparison::Castable + ); + assert_eq!( + db.compare_with_generics(types.any, generic, &mut stack, infer), + Comparison::NotEqual + ); + } + } + + #[test] + fn test_compare_enum_generic_inference() { + let mut db = TypeSystem::new(); + + let generic = db.alloc(Type::Generic); + + let mut stack = vec![HashMap::new()]; + + let enum_type = db.alloc(Type::Unknown); + let variant = db.alloc(Type::Unknown); + + let variant_inner = db.alloc(Type::Value(BigInt::ZERO)); + + *db.get_mut(variant) = Type::Variant(Variant { + original_enum_type_id: enum_type, + original_type_id: variant, + type_id: variant_inner, + field_names: None, + nil_terminated: false, + generic_types: vec![], + discriminant: BigInt::ZERO, + }); + + *db.get_mut(enum_type) = Type::Enum(Enum { + original_type_id: enum_type, + type_id: variant, + has_fields: false, + variants: indexmap! { + "A".to_string() => variant + }, + }); + + assert_eq!( + db.compare_with_generics(enum_type, generic, &mut stack, true), + Comparison::Assignable + ); + + assert_eq!(stack.len(), 1); + assert_eq!(stack[0].get(&generic), Some(&enum_type)); + } + + #[test] + fn test_compare_union_to_rhs_incompatible() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let pair = db.alloc(Type::Pair(types.int, types.public_key)); + let union = db.alloc(Type::Union(vec![types.bytes32, pair, types.nil])); + assert_eq!(db.compare(union, types.bytes), Comparison::NotEqual); + } + + #[test] + fn test_compare_union_to_rhs_superset() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let pair = db.alloc(Type::Pair(types.int, types.public_key)); + let union = db.alloc(Type::Union(vec![types.bytes, pair])); + assert_eq!(db.compare(union, types.bytes), Comparison::NotEqual); + } + + #[test] + fn test_compare_union_to_rhs_assignable() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let union = db.alloc(Type::Union(vec![types.bytes32, types.nil])); + assert_eq!(db.compare(union, types.bytes), Comparison::Assignable); + } + + #[test] + fn test_compare_lhs_to_union_incompatible() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let pair = db.alloc(Type::Pair(types.int, types.public_key)); + let union = db.alloc(Type::Union(vec![types.bytes32, pair, types.nil])); + assert_eq!(db.compare(types.bytes, union), Comparison::NotEqual); + } + + #[test] + fn test_compare_same_derivative_struct() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let struct_type = alloc_struct( + &mut db, + &indexmap! { + "x".to_string() => types.int, + "y".to_string() => types.int, + }, + true, + ); + + let Type::Struct(original) = db.get(struct_type) else { + unreachable!(); + }; + + let derivative_struct_type = db.alloc(Type::Struct(Struct { + original_type_id: struct_type, + type_id: original.type_id, + field_names: original.field_names.clone(), + nil_terminated: original.nil_terminated, + generic_types: original.generic_types.clone(), + })); + + assert_eq!( + db.compare(derivative_struct_type, struct_type), + Comparison::Equal + ); + + assert_eq!( + db.compare(struct_type, derivative_struct_type), + Comparison::Equal + ); + } + + #[test] + fn test_compare_different_derivative_struct() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let struct_type = alloc_struct( + &mut db, + &indexmap! { + "x".to_string() => types.int, + "y".to_string() => types.int, + }, + true, + ); + + let Type::Struct(original) = db.get(struct_type).clone() else { + unreachable!(); + }; + + let new_inner = alloc_tuple_of(&mut db, [types.int, types.bytes, types.nil].into_iter()); + + let derivative_struct_type = db.alloc(Type::Struct(Struct { + original_type_id: struct_type, + type_id: new_inner, + field_names: original.field_names, + nil_terminated: original.nil_terminated, + generic_types: original.generic_types, + })); + + assert_eq!( + db.compare(derivative_struct_type, struct_type), + Comparison::Castable + ); + + assert_eq!( + db.compare(struct_type, derivative_struct_type), + Comparison::Castable + ); + } + + #[test] + fn test_compare_different_struct() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let struct_type = alloc_struct( + &mut db, + &indexmap! { + "x".to_string() => types.int, + "y".to_string() => types.int, + }, + true, + ); + + let other_struct_type = alloc_struct( + &mut db, + &indexmap! { + "x".to_string() => types.int, + "y".to_string() => types.int, + }, + true, + ); + + assert_eq!( + db.compare(struct_type, other_struct_type), + Comparison::Castable + ); + } + + #[test] + fn test_compare_generic_equal() { + let mut db = TypeSystem::new(); + let types = db.std(); + let generic = db.alloc(Type::Generic); + assert_eq!(db.compare(types.int, generic), Comparison::NotEqual); + assert_eq!(db.compare(generic, generic), Comparison::Equal); + } + + #[test] + fn test_compare_generic_list_assignable() { + let mut db = TypeSystem::new(); + let types = db.std(); + let generic = db.alloc(Type::Generic); + let list = alloc_list(&mut db, generic); + let pair = db.alloc(Type::Pair(generic, list)); + assert_eq!(db.compare(types.nil, list), Comparison::Assignable); + assert_eq!(db.compare(list, list), Comparison::Assignable); + assert_eq!(db.compare(pair, list), Comparison::Assignable); + } + + #[test] + fn test_compare_pair_union() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let pair_enum = db.alloc(Type::Pair(types.int, types.nil)); + let pair_enum = db.alloc(Type::Pair(types.int, pair_enum)); + let zero = db.alloc(Type::Value(BigInt::ZERO)); + let pair_enum = db.alloc(Type::Pair(zero, pair_enum)); + + let int_enum = db.alloc(Type::Pair(types.int, types.nil)); + let one = db.alloc(Type::Value(BigInt::one())); + let int_enum = db.alloc(Type::Pair(one, int_enum)); + + let union = db.alloc(Type::Union(vec![pair_enum, int_enum])); + + assert_eq!(db.compare(pair_enum, union), Comparison::Assignable); + } + + #[test] + fn test_compare_list_unmapped_list_generics() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let mut stack = vec![HashMap::new()]; + + let list = alloc_list(&mut db, types.int); + assert_eq!( + db.compare_with_generics(list, types.unmapped_list, &mut stack, true), + Comparison::Assignable + ); + assert_eq!(stack, vec![[(types.generic_list_item, types.int)].into()]); + } + + #[test] + fn test_compare_tuple_list_generics() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let mut stack = vec![HashMap::new()]; + + let tuple = alloc_tuple_of( + &mut db, + [types.int, types.int, types.int, types.nil].into_iter(), + ); + + let generic = db.alloc(Type::Generic); + let list = alloc_list(&mut db, generic); + + assert_eq!( + db.compare_with_generics(tuple, list, &mut stack, true), + Comparison::Assignable + ); + assert_eq!(stack, vec![[(generic, types.int)].into()]); + } +} diff --git a/crates/rue-typing/src/debug_type.rs b/crates/rue-typing/src/debug_type.rs new file mode 100644 index 0000000..d5c3870 --- /dev/null +++ b/crates/rue-typing/src/debug_type.rs @@ -0,0 +1,123 @@ +use crate::HashSet; + +use crate::{Type, TypeId, TypeSystem}; + +pub(crate) fn debug_type( + ty: &TypeSystem, + prefix: &str, + type_id: TypeId, + indent: usize, + visited: &mut HashSet, +) -> String { + let mut result = String::new(); + + for _ in 0..indent { + result.push_str(" "); + } + + result.push_str(prefix); + result.push_str(&format!("({}) ", type_id.index())); + + if !visited.insert(type_id) { + result.push_str("..."); + return result; + } + + match ty.get_raw(type_id) { + Type::Unknown => result.push_str("Unknown"), + Type::Any => result.push_str("Any"), + Type::Never => result.push_str("Never"), + Type::Bytes => result.push_str("Bytes"), + Type::Bytes32 => result.push_str("Bytes32"), + Type::PublicKey => result.push_str("PublicKey"), + Type::Int => result.push_str("Int"), + Type::Nil => result.push_str("Nil"), + Type::True => result.push_str("True"), + Type::False => result.push_str("False"), + Type::Generic => result.push_str("Generic"), + Type::Value(value) => result.push_str(&format!("Literal {value}")), + Type::Pair(first, rest) => { + let first = debug_type(ty, "First", *first, indent + 1, visited); + let rest = debug_type(ty, "Rest", *rest, indent + 1, visited); + result.push_str(&format!("Pair\n{first}\n{rest}")); + } + Type::Union(types) => { + result.push_str("Union"); + for type_id in types { + let type_str = debug_type(ty, "", *type_id, indent + 1, visited); + result.push_str(&format!("\n{type_str}")); + } + } + Type::Ref(inner) => { + let inner = debug_type(ty, "", *inner, indent + 1, visited); + result.push_str(&format!("Ref\n{inner}")); + } + Type::Alias(alias) => { + result.push_str(&format!("Alias {}", alias.original_type_id.index())); + if !alias.generic_types.is_empty() { + generics(&mut result, &alias.generic_types); + } + let inner = debug_type(ty, "", alias.type_id, indent + 1, visited); + result.push_str(&format!("\n{inner}")); + } + Type::Lazy(lazy) => { + result.push_str("Lazy"); + if !lazy.substitutions.is_empty() { + result.push_str(" <"); + for (i, (from, to)) in lazy.substitutions.iter().enumerate() { + if i != 0 { + result.push_str(", "); + } + result.push_str(&format!("{} = {}", from.index(), to.index())); + } + result.push('>'); + } + let inner = debug_type(ty, "", lazy.type_id, indent + 1, visited); + result.push_str(&format!("\n{inner}")); + } + Type::Struct(struct_type) => { + result.push_str("Struct"); + if !struct_type.generic_types.is_empty() { + generics(&mut result, &struct_type.generic_types); + } + let inner = debug_type(ty, "", struct_type.type_id, indent + 1, visited); + result.push_str(&format!("\n{inner}")); + } + Type::Enum(enum_type) => { + result.push_str("Enum"); + let inner = debug_type(ty, "", enum_type.type_id, indent + 1, visited); + result.push_str(&format!("\n{inner}")); + } + Type::Variant(variant) => { + result.push_str("Variant"); + let inner = debug_type(ty, "", variant.type_id, indent + 1, visited); + result.push_str(&format!("\n{inner}")); + } + Type::Callable(callable) => { + result.push_str("Callable"); + if !callable.generic_types.is_empty() { + generics(&mut result, &callable.generic_types); + } + let inner = debug_type(ty, "Parameters", callable.parameters, indent + 1, visited); + result.push_str(&format!("\n{inner}")); + let inner = debug_type(ty, "Return", callable.return_type, indent + 1, visited); + result.push_str(&format!("\n{inner}")); + } + } + + visited.remove(&type_id); + result +} + +fn generics(result: &mut String, generics: &[TypeId]) { + if !generics.is_empty() { + result.push_str(" <"); + for (i, type_id) in generics.iter().enumerate() { + if i != 0 { + result.push_str(", "); + } + result.push_str(&format!("{}", type_id.index())); + } + result.push('>'); + } +} diff --git a/crates/rue-typing/src/difference_type.rs b/crates/rue-typing/src/difference_type.rs new file mode 100644 index 0000000..e98c727 --- /dev/null +++ b/crates/rue-typing/src/difference_type.rs @@ -0,0 +1,248 @@ +use crate::HashSet; + +use num_bigint::BigInt; +use num_traits::One; + +use crate::{bigint_to_bytes, Enum, Struct, Type, TypeId, TypeSystem, Variant}; + +pub(crate) fn difference_type( + types: &mut TypeSystem, + lhs: TypeId, + rhs: TypeId, + visited: &mut HashSet<(TypeId, TypeId)>, +) -> TypeId { + let std = types.std(); + + if !visited.insert((lhs, rhs)) { + return lhs; + } + + let result = match (types.get(lhs), types.get(rhs)) { + (Type::Ref(..) | Type::Lazy(..), _) | (_, Type::Ref(..) | Type::Lazy(..)) => unreachable!(), + + // If you subtract a supertype or equal type, there are no other possible types. + (Type::Never, _) + | (_, Type::Any | Type::Unknown | Type::Generic | Type::Callable(..)) + | (Type::Bytes | Type::Int | Type::Value(..), Type::Bytes | Type::Int) + | (Type::Bytes32, Type::Bytes32 | Type::Bytes | Type::Int) + | (Type::PublicKey, Type::PublicKey | Type::Bytes | Type::Int) + | (Type::Nil | Type::False, Type::Nil | Type::Bytes | Type::Int | Type::False) + | (Type::True, Type::True | Type::Bytes | Type::Int) => std.never, + + // If you subtract something which is a subtype, the result is the same as the original type. + (_, Type::Never) + | ( + Type::Any + | Type::Int + | Type::Bytes + | Type::Unknown + | Type::Generic + | Type::Callable(..), + Type::Bytes32 + | Type::PublicKey + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + ) + | ( + Type::Unknown | Type::Generic | Type::Callable(..), + Type::Int | Type::Bytes | Type::Pair(..), + ) + | ( + Type::Bytes32, + Type::PublicKey | Type::Nil | Type::True | Type::False | Type::Value(..), + ) + | ( + Type::PublicKey, + Type::Bytes32 | Type::Nil | Type::True | Type::False | Type::Value(..), + ) + | (Type::Nil | Type::False, Type::Bytes32 | Type::PublicKey | Type::True) + | (Type::True, Type::Bytes32 | Type::PublicKey | Type::Nil | Type::False) + | ( + Type::Pair(..), + Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + ) + | ( + Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..), + Type::Pair(..), + ) => lhs, + + // Any is defined as either Bytes or (Any, Any). + (Type::Any, Type::Pair(..)) => std.bytes, + (Type::Any, Type::Bytes | Type::Int) => types.alloc(Type::Pair(std.any, std.any)), + + (Type::Nil | Type::False, Type::Value(value)) + | (Type::Value(value), Type::Nil | Type::False) => { + if value == &BigInt::ZERO { + std.never + } else { + lhs + } + } + + (Type::True, Type::Value(value)) | (Type::Value(value), Type::True) => { + if value == &BigInt::one() { + std.never + } else { + lhs + } + } + + (Type::Value(value), Type::Bytes32) => { + if bigint_to_bytes(value.clone()).len() == 32 { + std.never + } else { + lhs + } + } + (Type::Value(value), Type::PublicKey) => { + if bigint_to_bytes(value.clone()).len() == 48 { + std.never + } else { + lhs + } + } + + (Type::Value(lhs_value), Type::Value(rhs_value)) => { + if lhs_value == rhs_value { + std.never + } else { + lhs + } + } + + (Type::Pair(lhs_first, lhs_rest), Type::Pair(rhs_first, rhs_rest)) => { + let (lhs_first, lhs_rest) = (*lhs_first, *lhs_rest); + let (rhs_first, rhs_rest) = (*rhs_first, *rhs_rest); + + let first = difference_type(types, lhs_first, rhs_first, visited); + let rest = difference_type(types, lhs_rest, rhs_rest, visited); + + if matches!(types.get(first), Type::Never) || matches!(types.get(first), Type::Never) { + std.never + } else if first == lhs_first && rest == lhs_rest { + lhs + } else { + types.alloc(Type::Pair(first, rest)) + } + } + + (Type::Alias(alias), _) => difference_type(types, alias.type_id, rhs, visited), + (_, Type::Alias(alias)) => difference_type(types, lhs, alias.type_id, visited), + + (Type::Struct(ty), _) => { + let ty = ty.clone(); + let type_id = difference_type(types, ty.type_id, rhs, visited); + + types.alloc(Type::Struct(Struct { + original_type_id: ty.original_type_id, + type_id, + field_names: ty.field_names, + nil_terminated: ty.nil_terminated, + generic_types: ty.generic_types, + })) + } + (_, Type::Struct(ty)) => difference_type(types, lhs, ty.type_id, visited), + + (Type::Enum(ty), _) => { + let ty = ty.clone(); + let type_id = difference_type(types, ty.type_id, rhs, visited); + + types.alloc(Type::Enum(Enum { + original_type_id: ty.original_type_id, + type_id, + has_fields: ty.has_fields, + variants: ty.variants, + })) + } + (_, Type::Enum(ty)) => difference_type(types, lhs, ty.type_id, visited), + + (Type::Variant(variant), _) => { + let variant = variant.clone(); + let type_id = difference_type(types, variant.type_id, rhs, visited); + + types.alloc(Type::Variant(Variant { + original_type_id: variant.original_type_id, + original_enum_type_id: variant.original_enum_type_id, + field_names: variant.field_names, + type_id, + nil_terminated: variant.nil_terminated, + generic_types: variant.generic_types, + discriminant: variant.discriminant, + })) + } + (_, Type::Variant(variant)) => difference_type(types, lhs, variant.type_id, visited), + + (Type::Union(items), _) => { + let items = items.clone(); + + let mut result = Vec::new(); + + for item in &items { + let item = difference_type(types, *item, rhs, visited); + if matches!(types.get_recursive(item), Type::Never) { + continue; + } + result.push(item); + } + + if result.is_empty() { + std.never + } else if result.len() == 1 { + result[0] + } else if result == items { + lhs + } else { + types.alloc(Type::Union(result)) + } + } + + (_, Type::Union(items)) => { + let items = items.clone(); + let mut lhs = lhs; + for item in items { + lhs = difference_type(types, lhs, item, visited); + } + lhs + } + }; + + visited.remove(&(lhs, rhs)); + + result +} + +#[cfg(test)] +mod tests { + use crate::{alloc_list, Comparison}; + + use super::*; + + #[test] + fn test_difference_list_nil() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let generic = db.alloc(Type::Generic); + let list = alloc_list(&mut db, generic); + let non_nil = db.difference(list, types.nil); + + assert_eq!(db.compare(non_nil, list), Comparison::Assignable); + assert_eq!(db.compare(types.nil, non_nil), Comparison::NotEqual); + } +} diff --git a/crates/rue-typing/src/lib.rs b/crates/rue-typing/src/lib.rs new file mode 100644 index 0000000..55f8e8e --- /dev/null +++ b/crates/rue-typing/src/lib.rs @@ -0,0 +1,36 @@ +mod bigint; +mod check; +mod compare_type; +mod debug_type; +mod difference_type; +mod map; +mod replace_type; +mod semantic_types; +mod standard_types; +mod stringify_type; +mod substitute_type; +mod ty; +mod type_path; +mod type_system; + +pub use bigint::*; +pub use check::*; +pub use compare_type::*; +pub use map::*; +pub use semantic_types::*; +pub use standard_types::*; +pub use ty::*; +pub use type_path::*; +pub use type_system::*; + +pub(crate) use debug_type::debug_type; +pub(crate) use difference_type::difference_type; +pub(crate) use replace_type::replace_type; +pub(crate) use stringify_type::stringify_type; +pub(crate) use substitute_type::substitute_type; + +#[cfg(test)] +mod test_tools; + +#[cfg(test)] +pub(crate) use test_tools::*; diff --git a/crates/rue-typing/src/map.rs b/crates/rue-typing/src/map.rs new file mode 100644 index 0000000..c6e996e --- /dev/null +++ b/crates/rue-typing/src/map.rs @@ -0,0 +1,6 @@ +use std::hash::BuildHasherDefault; + +use ahash::AHasher; + +pub type HashMap = hashbrown::HashMap>; +pub type HashSet = hashbrown::HashSet>; diff --git a/crates/rue-typing/src/replace_type.rs b/crates/rue-typing/src/replace_type.rs new file mode 100644 index 0000000..54565d4 --- /dev/null +++ b/crates/rue-typing/src/replace_type.rs @@ -0,0 +1,63 @@ +use crate::{Alias, Enum, Struct, Type, TypeId, TypePath, TypeSystem, Variant}; + +pub(crate) fn replace_type( + types: &mut TypeSystem, + type_id: TypeId, + replace_type_id: TypeId, + path: &[TypePath], +) -> TypeId { + if path.is_empty() { + return replace_type_id; + } + + match types.get(type_id) { + Type::Pair(first, rest) => match path[0] { + TypePath::First => replace_type(types, *first, replace_type_id, &path[1..]), + TypePath::Rest => replace_type(types, *rest, replace_type_id, &path[1..]), + }, + Type::Alias(alias) => { + let alias = alias.clone(); + let new_type_id = replace_type(types, alias.type_id, replace_type_id, path); + types.alloc(Type::Alias(Alias { + original_type_id: alias.original_type_id, + type_id: new_type_id, + generic_types: alias.generic_types, + })) + } + Type::Struct(ty) => { + let ty = ty.clone(); + let new_type_id = replace_type(types, ty.type_id, replace_type_id, path); + types.alloc(Type::Struct(Struct { + original_type_id: ty.original_type_id, + field_names: ty.field_names, + type_id: new_type_id, + nil_terminated: ty.nil_terminated, + generic_types: ty.generic_types, + })) + } + Type::Variant(ty) => { + let ty = ty.clone(); + let new_type_id = replace_type(types, ty.type_id, replace_type_id, path); + types.alloc(Type::Variant(Variant { + original_type_id: ty.original_type_id, + original_enum_type_id: ty.original_enum_type_id, + field_names: ty.field_names, + type_id: new_type_id, + nil_terminated: ty.nil_terminated, + generic_types: ty.generic_types, + discriminant: ty.discriminant, + })) + } + Type::Enum(ty) => { + let ty = ty.clone(); + let new_type_id = replace_type(types, ty.type_id, replace_type_id, path); + types.alloc(Type::Enum(Enum { + original_type_id: ty.original_type_id, + type_id: new_type_id, + has_fields: ty.has_fields, + variants: ty.variants, + })) + } + _ => type_id, + } +} diff --git a/crates/rue-typing/src/semantic_types.rs b/crates/rue-typing/src/semantic_types.rs new file mode 100644 index 0000000..784df15 --- /dev/null +++ b/crates/rue-typing/src/semantic_types.rs @@ -0,0 +1,230 @@ +use indexmap::{IndexMap, IndexSet}; +use num_bigint::BigInt; + +use crate::{Comparison, Type, TypeId, TypeSystem}; + +/// Allows you to map generic types lazily without having to resolve them immediately. +/// This prevents stack overflows when resolving generic type definitions that reference themselves. +/// When reducing a type to its structural form, lazy types should be removed. +#[derive(Debug, Clone)] +pub struct Lazy { + pub type_id: TypeId, + pub substitutions: IndexMap, +} + +/// Represents an alias to a type with a set of generic parameters that must be mapped prior to use. +#[derive(Debug, Clone)] +pub struct Alias { + /// A pointer to the alias from which this was derived. + pub original_type_id: TypeId, + pub type_id: TypeId, + pub generic_types: Vec, +} + +/// Struct types are just wrappers around a structural type that provide field information. +#[derive(Debug, Clone)] +pub struct Struct { + /// A pointer to the struct from which this was derived. + pub original_type_id: TypeId, + pub field_names: IndexSet, + pub type_id: TypeId, + pub nil_terminated: bool, + pub generic_types: Vec, +} + +/// Represents something which can be called with arguments and returns a given type. +#[derive(Debug, Clone)] +pub struct Callable { + /// A pointer to the callable from which this was derived. + pub original_type_id: TypeId, + pub parameter_names: IndexSet, + pub parameters: TypeId, + pub return_type: TypeId, + pub nil_terminated: bool, + pub generic_types: Vec, +} + +/// Represents an enum type which can have multiple variants. +#[derive(Debug, Clone)] +pub struct Enum { + /// A pointer to the enum from which this was derived. + pub original_type_id: TypeId, + /// The structural type of the enum. + pub type_id: TypeId, + /// Whether the enum semantically has fields. + pub has_fields: bool, + /// This is a map of the original variant names to their type ids. + pub variants: IndexMap, +} + +/// Represents a variant type which can optionally have fields. +#[derive(Debug, Clone)] +pub struct Variant { + /// A pointer to the variant from which this was derived. + pub original_type_id: TypeId, + /// The original enum type to which this variant belongs. + pub original_enum_type_id: TypeId, + /// The field names of the variant. + pub field_names: Option>, + /// The structural type of the enum variant. + pub type_id: TypeId, + /// Whether the chain of cons pairs is nil terminated. + pub nil_terminated: bool, + /// The generic types of the variant. + pub generic_types: Vec, + /// The discriminant value. + pub discriminant: BigInt, +} + +/// Constructs a structural type consisting of the items in a list. +pub fn construct_items( + db: &mut TypeSystem, + items: impl DoubleEndedIterator, + nil_terminated: bool, +) -> TypeId { + let mut result = db.std().nil; + for (i, item) in items.rev().enumerate() { + if i == 0 && !nil_terminated { + result = item; + continue; + } + result = db.alloc(Type::Pair(item, result)); + } + result +} + +/// Deconstructs a structural type into a list of items and a rest value. +pub fn deconstruct_items( + db: &mut TypeSystem, + type_id: TypeId, + length: usize, + nil_terminated: bool, +) -> Option> { + let mut items = Vec::with_capacity(length); + let mut current = type_id; + + for i in (0..length).rev() { + if i == 0 { + if !nil_terminated { + items.push(current); + break; + } + + let (first, rest) = db.get_pair(current)?; + items.push(first); + if db.compare(rest, db.std().nil) > Comparison::Equal { + return None; + } + break; + } + + let (first, rest) = db.get_pair(current)?; + items.push(first); + current = rest; + } + + Some(items) +} + +/// Unwraps a list type into its inner type. +pub fn unwrap_list(db: &mut TypeSystem, type_id: TypeId) -> Option { + if db.compare(db.std().nil, type_id) > Comparison::Assignable { + return None; + } + + let non_nil = db.difference(type_id, db.std().nil); + let (first, rest) = db.get_pair(non_nil)?; + + if db.compare(rest, type_id) > Comparison::Assignable { + return None; + } + + Some(first) +} + +#[cfg(test)] +mod tests { + use crate::alloc_list; + + use super::*; + + #[test] + fn test_construct_int_nil() { + let mut db = TypeSystem::new(); + let std = db.std(); + let type_id = construct_items(&mut db, [std.int].into_iter(), true); + let items = deconstruct_items(&mut db, type_id, 1, true); + assert_eq!(items, Some(vec![std.int])); + } + + #[test] + fn test_construct_int() { + let mut db = TypeSystem::new(); + let std = db.std(); + let type_id = construct_items(&mut db, [std.int].into_iter(), false); + let items = deconstruct_items(&mut db, type_id, 1, false); + assert_eq!(items, Some(vec![std.int])); + } + + #[test] + fn test_construct_empty_nil() { + let mut db = TypeSystem::new(); + let type_id = construct_items(&mut db, [].into_iter(), true); + let items = deconstruct_items(&mut db, type_id, 0, true); + assert_eq!(items, Some(vec![])); + } + + #[test] + fn test_construct_empty() { + let mut db = TypeSystem::new(); + let type_id = construct_items(&mut db, [].into_iter(), false); + let items = deconstruct_items(&mut db, type_id, 0, false); + assert_eq!(items, Some(vec![])); + } + + #[test] + fn test_construct_int_int_nil() { + let mut db = TypeSystem::new(); + let std = db.std(); + let type_id = construct_items(&mut db, [std.int, std.int].into_iter(), true); + let items = deconstruct_items(&mut db, type_id, 2, true); + assert_eq!(items, Some(vec![std.int, std.int])); + } + + #[test] + fn test_construct_int_int() { + let mut db = TypeSystem::new(); + let std = db.std(); + let type_id = construct_items(&mut db, [std.int, std.int].into_iter(), false); + let items = deconstruct_items(&mut db, type_id, 2, false); + assert_eq!(items, Some(vec![std.int, std.int])); + } + + #[test] + fn test_construct_bytes32_pair_nil() { + let mut db = TypeSystem::new(); + let std = db.std(); + let pair = db.alloc(Type::Pair(std.bytes32, std.nil)); + let type_id = construct_items(&mut db, [std.bytes32, pair].into_iter(), true); + let items = deconstruct_items(&mut db, type_id, 2, true); + assert_eq!(items, Some(vec![std.bytes32, pair])); + } + + #[test] + fn test_construct_bytes32_pair() { + let mut db = TypeSystem::new(); + let std = db.std(); + let pair = db.alloc(Type::Pair(std.bytes32, std.nil)); + let type_id = construct_items(&mut db, [std.bytes32, pair].into_iter(), false); + let items = deconstruct_items(&mut db, type_id, 2, false); + assert_eq!(items, Some(vec![std.bytes32, pair])); + } + + #[test] + fn test_unwrap_list() { + let mut db = TypeSystem::new(); + let std = db.std(); + let list = alloc_list(&mut db, std.public_key); + assert_eq!(unwrap_list(&mut db, list), Some(std.public_key)); + } +} diff --git a/crates/rue-typing/src/standard_types.rs b/crates/rue-typing/src/standard_types.rs new file mode 100644 index 0000000..41c069f --- /dev/null +++ b/crates/rue-typing/src/standard_types.rs @@ -0,0 +1,18 @@ +use crate::TypeId; + +#[derive(Debug, Clone, Copy)] +pub struct StandardTypes { + pub unknown: TypeId, + pub never: TypeId, + pub any: TypeId, + pub unmapped_list: TypeId, + pub generic_list_item: TypeId, + pub bytes: TypeId, + pub bytes32: TypeId, + pub public_key: TypeId, + pub int: TypeId, + pub bool: TypeId, + pub true_bool: TypeId, + pub false_bool: TypeId, + pub nil: TypeId, +} diff --git a/crates/rue-typing/src/stringify_type.rs b/crates/rue-typing/src/stringify_type.rs new file mode 100644 index 0000000..4cf18a7 --- /dev/null +++ b/crates/rue-typing/src/stringify_type.rs @@ -0,0 +1,142 @@ +use crate::{HashMap, HashSet}; + +use crate::{Callable, Enum, Struct, Type, TypeId, TypeSystem, Variant}; + +pub(crate) fn stringify_type( + types: &TypeSystem, + type_id: TypeId, + names: &HashMap, + visited: &mut HashSet, +) -> String { + if let Some(name) = names.get(&type_id) { + return name.clone(); + } + + if !visited.insert(type_id) { + return "{recursive}".to_string(); + } + + let result = match types.get(type_id) { + Type::Ref(..) => unreachable!(), + Type::Unknown => "{unknown}".to_string(), + Type::Generic => format!("{{{}}}", type_id.index()), + Type::Never => "Never".to_string(), + Type::Any => "Any".to_string(), + Type::Bytes => "Bytes".to_string(), + Type::Bytes32 => "Bytes32".to_string(), + Type::PublicKey => "PublicKey".to_string(), + Type::Int => "Int".to_string(), + Type::True => "True".to_string(), + Type::False => "False".to_string(), + Type::Nil => "Nil".to_string(), + Type::Value(value) => format!("{value}"), + Type::Pair(first, rest) => { + let first = stringify_type(types, *first, names, visited); + let rest = stringify_type(types, *rest, names, visited); + format!("({first}, {rest})") + } + Type::Union(items) => { + let mut result = String::new(); + + for (index, item) in items.iter().enumerate() { + if index > 0 { + result.push_str(" | "); + } + result.push_str(&stringify_type(types, *item, names, visited)); + } + + result + } + Type::Lazy(lazy) => { + let name = stringify_type(types, lazy.type_id, names, visited); + let mut generics = "<".to_string(); + + for (index, (_, generic)) in lazy.substitutions.iter().enumerate() { + if index > 0 { + generics.push_str(", "); + } + generics.push_str(&stringify_type(types, *generic, names, visited)); + } + + generics.push('>'); + name + &generics + } + Type::Alias(alias) => stringify_type(types, alias.type_id, names, visited), + Type::Struct(Struct { type_id, .. }) | Type::Variant(Variant { type_id, .. }) => { + stringify_type(types, *type_id, names, visited) + } + Type::Enum(Enum { type_id, .. }) => stringify_type(types, *type_id, names, visited), + Type::Callable(Callable { + parameters, + return_type, + .. + }) => { + let mut result = "fun(".to_string(); + result.push_str(&stringify_type(types, *parameters, names, visited)); + result.push_str(") -> "); + result.push_str(&stringify_type(types, *return_type, names, visited)); + result + } + }; + + visited.remove(&type_id); + + result +} + +#[cfg(test)] +mod tests { + use indexmap::indexmap; + + use crate::alloc_callable; + + use super::*; + + #[test] + fn stringify_atoms() { + let db = TypeSystem::new(); + let types = db.std(); + + assert_eq!(db.stringify(types.unknown), "{unknown}"); + assert_eq!(db.stringify(types.never), "Never"); + assert_eq!(db.stringify(types.bytes), "Bytes"); + assert_eq!(db.stringify(types.bytes32), "Bytes32"); + assert_eq!(db.stringify(types.public_key), "PublicKey"); + assert_eq!(db.stringify(types.int), "Int"); + assert_eq!(db.stringify(types.bool), "Bool"); + assert_eq!(db.stringify(types.nil), "Nil"); + assert_eq!(db.stringify(types.any), "Any"); + } + + #[test] + fn stringify_named() { + let db = TypeSystem::new(); + let types = db.std(); + + let mut names = HashMap::new(); + names.insert(types.any, "CustomAny".to_string()); + + assert_eq!(db.stringify_named(types.any, names), "CustomAny"); + } + + #[test] + fn test_stringify_callable() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let callable = alloc_callable( + &mut db, + &indexmap! { + "a".to_string() => types.int, + "b".to_string() => types.bytes, + }, + types.bool, + true, + ); + + assert_eq!( + db.stringify_named(callable, HashMap::new()), + "fun((Int, (Bytes, Nil))) -> Bool" + ); + } +} diff --git a/crates/rue-typing/src/substitute_type.rs b/crates/rue-typing/src/substitute_type.rs new file mode 100644 index 0000000..2f979cf --- /dev/null +++ b/crates/rue-typing/src/substitute_type.rs @@ -0,0 +1,194 @@ +use crate::HashMap; + +use crate::{Alias, Callable, Enum, Struct, Type, TypeId, TypeSystem, Variant}; + +pub(crate) fn substitute_type( + types: &mut TypeSystem, + type_id: TypeId, + substitutions: &mut Vec>, +) -> TypeId { + for frame in substitutions.iter().rev() { + if let Some(new_type_id) = frame.get(&type_id) { + return *new_type_id; + } + } + + match types.get(type_id) { + Type::Unknown + | Type::Generic + | Type::Never + | Type::Any + | Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..) => return type_id, + _ => {} + } + + let placeholder = types.alloc(Type::Unknown); + substitutions + .last_mut() + .unwrap() + .insert(type_id, placeholder); + + let result = match types.get(type_id) { + Type::Ref(..) + | Type::Unknown + | Type::Generic + | Type::Never + | Type::Any + | Type::Bytes + | Type::Bytes32 + | Type::PublicKey + | Type::Int + | Type::Nil + | Type::True + | Type::False + | Type::Value(..) => unreachable!(), + Type::Pair(first, rest) => { + let (first, rest) = (*first, *rest); + + let new_first = substitute_type(types, first, substitutions); + let new_rest = substitute_type(types, rest, substitutions); + + if new_first == first && new_rest == rest { + type_id + } else { + types.alloc(Type::Pair(new_first, new_rest)) + } + } + Type::Union(items) => { + let items = items.clone(); + let mut result = Vec::new(); + + for item in &items { + result.push(substitute_type(types, *item, substitutions)); + } + + if result == items { + type_id + } else { + types.alloc(Type::Union(result)) + } + } + Type::Lazy(lazy) => { + substitutions.push(lazy.substitutions.clone().into_iter().collect()); + let result = substitute_type(types, lazy.type_id, substitutions); + substitutions.pop().unwrap(); + result + } + Type::Alias(alias) => { + let alias = alias.clone(); + let new_type_id = substitute_type(types, alias.type_id, substitutions); + + if new_type_id == alias.type_id { + type_id + } else { + types.alloc(Type::Alias(Alias { + original_type_id: alias.original_type_id, + type_id: new_type_id, + generic_types: alias.generic_types, + })) + } + } + Type::Struct(ty) => { + let ty = ty.clone(); + let new_type_id = substitute_type(types, ty.type_id, substitutions); + + if new_type_id == ty.type_id { + type_id + } else { + types.alloc(Type::Struct(Struct { + original_type_id: ty.original_type_id, + type_id: new_type_id, + field_names: ty.field_names, + nil_terminated: ty.nil_terminated, + generic_types: ty.generic_types, + })) + } + } + Type::Variant(ty) => { + let ty = ty.clone(); + let new_type_id = substitute_type(types, ty.type_id, substitutions); + + if new_type_id == ty.type_id { + type_id + } else { + types.alloc(Type::Variant(Variant { + original_type_id: ty.original_type_id, + original_enum_type_id: ty.original_enum_type_id, + type_id: new_type_id, + field_names: ty.field_names, + nil_terminated: ty.nil_terminated, + generic_types: ty.generic_types, + discriminant: ty.discriminant, + })) + } + } + Type::Enum(ty) => { + let ty = ty.clone(); + let new_type_id = substitute_type(types, ty.type_id, substitutions); + + if new_type_id == ty.type_id { + type_id + } else { + types.alloc(Type::Enum(Enum { + original_type_id: ty.original_type_id, + type_id: new_type_id, + has_fields: ty.has_fields, + variants: ty.variants, + })) + } + } + Type::Callable(callable) => { + let callable = callable.clone(); + let new_return_type = substitute_type(types, callable.return_type, substitutions); + let new_parameters = substitute_type(types, callable.parameters, substitutions); + + if new_return_type == callable.return_type && new_parameters == callable.parameters { + type_id + } else { + types.alloc(Type::Callable(Callable { + original_type_id: callable.original_type_id, + parameter_names: callable.parameter_names, + parameters: new_parameters, + return_type: new_return_type, + nil_terminated: callable.nil_terminated, + generic_types: callable.generic_types, + })) + } + } + }; + + *types.get_mut(placeholder) = Type::Ref(result); + + result +} + +#[cfg(test)] +mod tests { + use crate::{alloc_list, Comparison}; + + use super::*; + + #[test] + fn test_substitute_generic_list() { + let mut db = TypeSystem::new(); + let types = db.std(); + + let generic = db.alloc(Type::Generic); + let list = alloc_list(&mut db, generic); + + let mut substitutions = HashMap::new(); + substitutions.insert(generic, types.bool); + + let result = db.substitute(list, substitutions); + let expected = alloc_list(&mut db, types.bool); + + assert_eq!(db.compare(expected, result), Comparison::Assignable); + } +} diff --git a/crates/rue-typing/src/test_tools.rs b/crates/rue-typing/src/test_tools.rs new file mode 100644 index 0000000..5b68986 --- /dev/null +++ b/crates/rue-typing/src/test_tools.rs @@ -0,0 +1,90 @@ +use indexmap::{indexmap, IndexMap}; + +use crate::{Callable, Struct, Type, TypeId, TypeSystem}; + +pub fn alloc_list(db: &mut TypeSystem, item_type_id: TypeId) -> TypeId { + db.substitute( + db.std().unmapped_list, + indexmap! { + db.std().generic_list_item => item_type_id, + } + .into_iter() + .collect(), + ) +} + +pub fn alloc_callable( + db: &mut TypeSystem, + parameters: &IndexMap, + return_type: TypeId, + nil_terminated: bool, +) -> TypeId { + let structure = if nil_terminated { + alloc_list_of(db, parameters.values().copied()) + } else { + alloc_tuple_of(db, parameters.values().copied()) + }; + + let type_id = db.alloc(Type::Unknown); + + *db.get_mut(type_id) = Type::Callable(Callable { + original_type_id: type_id, + parameter_names: parameters.keys().cloned().collect(), + parameters: structure, + return_type, + nil_terminated, + generic_types: Vec::new(), + }); + + type_id +} + +pub fn alloc_struct( + db: &mut TypeSystem, + fields: &IndexMap, + nil_terminated: bool, +) -> TypeId { + let structure = if nil_terminated { + alloc_list_of(db, fields.values().copied()) + } else { + alloc_tuple_of(db, fields.values().copied()) + }; + + let type_id = db.alloc(Type::Unknown); + + *db.get_mut(type_id) = Type::Struct(Struct { + original_type_id: type_id, + type_id: structure, + field_names: fields.keys().cloned().collect(), + nil_terminated, + generic_types: Vec::new(), + }); + + type_id +} + +pub fn alloc_list_of( + db: &mut TypeSystem, + items: impl DoubleEndedIterator, +) -> TypeId { + let mut tuple = db.std().nil; + for item in items.rev() { + tuple = db.alloc(Type::Pair(item, tuple)); + } + tuple +} + +pub fn alloc_tuple_of( + db: &mut TypeSystem, + items: impl DoubleEndedIterator, +) -> TypeId { + let mut tuple = db.std().nil; + for (i, item) in items.rev().enumerate() { + if i == 0 { + tuple = item; + continue; + } + tuple = db.alloc(Type::Pair(item, tuple)); + } + tuple +} diff --git a/crates/rue-typing/src/ty.rs b/crates/rue-typing/src/ty.rs new file mode 100644 index 0000000..bbcd043 --- /dev/null +++ b/crates/rue-typing/src/ty.rs @@ -0,0 +1,28 @@ +use num_bigint::BigInt; + +use crate::{Alias, Callable, Enum, Lazy, Struct, TypeId, Variant}; + +#[derive(Debug, Clone)] +pub enum Type { + Unknown, + Generic, + Never, + Any, + Bytes, + Bytes32, + PublicKey, + Int, + True, + False, + Nil, + Value(BigInt), + Pair(TypeId, TypeId), + Union(Vec), + Ref(TypeId), + Lazy(Lazy), + Alias(Alias), + Struct(Struct), + Callable(Callable), + Enum(Enum), + Variant(Variant), +} diff --git a/crates/rue-typing/src/type_path.rs b/crates/rue-typing/src/type_path.rs new file mode 100644 index 0000000..2cf8f2e --- /dev/null +++ b/crates/rue-typing/src/type_path.rs @@ -0,0 +1,16 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TypePath { + First, + Rest, +} + +pub fn index_to_path(index: usize, nil_terminated: bool) -> Vec { + let mut path = Vec::with_capacity(index); + for _ in 0..index { + path.push(TypePath::Rest); + } + if nil_terminated { + path.push(TypePath::First); + } + path +} diff --git a/crates/rue-typing/src/type_system.rs b/crates/rue-typing/src/type_system.rs new file mode 100644 index 0000000..90c9e18 --- /dev/null +++ b/crates/rue-typing/src/type_system.rs @@ -0,0 +1,218 @@ +use id_arena::{Arena, Id}; + +use crate::{ + check_type, compare_type, debug_type, difference_type, replace_type, simplify_check, + stringify_type, substitute_type, Alias, Callable, Check, CheckError, Comparison, + ComparisonContext, HashMap, HashSet, StandardTypes, Type, TypePath, +}; + +pub type TypeId = Id; + +#[derive(Debug, Clone)] +pub struct TypeSystem { + arena: Arena, + types: StandardTypes, + names: HashMap, +} + +impl Default for TypeSystem { + fn default() -> Self { + let mut arena = Arena::new(); + + let unknown = arena.alloc(Type::Unknown); + let never = arena.alloc(Type::Never); + let any = arena.alloc(Type::Any); + let bytes = arena.alloc(Type::Bytes); + let bytes32 = arena.alloc(Type::Bytes32); + let public_key = arena.alloc(Type::PublicKey); + let int = arena.alloc(Type::Int); + let true_bool = arena.alloc(Type::True); + let false_bool = arena.alloc(Type::False); + let nil = arena.alloc(Type::Nil); + let bool = arena.alloc(Type::Union(vec![false_bool, true_bool])); + + let generic_list_item = arena.alloc(Type::Generic); + let inner = arena.alloc(Type::Unknown); + let unmapped_list = arena.alloc(Type::Unknown); + arena[unmapped_list] = Type::Alias(Alias { + original_type_id: unmapped_list, + type_id: inner, + generic_types: vec![generic_list_item], + }); + let pair = arena.alloc(Type::Pair(generic_list_item, unmapped_list)); + arena[inner] = Type::Union(vec![pair, nil]); + + let mut names = HashMap::new(); + names.insert(never, "Never".to_string()); + names.insert(any, "Any".to_string()); + names.insert(bytes, "Bytes".to_string()); + names.insert(bytes32, "Bytes32".to_string()); + names.insert(public_key, "PublicKey".to_string()); + names.insert(int, "Int".to_string()); + names.insert(bool, "Bool".to_string()); + names.insert(true_bool, "True".to_string()); + names.insert(false_bool, "False".to_string()); + names.insert(nil, "Nil".to_string()); + names.insert(unmapped_list, "List".to_string()); + names.insert(generic_list_item, "{item}".to_string()); + + Self { + arena, + types: StandardTypes { + unknown, + never, + any, + unmapped_list, + generic_list_item, + bytes, + bytes32, + public_key, + int, + bool, + true_bool, + false_bool, + nil, + }, + names, + } + } +} + +impl TypeSystem { + pub fn new() -> Self { + Self::default() + } + + pub fn std(&self) -> StandardTypes { + self.types + } + + pub fn alloc(&mut self, ty: Type) -> TypeId { + self.arena.alloc(ty) + } + + pub fn get_raw(&self, type_id: TypeId) -> &Type { + &self.arena[type_id] + } + + pub fn get_raw_mut(&mut self, type_id: TypeId) -> &mut Type { + &mut self.arena[type_id] + } + + pub fn get_recursive(&self, type_id: TypeId) -> &Type { + match self.get(type_id) { + Type::Alias(ty) => self.get_recursive(ty.type_id), + Type::Struct(ty) => self.get_recursive(ty.type_id), + Type::Enum(ty) => self.get_recursive(ty.type_id), + Type::Variant(ty) => self.get_recursive(ty.type_id), + ty => ty, + } + } + + pub fn get(&self, type_id: TypeId) -> &Type { + match &self.arena[type_id] { + Type::Ref(type_id) => self.get(*type_id), + ty => ty, + } + } + + pub fn get_mut(&mut self, type_id: TypeId) -> &mut Type { + match &self.arena[type_id] { + Type::Ref(type_id) => self.get_mut(*type_id), + _ => &mut self.arena[type_id], + } + } + + pub fn get_pair(&self, type_id: TypeId) -> Option<(TypeId, TypeId)> { + match self.get(type_id) { + Type::Pair(first, rest) => Some((*first, *rest)), + _ => None, + } + } + + pub fn get_union(&self, type_id: TypeId) -> Option<&[TypeId]> { + match self.get(type_id) { + Type::Union(types) => Some(types), + _ => None, + } + } + + pub fn get_callable(&self, type_id: TypeId) -> Option<&Callable> { + match self.get(type_id) { + Type::Callable(callable) => Some(callable), + _ => None, + } + } + + pub fn get_callable_recursive(&mut self, type_id: TypeId) -> Option<&Callable> { + match self.get_recursive(type_id) { + Type::Callable(callable) => Some(callable), + _ => None, + } + } + + pub fn stringify_named(&self, type_id: TypeId, mut names: HashMap) -> String { + for (id, name) in &self.names { + names.entry(*id).or_insert_with(|| name.clone()); + } + stringify_type(self, type_id, &names, &mut HashSet::new()) + } + + pub fn stringify(&self, type_id: TypeId) -> String { + self.stringify_named(type_id, HashMap::new()) + } + + pub fn debug(&self, type_id: TypeId) -> String { + debug_type(self, "", type_id, 0, &mut HashSet::new()) + } + + pub fn compare(&self, lhs: TypeId, rhs: TypeId) -> Comparison { + self.compare_with_generics(lhs, rhs, &mut Vec::new(), false) + } + + pub fn compare_with_generics( + &self, + lhs: TypeId, + rhs: TypeId, + substitution_stack: &mut Vec>, + infer_generics: bool, + ) -> Comparison { + compare_type( + self, + lhs, + rhs, + &mut ComparisonContext { + visited: HashSet::new(), + lhs_substitutions: Vec::new(), + rhs_substitutions: Vec::new(), + inferred: substitution_stack, + infer_generics, + }, + ) + } + + pub fn substitute( + &mut self, + type_id: TypeId, + substitutions: HashMap, + ) -> TypeId { + substitute_type(self, type_id, &mut vec![substitutions]) + } + + pub fn check(&mut self, lhs: TypeId, rhs: TypeId) -> Result { + check_type(self, lhs, rhs, &mut HashSet::new()).map(simplify_check) + } + + pub fn difference(&mut self, lhs: TypeId, rhs: TypeId) -> TypeId { + difference_type(self, lhs, rhs, &mut HashSet::new()) + } + + pub fn replace( + &mut self, + type_id: TypeId, + replace_type_id: TypeId, + path: &[TypePath], + ) -> TypeId { + replace_type(self, type_id, replace_type_id, path) + } +} diff --git a/examples/cat.rue b/examples/cat.rue index b414691..9b8cb59 100644 --- a/examples/cat.rue +++ b/examples/cat.rue @@ -29,27 +29,27 @@ struct Truths { type Tail = fun( truths: Truths, parent_is_cat: Bool, - lineage_proof: LineageProof?, + lineage_proof: LineageProof | nil, extra_delta: Int, - conditions: Condition[], - tail_solution: Any[], -) -> Condition[]; + conditions: List, + tail_solution: List, +) -> List; // Information about the TAIL reveal. // This is revealed with `RunTailCondition`. struct TailInfo { tail_puzzle: Tail, - tail_solution: Any[], + tail_solution: List, } // A custom condition `(51 () -113 tail_puzzle tail_solution)`. // We can check `Condition::CreateCoin` then cast to this type instead. struct RunTailCondition { opcode: Int, - puzzle_hash: Nil, + puzzle_hash: nil, amount: Int, tail_puzzle: Tail, - tail_solution: Any[], + tail_solution: List, } // Information about the current coin. @@ -90,15 +90,15 @@ inline fun cat_puzzle_hash(cat_info: CatInfo, inner_puzzle_hash: Bytes32) -> Byt fun main( mod_hash: Bytes32, asset_id: Bytes32, - inner_puzzle: fun(...solution: Any) -> Condition[], + inner_puzzle: fun(...solution: Any) -> List, inner_solution: Any, - lineage_proof: LineageProof?, + lineage_proof: LineageProof | nil, prev_coin_id: Bytes32, my_coin: Coin, next_coin_proof: CoinProof, prev_subtotal: Int, extra_delta: Int, -) -> Condition[] { +) -> List { // For simplicity, we'll pack these values into a struct. let cat_info = CatInfo { mod_hash: mod_hash, @@ -144,12 +144,12 @@ fun main( // Prepend the ring conditions to the morphed conditions. // This ensures that the previous and next CATs are linked. // When they form a ring like this, you can be sure the supply isn't changed. - let conditions: Condition[] = [ + let conditions: List = [ Condition::CreateCoinAnnouncement { - message: RING_MORPH_BYTE + tree_hash([prev_coin_id, prev_subtotal] as Any[]), + message: RING_MORPH_BYTE + tree_hash([prev_coin_id, prev_subtotal] as List), }, Condition::AssertCoinAnnouncement { - announcement_id: sha256(next_coin_id + RING_MORPH_BYTE + tree_hash([my_coin_id, subtotal] as Any[])), + announcement_id: sha256(next_coin_id + RING_MORPH_BYTE + tree_hash([my_coin_id, subtotal] as List)), }, ...morph.conditions, ]; @@ -191,23 +191,23 @@ fun main( struct Morph { // The morphed conditions. - conditions: Condition[], + conditions: List, // The total amount of coins created. sum: Int, // Information about the TAIL, revealed in the conditions. - tail_info: TailInfo?, + tail_info: TailInfo | nil, } // Morph all of the conditions and extract the TAIL info. fun morph_conditions( - conditions: Condition[], + conditions: List, cat_info: CatInfo, - tail_info: TailInfo?, + tail_info: TailInfo | nil, ) -> Morph { // If there are no conditions, return an empty morph. - if conditions is Nil { + if conditions is nil { return Morph { conditions: nil, sum: 0, @@ -220,7 +220,7 @@ fun morph_conditions( if condition is Condition::CreateCoin { // If the amount is -113, it's a TAIL reveal. if condition.amount == -113 { - let run_tail = condition as Any as RunTailCondition; + let run_tail = cast::(condition); let rest = morph_conditions(conditions.rest, cat_info, TailInfo { tail_puzzle: run_tail.tail_puzzle, diff --git a/examples/multisig.rue b/examples/multisig.rue index 9bba1dc..a92e417 100644 --- a/examples/multisig.rue +++ b/examples/multisig.rue @@ -1,28 +1,28 @@ // This puzzle has not been audited or tested, and is for example purposes only. fun main( - public_keys: PublicKey[], + public_keys: List, required: Int, - indices: Int[], - conditions: Condition[], -) -> Condition[] { + indices: List, + conditions: List, +) -> List { let message = tree_hash(conditions); let agg_sigs = check_signatures(public_keys, required, indices, 0, message); concat(agg_sigs, conditions) } fun check_signatures( - public_keys: PublicKey[], + public_keys: List, required: Int, - indices: Int[], + indices: List, pos: Int, message: Bytes, -) -> Condition[] { +) -> List { if required == 0 { return nil; } - assume !(public_keys is Nil) && !(indices is Nil); + assume !(public_keys is nil) && !(indices is nil); if indices.first != pos { return check_signatures(public_keys.rest, required, indices, pos + 1, message); diff --git a/examples/p2_conditions.rue b/examples/p2_conditions.rue index f017e95..1f5dd95 100644 --- a/examples/p2_conditions.rue +++ b/examples/p2_conditions.rue @@ -1,6 +1,6 @@ // This puzzle has not been audited or tested, and is for example purposes only. -fun main(public_key: PublicKey, conditions: Condition[]) -> Condition[] { +fun main(public_key: PublicKey, conditions: List) -> List { let agg_sig = Condition::AggSigMe { public_key: public_key, message: tree_hash(conditions), diff --git a/examples/p2_delegated_or_hidden.rue b/examples/p2_delegated_or_hidden.rue index a4f66e4..2d25e74 100644 --- a/examples/p2_delegated_or_hidden.rue +++ b/examples/p2_delegated_or_hidden.rue @@ -2,10 +2,10 @@ fun main( synthetic_pk: PublicKey, - original_pk: PublicKey?, - delegated_puzzle: fun(...solution: Any) -> Condition[], + original_pk: PublicKey | nil, + delegated_puzzle: fun(...solution: Any) -> List, delegated_solution: Any -) -> Condition[] { +) -> List { let conditions = delegated_puzzle(...delegated_solution); let delegated_puzzle_hash = tree_hash(delegated_puzzle); diff --git a/examples/p2_fusion.rue b/examples/p2_fusion.rue index dbf19d0..3a896bf 100644 --- a/examples/p2_fusion.rue +++ b/examples/p2_fusion.rue @@ -20,7 +20,7 @@ fun main( my_inner_puzzle_hash: Bytes32, my_amount: Int, p2_puzzle_hash: Bytes32, -) -> Condition[] { +) -> List { // The NFT singleton has the same mod hash and launcher puzzle hash as the fusion singleton. let nft_singleton = SingletonInfo { mod_hash: fusion_singleton.mod_hash, @@ -52,7 +52,7 @@ fun main( Condition::CreateCoin { puzzle_hash: p2_puzzle_hash, amount: my_amount, - memos: [p2_puzzle_hash], + memos: ([p2_puzzle_hash], nil), }, // Announce that a specific singleton is being spent to diff --git a/examples/royalty_split.rue b/examples/royalty_split.rue index 6078916..df01431 100644 --- a/examples/royalty_split.rue +++ b/examples/royalty_split.rue @@ -5,7 +5,7 @@ struct Payout { share: Int, } -fun main(payouts: Payout[], total_shares: Int, my_amount: Int) -> Condition[] { +fun main(payouts: List, total_shares: Int, my_amount: Int) -> List { let announcement = Condition::CreateCoinAnnouncement { message: '$' }; let assert_amount = Condition::AssertMyAmount { amount: my_amount }; @@ -14,13 +14,13 @@ fun main(payouts: Payout[], total_shares: Int, my_amount: Int) -> Condition[] { } fun calculate_amount_and_split( - payouts: Payout[], + payouts: List, total_amount: Int, total_shares: Int, shares_sum: Int, remaining_amount: Int, -) -> Condition[] { - if payouts is (Payout, Payout[]) { +) -> List { + if payouts is (Payout, List) { let amount = get_amount(payouts.first, total_amount, total_shares); return split_amount_and_create_coins(payouts, amount, total_amount, total_shares, shares_sum, remaining_amount); } @@ -29,18 +29,18 @@ fun calculate_amount_and_split( } fun split_amount_and_create_coins( - payouts: (Payout, Payout[]), + payouts: (Payout, List), this_amount: Int, total_amount: Int, total_shares: Int, shares_sum: Int, remaining_amount: Int, -) -> Condition[] { +) -> List { let payout = payouts.first; let create_coin = Condition::CreateCoin { puzzle_hash: payout.puzzle_hash, amount: if payout.share > 0 { this_amount } else { remaining_amount }, - memos: [payout.puzzle_hash], + memos: ([payout.puzzle_hash], nil), }; let rest = calculate_amount_and_split(payouts.rest, total_amount, total_shares, shares_sum + payout.share, remaining_amount - this_amount); [create_coin, ...rest] diff --git a/examples/singleton.rue b/examples/singleton.rue index 8112a28..819d164 100644 --- a/examples/singleton.rue +++ b/examples/singleton.rue @@ -8,7 +8,7 @@ struct Singleton { struct LineageProof { parent_parent_coin_info: Bytes32, - parent_inner_puzzle_hash: Bytes32?, + parent_inner_puzzle_hash: Bytes32 | nil, parent_amount: Int, } @@ -18,11 +18,11 @@ fun singleton_puzzle_hash(singleton: Singleton, inner_puzzle_hash: Bytes32) -> B fun main( singleton: Singleton, - inner_puzzle: fun(...solution: Any) -> Condition[], + inner_puzzle: fun(...solution: Any) -> List, lineage_proof: LineageProof, my_amount: Int, inner_solution: Any, -) -> Condition[] { +) -> List { // Ensure that the amount is odd. assert my_amount & 1 == 1; @@ -55,10 +55,10 @@ fun main( fun morph_conditions( singleton: Singleton, - conditions: Condition[], + conditions: List, found_singleton_output: Bool, -) -> Condition[] { - if conditions is Nil { +) -> List { + if conditions is nil { // We must have a singleton output. assert found_singleton_output; return nil; diff --git a/tests.toml b/tests.toml index 7a5887f..536f961 100644 --- a/tests.toml +++ b/tests.toml @@ -1,6 +1,6 @@ [const_function_cycle] parser_errors = [] -compiler_errors = ["Error: Cannot recursively reference constant. (8:5)"] +compiler_errors = ["Error: Cannot recursively reference constant (8:5)"] [recursive_function] bytes = 95 @@ -18,7 +18,7 @@ hash = "3e851b5ad6b5c395bd46219991d2c7bb113e91918a065c705cc9068445bcc2a7" [inline_function_cycle] parser_errors = [] -compiler_errors = ["Error: Cannot recursively call inline function. (10:5)"] +compiler_errors = ["Error: Cannot recursively call inline function (10:5)"] [const_reference] bytes = 31 @@ -29,7 +29,7 @@ hash = "e3429ee993cedb79fa5b67ebd52635a294daf9c69b1c1d1a967f884fb4c3f1d1" [inline_const_inline_function_cycle] parser_errors = [] -compiler_errors = ["Error: Cannot recursively call inline function. (5:27)"] +compiler_errors = ["Error: Cannot recursively call inline function (5:27)"] [mixed_consts] bytes = 39 @@ -40,27 +40,27 @@ hash = "14870051d753aa4295466e166782fb5da7f18ad450a0ed0af9d7ccb83a82bed9" [inline_const_function_cycle] parser_errors = [] -compiler_errors = ["Error: Cannot recursively reference inline constant. (8:5)"] +compiler_errors = ["Error: Cannot recursively reference inline constant (8:5)"] [inline_const_self] parser_errors = [] -compiler_errors = ["Error: Cannot recursively reference inline constant. (1:27)"] +compiler_errors = ["Error: Cannot recursively reference inline constant (1:27)"] [recursive_inline_function] parser_errors = [] -compiler_errors = ["Error: Cannot recursively call inline function. (6:5)"] +compiler_errors = ["Error: Cannot recursively call inline function (6:5)"] [const_inline_function_cycle] parser_errors = [] -compiler_errors = ["Error: Cannot recursively call inline function. (5:20)"] +compiler_errors = ["Error: Cannot recursively call inline function (5:20)"] [const_self] parser_errors = [] -compiler_errors = ["Error: Cannot recursively reference constant. (1:20)"] +compiler_errors = ["Error: Cannot recursively reference constant (1:20)"] [const_cycle] parser_errors = [] -compiler_errors = ["Error: Cannot recursively reference constant. (2:16)"] +compiler_errors = ["Error: Cannot recursively reference constant (2:16)"] [inline_const_reference] bytes = 3 @@ -185,7 +185,7 @@ hash = "8172dac13078e7c4ce8a6062bb5dd117bb39d45b7e9ca3b9970dfef4be188ac8" [block_return] parser_errors = [] -compiler_errors = ["Error: Explicit return is not allowed within expressions. (2:5)"] +compiler_errors = ["Error: Explicit return is not allowed within expressions (2:5)"] [block_function] bytes = 57 @@ -203,7 +203,7 @@ hash = "6636275e009c26ccaba0bde1f76d4e1451f57462b1d2528a3a5df0a3202c1f78" [block_nested_return] parser_errors = [] -compiler_errors = ["Error: Explicit return is not allowed within expressions. (2:10)"] +compiler_errors = ["Error: Explicit return is not allowed within expressions (2:10)"] [enum_discriminant] bytes = 199 @@ -213,11 +213,11 @@ output = "()" hash = "1a9674474efa85b2616b28fc2ae4f1b6a199273973e482ea74498fcc507141cf" [enum_fields] -bytes = 35 -cost = 636 +bytes = 65 +cost = 1004 input = "()" output = "1000" -hash = "9cccd0bde90e1e21f335c88acd5af8788b190c7bcfe652995564064ff2f34e28" +hash = "6f486106a324ee36d5b3e63bd65d7e69c8a95278d4f405307492e0c4960b08f2" [enum_numeric] bytes = 103 @@ -243,39 +243,25 @@ hash = "600c2c11a7ceb22fd3f5d2559d22120f9ad7e8d47e0966ddcfbcce20e02f373f" [function_call] parser_errors = [] compiler_errors = [ - "Error: Expected 0 arguments, but found 1. (3:11)", - "Error: Expected either 0 or 1 arguments, but found 2. (7:11)", - "Error: This function does not support the spread operator on its last argument. (8:25)", - "Error: Expected either 0 or 1 arguments, but found 2. (9:11)", - "Error: This function does not support the spread operator on its last argument. (9:28)", - "Error: Expected either 1 or 2 arguments, but found 0. (11:11)", - "Error: Expected either 1 or 2 arguments, but found 3. (14:11)", - "Error: Expected either 1 or 2 arguments, but found 4. (15:11)", - "Error: This function does not support the spread operator on its last argument. (16:25)", - "Error: This function does not support the spread operator on its last argument. (17:28)", - "Error: Expected either 1 or 2 arguments, but found 3. (18:11)", - "Error: This function does not support the spread operator on its last argument. (18:31)", - "Error: Expected either 1 or 2 arguments, but found 4. (19:11)", - "Error: This function does not support the spread operator on its last argument. (19:34)", - "Error: This function does not support the spread operator on its last argument. (20:28)", - "Error: Expected type `Int[]`, but found `Int`. (26:28)", - "Error: Expected type `Int[]`, but found `Int`. (27:31)", - "Error: Expected at least 2 arguments, but found 0. (31:11)", - "Error: Expected type `Int[]`, but found `Int`. (36:31)", - "Error: Expected 1 argument, but found 0. (40:11)", - "Error: This function requires the spread operator on its last argument. (41:27)", - "Error: Expected 1 argument, but found 2. (42:11)", - "Error: This function requires the spread operator on its last argument. (42:27)", - "Error: Expected 1 argument, but found 2. (44:11)", - "Error: Expected type `Int`, but found `Int[]`. (45:27)", - "Error: Expected 1 argument, but found 2. (46:11)", - "Error: Expected type `Int`, but found `Int[]`. (46:30)", - "Error: Expected 2 arguments, but found 0. (48:11)", - "Error: Expected 2 arguments, but found 1. (49:11)", - "Error: This function requires the spread operator on its last argument. (50:30)", - "Error: Expected 2 arguments, but found 1. (51:11)", - "Error: Expected 2 arguments, but found 1. (53:11)", - "Error: Expected type `Int`, but found `Int[]`. (54:30)", + "Error: Expected 0 arguments, but found 1 (3:11)", + "Error: Expected type `(Int, {recursive} | Nil) | Nil`, but found `Int` (9:28)", + "Error: Expected type `(Int, {recursive} | Nil) | Nil`, but found `Int` (10:31)", + "Error: Expected at least 2 arguments, but found 0 (14:11)", + "Error: Expected type `(Int, {recursive} | Nil) | Nil`, but found `Int` (19:31)", + "Error: Expected 1 argument, but found 0 (23:11)", + "Error: This function requires the spread operator on its last argument (24:27)", + "Error: Expected 1 argument, but found 2 (25:11)", + "Error: This function requires the spread operator on its last argument (25:27)", + "Error: Expected 1 argument, but found 2 (27:11)", + "Error: Expected type `Int`, but found `(Int, (Int, Nil))` (28:27)", + "Error: Expected 1 argument, but found 2 (29:11)", + "Error: Expected type `Int`, but found `(Int, Nil)` (29:30)", + "Error: Expected 2 arguments, but found 0 (31:11)", + "Error: Expected 2 arguments, but found 1 (32:11)", + "Error: This function requires the spread operator on its last argument (33:30)", + "Error: Expected 2 arguments, but found 1 (34:11)", + "Error: Expected 2 arguments, but found 1 (36:11)", + "Error: Expected type `Int`, but found `(Int, Nil)` (37:30)", ] [struct_empty] @@ -309,10 +295,10 @@ hash = "e1bb38bc03b979e06bcfc2341001b12647928e29d682637eee425adf5b1bf212" [infer_lambda_params] parser_errors = [] compiler_errors = [ - "Error: Lambda parameter type could not be inferred. (2:24)", - "Error: Unused let binding `no_infer`. (2:9)", - "Error: Unused let binding `infer`. (3:9)", - "Error: Unused let binding `explicit`. (4:9)", + "Error: Lambda parameter type could not be inferred (2:24)", + "Error: Unused let binding `no_infer` (2:9)", + "Error: Unused let binding `infer` (3:9)", + "Error: Unused let binding `explicit` (4:9)", ] [infer_list] @@ -486,10 +472,10 @@ hash = "e977dca31ec8aba231dd0e858a3235d2dea5ed36aefed5f12c48ec02bf450b77" [enum_duplicate_variants] parser_errors = [] compiler_errors = [ - "Error: Duplicate enum variant `SameName` specified. (3:5)", - "Error: Duplicate enum discriminant `1` specified. (4:29)", - "Error: Unused enum variant `SameName`. (2:5)", - "Error: Unused enum variant `DuplicateDiscriminant`. (4:5)", + "Error: Duplicate enum variant `SameName` specified (3:5)", + "Error: Duplicate enum discriminant `1` specified (4:29)", + "Error: Unused enum variant `SameName` (2:5)", + "Error: Unused enum variant `DuplicateDiscriminant` (4:5)", ] [p2_fusion] @@ -549,25 +535,26 @@ output = "((g1_multiply 0xc00000000000000000000000000000000000000000000000000000 hash = "21f96d7bb1b15b83ce81dff3525d4c98793f906f6cc7ebba52a76524a7db6943" [singleton] -bytes = 1395 -cost = 52110 +bytes = 1531 +cost = 54718 input = "((0x42840c6aebec47ce2e01629ce381b461c19695264281a7b1aab5d4ff54506775 0x4696e7a2b7682e2df01ab47e6e002d0dca895f99c6172e4a55a3e033499532b7 0x291e4594b43d58e833cab95e4b165c5fac6b4d8391c81ebfd20efdd8d58b92d8) 1 (0x9b1c580707ca8282534c02c1a055427e0954818b6195a29f4442ac3e7ea8e8ee () 1) 1 ((51 0x173385b87af5d8940767c328026fe5f8e76bc238d2a3aaddf4f55e844f400fca 1)))" output = "((73 1) (71 0xf92f0ebbd0e5ecb1334331d98c1f3b3e41cfce2c15f1053ffd1e2151b361e909) (g1_negate 0x07d534114dd68436cb7a4026abade359cd9c9f28b253c60e305535c781bbc7ed 1))" -hash = "28a3c50a049c12c4ce49c6098618957b4c6791c563c70c895fa59bf47d93366c" +hash = "3d8892dcdac1a32aab1034fa9489b7b7dc3b6aabf512f38e73d54574b1f6b62c" [enum_type_guard] -bytes = 131 -cost = 1710 +bytes = 101 +cost = 1090 input = "()" output = "()" -hash = "e3153c4596c3b27d1f1cea81cf4c475e2d8f617330a82684bd97f6fae2bacf50" +hash = "146182f765c52c144e4fa6d44fd3073bb3cbed6fac2bf1ca15393fc244c7d2b8" [cat] -bytes = 2107 -cost = 359978 +bytes = 2323 +cost = 0 input = "(0x00f43ce9fcc63d5019e209c103e6b0aaf56bbe7fc7fafae5af7f5ee6887a8719 0xd622c62a7292ffee5cf2537a90360ca0b7337b76d7014ec042930c0a87592213 (q (g1_negate () -113 (a (q 2 (i 47 (q 8) (q 2 (i (= 45 2) () (q 8)) 1)) 1) (c (q . 0x895eb35a355941ba7f6a8679a73bb9b8b62cae2b04ef5351eda42583c0f2d861) 1)) ()) (g1_negate 0xb8705f94744e7fc30300ac9b12d306b283f5a702937ee99beabf665be6023001 1 (0xb8705f94744e7fc30300ac9b12d306b283f5a702937ee99beabf665be6023001))) () () 0x615236766bed52d7abaa41d270407f3ec852981852334b213bd8515924459a5d (0x895eb35a355941ba7f6a8679a73bb9b8b62cae2b04ef5351eda42583c0f2d861 0x1ecb863db5d2ae6c71e9a8b0741acb3e034e8164b8ca0e564d5fad8b9dc875d5 1) (0x895eb35a355941ba7f6a8679a73bb9b8b62cae2b04ef5351eda42583c0f2d861 0x130deb20b44082a68293974f8cab9c51e21f9a9f3005000168eb77e49e0fc378 1) () ())" -output = "((70 0x615236766bed52d7abaa41d270407f3ec852981852334b213bd8515924459a5d) (60 0xcb7f53b5de05b4afe58ab663b952aa785fd9ad911564709bd8eacbf4c58ba1e589) (61 0x389f1ea7b9fab7a9294104eb3a181d662f2dbe9c43fbe52802ba9b7eef357e77) (g1_negate 0xc9644528436f44cd9b33282684b4964f55d5551cb3a970ab9cf8f536e0d72ad1 1 (0xb8705f94744e7fc30300ac9b12d306b283f5a702937ee99beabf665be6023001)))" -hash = "00f43ce9fcc63d5019e209c103e6b0aaf56bbe7fc7fafae5af7f5ee6887a8719" +output = "()" +hash = "5db0d2b8b2ea766c75afdfee8127d704f41e849a781fda6655e305199d4c5e23" +error = "()" [external_function] bytes = 7 @@ -584,8 +571,8 @@ output = "1500" hash = "c9370ff4457a61a860e99fbbcf92aec2630e45a9447428174a1c5036a0768aad" [block_let_function] -bytes = 281 -cost = 9394 +bytes = 283 +cost = 9356 input = "()" output = "1" -hash = "a1d71c94c9e64a24f50f4aee3496e833aa8bb263457b4cb4459f851f0f3d575b" +hash = "488ceea9106d9d45a4380644bdb88d7f440a901d08c1281a5b94ba92a2ae418d" diff --git a/tests/block/block_let_function.rue b/tests/block/block_let_function.rue index e62877a..444ca69 100644 --- a/tests/block/block_let_function.rue +++ b/tests/block/block_let_function.rue @@ -1,8 +1,8 @@ fun main() -> Bool { - let outer: Int? = 42; + let outer: Int | nil = 42; let another_outer = 19; let not_used_in_block = false; - let result = !not_used_in_block && outer is Int && { + let result = !not_used_in_block && !(outer is nil) && { let double = double(outer); outer + double + double - double + another_outer == 126 + 19 }; diff --git a/tests/enum/enum_discriminant.rue b/tests/enum/enum_discriminant.rue index 2f9be1a..d401f31 100644 --- a/tests/enum/enum_discriminant.rue +++ b/tests/enum/enum_discriminant.rue @@ -7,7 +7,7 @@ enum Num { Six = 6, } -fun main() -> Nil { +fun main() -> nil { assert Num::Zero as Int == 0; assert Num::One as Int == 1; assert Num::Two as Int == 2; diff --git a/tests/enum/enum_empty.rue b/tests/enum/enum_empty.rue index ca521ee..880200f 100644 --- a/tests/enum/enum_empty.rue +++ b/tests/enum/enum_empty.rue @@ -2,7 +2,7 @@ enum Enum { Empty {}, } -fun main() -> Nil { +fun main() -> nil { assert tree_hash(Enum::Empty {}) == tree_hash([0]); nil } diff --git a/tests/enum/enum_mixed.rue b/tests/enum/enum_mixed.rue index 69724ac..ed60141 100644 --- a/tests/enum/enum_mixed.rue +++ b/tests/enum/enum_mixed.rue @@ -3,7 +3,7 @@ enum Test { Value = 1 { num: Int }, } -fun main() -> Nil { +fun main() -> nil { assert tree_hash(Test::Unit) == tree_hash([0]); assert tree_hash(Test::Value { num: 1000 }) == tree_hash([1, 1000]); nil diff --git a/tests/enum/enum_numeric.rue b/tests/enum/enum_numeric.rue index e2b0f11..d6176c4 100644 --- a/tests/enum/enum_numeric.rue +++ b/tests/enum/enum_numeric.rue @@ -4,7 +4,7 @@ enum Mode { Open, } -fun main() -> Nil { +fun main() -> nil { let open: Mode = Mode::Open; assert open is Mode::Open; diff --git a/tests/enum/enum_type_guard.rue b/tests/enum/enum_type_guard.rue index ff3c3cd..44f4461 100644 --- a/tests/enum/enum_type_guard.rue +++ b/tests/enum/enum_type_guard.rue @@ -15,7 +15,7 @@ fun main() -> Int { raise "Unreachable"; } - assert color is Color::Red; let red: Color::Red = color; + red as Int } diff --git a/tests/function/external_function.rue b/tests/function/external_function.rue index 97e6602..665c977 100644 --- a/tests/function/external_function.rue +++ b/tests/function/external_function.rue @@ -1,3 +1,3 @@ -fun main(conditions: fun(...solution: Any) -> Condition[]) -> Condition[] { +fun main(conditions: fun(...solution: Any) -> List) -> List { conditions(...nil) } diff --git a/tests/function/function_call.rue b/tests/function/function_call.rue index 0655ef1..dbbab41 100644 --- a/tests/function/function_call.rue +++ b/tests/function/function_call.rue @@ -1,24 +1,7 @@ -fun main() -> Nil { +fun main() -> nil { assert empty(); assert empty(1); - assert one_optional(); - assert one_optional(1); - assert one_optional(1, 2); - assert one_optional(...1); - assert one_optional(1, ...2); - - assert two_optional(); - assert two_optional(1); - assert two_optional(1, 2); - assert two_optional(1, 2, 3); - assert two_optional(1, 2, 3, 4); - assert two_optional(...[1, 2, 3, 4]); - assert two_optional(1, ...[2, 3, 4]); - assert two_optional(1, 2, ...[3, 4]); - assert two_optional(1, 2, 3, ...[4]); - assert two_optional(1, ...2); - assert one_spread_list(); assert one_spread_list(1); assert one_spread_list(1, 2); @@ -60,15 +43,7 @@ fun empty() -> Bool { true } -fun one_optional(_a?: Int) -> Bool { - true -} - -fun two_optional(_a: Int, _b?: Int) -> Bool { - true -} - -fun one_spread_list(..._a: Int[]) -> Bool { +fun one_spread_list(..._a: List) -> Bool { true } @@ -76,7 +51,7 @@ fun one_spread_raw(..._a: Int) -> Bool { true } -fun two_spread_list(_a: Int, ..._b: Int[]) -> Bool { +fun two_spread_list(_a: Int, ..._b: List) -> Bool { true } diff --git a/tests/function/function_optional_param.rue b/tests/function/function_optional_param.rue deleted file mode 100644 index b684a94..0000000 --- a/tests/function/function_optional_param.rue +++ /dev/null @@ -1,11 +0,0 @@ -fun main() -> Int { - multiply(10) * multiply(34, 42) -} - -fun multiply(num: Int, factor?: Int) -> Int { - if factor? { - num * factor - } else { - num * num - } -} diff --git a/tests/function/function_rest_param.rue b/tests/function/function_rest_param.rue index 667d63c..caadd06 100644 --- a/tests/function/function_rest_param.rue +++ b/tests/function/function_rest_param.rue @@ -2,15 +2,15 @@ fun main() -> Int { sum(...range_inclusive(1, 10)) } -fun range_inclusive(start: Int, end: Int) -> Int[] { +fun range_inclusive(start: Int, end: Int) -> List { if start > end { return nil; } [start, ...range_inclusive(start + 1, end)] } -fun sum(...nums: Int[]) -> Int { - if nums is (Int, Int[]) { +fun sum(...nums: List) -> Int { + if nums is (Int, List) { nums.first + sum(...nums.rest) } else { 0 diff --git a/tests/function/lambda_optional_param.rue b/tests/function/lambda_optional_param.rue deleted file mode 100644 index 8e642cf..0000000 --- a/tests/function/lambda_optional_param.rue +++ /dev/null @@ -1,10 +0,0 @@ -fun main() -> Int { - let max = fun(a: Int, b?: Int) => { - if b? && b > a { - b - } else { - a - } - }; - max(42) + max(12124, 81) -} diff --git a/tests/std.rue b/tests/std.rue index 6950624..715c9e9 100644 --- a/tests/std.rue +++ b/tests/std.rue @@ -1,4 +1,4 @@ -fun main() -> Nil { +fun main() -> nil { assert tree_hash(map([1, 2, 3], fun(num) => num * 2)) == tree_hash([2, 4, 6]); assert tree_hash(filter([1, 2, 3, 4, 5], fun(num) => num < 4)) == tree_hash([1, 2, 3]); assert tree_hash(fold([1, 2, 3], 0, fun(acc, num) => acc + num)) == tree_hash(6); diff --git a/tests/struct/struct_empty.rue b/tests/struct/struct_empty.rue index 2f3f62c..162d234 100644 --- a/tests/struct/struct_empty.rue +++ b/tests/struct/struct_empty.rue @@ -1,6 +1,6 @@ struct Empty {} -fun main() -> Nil { +fun main() -> nil { assert tree_hash(Empty {}) == tree_hash(nil); nil } diff --git a/tests/struct/struct_inner_optional.rue b/tests/struct/struct_inner_optional.rue deleted file mode 100644 index b81e625..0000000 --- a/tests/struct/struct_inner_optional.rue +++ /dev/null @@ -1,23 +0,0 @@ -struct Value { - inner?: Value, -} - -fun main() -> Nil { - let value = Value { - inner: Value { - inner: Value { - inner: Value { - inner: Value {} - } - } - } - }; - - assert value.inner?; - assert value.inner.inner?; - assert value.inner.inner.inner?; - assert value.inner.inner.inner.inner?; - assert !value.inner.inner.inner.inner.inner?; - - nil -} diff --git a/tests/struct/struct_optional.rue b/tests/struct/struct_optional.rue deleted file mode 100644 index 427cb29..0000000 --- a/tests/struct/struct_optional.rue +++ /dev/null @@ -1,19 +0,0 @@ -struct Point { - x: Int, - y: Int, - z?: Int, -} - -fun main() -> Int { - let point_xy = Point { x: 42, y: 34 }; - let point_xyz = Point { x: 42, y: 34, z: 69 }; - sum(point_xy) + sum(point_xyz) -} - -fun sum(point: Point) -> Int { - if point.z? { - point.x + point.y + point.z - } else { - point.x + point.y - } -} diff --git a/tests/struct/struct_optional_initializer.rue b/tests/struct/struct_optional_initializer.rue deleted file mode 100644 index 41d8b5d..0000000 --- a/tests/struct/struct_optional_initializer.rue +++ /dev/null @@ -1,18 +0,0 @@ -struct Point { - x: Int, - y: Int, - z?: Int, -} - -fun main() -> Nil { - let point_xy = Point { x: 42, y: 34 }; - let new_point_xy = Point { x: point_xy.x, y: point_xy.y, z: point_xy.z }; - - let point_xyz = Point { x: 42, y: 34, z: 69 }; - let new_point_xyz = Point { x: point_xyz.x, y: point_xyz.y, z: point_xyz.z }; - - assert tree_hash(point_xy) == tree_hash(new_point_xy); - assert tree_hash(point_xyz) == tree_hash(new_point_xyz); - - nil -} diff --git a/tests/struct/struct_single_optional.rue b/tests/struct/struct_single_optional.rue deleted file mode 100644 index 6096dc1..0000000 --- a/tests/struct/struct_single_optional.rue +++ /dev/null @@ -1,11 +0,0 @@ -struct Value { - num?: Int, -} - -fun main() -> Int { - let empty = Value {}; - let value = Value { num: 42 }; - assert !empty.num?; - assert value.num?; - value.num -}