Skip to content

Commit

Permalink
flambda-backend: Better region handling for functions (#1871)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ekdohibs authored Nov 20, 2023
1 parent 948507a commit 2170ee5
Show file tree
Hide file tree
Showing 21 changed files with 115 additions and 76 deletions.
15 changes: 11 additions & 4 deletions lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ include (struct
if Config.stack_allocation then Modify_maybe_stack
else Modify_heap

let equal_alloc_mode mode1 mode2 =
match mode1, mode2 with
| Alloc_local, Alloc_local | Alloc_heap, Alloc_heap -> true
| (Alloc_local | Alloc_heap), _ -> false

end : sig

type locality_mode = private
Expand All @@ -92,6 +97,7 @@ end : sig

val join_mode : alloc_mode -> alloc_mode -> alloc_mode

val equal_alloc_mode : alloc_mode -> alloc_mode -> bool
end)

let is_local_mode = function
Expand Down Expand Up @@ -612,6 +618,7 @@ and lfunction =
attr: function_attribute; (* specified with [@inline] attribute *)
loc: scoped_location;
mode: alloc_mode;
ret_mode: alloc_mode;
region: bool; }

and lambda_while =
Expand Down Expand Up @@ -675,7 +682,7 @@ let max_arity () =
(* 126 = 127 (the maximal number of parameters supported in C--)
- 1 (the hidden parameter containing the environment) *)

let lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~region =
let lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region =
assert (List.length params <= max_arity ());
(* A curried function type with n parameters has n arrows. Of these,
the first [n-nlocal] have return mode Heap, while the remainder
Expand All @@ -698,7 +705,7 @@ let lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~region =
if not region then assert (nlocal >= 1);
if is_local_mode mode then assert (nlocal = nparams)
end;
Lfunction { kind; params; return; body; attr; loc; mode; region }
Lfunction { kind; params; return; body; attr; loc; mode; ret_mode; region }

let lambda_unit = Lconst const_unit

Expand Down Expand Up @@ -1272,9 +1279,9 @@ let shallow_map ~tail ~non_tail:f = function
ap_specialised;
ap_probe;
}
| Lfunction { kind; params; return; body; attr; loc; mode; region } ->
| Lfunction { kind; params; return; body; attr; loc; mode; ret_mode; region } ->
Lfunction { kind; params; return; body = f body; attr; loc;
mode; region }
mode; ret_mode; region }
| Llet (str, layout, v, e1, e2) ->
Llet (str, layout, v, f e1, tail e2)
| Lmutlet (layout, v, e1, e2) ->
Expand Down
4 changes: 4 additions & 0 deletions lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ val modify_heap : modify_mode

val modify_maybe_stack : modify_mode

val equal_alloc_mode : alloc_mode -> alloc_mode -> bool

