Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Python backend with exceptions avoided #210

Merged
merged 1 commit into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions compiler/lcalc/compile_without_exceptions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ let rec translate_typ (tau : D.typ Pos.marked) : D.typ Pos.marked =
| D.TEnum (ts, en) -> D.TEnum (List.map translate_typ ts, en)
| D.TAny -> D.TAny
| D.TArray ts -> D.TArray (translate_typ ts)
(* catala is not polymorphic*)
| D.TArrow ((D.TLit D.TUnit, _), t2) ->
D.TEnum ([ translate_typ t2 ], A.option_enum) (* D.TAny *)
(* catala is not polymorphic *)
| D.TArrow ((D.TLit D.TUnit, pos_unit), t2) ->
D.TEnum ([ (D.TLit D.TUnit, pos_unit); translate_typ t2 ], A.option_enum) (* D.TAny *)
| D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2)
end

Expand Down
10 changes: 6 additions & 4 deletions compiler/scalc/compile_from_lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.ex
([], []) args
in
let new_args = List.rev new_args in
let args_stmts = List.rev args_stmts in
(f_stmts @ args_stmts, (A.EApp (new_f, new_args), Pos.get_position expr))
| L.EArray args ->
let args_stmts, new_args =
Expand All @@ -78,7 +77,6 @@ let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.ex
([], []) args
in
let new_args = List.rev new_args in
let args_stmts = List.rev args_stmts in
(args_stmts, (A.EArray new_args, Pos.get_position expr))
| L.EOp op -> ([], (A.EOp op, Pos.get_position expr))
| L.ELit l -> ([], (A.ELit l, Pos.get_position expr))
Expand Down Expand Up @@ -260,8 +258,12 @@ let translate_program (p : L.program) : A.program =
{ A.func_params = new_scope_params; A.func_body = new_scope_body };
}
:: new_scopes ))
( L.VarMap.singleton L.handle_default
(A.TopLevelName.fresh ("handle_default", Pos.no_pos)),
( (if !Cli.avoid_exceptions_flag then
L.VarMap.singleton L.handle_default_opt
(A.TopLevelName.fresh ("handle_default_opt", Pos.no_pos))
else
L.VarMap.singleton L.handle_default
(A.TopLevelName.fresh ("handle_default", Pos.no_pos))),
[] )
p.L.scopes
in
Expand Down
29 changes: 26 additions & 3 deletions compiler/scalc/to_python.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : u
(fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t))
ts
| TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s
| TEnum ([ _; some_typ ], e) when D.EnumName.compare e L.option_enum = 0 ->
(* We translate the option type with an overloading by Python's [None] *)
Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e
| TArrow (t1, t2) ->
Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 format_typ_with_parens t2
Expand Down Expand Up @@ -181,8 +184,8 @@ let format_exception (fmt : Format.formatter) (exc : L.except Pos.marked) : unit
| NoValueProvided ->
let pos = Pos.get_position exc in
Format.fprintf fmt
"NoValueProvided(SourcePosition(filename=\"%s\",@ start_line=%d,@ start_column=%d,@ \
end_line=%d,@ end_column=%d,@ law_headings=%a))"
"NoValueProvided(@[<hov 0>SourcePosition(@[<hov 0>filename=\"%s\",@ start_line=%d,@ \
start_column=%d,@ end_line=%d,@ end_column=%d,@ law_headings=%a)@])@]"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos)

Expand All @@ -201,6 +204,16 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
(List.combine es (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs)))
| EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field
| EInj (_, cons, e_name)
when D.EnumName.compare e_name L.option_enum = 0
&& D.EnumConstructor.compare cons L.none_constr = 0 ->
(* We translate the option type with an overloading by Python's [None] *)
Format.fprintf fmt "None"
| EInj (e, cons, e_name)
when D.EnumName.compare e_name L.option_enum = 0
&& D.EnumConstructor.compare cons L.some_constr = 0 ->
(* We translate the option type with an overloading by Python's [None] *)
format_expression ctx fmt e
| EInj (e, cons, enum_name) ->
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name format_enum_name enum_name
format_enum_cons_name cons (format_expression ctx) e
Expand Down Expand Up @@ -240,7 +253,7 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
| EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos) (format_expression ctx) arg1
| EApp (f, args) ->
Format.fprintf fmt "%a(%a)" (format_expression ctx) f
Format.fprintf fmt "%a(@[<hov 0>%a)@]" (format_expression ctx) f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
Expand Down Expand Up @@ -270,6 +283,14 @@ let rec format_statement (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s
| SIfThenElse (cond, b1, b2) ->
Format.fprintf fmt "@[<hov 4>if %a:@\n%a@]@\n@[<hov 4>else:@\n%a@]" (format_expression ctx)
cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch (e1, e_name, [ (case_none, _); (case_some, case_some_var) ])
when D.EnumName.compare e_name L.option_enum = 0 ->
(* We translate the option type with an overloading by Python's [None] *)
let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in
Format.fprintf fmt
"%a = %a@\n@[<hov 4>if %a is None:@\n%a@]@\n@[<hov 4>else:@\n%a = %a@\n%a@]" format_var
tmp_var (format_expression ctx) e1 format_var tmp_var (format_block ctx) case_none
format_var case_some_var format_var tmp_var (format_block ctx) case_some
| SSwitch (e1, e_name, cases) ->
let cases =
List.map2 (fun (x, y) (cons, _) -> (x, y, cons)) cases (D.EnumMap.find e_name ctx.ctx_enums)
Expand Down Expand Up @@ -406,6 +427,8 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form

let format_program (fmt : Format.formatter) (p : Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
(* We disable the style flag in order to enjoy formatting from the pretty-printers of Dcalc and
Lcalc but without the color terminal markers. *)
Cli.style_flag := false;
Format.fprintf fmt
"# This file has been generated by the Catala compiler, do not edit!\n\
Expand Down
27 changes: 26 additions & 1 deletion french_law/python/src/catala.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from gmpy2 import log2, mpz, mpq, mpfr, t_divmod # type: ignore
import datetime
import dateutil.relativedelta
from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable, Union
from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable, Union, Any
from functools import reduce
from enum import Enum
import copy
Expand Down Expand Up @@ -515,6 +515,31 @@ def handle_default(
return acc


def handle_default_opt(
exceptions: List[Optional[Any]],
just: Optional[bool],
cons: Optional[Alpha]
) -> Optional[Alpha]:
acc: Optional[Alpha] = None
for exception in exceptions:
if acc is None:
acc = exception
elif not (acc is None) and exception is None:
pass # acc stays the same
elif not (acc is None) and not (exception is None):
raise ConflictError
if acc is None:
if just is None:
return None
else:
if just:
return cons
else:
return None
else:
return acc


def no_input() -> Callable[[Unit], Alpha]:
def closure(_: Unit):
raise EmptyError
Expand Down