Skip to content

Commit

Permalink
flambda-backend: Index arrays with unboxed ints (#2337)
Browse files Browse the repository at this point in the history
* wip

* .

* new primitives for unboxed int indexing

* fix tests

* bytegen

* fix runtime

* reduce diff

* code cleanup

* fix upstream

* add comment in array_access_validity_condition
  • Loading branch information
alanechang authored Mar 15, 2024
1 parent c723951 commit 0f4d23e
Show file tree
Hide file tree
Showing 15 changed files with 574 additions and 115 deletions.
57 changes: 41 additions & 16 deletions bytecomp/bytegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,16 @@ let comp_bint_primitive bi suff args =
| Pint64 -> "caml_int64_" in
Kccall(pref ^ suff, List.length args)

let array_primitive (index_kind : Lambda.array_index_kind) prefix =
let suffix =
match index_kind with
| Ptagged_int_index -> ""
| Punboxed_int_index Pint64 -> "_indexed_by_int64"
| Punboxed_int_index Pint32 -> "_indexed_by_int32"
| Punboxed_int_index Pnativeint -> "_indexed_by_nativeint"
in
prefix ^ suffix

let comp_primitive stack_info p sz args =
check_stack stack_info sz;
match p with
Expand Down Expand Up @@ -504,30 +514,45 @@ let comp_primitive stack_info p sz args =
(* In bytecode, nothing is ever actually stack-allocated, so we ignore the
array modes (allocation for [Parrayref{s,u}], modification for
[Parrayset{s,u}]). *)
| Parrayrefs (Pgenarray_ref _) -> Kccall("caml_array_get", 2)
| Parrayrefs (Pfloatarray_ref _) -> Kccall("caml_floatarray_get", 2)
| Parrayrefs (Paddrarray_ref | Pintarray_ref) ->
| Parrayrefs (Pgenarray_ref _, index_kind)
| Parrayrefs ((Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _),
(Punboxed_int_index _ as index_kind)) ->
Kccall(array_primitive index_kind "caml_array_get", 2)
| Parrayrefs (Pfloatarray_ref _, Ptagged_int_index) ->
Kccall("caml_floatarray_get", 2)
| Parrayrefs ((Paddrarray_ref | Pintarray_ref), Ptagged_int_index) ->
Kccall("caml_array_get_addr", 2)
| Parrayrefs (Punboxedfloatarray_ref _ | Punboxedintarray_ref _) ->
| Parrayrefs ((Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _) ->
Misc.fatal_errorf "Cannot use primitive %a for unboxed arrays in bytecode"
Printlambda.primitive p
| Parraysets (Pgenarray_set _) -> Kccall("caml_array_set", 3)
| Parraysets Pfloatarray_set -> Kccall("caml_floatarray_set", 3)
| Parraysets (Paddrarray_set _ | Pintarray_set) ->
| Parraysets (Pgenarray_set _, index_kind)
| Parraysets ((Paddrarray_set _ | Pintarray_set | Pfloatarray_set),
(Punboxed_int_index _ as index_kind)) ->
Kccall(array_primitive index_kind "caml_array_set", 3)
| Parraysets (Pfloatarray_set, Ptagged_int_index) -> Kccall("caml_floatarray_set", 3)
| Parraysets ((Paddrarray_set _ | Pintarray_set), Ptagged_int_index) ->
Kccall("caml_array_set_addr", 3)
| Parraysets (Punboxedfloatarray_set _ | Punboxedintarray_set _) ->
| Parraysets ((Punboxedfloatarray_set _ | Punboxedintarray_set _), _index_kind) ->
Misc.fatal_errorf "Cannot use primitive %a for unboxed arrays in bytecode"
Printlambda.primitive p
| Parrayrefu (Pgenarray_ref _) -> Kccall("caml_array_unsafe_get", 2)
| Parrayrefu (Pfloatarray_ref _) -> Kccall("caml_floatarray_unsafe_get", 2)
| Parrayrefu (Paddrarray_ref | Pintarray_ref) -> Kgetvectitem
| Parrayrefu (Punboxedfloatarray_ref _ | Punboxedintarray_ref _) ->
| Parrayrefu (Pgenarray_ref _, index_kind)
| Parrayrefu ((Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _),
(Punboxed_int_index _ as index_kind)) ->
Kccall(array_primitive index_kind "caml_array_unsafe_get", 2)
| Parrayrefu (Pfloatarray_ref _, Ptagged_int_index) ->
Kccall("caml_floatarray_unsafe_get", 2)
| Parrayrefu ((Paddrarray_ref | Pintarray_ref), Ptagged_int_index) -> Kgetvectitem
| Parrayrefu ((Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _index_kind) ->
Misc.fatal_errorf "Cannot use primitive %a for unboxed arrays in bytecode"
Printlambda.primitive p
| Parraysetu (Pgenarray_set _) -> Kccall("caml_array_unsafe_set", 3)
| Parraysetu Pfloatarray_set -> Kccall("caml_floatarray_unsafe_set", 3)
| Parraysetu (Paddrarray_set _ | Pintarray_set) -> Ksetvectitem
| Parraysetu (Punboxedfloatarray_set _ | Punboxedintarray_set _) ->
| Parraysetu (Pgenarray_set _, index_kind)
| Parraysetu ((Paddrarray_set _ | Pintarray_set | Pfloatarray_set),
(Punboxed_int_index _ as index_kind)) ->
Kccall(array_primitive index_kind "caml_array_unsafe_set", 3)
| Parraysetu (Pfloatarray_set, Ptagged_int_index) ->
Kccall("caml_floatarray_unsafe_set", 3)
| Parraysetu ((Paddrarray_set _ | Pintarray_set), Ptagged_int_index) -> Ksetvectitem
| Parraysetu ((Punboxedfloatarray_set _ | Punboxedintarray_set _), _index_kind) ->
Misc.fatal_errorf "Cannot use primitive %a for unboxed arrays in bytecode"
Printlambda.primitive p
| Pctconst c ->
Expand Down
26 changes: 15 additions & 11 deletions lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ type primitive =
| Pmakearray of array_kind * mutable_flag * alloc_mode
| Pduparray of array_kind * mutable_flag
| Parraylength of array_kind
| Parrayrefu of array_ref_kind
| Parraysetu of array_set_kind
| Parrayrefs of array_ref_kind
| Parraysets of array_set_kind
| Parrayrefu of array_ref_kind * array_index_kind
| Parraysetu of array_set_kind * array_index_kind
| Parrayrefs of array_ref_kind * array_index_kind
| Parraysets of array_set_kind * array_index_kind
(* Test if the argument is a block or an immediate integer *)
| Pisint of { variant_only : bool }
(* Test if the (integer) argument is outside an interval *)
Expand Down Expand Up @@ -358,6 +358,10 @@ and array_set_kind =
| Punboxedfloatarray_set of unboxed_float
| Punboxedintarray_set of unboxed_integer

and array_index_kind =
| Ptagged_int_index
| Punboxed_int_index of unboxed_integer

and boxed_float = Primitive.boxed_float =
| Pfloat64
| Pfloat32
Expand Down Expand Up @@ -1586,12 +1590,12 @@ let primitive_may_allocate : primitive -> alloc_mode option = function
| Pduparray _ -> Some alloc_heap
| Parraylength _ -> None
| Parraysetu _ | Parraysets _
| Parrayrefu (Paddrarray_ref | Pintarray_ref
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _)
| Parrayrefs (Paddrarray_ref | Pintarray_ref
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _) -> None
| Parrayrefu (Pgenarray_ref m | Pfloatarray_ref m)
| Parrayrefs (Pgenarray_ref m | Pfloatarray_ref m) -> Some m
| Parrayrefu ((Paddrarray_ref | Pintarray_ref
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _)
| Parrayrefs ((Paddrarray_ref | Pintarray_ref
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _) -> None
| Parrayrefu ((Pgenarray_ref m | Pfloatarray_ref m), _)
| Parrayrefs ((Pgenarray_ref m | Pfloatarray_ref m), _) -> Some m
| Pisint _ | Pisout -> None
| Pintofbint _ -> None
| Pbintofint (_,m)
Expand Down Expand Up @@ -1743,7 +1747,7 @@ let primitive_result_layout (p : primitive) =
| Pstring_load_16 _ | Pbytes_load_16 _ | Pbigstring_load_16 _
| Pprobe_is_enabled _ | Pbswap16
-> layout_int
| Parrayrefu array_ref_kind | Parrayrefs array_ref_kind ->
| Parrayrefu (array_ref_kind, _) | Parrayrefs (array_ref_kind, _) ->
array_ref_kind_result_layout array_ref_kind
| Pbintofint (bi, _) | Pcvtbint (_,bi,_)
| Pnegbint (bi, _) | Paddbint (bi, _) | Psubbint (bi, _)
Expand Down
12 changes: 8 additions & 4 deletions lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ type primitive =
The arguments of [Pduparray] give the kind and mutability of the
array being *produced* by the duplication. *)
| Parraylength of array_kind
| Parrayrefu of array_ref_kind
| Parraysetu of array_set_kind
| Parrayrefs of array_ref_kind
| Parraysets of array_set_kind
| Parrayrefu of array_ref_kind * array_index_kind
| Parraysetu of array_set_kind * array_index_kind
| Parrayrefs of array_ref_kind * array_index_kind
| Parraysets of array_set_kind * array_index_kind
(* Test if the argument is a block or an immediate integer *)
| Pisint of { variant_only : bool }
(* Test if the (integer) argument is outside an interval *)
Expand Down Expand Up @@ -308,6 +308,10 @@ and array_set_kind =
| Punboxedfloatarray_set of unboxed_float
| Punboxedintarray_set of unboxed_integer

and array_index_kind =
| Ptagged_int_index
| Punboxed_int_index of unboxed_integer

and value_kind =
| Pgenval
| Pintval
Expand Down
2 changes: 1 addition & 1 deletion lambda/matching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,7 @@ let get_expr_args_array ~scopes kind head (arg, _mut, _sort, _layout) rem =
let ref_kind = Lambda.(array_ref_kind alloc_heap kind) in
let result_layout = array_ref_kind_result_layout ref_kind in
( Lprim
(Parrayrefu ref_kind,
(Parrayrefu (ref_kind, Ptagged_int_index),
[ arg; Lconst (Const_base (Const_int pos)) ],
loc),
(match am with
Expand Down
23 changes: 19 additions & 4 deletions lambda/printlambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ let array_ref_kind ppf k =
| Punboxedintarray_ref Pint64 -> fprintf ppf "unboxed_int64"
| Punboxedintarray_ref Pnativeint -> fprintf ppf "unboxed_nativeint"

let array_index_kind ppf k =
match k with
| Ptagged_int_index -> fprintf ppf "int"
| Punboxed_int_index Pint32 -> fprintf ppf "unboxed_int32"
| Punboxed_int_index Pint64 -> fprintf ppf "unboxed_int64"
| Punboxed_int_index Pnativeint -> fprintf ppf "unboxed_nativeint"

let array_set_kind ppf k =
let pp_mode ppf = function
| Modify_heap -> ()
Expand Down Expand Up @@ -482,10 +489,18 @@ let primitive ppf = function
| Pduparray (k, Immutable) -> fprintf ppf "duparray_imm[%s]" (array_kind k)
| Pduparray (k, Immutable_unique) ->
fprintf ppf "duparray_unique[%s]" (array_kind k)
| Parrayrefu rk -> fprintf ppf "array.unsafe_get[%a]" array_ref_kind rk
| Parraysetu sk -> fprintf ppf "array.unsafe_set[%a]" array_set_kind sk
| Parrayrefs rk -> fprintf ppf "array.get[%a]" array_ref_kind rk
| Parraysets sk -> fprintf ppf "array.set[%a]" array_set_kind sk
| Parrayrefu (rk, idx) -> fprintf ppf "array.unsafe_get[%a indexed by %a]"
array_ref_kind rk
array_index_kind idx
| Parraysetu (sk, idx) -> fprintf ppf "array.unsafe_set[%a indexed by %a]"
array_set_kind sk
array_index_kind idx
| Parrayrefs (rk, idx) -> fprintf ppf "array.get[%a indexed by %a]"
array_ref_kind rk
array_index_kind idx
| Parraysets (sk, idx) -> fprintf ppf "array.set[%a indexed by %a]"
array_set_kind sk
array_index_kind idx
| Pctconst c ->
let const_name = match c with
| Big_endian -> "big_endian"
Expand Down
6 changes: 4 additions & 2 deletions lambda/transl_array_comprehension.ml
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ let iterator ~transl_exp ~scopes ~loc
~return_layout:(Pvalue Pintval)
pattern.pat_loc
(Lprim(Parrayrefu
Lambda.(array_ref_kind alloc_heap iter_arr_kind),
(Lambda.(array_ref_kind alloc_heap iter_arr_kind),
Ptagged_int_index),
[iter_arr.var; Lvar iter_ix],
loc))
pattern
Expand Down Expand Up @@ -776,7 +777,8 @@ let body
let open Let_binding in
let set_element_raw elt =
(* array.(index) <- elt *)
Lprim(Parraysetu Lambda.(array_set_kind modify_heap array_kind),
Lprim(Parraysetu (Lambda.(array_set_kind modify_heap array_kind),
Ptagged_int_index),
[array.var; index.var; elt],
loc)
in
Expand Down
97 changes: 78 additions & 19 deletions lambda/translprim.ml
Original file line number Diff line number Diff line change
Expand Up @@ -293,23 +293,82 @@ let lookup_primitive loc ~poly_mode ~poly_sort pos p =
| "%bytes_unsafe_get" -> Primitive (Pbytesrefu, 2)
| "%bytes_unsafe_set" -> Primitive (Pbytessetu, 3)
| "%array_length" -> Primitive ((Parraylength gen_array_kind), 1)
| "%array_safe_get" -> Primitive ((Parrayrefs (gen_array_ref_kind mode)), 2)
| "%array_safe_get" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Ptagged_int_index)), 2)
| "%array_safe_set" ->
Primitive (Parraysets (gen_array_set_kind (get_first_arg_mode ())), 3)
| "%array_unsafe_get" -> Primitive (Parrayrefu (gen_array_ref_kind mode), 2)
Primitive
(Parraysets (gen_array_set_kind (get_first_arg_mode ()), Ptagged_int_index),
3)
| "%array_unsafe_get" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Ptagged_int_index), 2)
| "%array_unsafe_set" ->
Primitive ((Parraysetu (gen_array_set_kind (get_first_arg_mode ()))), 3)
Primitive
((Parraysetu (gen_array_set_kind (get_first_arg_mode ()), Ptagged_int_index)),
3)
| "%array_safe_get_indexed_by_int64#" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pint64)), 2)
| "%array_safe_set_indexed_by_int64#" ->
Primitive
(Parraysets
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint64),
3)
| "%array_unsafe_get_indexed_by_int64#" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pint64), 2)
| "%array_unsafe_set_indexed_by_int64#" ->
Primitive
((Parraysetu
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint64)),
3)
| "%array_safe_get_indexed_by_int32#" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pint32)), 2)
| "%array_safe_set_indexed_by_int32#" ->
Primitive
(Parraysets
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint32),
3)
| "%array_unsafe_get_indexed_by_int32#" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pint32), 2)
| "%array_unsafe_set_indexed_by_int32#" ->
Primitive
((Parraysetu
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint32)),
3)
| "%array_safe_get_indexed_by_nativeint#" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pnativeint)), 2)
| "%array_safe_set_indexed_by_nativeint#" ->
Primitive
(Parraysets
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pnativeint),
3)
| "%array_unsafe_get_indexed_by_nativeint#" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pnativeint), 2)
| "%array_unsafe_set_indexed_by_nativeint#" ->
Primitive
((Parraysetu
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pnativeint)),
3)
| "%obj_size" -> Primitive ((Parraylength Pgenarray), 1)
| "%obj_field" -> Primitive ((Parrayrefu (Pgenarray_ref mode)), 2)
| "%obj_field" -> Primitive ((Parrayrefu (Pgenarray_ref mode, Ptagged_int_index)), 2)
| "%obj_set_field" ->
Primitive ((Parraysetu (Pgenarray_set (get_first_arg_mode ()))), 3)
Primitive
((Parraysetu (Pgenarray_set (get_first_arg_mode ()), Ptagged_int_index)), 3)
| "%floatarray_length" -> Primitive ((Parraylength Pfloatarray), 1)
| "%floatarray_safe_get" ->
Primitive ((Parrayrefs (Pfloatarray_ref mode)), 2)
| "%floatarray_safe_set" -> Primitive (Parraysets Pfloatarray_set, 3)
Primitive ((Parrayrefs (Pfloatarray_ref mode, Ptagged_int_index)), 2)
| "%floatarray_safe_set" ->
Primitive (Parraysets (Pfloatarray_set, Ptagged_int_index), 3)
| "%floatarray_unsafe_get" ->
Primitive ((Parrayrefu (Pfloatarray_ref mode)), 2)
| "%floatarray_unsafe_set" -> Primitive ((Parraysetu Pfloatarray_set), 3)
Primitive ((Parrayrefu (Pfloatarray_ref mode, Ptagged_int_index)), 2)
| "%floatarray_unsafe_set" ->
Primitive ((Parraysetu (Pfloatarray_set, Ptagged_int_index)), 3)
| "%obj_is_int" -> Primitive (Pisint { variant_only = false }, 1)
| "%lazy_force" -> Lazy_force pos
| "%nativeint_of_int" -> Primitive ((Pbintofint (Pnativeint, mode)), 1)
Expand Down Expand Up @@ -807,26 +866,26 @@ let specialize_primitive env loc ty ~has_constant_constructor prim =
if t = array_type then None
else Some (Primitive (Parraylength array_type, arity))
end
| Primitive (Parrayrefu rt, arity), p1 :: _ -> begin
| Primitive (Parrayrefu (rt, index_kind), arity), p1 :: _ -> begin
let array_ref_type = glb_array_ref_type (to_location loc) rt (array_type_kind env p1)
in
if rt = array_ref_type then None
else Some (Primitive (Parrayrefu array_ref_type, arity))
else Some (Primitive (Parrayrefu (array_ref_type, index_kind), arity))
end
| Primitive (Parraysetu st, arity), p1 :: _ -> begin
| Primitive (Parraysetu (st, index_kind), arity), p1 :: _ -> begin
let array_set_type = glb_array_set_type (to_location loc) st (array_type_kind env p1) in
if st = array_set_type then None
else Some (Primitive (Parraysetu array_set_type, arity))
else Some (Primitive (Parraysetu (array_set_type, index_kind), arity))
end
| Primitive (Parrayrefs rt, arity), p1 :: _ -> begin
| Primitive (Parrayrefs (rt, index_kind), arity), p1 :: _ -> begin
let array_ref_type = glb_array_ref_type (to_location loc) rt (array_type_kind env p1) in
if rt = array_ref_type then None
else Some (Primitive (Parrayrefs array_ref_type, arity))
else Some (Primitive (Parrayrefs (array_ref_type, index_kind), arity))
end
| Primitive (Parraysets st, arity), p1 :: _ -> begin
| Primitive (Parraysets (st, index_kind), arity), p1 :: _ -> begin
let array_set_type = glb_array_set_type (to_location loc) st (array_type_kind env p1) in
if st = array_set_type then None
else Some (Primitive (Parraysets array_set_type, arity))
else Some (Primitive (Parraysets (array_set_type, index_kind), arity))
end
| Primitive (Pbigarrayref(unsafe, n, Pbigarray_unknown,
Pbigarray_unknown_layout), arity), p1 :: _ -> begin
Expand Down Expand Up @@ -1293,7 +1352,7 @@ let lambda_primitive_needs_event_after = function
| Pmulfloat (_, _) | Pdivfloat (_, _)
| Pstringrefs | Pbytesrefs
| Pbytessets | Pmakearray (Pgenarray, _, _) | Pduparray _
| Parrayrefu (Pgenarray_ref _ | Pfloatarray_ref _)
| Parrayrefu ((Pgenarray_ref _ | Pfloatarray_ref _), _)
| Parrayrefs _ | Parraysets _ | Pbintofint _ | Pcvtbint _ | Pnegbint _
| Paddbint _ | Psubbint _ | Pmulbint _ | Pdivbint _ | Pmodbint _ | Pandbint _
| Porbint _ | Pxorbint _ | Plslbint _ | Plsrbint _ | Pasrbint _
Expand Down
12 changes: 8 additions & 4 deletions middle_end/convert_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ let convert (prim : Lambda.primitive) : Clambda_primitives.primitive =
| Pmakearray (kind, mutability, mode) -> Pmakearray (kind, mutability, mode)
| Pduparray (kind, mutability) -> Pduparray (kind, mutability)
| Parraylength kind -> Parraylength kind
| Parrayrefu rkind -> Parrayrefu rkind
| Parraysetu skind -> Parraysetu skind
| Parrayrefs rkind -> Parrayrefs rkind
| Parraysets skind -> Parraysets skind
| Parrayrefu (rkind, Ptagged_int_index) -> Parrayrefu rkind
| Parraysetu (skind, Ptagged_int_index) -> Parraysetu skind
| Parrayrefs (rkind, Ptagged_int_index) -> Parrayrefs rkind
| Parraysets (skind, Ptagged_int_index) -> Parraysets skind
| Pisint _ -> Pisint
| Pisout -> Pisout
| Pcvtbint (src, dest, m) -> Pcvtbint (src, dest, m)
Expand Down Expand Up @@ -219,6 +219,10 @@ let convert (prim : Lambda.primitive) : Clambda_primitives.primitive =
| Punboxed_int32_array_set_128 _
| Punboxed_int64_array_set_128 _
| Punboxed_nativeint_array_set_128 _
| Parrayrefu (_, Punboxed_int_index _)
| Parraysetu (_, Punboxed_int_index _)
| Parrayrefs (_, Punboxed_int_index _)
| Parraysets (_, Punboxed_int_index _)
->
Misc.fatal_errorf "lambda primitive %a can't be converted to \
clambda primitive"
Expand Down
Loading

0 comments on commit 0f4d23e

Please sign in to comment.