diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 90d200ec1..986a33354 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -101,9 +101,9 @@ let rec translate_typ (tau : D.typ Pos.marked) : D.typ Pos.marked = | D.TEnum (ts, en) -> D.TEnum (List.map translate_typ ts, en) | D.TAny -> D.TAny | D.TArray ts -> D.TArray (translate_typ ts) - (* catala is not polymorphic*) - | D.TArrow ((D.TLit D.TUnit, _), t2) -> - D.TEnum ([ translate_typ t2 ], A.option_enum) (* D.TAny *) + (* catala is not polymorphic *) + | D.TArrow ((D.TLit D.TUnit, pos_unit), t2) -> + D.TEnum ([ (D.TLit D.TUnit, pos_unit); translate_typ t2 ], A.option_enum) (* D.TAny *) | D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2) end diff --git a/compiler/scalc/compile_from_lambda.ml b/compiler/scalc/compile_from_lambda.ml index afc74f8b5..32bb53a2a 100644 --- a/compiler/scalc/compile_from_lambda.ml +++ b/compiler/scalc/compile_from_lambda.ml @@ -67,7 +67,6 @@ let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.ex ([], []) args in let new_args = List.rev new_args in - let args_stmts = List.rev args_stmts in (f_stmts @ args_stmts, (A.EApp (new_f, new_args), Pos.get_position expr)) | L.EArray args -> let args_stmts, new_args = @@ -78,7 +77,6 @@ let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.ex ([], []) args in let new_args = List.rev new_args in - let args_stmts = List.rev args_stmts in (args_stmts, (A.EArray new_args, Pos.get_position expr)) | L.EOp op -> ([], (A.EOp op, Pos.get_position expr)) | L.ELit l -> ([], (A.ELit l, Pos.get_position expr)) @@ -260,8 +258,12 @@ let translate_program (p : L.program) : A.program = { A.func_params = new_scope_params; A.func_body = new_scope_body }; } :: new_scopes )) - ( L.VarMap.singleton L.handle_default - (A.TopLevelName.fresh ("handle_default", Pos.no_pos)), + ( (if !Cli.avoid_exceptions_flag then + L.VarMap.singleton L.handle_default_opt + (A.TopLevelName.fresh ("handle_default_opt", Pos.no_pos)) + else + L.VarMap.singleton L.handle_default + (A.TopLevelName.fresh ("handle_default", Pos.no_pos))), [] ) p.L.scopes in diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index aad84435a..0b9a00927 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -147,6 +147,9 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : u (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) ts | TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s + | TEnum ([ _; some_typ ], e) when D.EnumName.compare e L.option_enum = 0 -> + (* We translate the option type with an overloading by Python's [None] *) + Format.fprintf fmt "Optional[%a]" format_typ some_typ | TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e | TArrow (t1, t2) -> Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 format_typ_with_parens t2 @@ -181,8 +184,8 @@ let format_exception (fmt : Format.formatter) (exc : L.except Pos.marked) : unit | NoValueProvided -> let pos = Pos.get_position exc in Format.fprintf fmt - "NoValueProvided(SourcePosition(filename=\"%s\",@ start_line=%d,@ start_column=%d,@ \ - end_line=%d,@ end_column=%d,@ law_headings=%a))" + "NoValueProvided(@[SourcePosition(@[filename=\"%s\",@ start_line=%d,@ \ + start_column=%d,@ end_line=%d,@ end_column=%d,@ law_headings=%a)@])@]" (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) @@ -201,6 +204,16 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e (List.combine es (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) | EStructFieldAccess (e1, field, _) -> Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field + | EInj (_, cons, e_name) + when D.EnumName.compare e_name L.option_enum = 0 + && D.EnumConstructor.compare cons L.none_constr = 0 -> + (* We translate the option type with an overloading by Python's [None] *) + Format.fprintf fmt "None" + | EInj (e, cons, e_name) + when D.EnumName.compare e_name L.option_enum = 0 + && D.EnumConstructor.compare cons L.some_constr = 0 -> + (* We translate the option type with an overloading by Python's [None] *) + format_expression ctx fmt e | EInj (e, cons, enum_name) -> Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name format_enum_name enum_name format_enum_cons_name cons (format_expression ctx) e @@ -240,7 +253,7 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e | EApp ((EOp (Unop op), _), [ arg1 ]) -> Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos) (format_expression ctx) arg1 | EApp (f, args) -> - Format.fprintf fmt "%a(%a)" (format_expression ctx) f + Format.fprintf fmt "%a(@[%a)@]" (format_expression ctx) f (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (format_expression ctx)) @@ -270,6 +283,14 @@ let rec format_statement (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s | SIfThenElse (cond, b1, b2) -> Format.fprintf fmt "@[if %a:@\n%a@]@\n@[else:@\n%a@]" (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 + | SSwitch (e1, e_name, [ (case_none, _); (case_some, case_some_var) ]) + when D.EnumName.compare e_name L.option_enum = 0 -> + (* We translate the option type with an overloading by Python's [None] *) + let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in + Format.fprintf fmt + "%a = %a@\n@[if %a is None:@\n%a@]@\n@[else:@\n%a = %a@\n%a@]" format_var + tmp_var (format_expression ctx) e1 format_var tmp_var (format_block ctx) case_none + format_var case_some_var format_var tmp_var (format_block ctx) case_some | SSwitch (e1, e_name, cases) -> let cases = List.map2 (fun (x, y) (cons, _) -> (x, y, cons)) cases (D.EnumMap.find e_name ctx.ctx_enums) @@ -406,6 +427,8 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form let format_program (fmt : Format.formatter) (p : Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = + (* We disable the style flag in order to enjoy formatting from the pretty-printers of Dcalc and + Lcalc but without the color terminal markers. *) Cli.style_flag := false; Format.fprintf fmt "# This file has been generated by the Catala compiler, do not edit!\n\ diff --git a/french_law/python/src/catala.py b/french_law/python/src/catala.py index 02daeaabe..405e22635 100644 --- a/french_law/python/src/catala.py +++ b/french_law/python/src/catala.py @@ -12,7 +12,7 @@ from gmpy2 import log2, mpz, mpq, mpfr, t_divmod # type: ignore import datetime import dateutil.relativedelta -from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable, Union +from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable, Union, Any from functools import reduce from enum import Enum import copy @@ -515,6 +515,31 @@ def handle_default( return acc +def handle_default_opt( + exceptions: List[Optional[Any]], + just: Optional[bool], + cons: Optional[Alpha] +) -> Optional[Alpha]: + acc: Optional[Alpha] = None + for exception in exceptions: + if acc is None: + acc = exception + elif not (acc is None) and exception is None: + pass # acc stays the same + elif not (acc is None) and not (exception is None): + raise ConflictError + if acc is None: + if just is None: + return None + else: + if just: + return cons + else: + return None + else: + return acc + + def no_input() -> Callable[[Unit], Alpha]: def closure(_: Unit): raise EmptyError