diff --git a/integration-testing/src/jvmCoreTest/kotlin/ListAllCoroutineThrowableSubclassesTest.kt b/integration-testing/src/jvmCoreTest/kotlin/ListAllCoroutineThrowableSubclassesTest.kt index 5c564c8a8e..0cf592c636 100644 --- a/integration-testing/src/jvmCoreTest/kotlin/ListAllCoroutineThrowableSubclassesTest.kt +++ b/integration-testing/src/jvmCoreTest/kotlin/ListAllCoroutineThrowableSubclassesTest.kt @@ -26,6 +26,7 @@ class ListAllCoroutineThrowableSubclassesTest { "kotlinx.coroutines.internal.DiagnosticCoroutineContextException", "kotlinx.coroutines.internal.ExceptionSuccessfullyProcessed", "kotlinx.coroutines.CoroutinesInternalError", + "kotlinx.coroutines.DispatchException", "kotlinx.coroutines.channels.ClosedSendChannelException", "kotlinx.coroutines.channels.ClosedReceiveChannelException", "kotlinx.coroutines.flow.internal.ChildCancelledException", diff --git a/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt b/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt index 37b68760a0..b208c84f83 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt @@ -225,7 +225,7 @@ public abstract class CoroutineDispatcher : * @suppress **This an internal API and should not be used from general code.** */ @InternalCoroutinesApi - public open fun dispatchYield(context: CoroutineContext, block: Runnable): Unit = dispatch(context, block) + public open fun dispatchYield(context: CoroutineContext, block: Runnable): Unit = safeDispatch(context, block) /** * Returns a continuation that wraps the provided [continuation], thus intercepting all resumptions. diff --git a/kotlinx-coroutines-core/common/src/CoroutineExceptionHandler.kt b/kotlinx-coroutines-core/common/src/CoroutineExceptionHandler.kt index e6f1d9e63c..0899eb6fb6 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineExceptionHandler.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineExceptionHandler.kt @@ -16,18 +16,19 @@ import kotlin.coroutines.* */ @InternalCoroutinesApi public fun handleCoroutineException(context: CoroutineContext, exception: Throwable) { + val reportException = if (exception is DispatchException) exception.cause else exception // Invoke an exception handler from the context if present try { context[CoroutineExceptionHandler]?.let { - it.handleException(context, exception) + it.handleException(context, reportException) return } } catch (t: Throwable) { - handleUncaughtCoroutineException(context, handlerException(exception, t)) + handleUncaughtCoroutineException(context, handlerException(reportException, t)) return } // If a handler is not present in the context or an exception was thrown, fallback to the global handler - handleUncaughtCoroutineException(context, exception) + handleUncaughtCoroutineException(context, reportException) } internal fun handlerException(originalException: Throwable, thrownException: Throwable): Throwable { diff --git a/kotlinx-coroutines-core/common/src/Yield.kt b/kotlinx-coroutines-core/common/src/Yield.kt index 0598228640..afe79494b7 100644 --- a/kotlinx-coroutines-core/common/src/Yield.kt +++ b/kotlinx-coroutines-core/common/src/Yield.kt @@ -26,7 +26,7 @@ public suspend fun yield(): Unit = suspendCoroutineUninterceptedOrReturn sc@ { u val context = uCont.context context.ensureActive() val cont = uCont.intercepted() as? DispatchedContinuation ?: return@sc Unit - if (cont.dispatcher.isDispatchNeeded(context)) { + if (cont.dispatcher.safeIsDispatchNeeded(context)) { // this is a regular dispatcher -- do simple dispatchYield cont.dispatchYield(context, Unit) } else { diff --git a/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt b/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt index 26e7c5abd4..4c8f54e877 100644 --- a/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt +++ b/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt @@ -187,10 +187,10 @@ internal class DispatchedContinuation( override fun resumeWith(result: Result) { val state = result.toState() - if (dispatcher.isDispatchNeeded(context)) { + if (dispatcher.safeIsDispatchNeeded(context)) { _state = state resumeMode = MODE_ATOMIC - dispatcher.dispatch(context, this) + dispatcher.safeDispatch(context, this) } else { executeUnconfined(state, MODE_ATOMIC) { withCoroutineContext(context, countOrElement) { @@ -205,10 +205,10 @@ internal class DispatchedContinuation( @Suppress("NOTHING_TO_INLINE") internal inline fun resumeCancellableWith(result: Result) { val state = result.toState() - if (dispatcher.isDispatchNeeded(context)) { + if (dispatcher.safeIsDispatchNeeded(context)) { _state = state resumeMode = MODE_CANCELLABLE - dispatcher.dispatch(context, this) + dispatcher.safeDispatch(context, this) } else { executeUnconfined(state, MODE_CANCELLABLE) { if (!resumeCancelled(state)) { @@ -249,6 +249,22 @@ internal class DispatchedContinuation( "DispatchedContinuation[$dispatcher, ${continuation.toDebugString()}]" } +internal fun CoroutineDispatcher.safeDispatch(context: CoroutineContext, runnable: Runnable) { + try { + dispatch(context, runnable) + } catch (e: Throwable) { + throw DispatchException(e, this, context) + } +} + +internal fun CoroutineDispatcher.safeIsDispatchNeeded(context: CoroutineContext): Boolean { + try { + return isDispatchNeeded(context) + } catch (e: Throwable) { + throw DispatchException(e, this, context) + } +} + /** * It is not inline to save bytecode (it is pretty big and used in many places) * and we leave it public so that its name is not mangled in use stack traces if it shows there. diff --git a/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt b/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt index 309685bb7c..ad5fed1205 100644 --- a/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt +++ b/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt @@ -76,7 +76,6 @@ internal abstract class DispatchedTask internal constructor( final override fun run() { assert { resumeMode != MODE_UNINITIALIZED } // should have been set before dispatching - var fatalException: Throwable? = null try { val delegate = delegate as DispatchedContinuation val continuation = delegate.continuation @@ -102,11 +101,10 @@ internal abstract class DispatchedTask internal constructor( } } } + } catch (e: DispatchException) { + handleCoroutineException(delegate.context, e.cause) } catch (e: Throwable) { - // This instead of runCatching to have nicer stacktrace and debug experience - fatalException = e - } finally { - fatalException?.let { handleFatalException(it) } + handleFatalException(e) } } @@ -143,8 +141,8 @@ internal fun DispatchedTask.dispatch(mode: Int) { // dispatch directly using this instance's Runnable implementation val dispatcher = delegate.dispatcher val context = delegate.context - if (dispatcher.isDispatchNeeded(context)) { - dispatcher.dispatch(context, this) + if (dispatcher.safeIsDispatchNeeded(context)) { + dispatcher.safeDispatch(context, this) } else { resumeUnconfined() } @@ -205,3 +203,17 @@ internal inline fun DispatchedTask<*>.runUnconfinedEventLoop( internal inline fun Continuation<*>.resumeWithStackTrace(exception: Throwable) { resumeWith(Result.failure(recoverStackTrace(exception, this))) } + +/** + * This exception holds an exception raised in [CoroutineDispatcher.dispatch] method. + * When dispatcher methods fail unexpectedly, it is likely a user-induced programmatic bug, + * such as calling `executor.close()` prematurely. To avoid reporting such exceptions as fatal errors, + * we handle them with a separate code path. See also #4091. + * + * @see safeDispatch + */ +internal class DispatchException( + override val cause: Throwable, + dispatcher: CoroutineDispatcher, + context: CoroutineContext, +) : Exception("Coroutine dispatcher $dispatcher threw an exception, context = $context", cause) diff --git a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt index eb5196144f..488331fc37 100644 --- a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt +++ b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt @@ -42,7 +42,7 @@ internal class LimitedDispatcher( override fun dispatch(context: CoroutineContext, block: Runnable) { dispatchInternal(block) { worker -> - dispatcher.dispatch(this, worker) + dispatcher.safeDispatch(this, worker) } } @@ -116,10 +116,10 @@ internal class LimitedDispatcher( } currentTask = obtainTaskOrDeallocateWorker() ?: return // 16 is our out-of-thin-air constant to emulate fairness. Used in JS dispatchers as well - if (++fairnessCounter >= 16 && dispatcher.isDispatchNeeded(this@LimitedDispatcher)) { + if (++fairnessCounter >= 16 && dispatcher.safeIsDispatchNeeded(this@LimitedDispatcher)) { // Do "yield" to let other views execute their runnable as well // Note that we do not decrement 'runningWorkers' as we are still committed to our part of work - dispatcher.dispatch(this@LimitedDispatcher, this) + dispatcher.safeDispatch(this@LimitedDispatcher, this) return } } diff --git a/kotlinx-coroutines-core/common/src/intrinsics/Cancellable.kt b/kotlinx-coroutines-core/common/src/intrinsics/Cancellable.kt index 2f9a434a1f..1e87d767af 100644 --- a/kotlinx-coroutines-core/common/src/intrinsics/Cancellable.kt +++ b/kotlinx-coroutines-core/common/src/intrinsics/Cancellable.kt @@ -58,6 +58,7 @@ private fun dispatcherFailure(completion: Continuation<*>, e: Throwable) { * 2) Rethrow the exception immediately, so it will crash the caller (e.g. when the coroutine had * no parent or it was async/produce over MainScope). */ - completion.resumeWith(Result.failure(e)) - throw e + val reportException = if (e is DispatchException) e.cause else e + completion.resumeWith(Result.failure(reportException)) + throw reportException } diff --git a/kotlinx-coroutines-core/common/src/intrinsics/Undispatched.kt b/kotlinx-coroutines-core/common/src/intrinsics/Undispatched.kt index 0d2e0404fc..254182b387 100644 --- a/kotlinx-coroutines-core/common/src/intrinsics/Undispatched.kt +++ b/kotlinx-coroutines-core/common/src/intrinsics/Undispatched.kt @@ -20,7 +20,8 @@ internal fun (suspend (R) -> T).startCoroutineUndispatched(receiver: R, c startCoroutineUninterceptedOrReturn(receiver, actualCompletion) } } catch (e: Throwable) { - actualCompletion.resumeWithException(e) + val reportException = if (e is DispatchException) e.cause else e + actualCompletion.resumeWithException(reportException) return } if (value !== COROUTINE_SUSPENDED) { diff --git a/kotlinx-coroutines-core/jvm/src/Executors.kt b/kotlinx-coroutines-core/jvm/src/Executors.kt index 8ba3f18a24..bdfbe6dbbc 100644 --- a/kotlinx-coroutines-core/jvm/src/Executors.kt +++ b/kotlinx-coroutines-core/jvm/src/Executors.kt @@ -105,8 +105,8 @@ public fun CoroutineDispatcher.asExecutor(): Executor = private class DispatcherExecutor(@JvmField val dispatcher: CoroutineDispatcher) : Executor { override fun execute(block: Runnable) { - if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) { - dispatcher.dispatch(EmptyCoroutineContext, block) + if (dispatcher.safeIsDispatchNeeded(EmptyCoroutineContext)) { + dispatcher.safeDispatch(EmptyCoroutineContext, block) } else { block.run() } diff --git a/kotlinx-coroutines-core/jvm/test/ExecutorsTest.kt b/kotlinx-coroutines-core/jvm/test/ExecutorsTest.kt index 1ad2f8a2a4..965b8fc0be 100644 --- a/kotlinx-coroutines-core/jvm/test/ExecutorsTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ExecutorsTest.kt @@ -119,4 +119,106 @@ class ExecutorsTest : TestBase() { dispatcher.close() check(executorService.isShutdown) } + + @Test + fun testEarlyExecutorShutdown() { + runTestExceptionInDispatch(6, { it is RejectedExecutionException }) { + expect(1) + val dispatcher = newSingleThreadContext("Ctx") + launch(dispatcher) { + withContext(Dispatchers.Default) { + expect(2) + delay(100) + expect(4) + } + } + + delay(50) + expect(3) + + dispatcher.close() + } + } + + @Test + fun testExceptionInDispatch() { + runTestExceptionInDispatch(5, { it is TestException }) { + val dispatcher = object : CoroutineDispatcher() { + private var closed = false + override fun dispatch(context: CoroutineContext, block: Runnable) { + if (closed) throw TestException() + Dispatchers.Default.dispatch(context, block) + } + + fun close() { + closed = true + } + } + launch(dispatcher) { + withContext(Dispatchers.Default) { + expect(1) + delay(100) + expect(3) + } + } + + delay(50) + expect(2) + dispatcher.close() + } + } + + @Test + fun testExceptionInIsDispatchNeeded() { + val dispatcher = object : CoroutineDispatcher() { + override fun isDispatchNeeded(context: CoroutineContext): Boolean { + expect(2) + throw TestException() + } + override fun dispatch(context: CoroutineContext, block: Runnable) = expectUnreached() + } + try { + runBlocking { + expect(1) + try { + launch(dispatcher) { + expectUnreached() + } + expectUnreached() + } catch (_: TestException) { + expect(3) + } + + } + } catch (_: TestException) { + finish(4) + } + } + + private fun runTestExceptionInDispatch( + totalSteps: Int, + isExpectedException: (Throwable) -> Boolean, + block: suspend CoroutineScope.() -> Unit, + ) { + var mainThread: Thread? = null + val exceptionHandler = CoroutineExceptionHandler { _, e -> + if (isExpectedException(e)) { + expect(totalSteps - 1) + mainThread!!.run { + interrupt() + unpark(this) + } + } else { + expectUnreached() + } + } + try { + runBlocking(exceptionHandler) { + block() + mainThread = Thread.currentThread() + } + } catch (_: InterruptedException) { + finish(totalSteps) + } + } }