Skip to content

Commit

Permalink
Fix AsyncEx.AwaitTask cancellation (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheAngryByrd authored Apr 22, 2024
1 parent 8705e22 commit e720a0e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 92 deletions.
110 changes: 19 additions & 91 deletions src/IcedTasks/AsyncEx.fs
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,27 @@ type AsyncEx =
/// This is based on <see href="https://stackoverflow.com/a/66815960">How to use awaitable inside async?</see> and <see href="https://github.com/fsharp/fslang-suggestions/issues/840">Async.Await overload (esp. AwaitTask without throwing AggregateException)</see>
/// </remarks>
static member inline AwaitAwaiter(awaiter: 'Awaiter) =
let inline handleFinished (onNext: 'a -> unit, onError: exn -> unit, awaiter) =
try
onNext (Awaiter.GetResult awaiter)
with
| :? AggregateException as ae when ae.InnerExceptions.Count = 1 ->
onError ae.InnerExceptions.[0]
| e ->
// Why not handle TaskCanceledException/OperationCanceledException?
// From https://github.com/dotnet/fsharp/blob/89e641108e8773e8d5731437a2b944510de52567/src/FSharp.Core/async.fs#L1228-L1231:
// A cancelled task calls the exception continuation with TaskCanceledException, since it may not represent cancellation of
// the overall async (they may be governed by different cancellation tokens, or
// the task may not have a cancellation token at all).
onError e

Async.FromContinuations(fun (onNext, onError, onCancel) ->
if Awaiter.IsCompleted awaiter then
try
onNext (Awaiter.GetResult awaiter)
with
| :? TaskCanceledException as ce -> onCancel ce
| :? OperationCanceledException as ce -> onCancel ce
| :? AggregateException as ae ->
if ae.InnerExceptions.Count = 1 then
onError ae.InnerExceptions.[0]
else
onError ae
| e -> onError e
handleFinished (onNext, onError, awaiter)
else
Awaiter.OnCompleted(
Awaiter.UnsafeOnCompleted(
awaiter,
(fun () ->
try
onNext (Awaiter.GetResult awaiter)
with
| :? TaskCanceledException as ce -> onCancel ce
| :? OperationCanceledException as ce -> onCancel ce
| :? AggregateException as ae ->
if ae.InnerExceptions.Count = 1 then
onError ae.InnerExceptions.[0]
else
onError ae
| e -> onError e
)
(fun () -> handleFinished (onNext, onError, awaiter))
)
)

Expand All @@ -78,39 +70,7 @@ type AsyncEx =
/// <remarks>
/// This is based on <see href="https://github.com/fsharp/fslang-suggestions/issues/840">Async.Await overload (esp. AwaitTask without throwing AggregateException)</see>
/// </remarks>
static member AwaitTask(task: Task) : Async<unit> =
Async.FromContinuations(fun (onNext, onError, onCancel) ->
if task.IsCompleted then
if task.IsFaulted then
let e = task.Exception

if e.InnerExceptions.Count = 1 then
onError e.InnerExceptions.[0]
else
onError e
elif task.IsCanceled then
onCancel (TaskCanceledException(task))
else
onNext ()
else
task.ContinueWith(
(fun (task: Task) ->
if task.IsFaulted then
let e = task.Exception

if e.InnerExceptions.Count = 1 then
onError e.InnerExceptions.[0]
else
onError e
elif task.IsCanceled then
onCancel (TaskCanceledException(task))
else
onNext ()
),
TaskContinuationOptions.ExecuteSynchronously
)
|> ignore
)
static member AwaitTask(task: Task) : Async<unit> = AsyncEx.AwaitAwaitable(task)

/// <summary>
/// Return an asynchronous computation that will wait for the given Task to complete and return
Expand All @@ -122,39 +82,7 @@ type AsyncEx =
/// This is based on <see href="https://github.com/fsharp/fslang-suggestions/issues/840">Async.Await overload (esp. AwaitTask without throwing AggregateException)</see>
/// </remarks>
static member AwaitTask(task: Task<'T>) : Async<'T> =
Async.FromContinuations(fun (onNext, onError, onCancel) ->

if task.IsCompleted then
if task.IsFaulted then
let e = task.Exception

if e.InnerExceptions.Count = 1 then
onError e.InnerExceptions.[0]
else
onError e
elif task.IsCanceled then
onCancel (TaskCanceledException(task))
else
onNext task.Result
else
task.ContinueWith(
(fun (task: Task<'T>) ->
if task.IsFaulted then
let e = task.Exception

if e.InnerExceptions.Count = 1 then
onError e.InnerExceptions.[0]
else
onError e
elif task.IsCanceled then
onCancel (TaskCanceledException(task))
else
onNext task.Result
),
TaskContinuationOptions.ExecuteSynchronously
)
|> ignore
)
AsyncEx.AwaitAwaiter(Awaitable.GetTaskAwaiter task)


/// <summary>
Expand Down
63 changes: 62 additions & 1 deletion tests/IcedTasks.Tests/AsyncExTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,68 @@ module PolyfillTest =
let! result = outer
Expect.equal result () "Should return the data"
}

let withCancellation (ct: CancellationToken) (a: Async<'a>) : Async<'a> =
async {
let! ct2 = Async.CancellationToken
use cts = CancellationTokenSource.CreateLinkedTokenSource(ct, ct2)
let tcs = new TaskCompletionSource<'a>()

use _reg =
cts.Token.Register(fun () ->
tcs.TrySetCanceled(cts.Token)
|> ignore
)

let a =
async {
try
let! a = a

tcs.TrySetResult a
|> ignore
with ex ->
tcs.TrySetException ex
|> ignore
}

Async.Start(a, cts.Token)

return!
tcs.Task
|> AsyncEx.AwaitTask
}

testCase "Don't cancel everything if one task cancels"
<| fun () ->
use cts = new CancellationTokenSource()
cts.CancelAfter(100)

let doWork i =
asyncEx {
try
let! _ =
Async.Sleep(100)
|> withCancellation cts.Token

()
with :? OperationCanceledException as e ->
()
}

Seq.init
(Environment.ProcessorCount
* 2)
doWork
|> Async.Parallel
|> Async.RunSynchronously
|> ignore
]


[<Tests>]
let asyncExTests = testList "IcedTasks.Polyfill.Async" [ builderTests ]
let asyncExTests =
testList "IcedTasks.Polyfill.Async" [
builderTests

]

0 comments on commit e720a0e

Please sign in to comment.