Skip to content

Commit

Permalink
Support for list recombinations
Browse files Browse the repository at this point in the history
The primary use-case for this was to be able to run computations on a list of
structures, then return an updated list with some fields in the structures
modified : that is what we need for distribution of tax amounts among household
members, for example.

This patch has a few components:

- Addition of a test as an example for tax distributions

- Added a transformation, performed during desugaring, that -- where lists are
  syntactically expected, i.e. after the `among` keyword -- turns a (syntactic)
  tuple of lists into a list of tuples ("zipping" the lists)

- Arg-extremum transformation was also fixed to use an intermediate list instead
  of computing the predicate twice

- For convenience, allow to bind multiple variables in most* list
  operations (previously only `let in` and functions allowed it)

- Fixed the printer for tuples to differentiate them from lists

*Note: tuples are not yet allowed on the left-hand side of filters and
arg-extremums for annoying syntax conflict reasons.
  • Loading branch information
AltGr committed Jan 26, 2024
1 parent 13bc62a commit 371f955
Show file tree
Hide file tree
Showing 13 changed files with 359 additions and 90 deletions.
146 changes: 115 additions & 31 deletions compiler/desugared/from_surface.ml
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,40 @@ let rec translate_expr
let rec_helper ?(local_vars = local_vars) e =
translate_expr scope inside_definition_of ctxt local_vars e
in
let rec detuplify_list = function
(* Where a list is expected (e.g. after [among]), as syntactic sugar, if a
tuple is found instead we transpose it into a list of tuples *)
| S.Tuple ls, pos ->
let m = Untyped { pos } in
let ls = List.map detuplify_list ls in
let rec zip = function
| [] -> assert false
| [l] -> l
| l1 :: r ->
let rhs = zip r in
let rtys, explode =
match List.length r with
| 1 -> (TAny, pos), fun e -> [e]
| size ->
( (TTuple (List.map (fun _ -> TAny, pos) r), pos),
fun e ->
List.init size (fun index ->
Expr.etupleaccess ~e ~size ~index m) )
in
let tys = [TAny, pos; rtys] in
let f_join =
let x1 = Var.make "x1" and x2 = Var.make "x2" in
Expr.make_abs [| x1; x2 |]
(Expr.make_tuple (Expr.evar x1 m :: explode (Expr.evar x2 m)) m)
tys pos
in
Expr.eappop ~op:Map2 ~args:[f_join; l1; rhs]
~tys:((TAny, pos) :: List.map (fun ty -> TArray ty, pos) tys)
m
in
zip ls
| e -> rec_helper e
in
let pos = Mark.get expr in
let emark = Untyped { pos } in
match Mark.remove expr with
Expand Down Expand Up @@ -629,16 +663,39 @@ let rec translate_expr
| ArrayLit es -> Expr.earray (List.map rec_helper es) emark
| Tuple es -> Expr.etuple (List.map rec_helper es) emark
| CollectionOp (((S.Filter { f } | S.Map { f }) as op), collection) ->
let collection = rec_helper collection in
let param_name, predicate = f in
let param = Var.make (Mark.remove param_name) in
let local_vars = Ident.Map.add (Mark.remove param_name) param local_vars in
let collection = detuplify_list collection in
let param_names, predicate = f in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
local_vars param_names params
in
let f_pred =
Expr.make_abs [| param |]
Expr.make_abs (Array.of_list params)
(rec_helper ~local_vars predicate)
[TAny, pos]
(List.map (fun _ -> TAny, pos) params)
pos
in
let f_pred =
(* Detuplification (TODO: check if we couldn't fit this in the general
detuplification later) *)
match List.length param_names with
| 1 -> f_pred
| nb_args ->
let v =
Var.make (String.concat "_" (List.map Mark.remove param_names))
in
let x = Expr.evar v emark in
let tys = List.map (fun _ -> TAny, pos) param_names in
Expr.make_abs [| v |]
(Expr.make_app f_pred
(List.init nb_args (fun i ->
Expr.etupleaccess ~e:x ~index:i ~size:nb_args emark))
tys pos)
[TAny, pos]
pos
in
Expr.eappop
~op:
(match op with
Expand All @@ -648,64 +705,91 @@ let rec translate_expr
~tys:[TAny, pos; TAny, pos]
~args:[f_pred; collection] emark
| CollectionOp
( S.AggregateArgExtremum { max; default; f = param_name, predicate },
( S.AggregateArgExtremum { max; default; f = param_names, predicate },
collection ) ->
let default = rec_helper default in
let pos_dft = Expr.pos default in
let collection = rec_helper collection in
let param = Var.make (Mark.remove param_name) in
let local_vars = Ident.Map.add (Mark.remove param_name) param local_vars in
let collection = detuplify_list collection in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
local_vars param_names params
in
let cmp_op = if max then Op.Gt else Op.Lt in
let f_pred =
Expr.make_abs [| param |]
Expr.make_abs (Array.of_list params)
(rec_helper ~local_vars predicate)
[TAny, pos]
pos
in
let param_name = Bindlib.name_of param in
let v1, v2 = Var.make (param_name ^ "_1"), Var.make (param_name ^ "_2") in
let x1 = Expr.make_var v1 emark in
let x2 = Expr.make_var v2 emark in
let add_weight_f =
let vs = List.map (fun p -> Var.make (Bindlib.name_of p)) params in
let xs = List.map (fun v -> Expr.evar v emark) vs in
let x = match xs with [x] -> x | xs -> Expr.etuple xs emark in
Expr.make_abs (Array.of_list vs)
(Expr.make_tuple [x; Expr.eapp ~f:f_pred ~args:xs ~tys:[] emark] emark)
[TAny, pos]
pos
in
let reduce_f =
(* fun x1 x2 -> cmp_op (pred x1) (pred x2) *)
(* Note: this computes f_pred twice on every element, but we'd rather not
rely on returning tuples here *)
(* fun x1 x2 -> if cmp_op (x1.2) (x2.2) cmp *)
let v1, v2 = Var.make "x1", Var.make "x2" in
let x1, x2 = Expr.make_var v1 emark, Expr.make_var v2 emark in
Expr.make_abs [| v1; v2 |]
(Expr.eifthenelse
(Expr.eappop ~op:cmp_op
~tys:[TAny, pos_dft; TAny, pos_dft]
~args:
[
Expr.eapp ~f:f_pred ~args:[x1] ~tys:[] emark;
Expr.eapp ~f:f_pred ~args:[x2] ~tys:[] emark;
Expr.etupleaccess ~e:x1 ~index:1 ~size:2 emark;
Expr.etupleaccess ~e:x2 ~index:1 ~size:2 emark;
]
emark)
x1 x2 emark)
[TAny, pos; TAny, pos]
pos
in
Expr.eappop ~op:Reduce
~tys:[TAny, pos; TAny, pos; TAny, pos]
~args:[reduce_f; default; collection]
emark
let weights_var = Var.make "weights" in
let default = Expr.make_app add_weight_f [default] [TAny, pos] pos_dft in
let weighted_result =
Expr.make_let_in weights_var
(TArray (TTuple [TAny, pos; TAny, pos], pos), pos)
(Expr.eappop ~op:Map
~tys:[TAny, pos; TArray (TAny, pos), pos]
~args:[add_weight_f; collection] emark)
(Expr.eappop ~op:Reduce
~tys:[TAny, pos; TAny, pos; TAny, pos]
~args:[reduce_f; default; Expr.evar weights_var emark]
emark)
pos
in
Expr.etupleaccess ~e:weighted_result ~index:0 ~size:2 emark
| CollectionOp
(((Exists { predicate } | Forall { predicate }) as op), collection) ->
let collection = rec_helper collection in
let collection = detuplify_list collection in
let init, op =
match op with
| Exists _ -> false, S.Or
| Forall _ -> true, S.And
| _ -> assert false
in
let init = Expr.elit (LBool init) emark in
let param0, predicate = predicate in
let param = Var.make (Mark.remove param0) in
let local_vars = Ident.Map.add (Mark.remove param0) param local_vars in
let params0, predicate = predicate in
let params = List.map (fun n -> Var.make (Mark.remove n)) params0 in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
local_vars params0 params
in
let f =
let acc_var = Var.make "acc" in
let acc = Expr.make_var acc_var (Untyped { pos = Mark.get param0 }) in
let acc =
Expr.make_var acc_var (Untyped { pos = Mark.get (List.hd params0) })
in
Expr.eabs
(Expr.bind [| acc_var; param |]
(Expr.bind
(Array.of_list (acc_var :: params))
(translate_binop (op, pos) pos acc
(rec_helper ~local_vars predicate)))
[TAny, pos; TAny, pos]
Expand Down Expand Up @@ -766,7 +850,7 @@ let rec translate_expr
| MemCollection (member, collection) ->
let param_var = Var.make "collection_member" in
let param = Expr.make_var param_var emark in
let collection = rec_helper collection in
let collection = detuplify_list collection in
let init = Expr.elit (LBool false) emark in
let acc_var = Var.make "acc" in
let acc = Expr.make_var acc_var emark in
Expand Down
13 changes: 11 additions & 2 deletions compiler/shared_ast/interpreter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,15 @@ let rec evaluate_operator
(fun e1 e2 ->
evaluate_expr
(Mark.add m
(EApp { f; args = [e1; e2]; tys = [Expr.maybe_ty (Mark.get e1); Expr.maybe_ty (Mark.get e2)] })))
(EApp
{
f;
args = [e1; e2];
tys =
[
Expr.maybe_ty (Mark.get e1); Expr.maybe_ty (Mark.get e2);
];
})))
es1 es2)
| Reduce, [_; default; (EArray [], _)] -> Mark.remove default
| Reduce, [f; _; (EArray (x0 :: xn), _)] ->
Expand Down Expand Up @@ -257,7 +265,8 @@ let rec evaluate_operator
];
})))
init es)
| (Length | Log _ | Eq | Map | Map2 | Concat | Filter | Fold | Reduce), _ -> err ()
| (Length | Log _ | Eq | Map | Map2 | Concat | Filter | Fold | Reduce), _ ->
err ()
| Not, [(ELit (LBool b), _)] -> ELit (LBool (o_not b))
| GetDay, [(ELit (LDate d), _)] -> ELit (LInt (o_getDay d))
| GetMonth, [(ELit (LDate d), _)] -> ELit (LInt (o_getMonth d))
Expand Down
4 changes: 2 additions & 2 deletions compiler/shared_ast/operator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,8 @@ type 'a no_overloads =
let translate (t : 'a no_overloads t) : 'b no_overloads t =
match t with
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
| Or | Xor | HandleDefault | HandleDefaultOpt | Log _ | Length | Eq | Map | Map2
| Concat | Filter | Reduce | Fold | Minus_int | Minus_rat | Minus_mon
| Or | Xor | HandleDefault | HandleDefaultOpt | Log _ | Length | Eq | Map
| Map2 | Concat | Filter | Reduce | Fold | Minus_int | Minus_rat | Minus_mon
| Minus_dur | ToRat_int | ToRat_mon | ToMoney_rat | Round_rat | Round_mon
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur _ | Add_dur_dur
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
Expand Down
12 changes: 9 additions & 3 deletions compiler/shared_ast/print.ml
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,8 @@ module Precedence = struct
| Div | Div_int_int | Div_rat_rat | Div_mon_rat | Div_mon_mon
| Div_dur_dur ->
Op Div
| HandleDefault | HandleDefaultOpt | Map | Map2 | Concat | Filter | Reduce | Fold
| ToClosureEnv | FromClosureEnv ->
| HandleDefault | HandleDefaultOpt | Map | Map2 | Concat | Filter | Reduce
| Fold | ToClosureEnv | FromClosureEnv ->
App)
| EApp _ -> App
| EArray _ -> Contained
Expand Down Expand Up @@ -1090,12 +1090,18 @@ module UserFacing = struct
ppf e ->
match Mark.remove e with
| ELit l -> lit lang ppf l
| EArray l | ETuple l ->
| EArray l ->
Format.fprintf ppf "@[<hv 2>[@,@[<hov>%a@]@;<0 -2>]@]"
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ")
(value ~fallback lang))
l
| ETuple l ->
Format.fprintf ppf "@[<hv 2>(@,@[<hov>%a@]@;<0 -2>)@]"
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ",@ ")
(value ~fallback lang))
l
| EStruct { name; fields } ->
Format.fprintf ppf "@[<hv 2>%a {@ %a@;<1 -2>}@]" StructName.format name
(StructField.Map.format_bindings ~pp_sep:Format.pp_print_space
Expand Down
4 changes: 2 additions & 2 deletions compiler/shared_ast/typing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -755,9 +755,9 @@ and typecheck_expr_top_down :
| A.EAbs { binder; tys = t_args } ->
if Bindlib.mbinder_arity binder <> List.length t_args then
Message.raise_spanned_error (Expr.pos e)
"function has %d variables but was supplied %d types"
"function has %d variables but was supplied %d types\n%a"
(Bindlib.mbinder_arity binder)
(List.length t_args)
(List.length t_args) Expr.format e
else
let tau_args = List.map ast_to_typ t_args in
let t_ret = unionfind (TAny (Any.fresh ())) in
Expand Down
10 changes: 5 additions & 5 deletions compiler/surface/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ and literal =
| LDate of literal_date

and collection_op =
| Exists of { predicate : lident Mark.pos * expression }
| Forall of { predicate : lident Mark.pos * expression }
| Map of { f : lident Mark.pos * expression }
| Filter of { f : lident Mark.pos * expression }
| Exists of { predicate : lident Mark.pos list * expression }
| Forall of { predicate : lident Mark.pos list * expression }
| Map of { f : lident Mark.pos list * expression }
| Filter of { f : lident Mark.pos list * expression }
| AggregateSum of { typ : primitive_typ }
(* it would be nice to remove the need for specifying the and here like for
extremums, but we need an additionl overload for "neutral element for
Expand All @@ -157,7 +157,7 @@ and collection_op =
| AggregateArgExtremum of {
max : bool;
default : expression;
f : lident Mark.pos * expression;
f : lident Mark.pos list * expression;
}

and explicit_match_case = {
Expand Down
22 changes: 13 additions & 9 deletions compiler/surface/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ let qlident :=
}
| id = lident ; { [], id }

let mbinder ==
| id = lident ; { [id] }
| LPAREN ; ids = separated_nonempty_list(COMMA,lident) ; RPAREN ; <>

let expression :=
| e = addpos(naked_expression) ; <>

Expand Down Expand Up @@ -216,7 +220,7 @@ let naked_expression ==
CollectionOp (AggregateSum { typ = Mark.remove typ }, coll)
} %prec apply
| f = expression ;
FOR ; i = lident ;
FOR ; i = mbinder ;
AMONG ; coll = expression ; {
CollectionOp (Map {f = i, f}, coll)
} %prec apply
Expand All @@ -234,12 +238,12 @@ let naked_expression ==
e2 = expression ; {
Binop (binop, e1, e2)
}
| EXISTS ; i = lident ;
| EXISTS ; i = mbinder ;
AMONG ; coll = expression ;
SUCH ; THAT ; predicate = expression ; {
CollectionOp (Exists {predicate = i, predicate}, coll)
} %prec let_expr
| FOR ; ALL ; i = lident ;
| FOR ; ALL ; i = mbinder ;
AMONG ; coll = expression ;
WE_HAVE ; predicate = expression ; {
CollectionOp (Forall {predicate = i, predicate}, coll)
Expand All @@ -254,28 +258,28 @@ let naked_expression ==
ELSE ; e3 = expression ; {
IfThenElse (e1, e2, e3)
} %prec let_expr
| LET ; ids = separated_nonempty_list(COMMA,lident) ;
| LET ; ids = mbinder ;
DEFINED_AS ; e1 = expression ;
IN ; e2 = expression ; {
LetIn (ids, e1, e2)
} %prec let_expr
| i = lident ;
| i = lident ; (* FIXME: should be mbinder *)
AMONG ; coll = expression ;
SUCH ; THAT ; f = expression ; {
CollectionOp (Filter {f = i, f}, coll)
CollectionOp (Filter {f = [i], f}, coll)
} %prec top_expr
| fmap = expression ;
FOR ; i = lident ;
FOR ; i = mbinder ;
AMONG ; coll = expression ;
SUCH ; THAT ; ffilt = expression ; {
CollectionOp (Map {f = i, fmap}, (CollectionOp (Filter {f = i, ffilt}, coll), Pos.from_lpos $loc))
} %prec top_expr
| i = lident ;
| i = lident ; (* FIXME: should be mbinder *)
AMONG ; coll = expression ;
SUCH ; THAT ; f = expression ;
IS ; max = minmax ;
OR ; IF ; LIST_EMPTY ; THEN ; default = expression ; {
CollectionOp (AggregateArgExtremum { max; default; f = i, f }, coll)
CollectionOp (AggregateArgExtremum { max; default; f = [i], f }, coll)
} %prec top_expr


Expand Down
4 changes: 3 additions & 1 deletion runtimes/ocaml/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ module Oper = struct
let o_xor : bool -> bool -> bool = ( <> )
let o_eq = ( = )
let o_map = Array.map
let o_map2 f a b = try Array.map2 f a b with Invalid_argument _ -> raise NotSameLength

let o_map2 f a b =
try Array.map2 f a b with Invalid_argument _ -> raise NotSameLength

let o_reduce f dft a =
let len = Array.length a in
Expand Down
Loading

0 comments on commit 371f955

Please sign in to comment.