Skip to content

Commit

Permalink
Merge pull request #6 from Rigidity/generic-types
Browse files Browse the repository at this point in the history
Generic types
  • Loading branch information
Rigidity authored Jun 24, 2024
2 parents 08816b4 + 770c3bd commit ecc58e5
Show file tree
Hide file tree
Showing 21 changed files with 578 additions and 153 deletions.
6 changes: 3 additions & 3 deletions crates/rue-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn main() {
"Serialized output: {}",
hex::encode(node_to_bytes(&allocator, output.1).unwrap())
),
Err(error) => eprintln!("Error: {error:?}"),
Err(error) => eprintln!("error: {error:?}"),
}
}
}
Expand All @@ -71,10 +71,10 @@ fn print_diagnostics(source: &str, diagnostics: &[Diagnostic]) -> bool {
match error.kind() {
DiagnosticKind::Error(kind) => {
has_error = true;
eprintln!("Error: {kind} at {line}:{col}");
eprintln!("error: {kind} at {line}:{col}");
}
DiagnosticKind::Warning(kind) => {
eprintln!("Warning: {kind} at {line}:{col}");
eprintln!("warning: {kind} at {line}:{col}");
}
}
}
Expand Down
30 changes: 27 additions & 3 deletions crates/rue-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ use crate::{
hir::Hir,
scope::Scope,
ty::{FunctionType, PairType, Rest, Type, Value},
Comparison, ErrorKind, TypeSystem,
Comparison, ErrorKind,
};

mod block;
mod builtins;
mod context;
mod expr;
mod generic_types;
mod item;
mod path;
mod stmt;
Expand Down Expand Up @@ -47,6 +48,12 @@ pub struct Compiler<'a> {
// The type guard stack is used for overriding types in certain contexts.
type_guard_stack: Vec<HashMap<SymbolId, TypeId>>,

// The generic type stack is used for overriding generic types that are being checked against.
generic_type_stack: Vec<HashMap<TypeId, TypeId>>,

// Whether or not generic type inference is allowed.
allow_generic_inference_stack: Vec<bool>,

// Whether the current expression is directly the callee of a function call.
is_callee: bool,