type initialization_or_assignment =
(* [Assignment Alloc_local] is a mutation of a block that may be heap or local.
[Assignment Alloc_heap] is a mutation of a block that's definitely heap. *)
Expand Down Expand Up @@ -515,6 +517,7 @@ and lfunction = private
attr: function_attribute; (* specified with [@inline] attribute *)
loc : scoped_location;
mode : alloc_mode; (* alloc mode of the closure itself *)
ret_mode: alloc_mode;
region : bool; (* false if this function may locally
allocate in the caller's region *)
}
Expand Down Expand Up @@ -635,6 +638,7 @@ val lfunction :
attr:function_attribute -> (* specified with [@inline] attribute *)
loc:scoped_location ->
mode:alloc_mode ->
ret_mode:alloc_mode ->
region:bool ->
lambda

Expand Down
5 changes: 2 additions & 3 deletions lambda/printlambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ let rec lam ppf = function
apply_inlined_attribute ap.ap_inlined
apply_specialised_attribute ap.ap_specialised
apply_probe ap.ap_probe
| Lfunction{kind; params; return; body; attr; mode; region} ->
| Lfunction{kind; params; return; body; attr; ret_mode; mode} ->
let pr_params ppf params =
match kind with
| Curried {nlocal} ->
Expand All @@ -830,10 +830,9 @@ let rec lam ppf = function
layout ppf p.layout)
params;
fprintf ppf ")" in
let rmode = if region then alloc_heap else alloc_local in
fprintf ppf "@[<2>(function%s%a@ %a%a%a)@]"
(alloc_kind mode) pr_params params
function_attribute attr return_kind (rmode, return) lam body
function_attribute attr return_kind (ret_mode, return) lam body
| Llet _ | Lmutlet _ as expr ->
let let_kind = begin function
| Llet(str,_,_,_,_) ->
Expand Down
20 changes: 10 additions & 10 deletions lambda/simplif.ml
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ let simplify_exits lam =
| Lapply ap ->
Lapply{ap with ap_func = simplif ~layout:None ~try_depth ap.ap_func;
ap_args = List.map (simplif ~layout:None ~try_depth) ap.ap_args}
| Lfunction{kind; params; return; mode; region; body = l; attr; loc} ->
lfunction ~kind ~params ~return ~mode ~region
| Lfunction{kind; params; return; mode; ret_mode; region; body = l; attr; loc} ->
lfunction ~kind ~params ~return ~mode ~region ~ret_mode
~body:(simplif ~layout:None ~try_depth l) ~attr ~loc
| Llet(str, kind, v, l1, l2) ->
Llet(str, kind, v, simplif ~layout:None ~try_depth l1, simplif ~layout ~try_depth l2)
Expand Down Expand Up @@ -556,12 +556,12 @@ let simplify_lets lam =
| _ -> no_opt ()
end
| Lfunction{kind=outer_kind; params; return=outer_return; body = l;
attr; loc; mode; region=outer_region} ->
attr; loc; ret_mode; mode; region=outer_region} ->
begin match outer_kind, outer_region, simplif l with
Curried {nlocal=0},
true,
Lfunction{kind=Curried _ as kind; params=params'; return=return2;
body; attr; loc; mode=inner_mode; region}
body; attr; loc; mode=inner_mode; ret_mode; region}
when optimize &&
List.length params + List.length params' <= Lambda.max_arity() ->
(* The returned function's mode should match the outer return mode *)
Expand All @@ -571,9 +571,9 @@ let simplify_lets lam =
type of the merged function taking [params @ params'] as
parameters is the type returned after applying [params']. *)
let return = return2 in
lfunction ~kind ~params:(params @ params') ~return ~body ~attr ~loc ~mode ~region
lfunction ~kind ~params:(params @ params') ~return ~body ~attr ~loc ~mode ~ret_mode ~region
| kind, region, body ->
lfunction ~kind ~params ~return:outer_return ~body ~attr ~loc ~mode ~region
lfunction ~kind ~params ~return:outer_return ~body ~attr ~loc ~mode ~ret_mode ~region
end
| Llet(_str, _k, v, Lvar w, l2) when optimize ->
Hashtbl.add subst v (simplif (Lvar w));
Expand Down Expand Up @@ -759,7 +759,7 @@ and list_emit_tail_infos is_tail =
function's body. *)

let split_default_wrapper ~id:fun_id ~kind ~params ~return ~body
~attr ~loc ~mode ~region:orig_region =
~attr ~loc ~mode ~ret_mode ~region:orig_region =
let rec aux map add_region = function
(* When compiling [fun ?(x=expr) -> body], this is first translated
to:
Expand Down Expand Up @@ -836,7 +836,7 @@ let split_default_wrapper ~id:fun_id ~kind ~params ~return ~body
let inner_fun =
lfunction ~kind:(Curried {nlocal=0})
~params:new_ids
~return ~body ~attr ~loc ~mode ~region:true
~return ~body ~attr ~loc ~mode ~ret_mode ~region:true
in
(wrapper_body, (inner_id, inner_fun))
in
Expand All @@ -849,9 +849,9 @@ let split_default_wrapper ~id:fun_id ~kind ~params ~return ~body
end;
let body, inner = aux [] false body in
let attr = { default_stub_attribute with check = attr.check } in
[(fun_id, lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~region:true); inner]
[(fun_id, lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region:true); inner]
with Exit ->
[(fun_id, lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~region:orig_region)]
[(fun_id, lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region:orig_region)]

(* Simplify local let-bound functions: if all occurrences are
fully-applied function calls in the same "tail scope", replace the
Expand Down
1 change: 1 addition & 0 deletions lambda/simplif.mli
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@ val split_default_wrapper
-> attr:function_attribute
-> loc:Lambda.scoped_location
-> mode:Lambda.alloc_mode
-> ret_mode:Lambda.alloc_mode
-> region:bool
-> (Ident.t * lambda) list
5 changes: 3 additions & 2 deletions lambda/tmc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -991,9 +991,9 @@ and traverse_binding outer_ctx inner_ctx (var, def) =
(Debuginfo.Scoped_location.to_location lfun.loc)
Warnings.Unused_tmc_attribute;
let direct =
let { kind; params; return; body = _; attr; loc; mode; region } = lfun in
let { kind; params; return; body = _; attr; loc; mode; ret_mode; region } = lfun in
let body = Choice.direct fun_choice in
lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~region in
lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region in
let dps =
let dst_param = {
var = Ident.create_local "dst";
Expand Down Expand Up @@ -1021,6 +1021,7 @@ and traverse_binding outer_ctx inner_ctx (var, def) =
~attr:lfun.attr
~loc:lfun.loc
~mode:lfun.mode
~ret_mode:lfun.ret_mode
~region:true
in
let dps_var = special.dps_id in
Expand Down
1 change: 1 addition & 0 deletions lambda/transl_list_comprehension.ml
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ let rec translate_bindings
~attr:default_function_attribute
~loc
~mode:alloc_local
~ret_mode:alloc_local
~region:false
~body:(add_bindings body)
in
Expand Down
4 changes: 2 additions & 2 deletions lambda/translattribute.ml
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ let check_poll_local loc attr =
()

let lfunction_with_attr ~attr
{ kind; params; return; body; attr=_; loc; mode; region } =
lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~region
{ kind; params; return; body; attr=_; loc; mode; ret_mode; region } =
lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region

let add_inline_attribute expr loc attributes =
match expr with
Expand Down
16 changes: 12 additions & 4 deletions lambda/translclass.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,28 @@ let layout_meth = layout_any_value
let layout_tables = Lambda.Pvalue Pgenval


let lfunction ?(kind=Curried {nlocal=0}) ?(region=true) return_layout params body =
let lfunction ?(kind=Curried {nlocal=0}) ?(region=true) ?(ret_mode=alloc_heap) return_layout params body =
if params = [] then body else
match kind, body with
| Curried {nlocal=0},
Lfunction {kind = Curried _ as kind; params = params';
body = body'; attr; loc}
body = body'; attr; loc; mode = Alloc_heap; ret_mode; region}
when List.length params + List.length params' <= Lambda.max_arity() ->
lfunction ~kind ~params:(params @ params')
~return:return_layout
~body:body'
~attr
~loc
~mode:alloc_heap
~ret_mode
~region
| _ ->
lfunction ~kind ~params ~return:return_layout
~body
~attr:default_function_attribute
~loc:Loc_unknown
~mode:alloc_heap
~ret_mode
~region

let lapply ap =
Expand Down Expand Up @@ -226,6 +228,7 @@ let rec build_object_init ~scopes cl_table obj params inh_init obj_init cl =
~loc:(of_location ~scopes pat.pat_loc)
~body
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
in
begin match obj_init with
Expand Down Expand Up @@ -514,6 +517,7 @@ let rec transl_class_rebind ~scopes obj_init cl vf =
~loc:(of_location ~scopes pat.pat_loc)
~body
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
in
(path, path_lam,
Expand Down Expand Up @@ -792,7 +796,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
let new_ids_meths = ref [] in
let no_env_update _ _ env = env in
let msubst arr = function
Lfunction {kind = Curried _ as kind; region;
Lfunction {kind = Curried _ as kind; region; ret_mode;
params = self :: args; return; body} ->
let env = Ident.create_local "env" in
let body' =
Expand All @@ -804,7 +808,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
if not arr || !Clflags.debug then raise Not_found;
builtin_meths [self.name] env env2 (lfunction return args body')
with Not_found ->
[lfunction ~kind ~region return (self :: args)
[lfunction ~kind ~region ~ret_mode return (self :: args)
(if not (Ident.Set.mem env (free_variables body')) then body' else
Llet(Alias, layout_block, env,
Lprim(Pfield_computed Reads_vary,
Expand Down Expand Up @@ -875,6 +879,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~loc:Loc_unknown
~return:layout_function
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~params:[lparam cla layout_table] ~body:cl_init) in
Llet(Strict, layout_function, class_init, cl_init, lam (free_variables cl_init))
Expand All @@ -900,6 +905,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~loc:Loc_unknown
~return:layout_function
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~params:[lparam cla layout_table] ~body:cl_init;
lambda_unit; lenvs],
Expand Down Expand Up @@ -960,6 +966,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~attr:default_function_attribute
~loc:Loc_unknown
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~body:(def_ids cla cl_init), lam)
and lcache lam =
Expand All @@ -985,6 +992,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~attr:default_function_attribute
~loc:Loc_unknown
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~return:layout_function
~params:[lparam cla layout_table]
Expand Down
Loading

0 comments on commit 2170ee5

Please sign in to comment.