diff --git a/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt b/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt index 0efe5f86db..8610cf5a48 100644 --- a/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt @@ -161,7 +161,7 @@ internal abstract class ChannelFlowOperator( // Fast-path: When channel creation is optional (flowOn/flowWith operators without buffer) if (capacity == Channel.OPTIONAL_CHANNEL) { val collectContext = coroutineContext - val newContext = collectContext + context // compute resulting collect context + val newContext = collectContext.newCoroutineContext(context) // compute resulting collect context // #1: If the resulting context happens to be the same as it was -- fallback to plain collect if (newContext == collectContext) return flowCollect(collector) diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt index 83f5ae17db..3bb79b4971 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt @@ -7,6 +7,7 @@ package kotlinx.coroutines import org.junit.Test import kotlin.coroutines.* import kotlin.test.* +import kotlinx.coroutines.flow.* class ThreadContextElementTest : TestBase() { @@ -37,7 +38,7 @@ class ThreadContextElementTest : TestBase() { } @Test - fun testUndispatched()= runTest { + fun testUndispatched() = runTest { val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! val data = MyData() val element = MyElement(data) @@ -191,6 +192,21 @@ class ThreadContextElementTest : TestBase() { assertEquals(manuallyCaptured, captor.capturees) } + + @Test + fun testThreadLocalFlowOn() = runTest { + val myData = MyData() + myThreadLocal.set(myData) + expect(1) + flow { + assertEquals(myData, myThreadLocal.get()) + emit(1) + } + .flowOn(myThreadLocal.asContextElement() + Dispatchers.Default) + .single() + myThreadLocal.set(null) + finish(2) + } } class MyData @@ -259,6 +275,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle } } + /** * Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block]. * diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt index 34e5955fd7..73d4ee6e9e 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt @@ -4,6 +4,7 @@ package kotlinx.coroutines +import kotlinx.coroutines.flow.* import kotlin.coroutines.* import kotlin.test.* @@ -131,4 +132,32 @@ class ThreadContextMutableCopiesTest : TestBase() { finish(2) } } + + @Test + fun testDataIsCopiedThroughFlowOnUndispatched() = runTest { + expect(1) + val root = MyMutableElement(ArrayList()) + val originalData = root.mutableData + flow { + assertNotSame(originalData, threadLocalData.get()) + emit(1) + } + .flowOn(root) + .single() + finish(2) + } + + @Test + fun testDataIsCopiedThroughFlowOnDispatched() = runTest { + expect(1) + val root = MyMutableElement(ArrayList()) + val originalData = root.mutableData + flow { + assertNotSame(originalData, threadLocalData.get()) + emit(1) + } + .flowOn(root + Dispatchers.Default) + .single() + finish(2) + } }