Skip to content

Commit

Permalink
Confine context-specific state to the thread in UndispatchedCoroutine… (
Browse files Browse the repository at this point in the history
Kotlin#3155)

* Confine context-specific state to the thread in UndispatchedCoroutine in order to avoid state interference when the coroutine is updated concurrently.

Concurrency is inevitable in this scenario: when the coroutine that has UndispatchedCoroutine as its completion suspends, we have to clear the thread context, but while we are doing so, concurrent resume of the coroutine could've happened that also ends up in save/clear/update context

Fixes Kotlin#2930
  • Loading branch information
qwwdfsad authored and pablobaxter committed Sep 14, 2022
1 parent 5ae8611 commit ac276a1
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 12 deletions.
32 changes: 20 additions & 12 deletions kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedC

/**
* Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
* Used as a performance optimization to avoid stack walking where it is not nesessary.
* Used as a performance optimization to avoid stack walking where it is not necessary.
*/
private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
override val key: CoroutineContext.Key<*>
Expand All @@ -120,26 +120,34 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
uCont: Continuation<T>
) : ScopeCoroutine<T>(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) {

private var savedContext: CoroutineContext? = null
private var savedOldValue: Any? = null
/*
* The state is thread-local because this coroutine can be used concurrently.
* Scenario of usage (withContinuationContext):
* val state = saveThreadContext(ctx)
* try {
* invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called
* // COROUTINE_SUSPENDED is returned
* } finally {
* thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread
* // and it also calls saveThreadContext and clearThreadContext
* }
*/
private var threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()

fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
savedContext = context
savedOldValue = oldValue
threadStateToRecover.set(context to oldValue)
}

fun clearThreadContext(): Boolean {
if (savedContext == null) return false
savedContext = null
savedOldValue = null
if (threadStateToRecover.get() == null) return false
threadStateToRecover.set(null)
return true
}

override fun afterResume(state: Any?) {
savedContext?.let { context ->
restoreThreadContext(context, savedOldValue)
savedContext = null
savedOldValue = null
threadStateToRecover.get()?.let { (ctx, value) ->
restoreThreadContext(ctx, value)
threadStateToRecover.set(null)
}
// resume undispatched -- update context but stay on the same dispatcher
val result = recoverResult(state, uCont)
Expand Down
72 changes: 72 additions & 0 deletions kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlin.test.*


class ThreadLocalStressTest : TestBase() {

private val threadLocal = ThreadLocal<String>()

// See the comment in doStress for the machinery
@Test
fun testStress() = runTest {
repeat (100 * stressTestMultiplierSqrt) {
withContext(Dispatchers.Default) {
repeat(100) {
launch {
doStress(null)
}
}
}
}
}

@Test
fun testStressWithOuterValue() = runTest {
repeat (100 * stressTestMultiplierSqrt) {
withContext(Dispatchers.Default + threadLocal.asContextElement("bar")) {
repeat(100) {
launch {
doStress("bar")
}
}
}
}
}

private suspend fun doStress(expectedValue: String?) {
assertEquals(expectedValue, threadLocal.get())
try {
/*
* Here we are using very specific code-path to trigger the execution we want to.
* The bug, in general, has a larger impact, but this particular code pinpoints it:
*
* 1) We use _undispatched_ withContext with thread element
* 2) We cancel the coroutine
* 3) We use 'suspendCancellableCoroutineReusable' that does _postponed_ cancellation check
* which makes the reproduction of this race pretty reliable.
*
* Now the following code path is likely to be triggered:
*
* T1 from within 'withContinuationContext' method:
* Finds 'oldValue', finds undispatched completion, invokes its 'block' argument.
* 'block' is this coroutine, it goes to 'trySuspend', checks for postponed cancellation and *dispatches* it.
* The execution stops _right_ before 'undispatchedCompletion.clearThreadContext()'.
*
* T2 now executes the dispatched cancellation and concurrently mutates the state of the undispatched completion.
* All bets are off, now both threads can leave the thread locals state inconsistent.
*/
withContext(threadLocal.asContextElement("foo")) {
yield()
cancel()
suspendCancellableCoroutineReusable<Unit> { }
}
} finally {
assertEquals(expectedValue, threadLocal.get())
}
}
}

0 comments on commit ac276a1

Please sign in to comment.