From b569aa08feb160dd960243ae3e2b0f04f53c171b Mon Sep 17 00:00:00 2001 From: Enrico Tassi Date: Thu, 10 Oct 2024 17:14:15 +0200 Subject: [PATCH] fix --- src/compiler.ml | 46 +++++++++++++++++++++++++++++++++++------- src/parser/ast.ml | 20 +++++++++--------- src/parser/ast.mli | 11 +++++----- src/parser/grammar.mly | 24 ++++++++++++---------- src/utils/util.ml | 4 +++- src/utils/util.mli | 4 +++- 6 files changed, 75 insertions(+), 34 deletions(-) diff --git a/src/compiler.ml b/src/compiler.ml index f5ea44da8..d5825312f 100644 --- a/src/compiler.ml +++ b/src/compiler.ml @@ -665,19 +665,24 @@ module ScopedTypeExpression = struct type t_ = | Const of scope * F.t | App of F.t * e * e list + | Arrow of e * e + | Pred of Ast.Structured.functionality * (Ast.Mode.t * e) list | CData of CData.t and e = { it : t_; loc : Loc.t } [@@ deriving show] - type t = - | Lam of F.t * t + type v_ = + | Lam of F.t * v_ | Ty of e + type t = { name : F.t; value : v_; nparams : int; loc : Loc.t } let rec eqt ctx t1 t2 = match t1.it, t2.it with | Const(Global,c1), Const(Global,c2) -> F.equal c1 c2 | Const(Local,c1), Const(Local,c2) -> ScopedTerm.eq_var ctx c1 c2 | App(c1,x,xs), App(c2,y,ys) -> F.equal c1 c2 && eqt ctx x y && Util.for_all2 (eqt ctx) xs ys + | Arrow(s1,t1), Arrow(s2,t2) -> eqt ctx s1 s2 && eqt ctx t1 t2 + | Pred(f1,l1), Pred(f2,l2) -> f1 == f2 && Util.for_all2 (fun (m1,t1) (m2,t2) -> Ast.Mode.compare m1 m2 == 0 && eqt ctx t1 t2) l1 l2 | CData c1, CData c2 -> CData.equal c1 c2 | _ -> false @@ -881,7 +886,8 @@ end = struct (* {{{ *) { c with Chr.attributes = aux_chr { cid; cifexpr = None } attributes } - let rec structure_type_expression_aux ~loc = function + let rec structure_type_expression_aux ~loc t = { t with TypeExpression.tit = + match t.TypeExpression.tit with | TypeExpression.TPred([Functional],p) -> TypeExpression.TPred(Function,List.map (fun (m,p) -> m, structure_type_expression_aux ~loc p) p) | TypeExpression.TPred([],p) -> TypeExpression.TPred(Relation,List.map (fun (m,p) -> m, structure_type_expression_aux ~loc p) p) | TypeExpression.TPred(a :: _, _) -> error ~loc ("illegal attribute " ^ show_raw_attribute a) @@ -889,11 +895,13 @@ end = struct (* {{{ *) | TypeExpression.TApp(c,x,xs) -> TypeExpression.TApp(c,structure_type_expression_aux ~loc x,List.map (structure_type_expression_aux ~loc) xs) | TypeExpression.TCData c -> TypeExpression.TCData c | TypeExpression.TConst c -> TypeExpression.TConst c + } - let structure_type_expression loc toplevel_func = function + let structure_type_expression loc toplevel_func t = + match t.TypeExpression.tit with | TypeExpression.TPred([],p) -> - TypeExpression.TPred(toplevel_func,List.map (fun (m,p) -> m, structure_type_expression_aux ~loc p) p) - | x -> structure_type_expression_aux ~loc x + { t with TypeExpression.tit = TypeExpression.TPred(toplevel_func,List.map (fun (m,p) -> m, structure_type_expression_aux ~loc p) p) } + | x -> structure_type_expression_aux ~loc t let structure_type_attributes { Type.attributes; loc; name; ty } = let duplicate_err s = @@ -934,7 +942,7 @@ end = struct (* {{{ *) let structure_type_abbreviation { TypeAbbreviation.name; value; nparams; loc } = let rec aux = function - | TypeAbbreviation.Lam(c,t) -> TypeAbbreviation.Lam(c,aux t) + | TypeAbbreviation.Lam(c,loc,t) -> TypeAbbreviation.Lam(c,loc,aux t) | TypeAbbreviation.Ty t -> TypeAbbreviation.Ty (structure_type_expression loc Relation t) in { TypeAbbreviation.name; value = aux value; nparams; loc } @@ -1183,6 +1191,30 @@ end = struct let it = scope_term ctx ~loc it in { ScopedTerm.it; loc } + let rec scope_tye ctx ~loc t = + match t with + | Ast.TypeExpression.TConst c when F.Set.mem c ctx -> ScopedTypeExpression.(Const(ScopedTerm.Local,c)) + | Ast.TypeExpression.TConst c -> ScopedTypeExpression.(Const(ScopedTerm.Global,c)) + | Ast.TypeExpression.TApp(c,x,xs) -> + if F.Set.mem c ctx then error ~loc "type schema parameters cannot be type formers"; + ScopedTypeExpression.App(c,scope_loc_tye ctx x, List.map (scope_loc_tye ctx) xs) + | Ast.TypeExpression.TPred(m,xs) -> + ScopedTypeExpression.Pred(m,List.map (fun (m,t) -> m, scope_loc_tye ctx t) xs) + | Ast.TypeExpression.TArr(s,t) -> + ScopedTypeExpression.Arrow(scope_loc_tye ctx s, scope_loc_tye ctx t) + | Ast.TypeExpression.TCData c -> ScopedTypeExpression.CData c + and scope_loc_tye ctx { tloc; tit } = { loc = tloc; it = scope_tye ctx ~loc:tloc tit } + + let scope_type_abbrev { Ast.TypeAbbreviation.name; value; nparams; loc } = + let rec aux ctx = function + | Ast.TypeAbbreviation.Lam(c,loc,t) when is_uvar_name c -> + if F.Set.mem c ctx then error ~loc "duplicate type schema variable"; + ScopedTypeExpression.Lam(c,aux (F.Set.add c ctx) t) + | Ast.TypeAbbreviation.Lam(c,loc,_) -> error ~loc "only variables can be abstracted in type schema" + | Ast.TypeAbbreviation.Ty t -> ScopedTypeExpression.Ty (scope_loc_tye ctx t) + in + { ScopedTypeExpression.name; value = aux F.Set.empty value; nparams; loc } + let scope_loc_term = scope_loc_term F.Set.empty let compile_type_abbrev abbrvs { Ast.TypeAbbreviation.name; value; nparams; loc } = diff --git a/src/parser/ast.ml b/src/parser/ast.ml index c765b89d7..85f5c5727 100644 --- a/src/parser/ast.ml +++ b/src/parser/ast.ml @@ -62,7 +62,7 @@ end module Mode = struct - type mode = Util.arg_mode = Input | Output + type t = Util.arg_mode = Input | Output [@@deriving show, ord] end @@ -79,18 +79,20 @@ type raw_attribute = | Functional [@@deriving show, ord] + module TypeExpression = struct - type 'attribute t = - | TConst of Func.t - | TApp of Func.t * 'attribute t * 'attribute t list - | TPred of 'attribute * (Mode.mode * 'attribute t) list - | TArr of 'attribute t * 'attribute t - | TCData of CData.t + type 'attribute t_ = + | TConst of Func.t + | TApp of Func.t * 'attribute t * 'attribute t list + | TPred of 'attribute * (Mode.t * 'attribute t) list + | TArr of 'attribute t * 'attribute t + | TCData of CData.t + and 'a t = { tit : 'a t_; tloc : Loc.t } [@@ deriving show, ord] end - + module Term = struct type t_ = @@ -225,7 +227,7 @@ end module TypeAbbreviation = struct type 'ty closedTypeexpression = - | Lam of Func.t * 'ty closedTypeexpression + | Lam of Func.t * Loc.t * 'ty closedTypeexpression | Ty of 'ty [@@ deriving show, ord] diff --git a/src/parser/ast.mli b/src/parser/ast.mli index 81a22db92..7bd2cbaf3 100644 --- a/src/parser/ast.mli +++ b/src/parser/ast.mli @@ -42,7 +42,7 @@ end module Mode : sig - type mode = Input | Output + type t = Input | Output [@@deriving show, ord] end @@ -61,13 +61,14 @@ type raw_attribute = module TypeExpression : sig - type 'attribute t = + type 'attribute t_ = | TConst of Func.t | TApp of Func.t * 'attribute t * 'attribute t list - | TPred of 'attribute * ((Mode.mode * 'attribute t) list) + | TPred of 'attribute * (Mode.t * 'attribute t) list | TArr of 'attribute t * 'attribute t | TCData of CData.t - [@@ deriving show, ord] + and 'a t = { tit : 'a t_; tloc : Loc.t } + [@@ deriving show, ord] end @@ -153,7 +154,7 @@ end module TypeAbbreviation : sig type 'ty closedTypeexpression = - | Lam of Func.t * 'ty closedTypeexpression + | Lam of Func.t * Loc.t * 'ty closedTypeexpression | Ty of 'ty [@@ deriving show, ord] diff --git a/src/parser/grammar.mly b/src/parser/grammar.mly index 46f261d3e..58d3e3334 100644 --- a/src/parser/grammar.mly +++ b/src/parser/grammar.mly @@ -24,7 +24,7 @@ let loc (startpos, endpos) = { line_starts_at = startpos.Lexing.pos_bol; } -let desugar_multi_binder loc t = +let desugar_multi_binder loc (t : Ast.Term.t) = match t.it with | App( { it = Const hd } as binder,args) when Func.equal hd Func.pif || Func.equal hd Func.sigmaf && List.length args > 1 -> @@ -112,6 +112,8 @@ let mode_of_IO io = %type < Func.t > prefix_SYMB %type < Func.t > postfix_SYMB %type < Func.t > constant +%type < 'a TypeExpression.t > type_term +%type < 'a TypeExpression.t > atype_term (* entry points *) %start program @@ -175,14 +177,14 @@ chr_rule: pred: | attributes = attributes; PRED; name = constant; args = separated_list(option(CONJ),pred_item) { - { Type.loc=loc $sloc; name; attributes; ty = TPred ([], args) } + { Type.loc=loc $sloc; name; attributes; ty = { tloc = loc $loc; tit = TPred ([], args) } } } pred_item: | io = IO_COLON; ty = type_term { (mode_of_IO io,ty) } anonymous_pred: | attributes = attributes; PRED; - args = separated_list(option(CONJ),pred_item) { TPred (attributes, args) } + args = separated_list(option(CONJ),pred_item) { { tloc = loc $loc; tit = TPred (attributes, args) } } kind: | KIND; names = separated_nonempty_list(CONJ,constant); k = kind_term { @@ -197,20 +199,20 @@ type_: } atype_term: -| c = STRING { TCData (cstring.Util.CData.cin c) } -| c = constant { TConst (fix_church c) } +| c = STRING { { tloc = loc $loc; tit = TCData (cstring.Util.CData.cin c) } } +| c = constant { { tloc = loc $loc; tit = TConst (fix_church c) } } | LPAREN; t = type_term; RPAREN { t } | LPAREN; t = anonymous_pred; RPAREN { t } type_term: -| c = constant { TConst (fix_church c) } -| hd = constant; args = nonempty_list(atype_term) { TApp (hd, List.hd args, List.tl args) } -| hd = type_term; ARROW; t = type_term { TArr (hd, t) } +| c = constant { { tloc = loc $loc; tit = TConst (fix_church c) } } +| hd = constant; args = nonempty_list(atype_term) { { tloc = loc $loc; tit = TApp (hd, List.hd args, List.tl args) } } +| hd = type_term; ARROW; t = type_term { { tloc = loc $loc; tit = TArr (hd, t) } } | LPAREN; t = anonymous_pred; RPAREN { t } | LPAREN; t = type_term; RPAREN { t } kind_term: -| TYPE { TConst (Func.from_string "type") } -| TYPE; ARROW; t = kind_term { TArr (TConst (Func.from_string "type"), t) } +| TYPE { { tloc = loc $loc; tit = TConst (Func.from_string "type") } } +| x = TYPE; ARROW; t = kind_term { { tloc = loc $loc; tit = TArr ({ tloc = loc $loc(x); tit = TConst (Func.from_string "type") }, t) } } macro: | MACRO; m = term; VDASH; b = term { @@ -222,7 +224,7 @@ typeabbrev: | TYPEABBREV; a = abbrevform; t = type_term { let name, args = a in let nparams = List.length args in - let mkLam (n,_) body = TypeAbbreviation.Lam (n, body) in + let mkLam (n,l) body = TypeAbbreviation.Lam (n, l, body) in let value = List.fold_right mkLam args (Ty t) in { TypeAbbreviation.name = name; nparams = nparams; diff --git a/src/utils/util.ml b/src/utils/util.ml index 62e184b02..7865c76ee 100644 --- a/src/utils/util.ml +++ b/src/utils/util.ml @@ -232,7 +232,9 @@ let rec for_all3b p l1 l2 bl b = ;; type arg_mode = Input | Output -and mode_aux = +[@@deriving show, ord] + +type mode_aux = | Fo of arg_mode | Ho of arg_mode * mode and mode = mode_aux list diff --git a/src/utils/util.mli b/src/utils/util.mli index 20245d687..7aa901f81 100644 --- a/src/utils/util.mli +++ b/src/utils/util.mli @@ -118,7 +118,9 @@ val for_all2 : ('a -> 'a -> bool) -> 'a list -> 'a list -> bool val for_all23 : argsdepth:int -> (argsdepth:int -> matching:bool -> 'x -> 'y -> 'z -> 'a -> 'a -> bool) -> 'x -> 'y -> 'z -> 'a list -> 'a list -> bool val for_all3b : ('a -> 'a -> bool -> bool) -> 'a list -> 'a list -> bool list -> bool -> bool type arg_mode = Input | Output -and mode_aux = +[@@deriving show, ord] + +type mode_aux = | Fo of arg_mode | Ho of arg_mode * mode and mode = mode_aux list