From e65faaeb77e31d0abfb60c56d23e422bf0ce17da Mon Sep 17 00:00:00 2001 From: Richard Eisenberg Date: Thu, 19 Oct 2023 21:00:24 +0100 Subject: [PATCH] flambda-backend: Make `assert false` behave as local_ or not, depending on what's better (+ 2 bugfixes) (#1899) * Add test cases * Refactor is_local_returning_function; improve error * Actually allow [assert false] to imply [local_] * Add test case around local newtype annotations * Look through newtypes with layout annotations --- testsuite/tests/typing-local/local-layouts.ml | 12 + testsuite/tests/typing-local/local.ml | 25 ++ typing/typecore.ml | 234 ++++++++++-------- 3 files changed, 171 insertions(+), 100 deletions(-) create mode 100644 testsuite/tests/typing-local/local-layouts.ml diff --git a/testsuite/tests/typing-local/local-layouts.ml b/testsuite/tests/typing-local/local-layouts.ml new file mode 100644 index 00000000000..8e84ee814d6 --- /dev/null +++ b/testsuite/tests/typing-local/local-layouts.ml @@ -0,0 +1,12 @@ +(* TEST + * expect + flags = "-extension layouts_beta" +*) + +let foo _t (type a) = local_ 1 +let bar _t (type a : value) = local_ 2 + +[%%expect{| +val foo : 'a -> local_ int = +val bar : 'a -> local_ int = +|}] diff --git a/testsuite/tests/typing-local/local.ml b/testsuite/tests/typing-local/local.ml index 7e79efb01ff..f1016664111 100644 --- a/testsuite/tests/typing-local/local.ml +++ b/testsuite/tests/typing-local/local.ml @@ -2769,3 +2769,28 @@ Line 2, characters 33-58: Error: This expression has type int -> local_ (int -> int) but an expression was expected of type int -> (int -> int) |}];; + +(* test that [function] checks all its branches either for local_ or the + absence thereof *) +let foo = function + | false -> local_ 5 + | true -> 6 + +[%%expect{| +Line 3, characters 12-13: +3 | | true -> 6 + ^ +Error: This function return is not annotated with "local_" + whilst other returns were. +|}] + +(* test that [assert false] can mix with other returns being [local_] *) +let foo b = + if b + then assert false + else local_ Some 6 + +[%%expect{| +val foo : bool -> local_ int option = +|}] + diff --git a/typing/typecore.ml b/typing/typecore.ml index 69505390e44..6f925f4eeee 100644 --- a/typing/typecore.ml +++ b/typing/typecore.ml @@ -3647,72 +3647,140 @@ let check_recursive_class_bindings env ids exprs = raise(Error(expr.cl_loc, env, Illegal_class_expr))) exprs -(* Is the return value annotated with "local_" *) -let is_local_returning_expr e = - let combine (local1, loc1) (local2, loc2) = - match local1, local2 with - | true, true -> true, loc1 - | false, false -> false, loc1 - | false, true -> - raise(Error(loc1, Env.empty, Local_return_annotation_mismatch loc2)) - | true, false -> - raise(Error(loc2, Env.empty, Local_return_annotation_mismatch loc1)) - in - let rec loop e = - match Jane_syntax.Expression.of_ast e with - | Some (jexp, _attrs) -> begin - match jexp with - | Jexp_comprehension _ -> false, e.pexp_loc - | Jexp_immutable_array _ -> false, e.pexp_loc - | Jexp_layout (Lexp_constant _) -> false, e.pexp_loc - | Jexp_layout (Lexp_newtype (_, _, e)) -> loop e - | Jexp_n_ary_function _ -> false, e.pexp_loc - end - | None -> - match e.pexp_desc with - | Pexp_apply - ({ pexp_desc = Pexp_extension( - {txt = "extension.local"|"ocaml.local"|"local"}, PStr []) }, - [Nolabel, _]) -> - true, e.pexp_loc - | Pexp_apply - ({ pexp_desc = Pexp_extension( - {txt = "extension.unique"|"ocaml.unique"|"unique"}, PStr []) }, - [Nolabel, exp]) -> - loop exp - | Pexp_apply - ({ pexp_desc = Pexp_extension( - {txt = "extension.once" | "ocaml.once" | "once"}, PStr []) }, - [Nolabel, exp]) -> - loop exp - | Pexp_ident _ | Pexp_constant _ | Pexp_apply _ | Pexp_tuple _ - | Pexp_construct _ | Pexp_variant _ | Pexp_record _ | Pexp_field _ - | Pexp_setfield _ | Pexp_array _ | Pexp_while _ | Pexp_for _ | Pexp_send _ - | Pexp_new _ | Pexp_setinstvar _ | Pexp_override _ | Pexp_assert _ - | Pexp_lazy _ | Pexp_object _ | Pexp_pack _ | Pexp_function _ | Pexp_fun _ - | Pexp_letop _ | Pexp_extension _ | Pexp_unreachable -> - false, e.pexp_loc - | Pexp_let(_, _, e) | Pexp_sequence(_, e) | Pexp_constraint(e, _) - | Pexp_coerce(e, _, _) | Pexp_letmodule(_, _, e) | Pexp_letexception(_, e) - | Pexp_poly(e, _) | Pexp_newtype(_, e) | Pexp_open(_, e) - | Pexp_ifthenelse(_, e, None)-> - loop e - | Pexp_ifthenelse(_, e1, Some e2)-> combine (loop e1) (loop e2) - | Pexp_match(_, cases) -> begin - match cases with - | [] -> false, e.pexp_loc - | first :: rest -> - List.fold_left - (fun acc pc -> combine acc (loop pc.pc_rhs)) - (loop first.pc_rhs) rest +module Is_local_returning : sig + val function_ : Parsetree.case list -> bool +end = struct + + (* Is the return value annotated with "local_"? + [assert false] can work either way *) + + type local_returning_flag = + | Local of Location.t (* location of a local return *) + | Not of Location.t (* location of a non-local return *) + | Either + [@@warning "-unused-constructor"] + + let combine flag1 flag2 = + match flag1, flag2 with + | (Local _ as flag), Local _ + | (Local _ as flag), Either + | (Not _ as flag), Not _ + | (Not _ as flag), Either + | Either, (Local _ as flag) + | Either, (Not _ as flag) + | (Either as flag), Either -> + flag + + | Local local_loc, Not not_local_loc + | Not not_local_loc, Local local_loc -> + raise(Error(not_local_loc, Env.empty, + Local_return_annotation_mismatch local_loc)) + + let expr e = + let rec loop e = + match Jane_syntax.Expression.of_ast e with + | Some (jexp, _attrs) -> begin + match jexp with + | Jexp_comprehension _ -> Not e.pexp_loc + | Jexp_immutable_array _ -> Not e.pexp_loc + | Jexp_layout (Lexp_constant _) -> Not e.pexp_loc + | Jexp_layout (Lexp_newtype (_, _, e)) -> loop e + | Jexp_n_ary_function _ -> Not e.pexp_loc + end + | None -> + match e.pexp_desc with + | Pexp_apply + ({ pexp_desc = Pexp_extension( + {txt = "extension.local"|"ocaml.local"|"local"}, PStr []) }, + [Nolabel, _]) -> + Local e.pexp_loc + | Pexp_apply + ({ pexp_desc = Pexp_extension( + {txt = "extension.unique"|"ocaml.unique"|"unique"}, PStr []) }, + [Nolabel, exp]) -> + loop exp + | Pexp_apply + ({ pexp_desc = Pexp_extension( + {txt = "extension.once" | "ocaml.once" | "once"}, PStr []) }, + [Nolabel, exp]) -> + loop exp + | Pexp_assert { pexp_desc = Pexp_construct ({ txt = Lident "false" }, + None) } -> + Either + | Pexp_ident _ | Pexp_constant _ | Pexp_apply _ | Pexp_tuple _ + | Pexp_construct _ | Pexp_variant _ | Pexp_record _ | Pexp_field _ + | Pexp_setfield _ | Pexp_array _ | Pexp_while _ | Pexp_for _ | Pexp_send _ + | Pexp_new _ | Pexp_setinstvar _ | Pexp_override _ | Pexp_assert _ + | Pexp_lazy _ | Pexp_object _ | Pexp_pack _ | Pexp_function _ | Pexp_fun _ + | Pexp_letop _ | Pexp_extension _ | Pexp_unreachable -> + Not e.pexp_loc + | Pexp_let(_, _, e) | Pexp_sequence(_, e) | Pexp_constraint(e, _) + | Pexp_coerce(e, _, _) | Pexp_letmodule(_, _, e) | Pexp_letexception(_, e) + | Pexp_poly(e, _) | Pexp_newtype(_, e) | Pexp_open(_, e) + | Pexp_ifthenelse(_, e, None)-> + loop e + | Pexp_ifthenelse(_, e1, Some e2)-> combine (loop e1) (loop e2) + | Pexp_match(_, cases) -> begin + match cases with + | [] -> Not e.pexp_loc + | first :: rest -> + List.fold_left + (fun acc pc -> combine acc (loop pc.pc_rhs)) + (loop first.pc_rhs) rest + end + | Pexp_try(e, cases) -> + List.fold_left + (fun acc pc -> combine acc (loop pc.pc_rhs)) + (loop e) cases + in + loop e + + let function_ cases = + let rec loop_cases cases = + match cases with + | [] -> Misc.fatal_error "empty cases in function_" + | [{pc_lhs = _; pc_guard = None; pc_rhs = e}] -> + loop_body e + | case :: cases -> + let is_local_returning_case case = + expr case.pc_rhs + in + List.fold_left + (fun acc case -> combine acc (is_local_returning_case case)) + (is_local_returning_case case) cases + and loop_body e = + if Builtin_attributes.has_curry e.pexp_attributes then + expr e + else begin + match Jane_syntax.Expression.of_ast e with + | Some (jexp, _attrs) -> begin + match jexp with + | Jexp_n_ary_function (_, _, Pfunction_cases (cases, _, _)) -> + loop_cases cases + | Jexp_n_ary_function (_, _, Pfunction_body body) -> + loop_body body + | Jexp_comprehension _ | Jexp_immutable_array _ -> + expr e + | Jexp_layout (Lexp_constant _) -> + Not e.pexp_loc + | Jexp_layout (Lexp_newtype (_, _, body)) -> + loop_body body + end + | None -> match e.pexp_desc, e.pexp_attributes with + | Pexp_fun(_, _, _, e), _ -> loop_body e + | Pexp_function cases, _ -> loop_cases cases + | Pexp_constraint (e, _), _ -> loop_body e + | Pexp_let (Nonrecursive, _, e), + [{Parsetree.attr_name = {txt="#default"};_}] -> loop_body e + | _ -> expr e end - | Pexp_try(e, cases) -> - List.fold_left - (fun acc pc -> combine acc (loop pc.pc_rhs)) - (loop e) cases - in - let local, _ = loop e in - local + in + match loop_cases cases with + | Local _ -> true + | Either | Not _ -> false + (* [fun _ -> assert false] must not be local-returning for + backward compatibility *) +end let rec is_an_uncurried_function e = if Builtin_attributes.has_curry e.pexp_attributes then false @@ -3729,40 +3797,6 @@ let rec is_an_uncurried_function e = | _ -> false end -let is_local_returning_function cases = - let rec loop_cases cases = - match cases with - | [] -> false - | [{pc_lhs = _; pc_guard = None; pc_rhs = e}] -> - loop_body e - | cases -> - List.for_all (fun case -> is_local_returning_expr case.pc_rhs) cases - and loop_body e = - if Builtin_attributes.has_curry e.pexp_attributes then - is_local_returning_expr e - else begin - match Jane_syntax.Expression.of_ast e with - | Some (jexp, _attrs) -> begin - match jexp with - | Jexp_n_ary_function (_, _, Pfunction_cases (cases, _, _)) -> - loop_cases cases - | Jexp_n_ary_function (_, _, Pfunction_body body) -> - loop_body body - | Jexp_comprehension _ | Jexp_immutable_array _ -> - is_local_returning_expr e - | Jexp_layout (Lexp_constant _ | Lexp_newtype _) -> false - end - | None -> match e.pexp_desc, e.pexp_attributes with - | Pexp_fun(_, _, _, e), _ -> loop_body e - | Pexp_function cases, _ -> loop_cases cases - | Pexp_constraint (e, _), _ -> loop_body e - | Pexp_let (Nonrecursive, _, e), - [{Parsetree.attr_name = {txt="#default"};_}] -> loop_body e - | _ -> is_local_returning_expr e - end - in - loop_cases cases - (* The "rest of the function" extends from the start of the first parameter to the end of the overall function. The parser does not construct such a location so we forge one for type errors. @@ -6154,7 +6188,7 @@ and type_function match in_function with | Some (_, _, region_locked) -> env, region_locked | None -> - let region_locked = not (is_local_returning_function caselist) in + let region_locked = not (Is_local_returning.function_ caselist) in let env = Env.add_closure_lock ?closure_context:expected_mode.closure_context