Skip to content

Commit

Permalink
Merge pull request #10 from Rigidity/overhaul-function-types
Browse files Browse the repository at this point in the history
Overhaul function types
  • Loading branch information
Rigidity authored Jun 25, 2024
2 parents f8630e3 + 837ad74 commit bf598ad
Show file tree
Hide file tree
Showing 23 changed files with 568 additions and 197 deletions.
36 changes: 5 additions & 31 deletions crates/rue-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
database::{Database, HirId, ScopeId, SymbolId, TypeId},
hir::Hir,
scope::Scope,
ty::{FunctionType, PairType, Rest, Type, Value},
ty::{PairType, Type, Value},
ErrorKind,
};

Expand Down Expand Up @@ -97,36 +97,6 @@ impl<'a> Compiler<'a> {
result
}

fn expected_param_type(
&self,
function_type: FunctionType,
index: usize,
spread: bool,
) -> Option<TypeId> {
let param_types = function_type.param_types;
let len = param_types.len();

if index + 1 < len {
return Some(param_types[index]);
}

if function_type.rest == Rest::Nil {
if index + 1 == len {
return Some(param_types[index]);
}
return None;
}

if spread {
return Some(param_types[len - 1]);
}

match self.db.ty(param_types[len - 1]) {
Type::List(list_type) => Some(*list_type),
_ => None,
}
}

fn type_reference(&mut self, referenced_type_id: TypeId) {
if let Some(symbol_id) = self.symbol_stack.last() {
self.sym
Expand Down Expand Up @@ -237,6 +207,10 @@ impl<'a> Compiler<'a> {
let inner = self.type_name_visitor(*ty, stack);
format!("{inner}?")
}
Type::PossiblyUndefined(ty) => {
let inner = self.type_name_visitor(*ty, stack);
format!("possibly undefined {inner}")
}
};

stack.pop().unwrap();
Expand Down
15 changes: 14 additions & 1 deletion crates/rue-compiler/src/compiler/expr/field_access_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rue_parser::FieldAccessExpr;
use crate::{
compiler::Compiler,
hir::Hir,
ty::{PairType, Type, Value},
ty::{Guard, PairType, Type, Value},
ErrorKind,
};

Expand Down Expand Up @@ -51,6 +51,19 @@ impl Compiler<'_> {
self.builtins.int,
);
}
Type::PossiblyUndefined(inner) if field_name.text() == "exists" => {
let maybe_nil_reference = self.db.alloc_hir(Hir::CheckExists(value.hir_id));
let exists = self.db.alloc_hir(Hir::IsCons(maybe_nil_reference));
let mut new_value = Value::new(exists, self.builtins.bool);

if let Hir::Reference(symbol_id) = self.db.hir(value.hir_id).clone() {
new_value
.guards
.insert(symbol_id, Guard::new(inner, value.type_id));
}

return new_value;
}
_ => {}
}

Expand Down
253 changes: 145 additions & 108 deletions crates/rue-compiler/src/compiler/expr/function_call_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,159 +5,196 @@ use rue_parser::{AstNode, FunctionCallExpr};
use crate::{
compiler::Compiler,
hir::Hir,
ty::{Rest, Type, Value},
ErrorKind,
ty::{FunctionType, Rest, Type, Value},
ErrorKind, TypeId,
};

