Skip to content

Commit

Permalink
flambda-backend: Generalize unique_to_linear to `monadic_to_comonad…
Browse files Browse the repository at this point in the history
…ic` (#2351)

* generalize unique_to_linear for variables

* generalize unique_to_linear for constants

* fix typo

Co-authored-by: Richard Eisenberg <[email protected]>

---------

Co-authored-by: Richard Eisenberg <[email protected]>
  • Loading branch information
riaqn and goldfirere authored Mar 13, 2024
1 parent e132455 commit c723951
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 102 deletions.
15 changes: 8 additions & 7 deletions typing/env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2977,22 +2977,23 @@ let share_mode ~errors ~env ~loc id vmode shared_context =
(Once_value_used_in (id, shared_context))
| Ok () -> Mode.Value.join [Mode.Value.min_with_uniqueness Mode.Uniqueness.shared; vmode]

let closure_mode ~errors ~env ~loc id vmode closure_context comonadic =
let closure_mode ~errors ~env ~loc id {Mode.monadic; comonadic}
closure_context comonadic0 : Mode.Value.l =
begin
match
Mode.Value.Comonadic.submode vmode.Mode.comonadic comonadic
Mode.Value.Comonadic.submode comonadic comonadic0
with
| Error e ->
may_lookup_error errors loc env
(Value_used_in_closure (id, e, closure_context))
| Ok () -> ()
end;
let uniqueness =
Mode.Uniqueness.join
[ Mode.Value.uniqueness vmode;
Mode.linear_to_unique (Mode.Value.Comonadic.linearity comonadic) ]
let monadic =
Mode.Value.Monadic.join
[ monadic;
Mode.Value.comonadic_to_monadic comonadic0 ]
in
Mode.Value.join [Mode.Value.min_with_uniqueness uniqueness; vmode]
{monadic; comonadic}

let exclave_mode ~errors ~env ~loc id vmode =
match
Expand Down
197 changes: 113 additions & 84 deletions typing/mode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,16 @@ module Lattices_mono = struct
('a0, 'a1, 'd) morph
-> ('a0 comonadic_with, 'a1 comonadic_with, 'd) morph
(** Lift an morphism on areality to a morphism on the comonadic fragment *)
| Unique_to_linear : (Uniqueness_op.t, Linearity.t, 'l * 'r) morph
(** Returns the linearity dual to the given uniqueness *)
| Linear_to_unique : (Linearity.t, Uniqueness_op.t, 'l * 'r) morph
(** Returns the uniqueness dual to the given linearity *)
| Monadic_to_comonadic_min
: (Monadic_op.t, 'a comonadic_with, 'l * disallowed) morph
(** Dualize the monadic fragment to the comonadic fragment. The areality is set to min. *)
| Comonadic_to_monadic :
'a comonadic_with obj
-> ('a comonadic_with, Monadic_op.t, 'l * 'r) morph
(** Dualize the comonadic fragment to the monadic fragment. The areality axis is ignored. *)
| Monadic_to_comonadic_max
: (Monadic_op.t, 'a comonadic_with, disallowed * 'r) morph
(** Dualize the monadic fragment to the comonadic fragment. The areality is set to max. *)
(* Following is a chain of adjunction (complete and cannot extend in
either direction) *)
| Local_to_regional : (Locality.t, Regionality.t, 'l * disallowed) morph
Expand Down Expand Up @@ -585,8 +591,8 @@ module Lattices_mono = struct
let f = allow_left f in
let g = allow_left g in
Compose (f, g)
| Unique_to_linear -> Unique_to_linear
| Linear_to_unique -> Linear_to_unique
| Monadic_to_comonadic_min -> Monadic_to_comonadic_min
| Comonadic_to_monadic a -> Comonadic_to_monadic a
| Local_to_regional -> Local_to_regional
| Locality_as_regionality -> Locality_as_regionality
| Regional_to_local -> Regional_to_local
Expand All @@ -607,8 +613,8 @@ module Lattices_mono = struct
let f = allow_right f in
let g = allow_right g in
Compose (f, g)
| Unique_to_linear -> Unique_to_linear
| Linear_to_unique -> Linear_to_unique
| Comonadic_to_monadic a -> Comonadic_to_monadic a
| Monadic_to_comonadic_max -> Monadic_to_comonadic_max
| Global_to_regional -> Global_to_regional
| Locality_as_regionality -> Locality_as_regionality
| Regional_to_local -> Regional_to_local
Expand All @@ -632,8 +638,9 @@ module Lattices_mono = struct
let f = disallow_left f in
let g = disallow_left g in
Compose (f, g)
| Unique_to_linear -> Unique_to_linear
| Linear_to_unique -> Linear_to_unique
| Monadic_to_comonadic_min -> Monadic_to_comonadic_min
| Comonadic_to_monadic a -> Comonadic_to_monadic a
| Monadic_to_comonadic_max -> Monadic_to_comonadic_max
| Local_to_regional -> Local_to_regional
| Global_to_regional -> Global_to_regional
| Locality_as_regionality -> Locality_as_regionality
Expand All @@ -658,8 +665,9 @@ module Lattices_mono = struct
let f = disallow_right f in
let g = disallow_right g in
Compose (f, g)
| Unique_to_linear -> Unique_to_linear
| Linear_to_unique -> Linear_to_unique
| Monadic_to_comonadic_min -> Monadic_to_comonadic_min
| Comonadic_to_monadic a -> Comonadic_to_monadic a
| Monadic_to_comonadic_max -> Monadic_to_comonadic_max
| Local_to_regional -> Local_to_regional
| Global_to_regional -> Global_to_regional
| Locality_as_regionality -> Locality_as_regionality
Expand All @@ -684,8 +692,9 @@ module Lattices_mono = struct
| Compose (f, g) ->
let mid = src dst f in
src mid g
| Unique_to_linear -> Uniqueness_op
| Linear_to_unique -> Linearity
| Monadic_to_comonadic_min -> Monadic_op
| Comonadic_to_monadic src -> src
| Monadic_to_comonadic_max -> Monadic_op
| Local_to_regional -> Locality
| Locality_as_regionality -> Locality
| Global_to_regional -> Locality
Expand Down Expand Up @@ -728,8 +737,10 @@ module Lattices_mono = struct
| Join_with c0, Join_with c1 -> if c0 = c1 then Some Refl else None
| Imply c0, Imply c1 -> if c0 = c1 then Some Refl else None
| Subtract c0, Subtract c1 -> if c0 = c1 then Some Refl else None
| Unique_to_linear, Unique_to_linear -> Some Refl
| Linear_to_unique, Linear_to_unique -> Some Refl
| Monadic_to_comonadic_min, Monadic_to_comonadic_min -> Some Refl
| Comonadic_to_monadic a0, Comonadic_to_monadic a1 -> (
match eq_obj a0 a1 with None -> None | Some Refl -> Some Refl)
| Monadic_to_comonadic_max, Monadic_to_comonadic_max -> Some Refl
| Local_to_regional, Local_to_regional -> Some Refl
| Locality_as_regionality, Locality_as_regionality -> Some Refl
| Global_to_regional, Global_to_regional -> Some Refl
Expand All @@ -743,7 +754,8 @@ module Lattices_mono = struct
| Map_comonadic f, Map_comonadic g -> (
match equal f g with Some Refl -> Some Refl | None -> None)
| ( ( Id | Proj _ | Max_with _ | Min_with _ | Meet_with _ | Join_with _
| Unique_to_linear | Linear_to_unique | Local_to_regional
| Monadic_to_comonadic_min | Comonadic_to_monadic _
| Monadic_to_comonadic_max | Local_to_regional
| Locality_as_regionality | Global_to_regional | Regional_to_local
| Regional_to_global | Compose _ | Map_comonadic _ | Imply _
| Subtract _ ),
Expand All @@ -767,8 +779,9 @@ module Lattices_mono = struct
| Map_comonadic f ->
let dst0 = proj_obj Areality dst in
Format.fprintf ppf "map_comonadic(%a)" (print_morph dst0) f
| Unique_to_linear -> Format.fprintf ppf "unique_to_linear"
| Linear_to_unique -> Format.fprintf ppf "linear_to_unique"
| Monadic_to_comonadic_min -> Format.fprintf ppf "monadic_to_comonadic_min"
| Comonadic_to_monadic _ -> Format.fprintf ppf "comonadic_to_monadic"
| Monadic_to_comonadic_max -> Format.fprintf ppf "monadic_to_comonadic_max"
| Local_to_regional -> Format.fprintf ppf "local_to_regional"
| Regional_to_local -> Format.fprintf ppf "regional_to_local"
| Locality_as_regionality -> Format.fprintf ppf "locality_as_regionality"
Expand Down Expand Up @@ -814,6 +827,27 @@ module Lattices_mono = struct

let max_with dst ax a = update ax a (max dst)

let monadic_to_comonadic_min :
type a. a comonadic_with obj -> Monadic_op.t -> a comonadic_with =
fun obj (uniqueness, ()) ->
match obj with
| Comonadic_with_locality -> Locality.min, unique_to_linear uniqueness
| Comonadic_with_regionality -> Regionality.min, unique_to_linear uniqueness

let comonadic_to_monadic :
type a. a comonadic_with obj -> a comonadic_with -> Monadic_op.t =
fun obj (_, linearity) ->
match obj with
| Comonadic_with_locality -> linear_to_unique linearity, ()
| Comonadic_with_regionality -> linear_to_unique linearity, ()

let monadic_to_comonadic_max :
type a. a comonadic_with obj -> Monadic_op.t -> a comonadic_with =
fun obj (uniqueness, ()) ->
match obj with
| Comonadic_with_locality -> Locality.max, unique_to_linear uniqueness
| Comonadic_with_regionality -> Regionality.max, unique_to_linear uniqueness

let rec apply : type a b d. b obj -> (a, b, d) morph -> a -> b =
fun dst f a ->
match f with
Expand All @@ -830,8 +864,9 @@ module Lattices_mono = struct
| Join_with c -> join dst c a
| Imply c -> imply dst c a
| Subtract c -> subtract dst c a
| Unique_to_linear -> unique_to_linear a
| Linear_to_unique -> linear_to_unique a
| Monadic_to_comonadic_min -> monadic_to_comonadic_min dst a
| Comonadic_to_monadic src -> comonadic_to_monadic src a
| Monadic_to_comonadic_max -> monadic_to_comonadic_max dst a
| Local_to_regional -> local_to_regional a
| Regional_to_local -> regional_to_local a
| Locality_as_regionality -> locality_as_regionality a
Expand Down Expand Up @@ -878,8 +913,9 @@ module Lattices_mono = struct
match ax with
| Areality -> Some (compose dst f (Proj (src', Areality)))
| Linearity -> Some (Proj (src', Linearity)))
| Unique_to_linear, Linear_to_unique -> Some Id
| Linear_to_unique, Unique_to_linear -> Some Id
| Proj _, Monadic_to_comonadic_min -> None
| Proj _, Monadic_to_comonadic_max -> None
| Proj _, Comonadic_to_monadic _ -> None
| Map_comonadic f, Map_comonadic g ->
let dst0 = proj_obj Areality dst in
Some (Map_comonadic (compose dst0 f g))
Expand Down Expand Up @@ -912,14 +948,6 @@ module Lattices_mono = struct
(compose dst
(Join_with (locality_as_regionality c))
Locality_as_regionality)
| Unique_to_linear, Meet_with c ->
Some (compose dst (Meet_with (unique_to_linear c)) Unique_to_linear)
| Unique_to_linear, Join_with c ->
Some (compose dst (Join_with (unique_to_linear c)) Unique_to_linear)
| Linear_to_unique, Meet_with c ->
Some (compose dst (Meet_with (linear_to_unique c)) Linear_to_unique)
| Linear_to_unique, Join_with c ->
Some (compose dst (Join_with (linear_to_unique c)) Linear_to_unique)
| Map_comonadic f, Join_with c ->
let dst0 = proj_obj Areality dst in
let areality, linearity = c in
Expand Down Expand Up @@ -955,19 +983,19 @@ module Lattices_mono = struct
| Subtract _, _ -> None
| _, Proj _ -> None
| Map_comonadic _, _ -> None
| Monadic_to_comonadic_min, _ -> None
| Monadic_to_comonadic_max, _ -> None
| Comonadic_to_monadic _, _ -> None
| ( Proj _,
( Unique_to_linear | Linear_to_unique | Local_to_regional
| Regional_to_local | Locality_as_regionality | Regional_to_global
| Global_to_regional ) ) ->
( Local_to_regional | Regional_to_local | Locality_as_regionality
| Regional_to_global | Global_to_regional ) ) ->
.
| ( ( Unique_to_linear | Linear_to_unique | Local_to_regional
| Regional_to_local | Locality_as_regionality | Regional_to_global
| Global_to_regional ),
| ( ( Local_to_regional | Regional_to_local | Locality_as_regionality
| Regional_to_global | Global_to_regional ),
Min_with _ ) ->
.
| ( ( Unique_to_linear | Linear_to_unique | Local_to_regional
| Regional_to_local | Locality_as_regionality | Regional_to_global
| Global_to_regional ),
| ( ( Local_to_regional | Regional_to_local | Locality_as_regionality
| Regional_to_global | Global_to_regional ),
Max_with _ ) ->
.

Expand All @@ -992,8 +1020,8 @@ module Lattices_mono = struct
Compose (g', f')
| Join_with c -> Subtract c
| Imply c -> Meet_with c
| Unique_to_linear -> Linear_to_unique
| Linear_to_unique -> Unique_to_linear
| Comonadic_to_monadic _ -> Monadic_to_comonadic_min
| Monadic_to_comonadic_max -> Comonadic_to_monadic dst
| Global_to_regional -> Regional_to_global
| Regional_to_global -> Locality_as_regionality
| Locality_as_regionality -> Regional_to_local
Expand All @@ -1018,8 +1046,8 @@ module Lattices_mono = struct
Compose (g', f')
| Meet_with c -> Imply c
| Subtract c -> Join_with c
| Unique_to_linear -> Linear_to_unique
| Linear_to_unique -> Unique_to_linear
| Comonadic_to_monadic _ -> Monadic_to_comonadic_max
| Monadic_to_comonadic_min -> Comonadic_to_monadic dst
| Local_to_regional -> Regional_to_local
| Regional_to_local -> Locality_as_regionality
| Locality_as_regionality -> Regional_to_global
Expand Down Expand Up @@ -1210,12 +1238,6 @@ module Uniqueness = struct
let zap_to_legacy = zap_to_ceil
end

let unique_to_linear m =
S.Positive.via_antitone Linearity.Obj.obj C.Unique_to_linear m

let linear_to_unique m =
S.Negative.via_antitone Uniqueness.Obj.obj C.Linear_to_unique m

let regional_to_local m =
S.Positive.via_monotone Locality.Obj.obj C.Regional_to_local m

Expand All @@ -1225,10 +1247,6 @@ let locality_as_regionality m =
let regional_to_global m =
S.Positive.via_monotone Locality.Obj.obj C.Regional_to_global m

module Const = struct
let unique_to_linear a = C.unique_to_linear a
end

module Comonadic_with_regionality = struct
module Const = C.Comonadic_with_regionality

Expand Down Expand Up @@ -1697,6 +1715,10 @@ module Value = struct
let monadic = Monadic.meet mo in
{ comonadic; monadic }

let comonadic_to_monadic m =
S.Negative.via_antitone Monadic.Obj.obj
(Comonadic_to_monadic Comonadic.Obj.obj) m

module Const = struct
type t = Regionality.Const.t * Linearity.Const.t * Uniqueness.Const.t

Expand Down Expand Up @@ -1935,6 +1957,10 @@ module Alloc = struct
let monadic = Monadic.meet mo in
{ comonadic; monadic }

let monadic_to_comonadic_min m =
S.Positive.via_antitone Comonadic.Obj.obj Monadic_to_comonadic_min
(Monadic.disallow_left m)

module Const = struct
type ('loc, 'lin, 'uni) modes =
{ locality : 'loc;
Expand Down Expand Up @@ -2006,26 +2032,31 @@ module Alloc = struct
{ locality; uniqueness; linearity }
end

let split { locality; linearity; uniqueness } =
let monadic = uniqueness, () in
let comonadic = locality, linearity in
{ comonadic; monadic }

let merge { comonadic; monadic } =
let locality, linearity = comonadic in
let uniqueness, () = monadic in
{ locality; linearity; uniqueness }

(** See [Alloc.close_over] for explanation. *)
let close_over m =
let locality = m.locality in
(* uniqueness of the returned function is not constrained *)
let uniqueness = Uniqueness.Const.min in
let linearity =
Linearity.Const.join m.linearity
(* In addition, unique argument make the returning function once.
In other words, if argument <= unique, returning function >= once.
That is, returning function >= (dual of argument) *)
(Const.unique_to_linear m.uniqueness)
let { monadic; comonadic } = split m in
let comonadic =
Comonadic.Const.join comonadic
(C.monadic_to_comonadic_min Comonadic.Obj.obj monadic)
in
{ locality; linearity; uniqueness }
let monadic = Monadic.Const.min in
merge { comonadic; monadic }

(** See [Alloc.partial_apply] for explanation. *)
let partial_apply m =
let locality = m.locality in
let uniqueness = Uniqueness.Const.min in
let linearity = m.linearity in
{ locality; linearity; uniqueness }
let { comonadic; _ } = split m in
let monadic = Monadic.Const.min in
merge { comonadic; monadic }
end

let of_const = Const.of_const
Expand Down Expand Up @@ -2054,25 +2085,23 @@ module Alloc = struct
C]. [comonadic] and [monadic] constutute the mode of [A], and we need to
give the lower bound mode of [B -> C]. *)
let close_over { comonadic; monadic } =
(* If [A] is [local], [B -> C] containining a pointer to [A] must
be [local] too. *)
let locality = min_with_locality (Comonadic.locality comonadic) in
(* [B -> C] is arrow type and thus crosses uniqueness *)
(* If [A] is [once], [B -> C] containing a pointer to [A] must be [once] too
*)
let linearity0 = min_with_linearity (Comonadic.linearity comonadic) in
(* Moreover, if [A] is [unique], [B -> C] must be [once]. *)
let linearity1 =
min_with_linearity (unique_to_linear (Monadic.uniqueness monadic))
in
join [locality; linearity0; linearity1]
let comonadic = Comonadic.disallow_right comonadic in
(* The comonadic of the returned function is constrained by the monadic of the closed argument via the dualizing morphism. *)
let comonadic1 = monadic_to_comonadic_min monadic in
(* It's also constrained by the comonadic of the closed argument. *)
let comonadic = Comonadic.join [comonadic; comonadic1] in
(* The returned function crosses all monadic axes that we know of
(uniqueness/contention). *)
let monadic = Monadic.disallow_right Monadic.min in
{ comonadic; monadic }

(** Similar to above, but we are given the mode of [A -> B -> C], and need to
give the lower bound mode of [B -> C]. *)
let partial_apply alloc_mode =
(* [B -> C] should be always higher than [A -> B -> C] except the uniqueness
axis where it's not constrained *)
meet_with_uniqueness Unique alloc_mode
let partial_apply { comonadic; _ } =
(* The returned function crosses all monadic axes that we know of. *)
let monadic = Monadic.disallow_right Monadic.min in
let comonadic = Comonadic.disallow_right comonadic in
{ comonadic; monadic }
end

let alloc_as_value m =
Expand Down
Loading

0 comments on commit c723951

Please sign in to comment.