From 9a2b99c148f07e12001b62f1b399a280193a7114 Mon Sep 17 00:00:00 2001 From: Sam Mohr Date: Fri, 8 Nov 2024 00:05:56 -0800 Subject: [PATCH 1/4] Constrain early returns in functions in addition to closures --- crates/compiler/constrain/src/expr.rs | 16 +++++++++ crates/compiler/test_gen/src/gen_return.rs | 42 ++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/crates/compiler/constrain/src/expr.rs b/crates/compiler/constrain/src/expr.rs index ed8847043e..aedd7ab7f3 100644 --- a/crates/compiler/constrain/src/expr.rs +++ b/crates/compiler/constrain/src/expr.rs @@ -2114,6 +2114,21 @@ fn constrain_function_def( ret_type_index, )); + let mut early_return_constraints = Vec::with_capacity(function_def.early_returns.len()); + for (early_return_variable, early_return_region) in &function_def.early_returns { + let early_return_var = constraints.push_variable(*early_return_variable); + let early_return_con = constraints.equal_types( + early_return_var, + return_type_annotation_expected, + Category::Return, + *early_return_region, + ); + + early_return_constraints.push(early_return_con); + } + + let early_returns_constraint = constraints.and_constraint(early_return_constraints); + let solved_fn_type = { // TODO(types-soa) optimize for Variable let pattern_types = types.from_old_type_slice( @@ -2151,6 +2166,7 @@ fn constrain_function_def( std::file!(), std::line!(), ), + early_returns_constraint, constraints.let_constraint( [], argument_pattern_state.vars, diff --git a/crates/compiler/test_gen/src/gen_return.rs b/crates/compiler/test_gen/src/gen_return.rs index c2cb6d4c1c..09a2aa8e76 100644 --- a/crates/compiler/test_gen/src/gen_return.rs +++ b/crates/compiler/test_gen/src/gen_return.rs @@ -107,3 +107,45 @@ fn early_return_solo() { true ); } + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))] +fn early_return_annotated_function() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [main] to "./platform" + + failIfLessThanFive : U64 -> Result {} [LessThanFive] + failIfLessThanFive = \n -> + if n < 5 then + Err LessThanFive + else + Ok {} + + validateInput : Str -> Result U64 [InvalidNumStr, LessThanFive] + validateInput = \str -> + num = try Str.toU64 str + + when failIfLessThanFive num is + Err err -> + return Err err + + Ok {} -> + Ok num + + main : List Str + main = + ["abc", "3", "7"] + |> List.map validateInput + |> List.map Inspect.toStr + "# + ), + RocList::from_slice(&[ + RocStr::from("(Err InvalidNumStr)"), + RocStr::from("(Err LessThanFive)"), + RocStr::from("(Ok 7)") + ]), + RocList + ); +} From f8578729035b99e139d1add7f8c70f07bf895745 Mon Sep 17 00:00:00 2001 From: Sam Mohr Date: Thu, 21 Nov 2024 04:09:11 -0800 Subject: [PATCH 2/4] Properly type constrain all function types --- crates/compiler/constrain/src/expr.rs | 222 +++++++++++------- crates/compiler/test_gen/src/gen_return.rs | 114 +++++++++ .../test_mono/generated/return_annotated.txt | 48 ++++ crates/compiler/test_mono/src/tests.rs | 15 ++ 4 files changed, 320 insertions(+), 79 deletions(-) create mode 100644 crates/compiler/test_mono/generated/return_annotated.txt diff --git a/crates/compiler/constrain/src/expr.rs b/crates/compiler/constrain/src/expr.rs index 99ff1c9062..3921cf218e 100644 --- a/crates/compiler/constrain/src/expr.rs +++ b/crates/compiler/constrain/src/expr.rs @@ -168,37 +168,38 @@ fn constrain_untyped_closure( vars.push(closure_var); vars.push(fn_var); - let body_type = constraints.push_expected_type(ForReason( + let return_type_index = constraints.push_expected_type(ForReason( Reason::FunctionOutput, return_type_index, loc_body_expr.region, )); - let ret_constraint = env.with_fx_expectation(fx_var, None, |env| { - constrain_expr( + let returns_constraint = env.with_fx_expectation(fx_var, None, |env| { + let return_con = constrain_expr( types, constraints, env, loc_body_expr.region, &loc_body_expr.value, - body_type, - ) - }); - - let mut early_return_constraints = Vec::with_capacity(early_returns.len()); - for (early_return_variable, early_return_region) in early_returns { - let early_return_var = constraints.push_variable(*early_return_variable); - let early_return_con = constraints.equal_types( - early_return_var, - body_type, - Category::Return, - *early_return_region, + return_type_index, ); - early_return_constraints.push(early_return_con); - } + let mut return_constraints = Vec::with_capacity(early_returns.len() + 1); + return_constraints.push(return_con); + + for (early_return_variable, early_return_region) in early_returns { + let early_return_con = constraints.equal_types_var( + *early_return_variable, + return_type_index, + Category::Return, + *early_return_region, + ); + + return_constraints.push(early_return_con); + } - let early_returns_constraint = constraints.and_constraint(early_return_constraints); + constraints.and_constraint(return_constraints) + }); // make sure the captured symbols are sorted! debug_assert_eq!(captured_symbols.to_vec(), { @@ -231,7 +232,7 @@ fn constrain_untyped_closure( pattern_state.vars, pattern_state.headers, pattern_state_constraints, - ret_constraint, + returns_constraint, Generalizable(true), ), constraints.and_constraint(pattern_state.delayed_fx_suffix_constraints), @@ -242,7 +243,6 @@ fn constrain_untyped_closure( region, fn_var, ), - early_returns_constraint, closure_constraint, constraints.flex_to_pure(fx_var), ]; @@ -1423,23 +1423,20 @@ pub fn constrain_expr( return_var, } => { let return_type_index = constraints.push_variable(*return_var); - let expected_return_value = constraints.push_expected_type(ForReason( Reason::FunctionOutput, return_type_index, return_value.region, )); - let return_con = constrain_expr( + constrain_expr( types, constraints, env, return_value.region, &return_value.value, expected_return_value, - ); - - constraints.exists([*return_var], return_con) + ) } Tag { tag_union_var: variant_var, @@ -2062,21 +2059,6 @@ fn constrain_function_def( ret_type_index, )); - let mut early_return_constraints = Vec::with_capacity(function_def.early_returns.len()); - for (early_return_variable, early_return_region) in &function_def.early_returns { - let early_return_var = constraints.push_variable(*early_return_variable); - let early_return_con = constraints.equal_types( - early_return_var, - return_type_annotation_expected, - Category::Return, - *early_return_region, - ); - - early_return_constraints.push(early_return_con); - } - - let early_returns_constraint = constraints.and_constraint(early_return_constraints); - let solved_fn_type = { // TODO(types-soa) optimize for Variable let pattern_types = types.from_old_type_slice( @@ -2090,8 +2072,8 @@ fn constrain_function_def( constraints.push_type(types, fn_type) }; - let ret_constraint = { - let con = constrain_expr( + let returns_constraint = { + let return_con = constrain_expr( types, constraints, env, @@ -2099,7 +2081,33 @@ fn constrain_function_def( &loc_body_expr.value, return_type_annotation_expected, ); - attach_resolution_constraints(constraints, env, con) + + let mut return_constraints = + Vec::with_capacity(function_def.early_returns.len() + 1); + return_constraints.push(return_con); + + for (early_return_variable, early_return_region) in &function_def.early_returns { + let early_return_type_expected = + constraints.push_expected_type(Expected::ForReason( + Reason::FunctionOutput, + ret_type_index, + *early_return_region, + )); + + vars.push(*early_return_variable); + let early_return_con = constraints.equal_types_var( + *early_return_variable, + early_return_type_expected, + Category::Return, + *early_return_region, + ); + + return_constraints.push(early_return_con); + } + + let returns_constraint = constraints.and_constraint(return_constraints); + + attach_resolution_constraints(constraints, env, returns_constraint) }; vars.push(expr_var); @@ -2114,13 +2122,12 @@ fn constrain_function_def( std::file!(), std::line!(), ), - early_returns_constraint, constraints.let_constraint( [], argument_pattern_state.vars, argument_pattern_state.headers, defs_constraint, - ret_constraint, + returns_constraint, // This is a syntactic function, it can be generalized Generalizable(true), ), @@ -2876,6 +2883,7 @@ fn constrain_typed_def( function_type: fn_var, closure_type: closure_var, return_type: ret_var, + early_returns, fx_type: fx_var, captured_symbols, arguments, @@ -2945,7 +2953,7 @@ fn constrain_typed_def( constraints.push_type(types, fn_type) }; - let body_type = constraints.push_expected_type(FromAnnotation( + let return_type = constraints.push_expected_type(FromAnnotation( def.loc_pattern.clone(), arguments.len(), AnnotationSource::TypedBody { @@ -2954,18 +2962,35 @@ fn constrain_typed_def( ret_type_index, )); - let ret_constraint = env.with_fx_expectation(fx_var, Some(annotation.region), |env| { - constrain_expr( - types, - constraints, - env, - loc_body_expr.region, - &loc_body_expr.value, - body_type, - ) - }); + let returns_constraint = + env.with_fx_expectation(fx_var, Some(annotation.region), |env| { + let return_con = constrain_expr( + types, + constraints, + env, + loc_body_expr.region, + &loc_body_expr.value, + return_type, + ); - let ret_constraint = attach_resolution_constraints(constraints, env, ret_constraint); + let mut return_constraints = Vec::with_capacity(early_returns.len() + 1); + return_constraints.push(return_con); + + for (early_return_variable, early_return_region) in early_returns { + let early_return_con = constraints.equal_types_var( + *early_return_variable, + return_type, + Category::Return, + *early_return_region, + ); + + return_constraints.push(early_return_con); + } + + let returns_constraint = constraints.and_constraint(return_constraints); + + attach_resolution_constraints(constraints, env, returns_constraint) + }); vars.push(*fn_var); let defs_constraint = constraints.and_constraint(argument_pattern_state.constraints); @@ -2978,7 +3003,7 @@ fn constrain_typed_def( argument_pattern_state.vars, argument_pattern_state.headers, defs_constraint, - ret_constraint, + returns_constraint, // This is a syntactic function, it can be generalized Generalizable(true), ), @@ -3985,18 +4010,38 @@ fn constraint_recursive_function( constraints.push_type(types, typ) }; - let expr_con = env.with_fx_expectation(fx_var, Some(annotation.region), |env| { - let expected = constraints.push_expected_type(NoExpectation(ret_type_index)); - constrain_expr( - types, - constraints, - env, - loc_body_expr.region, - &loc_body_expr.value, - expected, - ) - }); - let expr_con = attach_resolution_constraints(constraints, env, expr_con); + let returns_constraint = + env.with_fx_expectation(fx_var, Some(annotation.region), |env| { + let expected = constraints.push_expected_type(NoExpectation(ret_type_index)); + let return_con = constrain_expr( + types, + constraints, + env, + loc_body_expr.region, + &loc_body_expr.value, + expected, + ); + + let mut return_constraints = + Vec::with_capacity(function_def.early_returns.len() + 1); + return_constraints.push(return_con); + + for (early_return_variable, early_return_region) in &function_def.early_returns + { + let early_return_con = constraints.equal_types_var( + *early_return_variable, + expected, + Category::Return, + *early_return_region, + ); + + return_constraints.push(early_return_con); + } + + let returns_constraint = constraints.and_constraint(return_constraints); + + attach_resolution_constraints(constraints, env, returns_constraint) + }); vars.push(expr_var); @@ -4008,7 +4053,7 @@ fn constraint_recursive_function( argument_pattern_state.vars, argument_pattern_state.headers, state_constraints, - expr_con, + returns_constraint, // Syntactic function can be generalized Generalizable(true), ), @@ -4470,6 +4515,7 @@ fn rec_defs_help( function_type: fn_var, closure_type: closure_var, return_type: ret_var, + early_returns, fx_type: fx_var, captured_symbols, arguments, @@ -4537,22 +4583,40 @@ fn rec_defs_help( let typ = types.function(pattern_types, lambda_set, ret_type, fx_type); constraints.push_type(types, typ) }; - let expr_con = + let returns_constraint = env.with_fx_expectation(fx_var, Some(annotation.region), |env| { - let body_type = + let return_type_expected = constraints.push_expected_type(NoExpectation(ret_type_index)); - constrain_expr( + let return_con = constrain_expr( types, constraints, env, loc_body_expr.region, &loc_body_expr.value, - body_type, - ) - }); + return_type_expected, + ); - let expr_con = attach_resolution_constraints(constraints, env, expr_con); + let mut return_constraints = + Vec::with_capacity(early_returns.len() + 1); + return_constraints.push(return_con); + + for (early_return_variable, early_return_region) in early_returns { + let early_return_con = constraints.equal_types_var( + *early_return_variable, + return_type_expected, + Category::Return, + *early_return_region, + ); + + return_constraints.push(early_return_con); + } + + let returns_constraint = + constraints.and_constraint(return_constraints); + + attach_resolution_constraints(constraints, env, returns_constraint) + }); vars.push(*fn_var); @@ -4567,7 +4631,7 @@ fn rec_defs_help( argument_pattern_state.vars, argument_pattern_state.headers, state_constraints, - expr_con, + returns_constraint, generalizable, ), // Check argument suffixes against usage diff --git a/crates/compiler/test_gen/src/gen_return.rs b/crates/compiler/test_gen/src/gen_return.rs index 09a2aa8e76..b35ac895a4 100644 --- a/crates/compiler/test_gen/src/gen_return.rs +++ b/crates/compiler/test_gen/src/gen_return.rs @@ -108,6 +108,24 @@ fn early_return_solo() { ); } +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] +fn early_return_solo_annotated() { + assert_evals_to!( + r#" + identity : Str -> Str + identity = \x -> + return x + + identity "abc" + "#, + RocStr::from("abc"), + RocStr, + identity, + true + ); +} + #[test] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))] fn early_return_annotated_function() { @@ -149,3 +167,99 @@ fn early_return_annotated_function() { RocList ); } + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))] +fn early_return_nested_annotated_function() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [main] to "./platform" + + validateInput : Str -> Result U64 [InvalidNumStr, LessThanFive] + validateInput = \str -> + failIfLessThanFive : U64 -> Result {} [LessThanFive] + failIfLessThanFive = \n -> + if n < 5 then + Err LessThanFive + else + Ok {} + + num = try Str.toU64 str + + when failIfLessThanFive num is + Err err -> + return Err err + + Ok {} -> + Ok num + + main : List Str + main = + ["abc", "3", "7"] + |> List.map validateInput + |> List.map Inspect.toStr + "# + ), + RocList::from_slice(&[ + RocStr::from("(Err InvalidNumStr)"), + RocStr::from("(Err LessThanFive)"), + RocStr::from("(Ok 7)") + ]), + RocList + ); +} + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))] +fn early_return_annotated_recursive_function() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [main] to "./platform" + + mightCallSecond : U64 -> Result U64 _ + mightCallSecond = \num -> + nextNum = + if num < 5 then + return Err LessThanFive + else + num - 1 + + mightCallFirst nextNum + + mightCallFirst : U64 -> Result U64 _ + mightCallFirst = \num -> + nextNum = + if num < 10 then + return Err LessThanTen + else + num * 2 + + if nextNum > 25 then + Ok nextNum + else + mightCallSecond nextNum + + main : List Str + main = + [ + mightCallSecond 3, + mightCallSecond 7, + mightCallSecond 20, + mightCallFirst 7, + mightCallFirst 15, + ] + |> List.map Inspect.toStr + "# + ), + RocList::from_slice(&[ + RocStr::from("(Err LessThanFive)"), + RocStr::from("(Err LessThanTen)"), + RocStr::from("(Ok 38)"), + RocStr::from("(Err LessThanTen)"), + RocStr::from("(Ok 30)") + ]), + RocList + ); +} diff --git a/crates/compiler/test_mono/generated/return_annotated.txt b/crates/compiler/test_mono/generated/return_annotated.txt new file mode 100644 index 0000000000..c87a64563d --- /dev/null +++ b/crates/compiler/test_mono/generated/return_annotated.txt @@ -0,0 +1,48 @@ +procedure Bool.11 (#Attr.2, #Attr.3): + let Bool.23 : Int1 = lowlevel Eq #Attr.2 #Attr.3; + ret Bool.23; + +procedure Str.26 (Str.83): + let Str.246 : [C {}, C U64] = CallByName Str.66 Str.83; + ret Str.246; + +procedure Str.42 (#Attr.2): + let Str.254 : {U64, U8} = lowlevel StrToNum #Attr.2; + ret Str.254; + +procedure Str.66 (Str.191): + let Str.192 : {U64, U8} = CallByName Str.42 Str.191; + let Str.252 : U8 = StructAtIndex 1 Str.192; + let Str.253 : U8 = 0i64; + let Str.249 : Int1 = CallByName Bool.11 Str.252 Str.253; + if Str.249 then + let Str.251 : U64 = StructAtIndex 0 Str.192; + let Str.250 : [C {}, C U64] = TagId(1) Str.251; + ret Str.250; + else + let Str.248 : {} = Struct {}; + let Str.247 : [C {}, C U64] = TagId(0) Str.248; + ret Str.247; + +procedure Test.3 (Test.4): + joinpoint Test.14 Test.5: + let Test.12 : [C {}, C U64] = TagId(1) Test.5; + ret Test.12; + in + let Test.13 : [C {}, C U64] = CallByName Str.26 Test.4; + let Test.18 : U8 = 1i64; + let Test.19 : U8 = GetTagId Test.13; + let Test.20 : Int1 = lowlevel Eq Test.18 Test.19; + if Test.20 then + let Test.6 : U64 = UnionAtIndex (Id 1) (Index 0) Test.13; + jump Test.14 Test.6; + else + let Test.7 : {} = UnionAtIndex (Id 0) (Index 0) Test.13; + let Test.17 : [C {}, C U64] = TagId(0) Test.7; + ret Test.17; + +procedure Test.0 (): + let Test.11 : Str = "123"; + let Test.10 : [C {}, C U64] = CallByName Test.3 Test.11; + dec Test.11; + ret Test.10; diff --git a/crates/compiler/test_mono/src/tests.rs b/crates/compiler/test_mono/src/tests.rs index 87d5ed8018..4a4aa2fb6d 100644 --- a/crates/compiler/test_mono/src/tests.rs +++ b/crates/compiler/test_mono/src/tests.rs @@ -3695,3 +3695,18 @@ fn dec_refcount_for_usage_after_early_return_in_if() { "# ) } + +#[mono_test] +fn return_annotated() { + indoc!( + r#" + validateInput : Str -> Result U64 _ + validateInput = \str -> + num = try Str.toU64 str + + Ok num + + validateInput "123" + "# + ) +} From 899a7d3308306a60c55701f57e17c8b09f4eefc3 Mon Sep 17 00:00:00 2001 From: Sam Mohr Date: Thu, 21 Nov 2024 04:38:58 -0800 Subject: [PATCH 3/4] Add reporting tests for annotated functions with early returns --- crates/compiler/load/tests/test_reporting.rs | 71 ++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/crates/compiler/load/tests/test_reporting.rs b/crates/compiler/load/tests/test_reporting.rs index c209249887..590abbfa25 100644 --- a/crates/compiler/load/tests/test_reporting.rs +++ b/crates/compiler/load/tests/test_reporting.rs @@ -14655,6 +14655,77 @@ All branches in an `if` must have the same type! "### ); + test_report!( + mismatch_only_early_returns, + indoc!( + r#" + myFunction = \x -> + if x == 5 then + return "abc" + else + return 123 + + myFunction 3 + "# + ), + @r###" + ── TYPE MISMATCH in /code/proj/Main.roc ──────────────────────────────────────── + + This `return` statement doesn't match the return type of its enclosing + function: + + 5│ if x == 5 then + 6│ return "abc" + 7│ else + 8│ return 123 + ^^^^^^^^^^ + + This returns a value of type: + + Num * + + But I expected the function to have return type: + + Str + "### + ); + + test_report!( + mismatch_early_return_annotated_function, + indoc!( + r#" + myFunction : U64 -> Str + myFunction = \x -> + if x == 5 then + return 123 + else + "abc" + + myFunction 3 + "# + ), + @r###" + ── TYPE MISMATCH in /code/proj/Main.roc ──────────────────────────────────────── + + Something is off with the body of the `myFunction` definition: + + 4│ myFunction : U64 -> Str + 5│ myFunction = \x -> + 6│ if x == 5 then + 7│ return 123 + ^^^^^^^^^^ + + This returns a value of type: + + Num * + + But the type annotation on `myFunction` says it should be: + + Str + + "### + ); + test_report!( leftover_statement, indoc!( From c5b2e160cabcc256d62ef201592847dc388b5d08 Mon Sep 17 00:00:00 2001 From: Sam Mohr Date: Thu, 21 Nov 2024 09:52:38 -0800 Subject: [PATCH 4/4] Ignore switch stmt ret_layout --- crates/compiler/gen_dev/src/lib.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/compiler/gen_dev/src/lib.rs b/crates/compiler/gen_dev/src/lib.rs index 2cc19f6f23..72391fee81 100644 --- a/crates/compiler/gen_dev/src/lib.rs +++ b/crates/compiler/gen_dev/src/lib.rs @@ -663,7 +663,9 @@ trait Backend<'a> { cond_layout, branches, default_branch, - ret_layout, + // always use the proc's ret_layout, as early returns can make + // this ret_layout inaccurate + ret_layout: _, } => { self.load_literal_symbols(&[*cond_symbol]); self.build_switch(