Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nat_split, pos_split #114

Merged
merged 2 commits into from
May 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/core/QCheck.ml
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,34 @@ module Gen = struct
let samples = List.rev_map sample l in
List.sort (fun (w1, _) (w2, _) -> poly_compare w1 w2) samples |> List.rev_map snd

let range_subset ~size low high st =
if not (low <= high && size <= high - low + 1) then invalid_arg "Gen.range_subset";
(* The algorithm below is attributed to Floyd, see for example
https://eyalsch.wordpress.com/2010/04/01/random-sample/
https://math.stackexchange.com/questions/178690

Note: the code be made faster by checking membership in [arr]
directly instead of using an additional Set. None of our
dependencies implements dichotomic search, so using Set is
easier.
*)
let module ISet = Set.Make(Int) in
let s = ref ISet.empty in
let arr = Array.make size 0 in
for i = high - size to high do
let pos = int_range high i st in
let choice =
if ISet.mem pos !s then i else pos
in
arr.(i - low) <- choice;
s := ISet.add choice !s;
done;
arr

let array_subset size arr st =
range_subset ~size 0 (Array.length arr - 1) st
|> Array.map (fun i -> arr.(i))

let pair g1 g2 st = (g1 st, g2 st)

let triple g1 g2 g3 st = (g1 st, g2 st, g3 st)
Expand Down Expand Up @@ -300,6 +328,36 @@ module Gen = struct
let rec f' n st = f f' n st in
f'

(* nat splitting *)

let nat_split2 n st =
if (n < 2) then invalid_arg "nat_split2";
let n1 = int_range 1 (n - 1) st in
(n1, n - n1)

let pos_split2 n st =
let n1 = int_range 0 n st in
(n1, n - n1)

let pos_split ~size:k n st =
if (k > n) then invalid_arg "nat_split";
(* To split n into n{0}+n{1}+..+n{k-1}, we draw distinct "boundaries"
b{-1}..b{k-1}, with b{-1}=0 and b{k-1} = n
and the k-1 intermediate boundaries b{0}..b{k-2}
chosen randomly distinct in [1;n-1].

Then each n{i} is defined as b{i}-b{i-1}. *)
let b = range_subset ~size:(k-1) 1 (n - 1) st in
Array.init k (fun i ->
if i = 0 then b.(0)
else if i = k-1 then n - b.(i-1)
else b.(i) - b.(i-1)
)

let nat_split ~size:k n st =
pos_split ~size:k (n+k) st
|> Array.map (fun v -> v - 1)

let generate ?(rand=Random.State.make_self_init()) ~n g =
list_repeat n g rand

Expand Down
69 changes: 69 additions & 0 deletions src/core/QCheck.mli
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,30 @@ module Gen : sig
@since 0.11
*)

val range_subset : size:int -> int -> int -> int array t
(** [range_subset ~size:k low high] generates an array of length [k]
of sorted distinct integers in the range [low..high] (included).

Complexity O(k log k), drawing [k] random integers.

@raise Invalid_argument outside the valid region [0 <= k <= high-low+1].

@since 0.18
*)

val array_subset : int -> 'a array -> 'a array t
(** [array_subset k arr] generates a sub-array of [k] elements
at distinct positions in the input array [arr],
in the same order.

Complexity O(k log k), drawing [k] random integers.

@raise Invalid_argument outside the valid region
[0 <= size <= Array.length arr].

@since 0.18
*)

val unit : unit t (** The unit generator. *)

val bool : bool t (** The boolean generator. *)
Expand Down Expand Up @@ -422,6 +446,51 @@ module Gen : sig

*)

val nat_split2 : int -> (int * int) t
(** [nat_split2 n] generates pairs [(n1, n2)] of natural numbers
with [n1 + n2 = n].

This is useful to split sizes to combine sized generators.

@raise Invalid_argument unless [n >= 2].

@since 0.18
*)

val pos_split2 : int -> (int * int) t
(** [nat_split2 n] generates pairs [(n1, n2)] of strictly positive
(nonzero) natural numbers with [n1 + n2 = n].

This is useful to split sizes to combine sized generators.

@since 0.18
*)

val nat_split : size:int -> int -> int array t
(** [nat_split2 ~size:k n] generates [k]-sized arrays [n1,n2,..nk]
of natural numbers in [[0;n]] with [n1 + n2 + ... + nk = n].

This is useful to split sizes to combine sized generators.

Complexity O(k log k).

@since 0.18
*)

val pos_split : size:int -> int -> int array t
(** [nat_split2 ~size:k n] generates [k]-sized arrays [n1,n2,..nk]
of strictly positive (non-zero) natural numbers with
[n1 + n2 + ... + nk = n].

This is useful to split sizes to combine sized generators.

Complexity O(k log k).

@raise Invalid_argument unless [k <= n].

@since 0.18
*)

val delay : (unit -> 'a t) -> 'a t
(** Delay execution of some code until the generator is actually called.
This can be used to manually implement recursion or control flow
Expand Down