Skip to content

Commit

Permalink
flambda-backend: Mode system exposes bounds (#2449)
Browse files Browse the repository at this point in the history
  • Loading branch information
riaqn authored Apr 16, 2024
1 parent a0777ca commit a1fe776
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
<def>
pattern (test_locations.ml[17,534+8]..test_locations.ml[17,534+11])
Tpat_var "fib"
value_mode meet_local,once(0[global,many,global,many]),join_shared(1[shared,shared])
value_mode global,many,shared
expression (test_locations.ml[17,534+14]..test_locations.ml[19,572+34])
Texp_function
region true
alloc_mode map_comonadic(regional_to_global)(6[global,many,global,many]),id(7[shared,shared])
alloc_mode global,many,shared
[]
Tfunction_cases (test_locations.ml[17,534+14]..test_locations.ml[19,572+34])
alloc_mode id(2[global,many,global,many]),id(3[shared,shared])
alloc_mode global,many,shared
value
[
<case>
Expand All @@ -110,11 +110,11 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
<case>
pattern (test_locations.ml[19,572+4]..test_locations.ml[19,572+5])
Tpat_var "n"
value_mode meet_global,many ∘ map_comonadic(local_to_regional)(2[global,many,global,many]),join_unique(3[shared,shared])
value_mode global,many,unique
expression (test_locations.ml[19,572+9]..test_locations.ml[19,572+34])
Texp_apply
apply_mode Tail
locality_mode proj_areality(10[global,many,global,many])
locality_mode global
expression (test_locations.ml[19,572+21]..test_locations.ml[19,572+22])
Texp_ident "Stdlib!.+"
[
Expand All @@ -123,7 +123,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression (test_locations.ml[19,572+9]..test_locations.ml[19,572+20])
Texp_apply
apply_mode Default
locality_mode proj_areality(4[global,many,global,many])
locality_mode global
expression (test_locations.ml[19,572+9]..test_locations.ml[19,572+12])
Texp_ident "fib"
[
Expand All @@ -132,7 +132,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression (test_locations.ml[19,572+13]..test_locations.ml[19,572+20])
Texp_apply
apply_mode Default
locality_mode proj_areality(20[global,many,global,many])
locality_mode global
expression (test_locations.ml[19,572+16]..test_locations.ml[19,572+17])
Texp_ident "Stdlib!.-"
[
Expand All @@ -151,7 +151,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression (test_locations.ml[19,572+23]..test_locations.ml[19,572+34])
Texp_apply
apply_mode Default
locality_mode proj_areality(4[global,many,global,many])
locality_mode global
expression (test_locations.ml[19,572+23]..test_locations.ml[19,572+26])
Texp_ident "fib"
[
Expand All @@ -160,7 +160,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression (test_locations.ml[19,572+27]..test_locations.ml[19,572+34])
Texp_apply
apply_mode Default
locality_mode proj_areality(34[global,many,global,many])
locality_mode global
expression (test_locations.ml[19,572+30]..test_locations.ml[19,572+31])
Texp_ident "Stdlib!.-"
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
<def>
pattern
Tpat_var "fib"
value_mode meet_local,once(0[global,many,global,many]),join_shared(1[shared,shared])
value_mode global,many,shared
expression
Texp_function
region true
alloc_mode map_comonadic(regional_to_global)(6[global,many,global,many]),id(7[shared,shared])
alloc_mode global,many,shared
[]
Tfunction_cases
alloc_mode id(2[global,many,global,many]),id(3[shared,shared])
alloc_mode global,many,shared
value
[
<case>
Expand All @@ -110,11 +110,11 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
<case>
pattern
Tpat_var "n"
value_mode meet_global,many ∘ map_comonadic(local_to_regional)(2[global,many,global,many]),join_unique(3[shared,shared])
value_mode global,many,unique
expression
Texp_apply
apply_mode Tail
locality_mode proj_areality(10[global,many,global,many])
locality_mode global
expression
Texp_ident "Stdlib!.+"
[
Expand All @@ -123,7 +123,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression
Texp_apply
apply_mode Default
locality_mode proj_areality(4[global,many,global,many])
locality_mode global
expression
Texp_ident "fib"
[
Expand All @@ -132,7 +132,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression
Texp_apply
apply_mode Default
locality_mode proj_areality(20[global,many,global,many])
locality_mode global
expression
Texp_ident "Stdlib!.-"
[
Expand All @@ -151,7 +151,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression
Texp_apply
apply_mode Default
locality_mode proj_areality(4[global,many,global,many])
locality_mode global
expression
Texp_ident "fib"
[
Expand All @@ -160,7 +160,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2))
expression
Texp_apply
apply_mode Default
locality_mode proj_areality(34[global,many,global,many])
locality_mode global
expression
Texp_ident "Stdlib!.-"
[
Expand Down
62 changes: 27 additions & 35 deletions typing/mode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1195,15 +1195,21 @@ module Common (Obj : Obj) = struct

let print ?verbose () ppf m = Solver.print ?verbose obj ppf m

let print_raw ?verbose () ppf m = Solver.print_raw ?verbose obj ppf m

let zap_to_ceil m = with_log (Solver.zap_to_ceil obj m)

let zap_to_floor m = with_log (Solver.zap_to_floor obj m)

let of_const : type l r. const -> (l * r) t = fun a -> Solver.of_const obj a

let check_const m = Solver.check_const obj m
module Guts = struct
let get_floor m = Solver.get_floor obj m

let get_ceil m = Solver.get_ceil obj m

let get_conservative_floor m = Solver.get_conservative_floor obj m

let get_conservative_ceil m = Solver.get_conservative_ceil obj m
end
end
[@@inline]

Expand All @@ -1227,6 +1233,18 @@ module Locality = struct
let legacy = of_const Const.legacy

let zap_to_legacy = zap_to_floor

module Guts = struct
let check_const m =
let floor = Guts.get_floor m in
let ceil = Guts.get_ceil m in
if Const.le ceil floor then Some ceil else None

let check_const_conservative m =
let floor = Guts.get_conservative_floor m in
let ceil = Guts.get_conservative_ceil m in
if Const.le ceil floor then Some ceil else None
end
end

module Regionality = struct
Expand Down Expand Up @@ -1430,12 +1448,6 @@ module Comonadic_with_locality = struct

(* override to report the offending axis *)
let equate a b = try_with_log (equate_from_submode submode_log a b)

(** overriding to check per-axis *)
let check_const m =
let locality = Locality.check_const (proj Areality m) in
let linearity = Linearity.check_const (proj Linearity m) in
locality, linearity
end

module Monadic = struct
Expand Down Expand Up @@ -1502,11 +1514,6 @@ module Monadic = struct

(* override to report the offending axis *)
let equate a b = try_with_log (equate_from_submode submode_log a b)

(** overriding to check per-axis *)
let check_const m =
let uniqueness = Uniqueness.check_const (proj Uniqueness m) in
uniqueness, ()
end

type ('mo, 'como) monadic_comonadic =
Expand Down Expand Up @@ -1550,13 +1557,6 @@ module Value = struct
let uniqueness, () = monadic in
{ regionality; linearity; uniqueness }

let print_raw ?verbose () ppf { monadic; comonadic } =
Format.fprintf ppf "%a,%a"
(Comonadic.print_raw ?verbose ())
comonadic
(Monadic.print_raw ?verbose ())
monadic

let print ?verbose () ppf { monadic; comonadic } =
Format.fprintf ppf "%a,%a"
(Comonadic.print ?verbose ())
Expand Down Expand Up @@ -1585,7 +1585,9 @@ module Value = struct
let m1 = split m1 in
Comonadic.le m0.comonadic m1.comonadic && Monadic.le m0.monadic m1.monadic

let print ppf m = print_raw () ppf (of_const m)
let print ppf m =
let { monadic; comonadic } = split m in
Format.fprintf ppf "%a,%a" Comonadic.print comonadic Monadic.print monadic

let legacy =
merge { comonadic = Comonadic.legacy; monadic = Monadic.legacy }
Expand Down Expand Up @@ -1919,13 +1921,6 @@ module Alloc = struct
let equate_exn m0 m1 =
match equate m0 m1 with Ok () -> () | Error _ -> invalid_arg "equate_exn"

let print_raw ?verbose () ppf { monadic; comonadic } =
Format.fprintf ppf "%a,%a"
(Comonadic.print_raw ?verbose ())
comonadic
(Monadic.print_raw ?verbose ())
monadic

let print ?verbose () ppf { monadic; comonadic } =
Format.fprintf ppf "%a,%a"
(Comonadic.print ?verbose ())
Expand Down Expand Up @@ -2069,7 +2064,9 @@ module Alloc = struct
let m1 = split m1 in
Comonadic.le m0.comonadic m1.comonadic && Monadic.le m0.monadic m1.monadic

let print ppf m = print_raw () ppf (of_const m)
let print ppf m =
let { monadic; comonadic } = split m in
Format.fprintf ppf "%a,%a" Comonadic.print comonadic Monadic.print monadic

let legacy =
merge { comonadic = Comonadic.legacy; monadic = Monadic.legacy }
Expand Down Expand Up @@ -2158,11 +2155,6 @@ module Alloc = struct
let comonadic = Comonadic.zap_to_legacy comonadic in
merge { monadic; comonadic }

let check_const { comonadic; monadic } =
let comonadic = Comonadic.check_const comonadic in
let monadic = Monadic.check_const monadic in
merge { monadic; comonadic }

(** This is about partially applying [A -> B -> C] to [A] and getting [B ->
C]. [comonadic] and [monadic] constutute the mode of [A], and we need to
give the lower bound mode of [B -> C]. *)
Expand Down
24 changes: 16 additions & 8 deletions typing/mode_intf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ module type Common = sig

val newvar_below : ('l * allowed) t -> ('l_ * 'r) t * bool

val print_raw :
?verbose:bool -> unit -> Format.formatter -> ('l * 'r) t -> unit

val print :
?verbose:bool -> unit -> Format.formatter -> (allowed * allowed) t -> unit
val print : ?verbose:bool -> unit -> Format.formatter -> ('l * 'r) t -> unit

val of_const : Const.t -> ('l * 'r) t
end
Expand Down Expand Up @@ -147,7 +143,21 @@ module type S = sig

val zap_to_ceil : ('l * allowed) t -> Const.t

val check_const : (allowed * allowed) t -> Const.t option
module Guts : sig
(** This module exposes some functions that allow callers to inspect modes
directly, which could be useful for error printing and dev tools (such as
merlin). Any usage of this in type checking should be pondered. *)

(** Returns [Some c] if the given mode has been constrained to constant
[c]. see notes on [get_floor] in [solver_intf.mli] for cautions. *)
val check_const : (allowed * allowed) t -> Const.t option

(** Similar to [check_const] but doesn't run the further constraining
needed for precise bounds. As a result, it is inexpensive and returns
a conservative result. I.e., it might return [None] for
fully-constrained modes. *)
val check_const_conservative : (l * 'r) t -> Const.t option
end
end

module Regionality : sig
Expand Down Expand Up @@ -389,8 +399,6 @@ module type S = sig
and type error := error
and type 'd t := 'd t

val check_const : (allowed * allowed) t -> Const.Option.t

val proj : ('m, 'a, 'l * 'r) axis -> ('l * 'r) t -> 'm

val max_with : ('m, 'a, 'l * 'r) axis -> 'm -> (disallowed * 'r) t
Expand Down
6 changes: 3 additions & 3 deletions typing/printtyped.ml
Original file line number Diff line number Diff line change
Expand Up @@ -404,16 +404,16 @@ and expression_extra i ppf x attrs =
attributes i ppf attrs;

and alloc_mode: type l r. _ -> _ -> (l * r) Mode.Alloc.t -> _
= fun i ppf m -> line i ppf "alloc_mode %a\n" (Mode.Alloc.print_raw ()) m
= fun i ppf m -> line i ppf "alloc_mode %a\n" (Mode.Alloc.print ()) m

and alloc_mode_option i ppf m = Option.iter (alloc_mode i ppf) m

and locality_mode i ppf m =
line i ppf "locality_mode %a\n"
(Mode.Locality.print_raw ()) m
(Mode.Locality.print ()) m

and value_mode i ppf m =
line i ppf "value_mode %a\n" (Mode.Value.print_raw ()) m
line i ppf "value_mode %a\n" (Mode.Value.print ()) m

and expression_alloc_mode i ppf (expr, am) =
alloc_mode i ppf am;
Expand Down
Loading

0 comments on commit a1fe776

Please sign in to comment.