diff --git a/src/FSharp.Control.TaskSeq.Test/TaskSeq.TakeWhile.Tests.fs b/src/FSharp.Control.TaskSeq.Test/TaskSeq.TakeWhile.Tests.fs index dab9b748..36043e39 100644 --- a/src/FSharp.Control.TaskSeq.Test/TaskSeq.TakeWhile.Tests.fs +++ b/src/FSharp.Control.TaskSeq.Test/TaskSeq.TakeWhile.Tests.fs @@ -27,11 +27,42 @@ module EmptySeq = |> TaskSeq.toListAsync |> Task.map (List.isEmpty >> should be True) +module Terminates = + + [] + let ``TaskSeq-takeWhile stops after predicate fails`` () = + seq { 1; 2; 3; failwith "Too far" } + |> TaskSeq.ofSeq + |> TaskSeq.takeWhile (fun x -> x <= 2) + |> TaskSeq.map char + |> TaskSeq.map ((+) '@') + |> TaskSeq.toArrayAsync + |> Task.map (String >> should equal "AB") + + [] + let ``TaskSeq-takeWhileAsync stops after predicate fails`` () = + taskSeq { 1; 2; 3; failwith "Too far" } + |> TaskSeq.takeWhileAsync (fun x -> task { return x <= 2 }) + |> TaskSeq.map char + |> TaskSeq.map ((+) '@') + |> TaskSeq.toArrayAsync + |> Task.map (String >> should equal "AB") + +// This is the base condition as one would expect in actual code +let inline cond x = x <> 6 + +// For each of the tests below, we add a guard that will trigger if the predicate is passed items known to be beyond the +// first failing item in the known sequence (which is 1..10) +let inline condWithGuard x = + let res = cond x + if x > 6 then failwith "Test sequence should not be enumerated beyond the first item failing the predicate" + res + module Immutable = [)>] let ``TaskSeq-takeWhile filters correctly`` variant = Gen.getSeqImmutable variant - |> TaskSeq.takeWhile (fun x -> x <> 6) + |> TaskSeq.takeWhile condWithGuard |> TaskSeq.map char |> TaskSeq.map ((+) '@') |> TaskSeq.toArrayAsync @@ -40,7 +71,7 @@ module Immutable = [)>] let ``TaskSeq-takeWhileAsync filters correctly`` variant = Gen.getSeqImmutable variant - |> TaskSeq.takeWhileAsync (fun x -> task { return x <> 6 }) + |> TaskSeq.takeWhileAsync (fun x -> task { return condWithGuard x }) |> TaskSeq.map char |> TaskSeq.map ((+) '@') |> TaskSeq.toArrayAsync @@ -50,7 +81,7 @@ module SideEffects = [)>] let ``TaskSeq-takeWhile filters correctly`` variant = Gen.getSeqWithSideEffect variant - |> TaskSeq.takeWhile (fun x -> x <> 6) + |> TaskSeq.takeWhile condWithGuard |> TaskSeq.map char |> TaskSeq.map ((+) '@') |> TaskSeq.toArrayAsync @@ -59,7 +90,7 @@ module SideEffects = [)>] let ``TaskSeq-takeWhileAsync filters correctly`` variant = Gen.getSeqWithSideEffect variant - |> TaskSeq.takeWhileAsync (fun x -> task { return x <> 6 }) + |> TaskSeq.takeWhileAsync (fun x -> task { return condWithGuard x }) |> TaskSeq.map char |> TaskSeq.map ((+) '@') |> TaskSeq.toArrayAsync