impl Compiler<'_> {
pub fn compile_function_call_expr(&mut self, call: &FunctionCallExpr) -> Value {
let Some(callee) = call.callee() else {
return self.unknown();
};

// Compile the callee expression.
// We mark this expression as a callee to allow inline function references.
self.is_callee = true;
let callee = self.compile_expr(&callee, None);

let expected = if let Type::Function(fun) = self.db.ty(callee.type_id) {
Some(fun.clone())
} else if let Type::Unknown = self.db.ty(callee.type_id) {
None
} else {
self.db.error(
ErrorKind::UncallableType(self.type_name(callee.type_id)),
call.callee().unwrap().syntax().text_range(),
);
None
};
let callee = call.callee().map(|callee| self.compile_expr(&callee, None));

// Get the function type of the callee.
let function_type =
callee
.as_ref()
.and_then(|callee| match self.db.ty(callee.type_id).clone() {
Type::Function(function_type) => Some(function_type),
_ => None,
});

// Make sure the callee is callable, if present.
if let Some(callee) = callee.as_ref() {
if function_type.is_none() {
self.db.error(
ErrorKind::UncallableType(self.type_name(callee.type_id)),
call.callee().unwrap().syntax().text_range(),
);
}
}

// Push a generic type context for the function, and allow inference.
self.generic_type_stack.push(HashMap::new());
self.allow_generic_inference_stack.push(true);

// Compile the arguments naively, and defer type checking until later.
let mut args = Vec::new();
let mut arg_types = Vec::new();
let mut spread = false;

let arg_len = call.args().len();
let call_args = call.args();
let len = call_args.len();

for (i, arg) in call.args().into_iter().enumerate().rev() {
let expected_type = expected.as_ref().and_then(|expected| {
self.expected_param_type(
expected.clone(),
i,
i + 1 == arg_len && arg.spread().is_some(),
)
for (i, arg) in call_args.into_iter().enumerate() {
// Determine the expected type.
let expected_type = function_type.as_ref().and_then(|ty| {
if i < ty.param_types.len() {
Some(ty.param_types[i])
} else if ty.rest == Rest::Spread {
self.db.unwrap_list(*ty.param_types.last().unwrap())
} else {
None
}
});

let value = arg
// Compile the argument expression, if present.
// Otherwise, it's a parser error
let expr = arg
.expr()
.map(|expr| self.compile_expr(&expr, expected_type))
.unwrap_or_else(|| self.unknown());

arg_types.push(value.type_id);
// Add the argument to the list.
args.push(expr);

// Check if it's a spread argument.
if arg.spread().is_some() {
if i + 1 == arg_len {
if i == len - 1 {
spread = true;
} else {
self.db
.error(ErrorKind::NonFinalSpread, arg.syntax().text_range());
.error(ErrorKind::InvalidSpreadArgument, arg.syntax().text_range());
}
}
}

args.push(value.hir_id);
// Check that the arguments match the parameters.
if let Some(function_type) = function_type.as_ref() {
let arg_types = args.iter().map(|arg| arg.type_id).collect::<Vec<_>>();
self.check_arguments(call, function_type, &arg_types, spread);
}

args.reverse();
arg_types.reverse();
// The generic type context is no longer needed.
let generic_types = self.generic_type_stack.pop().unwrap();
self.allow_generic_inference_stack.pop().unwrap();

if let Some(expected) = expected.as_ref() {
let param_len = expected.param_types.len();
// Calculate the return type.
let mut type_id =
function_type.map_or(self.builtins.unknown, |expected| expected.return_type);

let too_few_args = arg_types.len() < param_len
&& !(expected.rest == Rest::Parameter && arg_types.len() == param_len - 1);
let too_many_args = arg_types.len() > param_len && expected.rest == Rest::Nil;
if !generic_types.is_empty() {
type_id = self.db.substitute_type(type_id, &generic_types);
}

if too_few_args && expected.rest == Rest::Parameter {
self.db.error(
ErrorKind::TooFewArgumentsWithVarargs {
expected: param_len - 1,
found: arg_types.len(),
},
call.syntax().text_range(),
);
} else if too_few_args || too_many_args {
self.db.error(
ErrorKind::ArgumentMismatch {
expected: param_len,
found: arg_types.len(),
},
call.syntax().text_range(),
);
}
// Build the HIR for the function call.

for (i, arg) in arg_types.into_iter().enumerate() {
if i + 1 == arg_len && spread && expected.rest == Rest::Nil {
let hir_id = self.db.alloc_hir(Hir::FunctionCall {
callee: callee.map_or(self.builtins.unknown_hir, |callee| callee.hir_id),
args: args.iter().map(|arg| arg.hir_id).collect(),
varargs: spread,
});

Value::new(hir_id, type_id)
}

fn check_arguments(
&mut self,
ast: &FunctionCallExpr,
function: &FunctionType,
args: &[TypeId],
spread: bool,
) {
match function.rest {
Rest::Nil => {
if args.len() != function.param_types.len() {
self.db.error(
ErrorKind::NonVarargSpread,
call.args()[i].syntax().text_range(),
ErrorKind::ArgumentMismatch {
expected: function.param_types.len(),
found: args.len(),
},
ast.syntax().text_range(),
);
continue;
}

if i + 1 >= param_len
&& (i + 1 < arg_len || !spread)
&& expected.rest == Rest::Parameter
}
Rest::Optional => {
if args.len() != function.param_types.len()
&& args.len() != function.param_types.len() - 1
{
match self.db.ty(expected.param_types.last().copied().unwrap()) {
Type::List(list_type) => {
self.type_check(arg, *list_type, call.args()[i].syntax().text_range());
}
_ => {
self.db.error(
ErrorKind::NonListVararg,
call.args()[i].syntax().text_range(),
);
}
}
continue;
self.db.error(
ErrorKind::ArgumentMismatchOptional {
expected: function.param_types.len(),
found: args.len(),
},
ast.syntax().text_range(),
);
}

if i + 1 == arg_len && spread && expected.rest == Rest::Parameter {
self.type_check(
arg,
expected.param_types[param_len - 1],
call.args()[i].syntax().text_range(),
}
Rest::Spread => {
if self
.db
.unwrap_list(*function.param_types.last().unwrap())
.is_some()
{
if args.len() < function.param_types.len() - 1 {
self.db.error(
ErrorKind::ArgumentMismatchSpread {
expected: function.param_types.len(),
found: args.len(),
},
ast.syntax().text_range(),
);
}
} else if args.len() != function.param_types.len() {
self.db.error(
ErrorKind::ArgumentMismatch {
expected: function.param_types.len(),
found: args.len(),
},
ast.syntax().text_range(),
);
continue;
}

self.type_check(
arg,
expected
.param_types
.get(i)
.copied()
.unwrap_or(self.builtins.unknown),
call.args()[i].syntax().text_range(),
);
}
}

let hir_id = self.db.alloc_hir(Hir::FunctionCall {
callee: callee.hir_id,
args,
varargs: spread,
});
let ast_args = ast.args();

let mut type_id = expected.map_or(self.builtins.unknown, |expected| expected.return_type);
for (i, &arg) in args.iter().enumerate() {
let last = i == args.len() - 1;

self.allow_generic_inference_stack.pop().unwrap();
let generic_types = self.generic_type_stack.pop().unwrap();

if !generic_types.is_empty() {
type_id = self.db.substitute_type(type_id, &generic_types);
if last && spread {
if function.rest != Rest::Spread {
self.db.error(
ErrorKind::DisallowedSpread,
ast_args[i].syntax().text_range(),
);
} else if i >= function.param_types.len() - 1 {
let expected_type = *function.param_types.last().unwrap();
self.type_check(arg, expected_type, ast_args[i].syntax().text_range());
}
} else if function.rest == Rest::Spread && i >= function.param_types.len() - 1 {
if let Some(inner_list_type) =
self.db.unwrap_list(*function.param_types.last().unwrap())
{
self.type_check(arg, inner_list_type, ast_args[i].syntax().text_range());
} else if i == function.param_types.len() - 1 && !spread {
self.db
.error(ErrorKind::RequiredSpread, ast_args[i].syntax().text_range());
}
} else if i < function.param_types.len() {
let param_type = function.param_types[i];
self.type_check(arg, param_type, ast_args[i].syntax().text_range());
}
}

Value::new(hir_id, type_id)
}
}
Loading

0 comments on commit bf598ad

Please sign in to comment.