Expand All @@ -66,6 +73,8 @@ impl<'a> Compiler<'a> {
symbol_stack: Vec::new(),
type_definition_stack: Vec::new(),
type_guard_stack: Vec::new(),
generic_type_stack: Vec::new(),
allow_generic_inference_stack: vec![false],
is_callee: false,
sym: SymbolTable::default(),
builtins,
Expand Down Expand Up @@ -165,6 +174,7 @@ impl<'a> Compiler<'a> {

let name = match self.db.ty(ty) {
Type::Unknown => "{unknown}".to_string(),
Type::Generic => "{generic}".to_string(),
Type::Nil => "Nil".to_string(),
Type::Any => "Any".to_string(),
Type::Int => "Int".to_string(),
Expand Down Expand Up @@ -231,7 +241,14 @@ impl<'a> Compiler<'a> {
}

fn type_check(&mut self, from: TypeId, to: TypeId, range: TextRange) {
if self.db.compare_type(from, to) > Comparison::Assignable {
let comparison = if self.allow_generic_inference_stack.last().copied().unwrap() {
self.db
.compare_type_with_generics(from, to, &mut self.generic_type_stack)
} else {
self.db.compare_type_raw(from, to)
};

if comparison > Comparison::Assignable {
self.db.error(
ErrorKind::TypeMismatch {
expected: self.type_name(to),
Expand All @@ -243,7 +260,14 @@ impl<'a> Compiler<'a> {
}

fn cast_check(&mut self, from: TypeId, to: TypeId, range: TextRange) {
if self.db.compare_type(from, to) > Comparison::Castable {
let comparison = if self.allow_generic_inference_stack.last().copied().unwrap() {
self.db
.compare_type_with_generics(from, to, &mut self.generic_type_stack)
} else {
self.db.compare_type_raw(from, to)
};

if comparison > Comparison::Castable {
self.db.error(
ErrorKind::CastMismatch {
expected: self.type_name(to),
Expand Down
2 changes: 2 additions & 0 deletions crates/rue-compiler/src/compiler/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ fn sha256(db: &mut Database, builtins: &Builtins) -> SymbolId {
param_types: vec![builtins.bytes],
rest: Rest::Nil,
return_type: builtins.bytes32,
generic_types: Vec::new(),
},
}))
}
Expand All @@ -105,6 +106,7 @@ fn pubkey_for_exp(db: &mut Database, builtins: &Builtins) -> SymbolId {
param_types: vec![builtins.bytes32],
rest: Rest::Nil,
return_type: builtins.public_key,
generic_types: Vec::new(),
},
}))
}
7 changes: 7 additions & 0 deletions crates/rue-compiler/src/compiler/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ pub fn build_graph(
ignored_types.extend(enum_type.variants.values());
}

for symbol_id in ignored_symbols.clone() {
let Symbol::Function(function) = db.symbol_mut(symbol_id).clone() else {
continue;
};
ignored_types.extend(function.ty.generic_types.iter().copied());
}

let Symbol::Module(module) = db.symbol_mut(main_module_id).clone() else {
unreachable!();
};
Expand Down
31 changes: 20 additions & 11 deletions crates/rue-compiler/src/compiler/expr/binary_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
compiler::Compiler,
hir::{BinOp, Hir},
ty::{Guard, Value},
Comparison, ErrorKind, TypeSystem,
Comparison, ErrorKind,
};

impl Compiler<'_> {
Expand Down Expand Up @@ -39,7 +39,10 @@ impl Compiler<'_> {
let lhs = lhs!();
let rhs = rhs!();

if self.db.compare_type(lhs.type_id, self.builtins.public_key) == Comparison::Equal
if self
.db
.compare_type_raw(lhs.type_id, self.builtins.public_key)
== Comparison::Equal
{
self.type_check(rhs.type_id, self.builtins.public_key, text_range);
(
Expand All @@ -48,7 +51,7 @@ impl Compiler<'_> {
rhs.hir_id,
self.builtins.public_key,
)
} else if self.db.compare_type(lhs.type_id, self.builtins.bytes)
} else if self.db.compare_type_raw(lhs.type_id, self.builtins.bytes)
== Comparison::Equal
{
self.type_check(rhs.type_id, self.builtins.bytes, text_range);
Expand Down Expand Up @@ -91,22 +94,25 @@ impl Compiler<'_> {
let lhs = lhs!();
let rhs = rhs!();

if self.db.compare_type(lhs.type_id, self.builtins.bytes) > Comparison::Castable
|| self.db.compare_type(rhs.type_id, self.builtins.bytes) > Comparison::Castable
if self.db.compare_type_raw(lhs.type_id, self.builtins.bytes) > Comparison::Castable
|| self.db.compare_type_raw(rhs.type_id, self.builtins.bytes)
> Comparison::Castable
{
self.db.error(
ErrorKind::NonAtomEquality(self.type_name(lhs.type_id)),
text_range,
);
} else if self.db.compare_type(lhs.type_id, self.builtins.nil) == Comparison::Equal
} else if self.db.compare_type_raw(lhs.type_id, self.builtins.nil)
== Comparison::Equal
{
if let Hir::Reference(symbol_id) = self.db.hir(rhs.hir_id) {
value.guards.insert(
*symbol_id,
Guard::new(self.builtins.nil, self.try_unwrap_optional(rhs.type_id)),
);
}
} else if self.db.compare_type(rhs.type_id, self.builtins.nil) == Comparison::Equal
} else if self.db.compare_type_raw(rhs.type_id, self.builtins.nil)
== Comparison::Equal
{
if let Hir::Reference(symbol_id) = self.db.hir(lhs.hir_id) {
value.guards.insert(
Expand All @@ -124,22 +130,25 @@ impl Compiler<'_> {
let lhs = lhs!();
let rhs = rhs!();

if self.db.compare_type(lhs.type_id, self.builtins.bytes) > Comparison::Castable
|| self.db.compare_type(rhs.type_id, self.builtins.bytes) > Comparison::Castable
if self.db.compare_type_raw(lhs.type_id, self.builtins.bytes) > Comparison::Castable
|| self.db.compare_type_raw(rhs.type_id, self.builtins.bytes)
> Comparison::Castable
{
self.db.error(
ErrorKind::NonAtomEquality(self.type_name(lhs.type_id)),
text_range,
);
} else if self.db.compare_type(lhs.type_id, self.builtins.nil) == Comparison::Equal
} else if self.db.compare_type_raw(lhs.type_id, self.builtins.nil)
== Comparison::Equal
{
if let Hir::Reference(symbol_id) = self.db.hir(rhs.hir_id) {
value.guards.insert(
*symbol_id,
Guard::new(self.try_unwrap_optional(rhs.type_id), self.builtins.nil),
);
}
} else if self.db.compare_type(rhs.type_id, self.builtins.nil) == Comparison::Equal
} else if self.db.compare_type_raw(rhs.type_id, self.builtins.nil)
== Comparison::Equal
{
if let Hir::Reference(symbol_id) = self.db.hir(lhs.hir_id) {
value.guards.insert(
Expand Down
14 changes: 13 additions & 1 deletion crates/rue-compiler/src/compiler/expr/function_call_expr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use rue_parser::{AstNode, FunctionCallExpr};

use crate::{
Expand Down Expand Up @@ -28,6 +30,9 @@ impl Compiler<'_> {
None
};

self.generic_type_stack.push(HashMap::new());
self.allow_generic_inference_stack.push(true);

let mut args = Vec::new();
let mut arg_types = Vec::new();
let mut spread = false;
Expand Down Expand Up @@ -144,7 +149,14 @@ impl Compiler<'_> {
varargs: spread,
});

let type_id = expected.map_or(self.builtins.unknown, |expected| expected.return_type);
let mut type_id = expected.map_or(self.builtins.unknown, |expected| expected.return_type);

self.allow_generic_inference_stack.pop().unwrap();
let generic_types = self.generic_type_stack.pop().unwrap();

if !generic_types.is_empty() {
type_id = self.db.substitute_type(type_id, &generic_types);
}

Value::new(hir_id, type_id)
}
Expand Down
12 changes: 6 additions & 6 deletions crates/rue-compiler/src/compiler/expr/guard_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
compiler::Compiler,
hir::{BinOp, Hir},
ty::{Guard, PairType, Type, Value},
Comparison, ErrorKind, HirId, TypeId, TypeSystem, WarningKind,
Comparison, ErrorKind, HirId, TypeId, WarningKind,
};

impl Compiler<'_> {
Expand Down Expand Up @@ -47,7 +47,7 @@ impl Compiler<'_> {
hir_id: HirId,
text_range: TextRange,
) -> Option<(Guard, HirId)> {
if self.db.compare_type(from, to) <= Comparison::Assignable {
if self.db.compare_type_raw(from, to) <= Comparison::Assignable {
self.db.warning(
WarningKind::RedundantTypeCheck(self.type_name(from)),
text_range,
Expand All @@ -57,11 +57,11 @@ impl Compiler<'_> {

match (self.db.ty(from).clone(), self.db.ty(to).clone()) {
(Type::Any, Type::Pair(PairType { first, rest })) => {
if self.db.compare_type(first, self.builtins.any) != Comparison::Equal {
if self.db.compare_type_raw(first, self.builtins.any) > Comparison::Equal {
self.db.error(ErrorKind::NonAnyPairTypeGuard, text_range);
}

if self.db.compare_type(rest, self.builtins.any) != Comparison::Equal {
if self.db.compare_type_raw(rest, self.builtins.any) > Comparison::Equal {
self.db.error(ErrorKind::NonAnyPairTypeGuard, text_range);
}

Expand All @@ -78,11 +78,11 @@ impl Compiler<'_> {
Some((Guard::new(to, pair_type), hir_id))
}
(Type::List(inner), Type::Pair(PairType { first, rest })) => {
if self.db.compare_type(first, inner) != Comparison::Equal {
if self.db.compare_type_raw(first, inner) > Comparison::Equal {
self.db.error(ErrorKind::NonListPairTypeGuard, text_range);
}

if self.db.compare_type(rest, from) != Comparison::Equal {
if self.db.compare_type_raw(rest, from) > Comparison::Equal {
self.db.error(ErrorKind::NonListPairTypeGuard, text_range);
}

Expand Down
3 changes: 3 additions & 0 deletions crates/rue-compiler/src/compiler/expr/lambda_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ impl Compiler<'_> {
.or(expected.map(|expected| expected.return_type));

self.scope_stack.push(scope_id);
self.allow_generic_inference_stack.push(false);
let body = self.compile_expr(&body, expected_return_type);
self.allow_generic_inference_stack.pop().unwrap();
self.scope_stack.pop().expect("lambda not in scope stack");

let return_type = expected_return_type.unwrap_or(body.type_id);
Expand All @@ -80,6 +82,7 @@ impl Compiler<'_> {
param_types: param_types.clone(),
rest,
return_type,
generic_types: Vec::new(),
};

let symbol_id = self.db.alloc_symbol(Symbol::Function(Function {
Expand Down
3 changes: 3 additions & 0 deletions crates/rue-compiler/src/compiler/generic_types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
use super::Compiler;

impl Compiler<'_> {}
Loading

0 comments on commit ecc58e5

Please sign in to comment.