Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

float32 flambda2 operations #2384

Merged
merged 12 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions backend/cmm_helpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ let mk_compare_floats_untagged dbg a1 a2 =
runtime/floats.c *)
add_int (sub_int op1 op2 dbg) (sub_int op3 op4 dbg) dbg))

let mk_compare_float32s_untagged _dbg _a1 _a2 = assert false

let mk_compare_floats dbg a1 a2 =
bind "float_cmp" a2 (fun a2 ->
bind "float_cmp" a1 (fun a1 ->
Expand Down Expand Up @@ -3702,6 +3704,30 @@ let float_gt = binary (Ccmpf CFgt)

let float_ge = binary (Ccmpf CFge)

let float32_abs ~dbg:_ _ = assert false

let float32_neg ~dbg:_ _ = assert false

let float32_add ~dbg:_ _ _ = assert false

let float32_sub ~dbg:_ _ _ = assert false

let float32_mul ~dbg:_ _ _ = assert false

let float32_div ~dbg:_ _ _ = assert false

let float32_eq ~dbg:_ _ _ = assert false

let float32_neq ~dbg:_ _ _ = assert false

let float32_lt ~dbg:_ _ _ = assert false

let float32_le ~dbg:_ _ _ = assert false

let float32_gt ~dbg:_ _ _ = assert false

let float32_ge ~dbg:_ _ _ = assert false

let beginregion ~dbg = Cop (Cbeginregion, [], dbg)

let endregion ~dbg region = Cop (Cendregion, [region], dbg)
Expand Down
27 changes: 27 additions & 0 deletions backend/cmm_helpers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ val mk_compare_ints_untagged :
val mk_compare_floats_untagged :
Debuginfo.t -> expression -> expression -> expression

val mk_compare_float32s_untagged :
Debuginfo.t -> expression -> expression -> expression

(** Convert a tagged integer into a raw integer with boolean meaning *)
val test_bool : Debuginfo.t -> expression -> expression

Expand Down Expand Up @@ -723,6 +726,8 @@ val uge : dbg:Debuginfo.t -> expression -> expression -> expression
(** Asbolute value on floats. *)
val float_abs : dbg:Debuginfo.t -> expression -> expression

val float32_abs : dbg:Debuginfo.t -> expression -> expression

(** Arithmetic negation on floats. *)
val float_neg : dbg:Debuginfo.t -> expression -> expression

Expand All @@ -732,11 +737,23 @@ val float_sub : dbg:Debuginfo.t -> expression -> expression -> expression

val float_mul : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_neg : dbg:Debuginfo.t -> expression -> expression

val float32_add : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_sub : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_mul : dbg:Debuginfo.t -> expression -> expression -> expression

(** Float arithmetic operations. *)
val float_div : dbg:Debuginfo.t -> expression -> expression -> expression

val float_eq : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_div : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_eq : dbg:Debuginfo.t -> expression -> expression -> expression

(** Float arithmetic (dis)equality of cmm expressions. Returns an untagged
integer (either 0 or 1) to represent the result of the comparison. *)
val float_neq : dbg:Debuginfo.t -> expression -> expression -> expression
Expand All @@ -747,10 +764,20 @@ val float_le : dbg:Debuginfo.t -> expression -> expression -> expression

val float_gt : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_neq : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_lt : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_le : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_gt : dbg:Debuginfo.t -> expression -> expression -> expression

(** Float arithmetic comparisons on cmm expressions. Returns an untagged integer
(either 0 or 1) to represent the result of the comparison. *)
val float_ge : dbg:Debuginfo.t -> expression -> expression -> expression

val float32_ge : dbg:Debuginfo.t -> expression -> expression -> expression

val beginregion : dbg:Debuginfo.t -> expression

val endregion : dbg:Debuginfo.t -> expression -> expression
Expand Down
113 changes: 78 additions & 35 deletions middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1120,35 +1120,39 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
let dst = K.Standard_int_or_float.Naked_float in
[box_float mode (Unary (Num_conv { src; dst }, arg)) ~current_region]
| Pnegfloat (Pfloat64, mode), [[arg]] ->
[box_float mode (Unary (Float_arith Neg, unbox_float arg)) ~current_region]
[ box_float mode
(Unary (Float_arith (Float64, Neg), unbox_float arg))
~current_region ]
| Pabsfloat (Pfloat64, mode), [[arg]] ->
[box_float mode (Unary (Float_arith Abs, unbox_float arg)) ~current_region]
[ box_float mode
(Unary (Float_arith (Float64, Abs), unbox_float arg))
~current_region ]
| Paddfloat (Pfloat64, mode), [[arg1]; [arg2]] ->
[ box_float mode
(Binary (Float_arith Add, unbox_float arg1, unbox_float arg2))
(Binary (Float_arith (Float64, Add), unbox_float arg1, unbox_float arg2))
~current_region ]
| Psubfloat (Pfloat64, mode), [[arg1]; [arg2]] ->
[ box_float mode
(Binary (Float_arith Sub, unbox_float arg1, unbox_float arg2))
(Binary (Float_arith (Float64, Sub), unbox_float arg1, unbox_float arg2))
~current_region ]
| Pmulfloat (Pfloat64, mode), [[arg1]; [arg2]] ->
[ box_float mode
(Binary (Float_arith Mul, unbox_float arg1, unbox_float arg2))
(Binary (Float_arith (Float64, Mul), unbox_float arg1, unbox_float arg2))
~current_region ]
| Pdivfloat (Pfloat64, mode), [[arg1]; [arg2]] ->
[ box_float mode
(Binary (Float_arith Div, unbox_float arg1, unbox_float arg2))
(Binary (Float_arith (Float64, Div), unbox_float arg1, unbox_float arg2))
~current_region ]
| Pfloatcomp (Pfloat64, comp), [[arg1]; [arg2]] ->
[ tag_int
(Binary
( Float_comp (Yielding_bool (convert_float_comparison comp)),
( Float_comp (Float64, Yielding_bool (convert_float_comparison comp)),
unbox_float arg1,
unbox_float arg2 )) ]
| Punboxed_float_comp (Pfloat64, comp), [[arg1]; [arg2]] ->
[ tag_int
(Binary
( Float_comp (Yielding_bool (convert_float_comparison comp)),
( Float_comp (Float64, Yielding_bool (convert_float_comparison comp)),
arg1,
arg2 )) ]
| Punbox_float Pfloat64, [[arg]] -> [Unary (Unbox_number Naked_float, arg)]
Expand All @@ -1166,19 +1170,53 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
let src = K.Standard_int_or_float.Tagged_immediate in
let dst = K.Standard_int_or_float.Naked_float32 in
[box_float32 mode (Unary (Num_conv { src; dst }, arg)) ~current_region]
| Pnegfloat (Pfloat32, _), _
| Pabsfloat (Pfloat32, _), _
| Paddfloat (Pfloat32, _), _
| Psubfloat (Pfloat32, _), _
| Pmulfloat (Pfloat32, _), _
| Pdivfloat (Pfloat32, _), _
| Pfloatcomp (Pfloat32, _), _
| Punbox_float Pfloat32, _
| Pbox_float (Pfloat32, _), _
| Pcompare_floats Pfloat32, _
| Punboxed_float_comp (Pfloat32, _), _ ->
(* CR mslater: (float32) runtime *)
assert false
| Pnegfloat (Pfloat32, mode), [[arg]] ->
[ box_float32 mode
(Unary (Float_arith (Float32, Neg), unbox_float32 arg))
~current_region ]
| Pabsfloat (Pfloat32, mode), [[arg]] ->
[ box_float32 mode
(Unary (Float_arith (Float32, Abs), unbox_float32 arg))
~current_region ]
| Paddfloat (Pfloat32, mode), [[arg1]; [arg2]] ->
[ box_float32 mode
(Binary
(Float_arith (Float32, Add), unbox_float32 arg1, unbox_float32 arg2))
~current_region ]
| Psubfloat (Pfloat32, mode), [[arg1]; [arg2]] ->
[ box_float32 mode
(Binary
(Float_arith (Float32, Sub), unbox_float32 arg1, unbox_float32 arg2))
~current_region ]
| Pmulfloat (Pfloat32, mode), [[arg1]; [arg2]] ->
[ box_float32 mode
(Binary
(Float_arith (Float32, Mul), unbox_float32 arg1, unbox_float32 arg2))
~current_region ]
| Pdivfloat (Pfloat32, mode), [[arg1]; [arg2]] ->
[ box_float32 mode
(Binary
(Float_arith (Float32, Div), unbox_float32 arg1, unbox_float32 arg2))
~current_region ]
| Pfloatcomp (Pfloat32, comp), [[arg1]; [arg2]] ->
[ tag_int
(Binary
( Float_comp (Float32, Yielding_bool (convert_float_comparison comp)),
unbox_float32 arg1,
unbox_float32 arg2 )) ]
| Punboxed_float_comp (Pfloat32, comp), [[arg1]; [arg2]] ->
[ tag_int
(Binary
( Float_comp (Float32, Yielding_bool (convert_float_comparison comp)),
arg1,
arg2 )) ]
| Punbox_float Pfloat32, [[arg]] -> [Unary (Unbox_number Naked_float32, arg)]
| Pbox_float (Pfloat32, mode), [[arg]] ->
[ Unary
( Box_number
( Naked_float32,
Alloc_mode.For_allocations.from_lambda mode ~current_region ),
arg ) ]
| Punbox_int bi, [[arg]] ->
let kind = boxable_number_of_boxed_integer bi in
[Unary (Unbox_number kind, arg)]
Expand Down Expand Up @@ -1818,9 +1856,15 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Pcompare_floats Pfloat64, [[f1]; [f2]] ->
[ tag_int
(Binary
( Float_comp (Yielding_int_like_compare_functions ()),
( Float_comp (Float64, Yielding_int_like_compare_functions ()),
Prim (Unary (Unbox_number Naked_float, f1)),
Prim (Unary (Unbox_number Naked_float, f2)) )) ]
| Pcompare_floats Pfloat32, [[f1]; [f2]] ->
[ tag_int
(Binary
( Float_comp (Float32, Yielding_int_like_compare_functions ()),
Prim (Unary (Unbox_number Naked_float32, f1)),
Prim (Unary (Unbox_number Naked_float32, f2)) )) ]
| Pcompare_bints int_kind, [[i1]; [i2]] ->
let unboxing_kind = boxable_number_of_boxed_integer int_kind in
[ tag_int
Expand Down Expand Up @@ -1860,18 +1904,17 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
%a (%a)"
Printlambda.primitive prim H.print_list_of_simple_or_prim
(List.flatten args)
| ( ( Pfield _ | Pnegint | Pnot | Poffsetint _
| Pintoffloat (Pfloat64 | Pfloat32)
| Pfloatofint ((Pfloat64 | Pfloat32), _)
| ( ( Pfield _ | Pnegint | Pnot | Poffsetint _ | Pintoffloat _
| Pfloatofint (_, _)
| Pfloatoffloat32 _ | Pfloat32offloat _
| Pnegfloat (Pfloat64, _)
| Pabsfloat (Pfloat64, _)
| Pnegfloat (_, _)
| Pabsfloat (_, _)
| Pstringlength | Pbyteslength | Pbintofint _ | Pintofbint _ | Pnegbint _
| Popaque _ | Pduprecord _ | Parraylength _ | Pduparray _ | Pfloatfield _
| Pcvtbint _ | Poffsetref _ | Pbswap16 | Pbbswap _ | Pisint _
| Pint_as_pointer _ | Pbigarraydim _ | Pobj_dup | Pobj_magic _
| Punbox_float Pfloat64
| Pbox_float (Pfloat64, _)
| Punbox_float _
| Pbox_float (_, _)
| Punbox_int _ | Pbox_int _ | Punboxed_product_field _ | Pget_header _
| Pufloatfield _ | Patomic_load _ | Pmixedfield _ ),
([] | _ :: _ :: _ | [([] | _ :: _ :: _)]) ) ->
Expand All @@ -1881,12 +1924,12 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
Printlambda.primitive prim H.print_list_of_lists_of_simple_or_prim args
| ( ( Paddint | Psubint | Pmulint | Pandint | Porint | Pxorint | Plslint
| Plsrint | Pasrint | Pdivint _ | Pmodint _ | Psetfield _ | Pintcomp _
| Paddfloat (Pfloat64, _)
| Psubfloat (Pfloat64, _)
| Pmulfloat (Pfloat64, _)
| Pdivfloat (Pfloat64, _)
| Pfloatcomp (Pfloat64, _)
| Punboxed_float_comp (Pfloat64, _)
| Paddfloat (_, _)
| Psubfloat (_, _)
| Pmulfloat (_, _)
| Pdivfloat (_, _)
| Pfloatcomp (_, _)
| Punboxed_float_comp (_, _)
| Pstringrefu | Pbytesrefu | Pstringrefs | Pbytesrefs | Pstring_load_16 _
| Pstring_load_32 _ | Pstring_load_64 _ | Pstring_load_128 _
| Pbytes_load_16 _ | Pbytes_load_32 _ | Pbytes_load_64 _
Expand Down
6 changes: 4 additions & 2 deletions middle_end/flambda2/parser/fexpr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,14 @@ type bytes_like_value = Flambda_primitive.bytes_like_value =
| Bytes
| Bigstring

type float_bitwidth = Flambda_primitive.float_bitwidth

type infix_binop =
| Int_arith of binary_int_arith_op (* on tagged immediates *)
| Int_shift of int_shift_op (* on tagged immediates *)
| Int_comp of signed_or_unsigned comparison_behaviour (* on tagged imms *)
| Float_arith of binary_float_arith_op
| Float_comp of unit comparison_behaviour
| Float_arith of float_bitwidth * binary_float_arith_op
| Float_comp of float_bitwidth * unit comparison_behaviour

type binop =
| Array_load of array_kind * array_accessor_width * mutability
Expand Down
4 changes: 2 additions & 2 deletions middle_end/flambda2/parser/fexpr_to_flambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ let infix_binop (binop : Fexpr.infix_binop) : Flambda_primitive.binary_primitive
| Int_arith o -> Int_arith (Tagged_immediate, o)
| Int_comp c -> Int_comp (Tagged_immediate, c)
| Int_shift s -> Int_shift (Tagged_immediate, s)
| Float_arith o -> Float_arith o
| Float_comp c -> Float_comp c
| Float_arith (w, o) -> Float_arith (w, o)
| Float_comp (w, c) -> Float_comp (w, c)

let block_access_kind (ak : Fexpr.block_access_kind) :
Flambda_primitive.Block_access_kind.t =
Expand Down
4 changes: 2 additions & 2 deletions middle_end/flambda2/parser/flambda_parser.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4784,7 +4784,7 @@ module Tables = struct
# 4785 "flambda_parser_in.ml"
) =
# 431 "flambda_parser.mly"
( Float_arith o )
( Float_arith (Float64, o) )
# 4789 "flambda_parser_in.ml"
in
{
Expand Down Expand Up @@ -4817,7 +4817,7 @@ module Tables = struct
# 4818 "flambda_parser_in.ml"
) =
# 432 "flambda_parser.mly"
( Float_comp c )
( Float_comp (Float64, c) )
# 4822 "flambda_parser_in.ml"
in
{
Expand Down
4 changes: 2 additions & 2 deletions middle_end/flambda2/parser/flambda_parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ infix_binop:
| o = binary_int_arith_op { Int_arith o }
| c = int_comp { Int_comp (c Signed) }
| s = int_shift { Int_shift s }
| o = binary_float_arith_op { Float_arith o }
| c = float_comp { Float_comp c }
| o = binary_float_arith_op { Float_arith (Float64, o) }
| c = float_comp { Float_comp (Float64, c) }
;

prefix_binop:
Expand Down
4 changes: 2 additions & 2 deletions middle_end/flambda2/parser/flambda_to_fexpr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,8 @@ let binop (op : Flambda_primitive.binary_primitive) : Fexpr.binop =
| Int_comp (i, c) -> Int_comp (i, c)
| Int_shift (Tagged_immediate, s) -> Infix (Int_shift s)
| Int_shift (i, s) -> Int_shift (i, s)
| Float_arith o -> Infix (Float_arith o)
| Float_comp c -> Infix (Float_comp c)
| Float_arith (w, o) -> Infix (Float_arith (w, o))
| Float_comp (w, c) -> Infix (Float_comp (w, c))
| String_or_bigstring_load (slv, saw) -> String_or_bigstring_load (slv, saw)
| Bigarray_get_alignment align -> Bigarray_get_alignment align
| Bigarray_load _ | Atomic_exchange | Atomic_fetch_and_add ->
Expand Down
Loading
Loading