Skip to content

Commit

Permalink
Merge pull request #3636 from armanbilge/topic/thread-local-iolocal
Browse files Browse the repository at this point in the history
`IOLocal` propagation for unsafe access
  • Loading branch information
armanbilge authored Nov 22, 2024
2 parents ccae9c7 + 1adf368 commit 8091026
Show file tree
Hide file tree
Showing 15 changed files with 276 additions and 87 deletions.
13 changes: 11 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,15 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform)
"cats.effect.unsafe.IORuntimeBuilder.this"),
// introduced by #3695, which enabled fiber dumps on native
ProblemFilters.exclude[MissingClassProblem](
"cats.effect.unsafe.FiberMonitorCompanionPlatform")
"cats.effect.unsafe.FiberMonitorCompanionPlatform"),
// introduced by #3636, IOLocal propagation
// IOLocal is a sealed trait
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.getOrDefault"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.set"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.reset"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.lens"),
// this filter is particulary terrible, because it can also mask real issues :(
ProblemFilters.exclude[DirectMissingMethodProblem]("cats.effect.IOLocal.lens")
) ++ {
if (tlIsScala3.value) {
// Scala 3 specific exclusions
Expand Down Expand Up @@ -905,7 +913,8 @@ lazy val tests: CrossProject = crossProject(JSPlatform, JVMPlatform, NativePlatf
scalacOptions ~= { _.filterNot(_.startsWith("-P:scalajs:mapSourceURI")) }
)
.jvmSettings(
fork := true
fork := true,
Test / javaOptions += "-Dcats.effect.ioLocalPropagation=true"
)
.nativeSettings(
Compile / mainClass := Some("catseffect.examples.NativeRunner")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ private object IOFiberConstants {
final val AutoCedeR = 7
final val DoneR = 8

final val ioLocalPropagation = false

@nowarn212
@inline def isVirtualThread(t: Thread): Boolean = false
}
19 changes: 19 additions & 0 deletions core/js-native/src/main/scala/cats/effect/IOLocalPlatform.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright 2020-2024 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect

private[effect] trait IOLocalPlatform[A]
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ private[effect] sealed abstract class WorkStealingThreadPool[P] private ()
Map[Runnable, Trace])
}

private[unsafe] sealed abstract class WorkerThread[P] private () extends Thread {
private[effect] sealed abstract class WorkerThread[P] private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool[_]): Boolean
private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle
private[unsafe] def index: Int
private[effect] var currentIOFiber: IOFiber[_]
}
2 changes: 2 additions & 0 deletions core/jvm/src/main/java/cats/effect/IOFiberConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ final class IOFiberConstants {
static final byte AutoCedeR = 7;
static final byte DoneR = 8;

static final boolean ioLocalPropagation = Boolean.getBoolean("cats.effect.ioLocalPropagation");

static boolean isVirtualThread(final Thread thread) {
try {
return (boolean) THREAD_IS_VIRTUAL_HANDLE.invokeExact(thread);
Expand Down
57 changes: 57 additions & 0 deletions core/jvm/src/main/scala/cats/effect/IOLocalPlatform.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2020-2024 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect

import IOFiberConstants.ioLocalPropagation

private[effect] trait IOLocalPlatform[A] { self: IOLocal[A] =>

/**
* Returns a [[java.lang.ThreadLocal]] view of this [[IOLocal]] that allows to unsafely get,
* set, and remove (aka reset) the value in the currently running fiber. The system property
* `cats.effect.ioLocalPropagation` must be `true`, otherwise throws an
* [[java.lang.UnsupportedOperationException]].
*/
def unsafeThreadLocal(): ThreadLocal[A] = if (ioLocalPropagation)
new ThreadLocal[A] {
override def get(): A = {
val fiber = IOFiber.currentIOFiber()
val state = if (fiber ne null) fiber.getLocalState() else IOLocalState.empty
self.getOrDefault(state)
}

override def set(value: A): Unit = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.set(fiber.getLocalState(), value))
}
}

override def remove(): Unit = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.reset(fiber.getLocalState()))
}
}
}
else
throw new UnsupportedOperationException(
"IOLocal-ThreadLocal propagation is disabled.\n" +
"Enable by setting cats.effect.ioLocalPropagation=true."
)

}
5 changes: 4 additions & 1 deletion core/jvm/src/main/scala/cats/effect/IOPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A]
implicit runtime: unsafe.IORuntime): Option[A] = {
val queue = new ArrayBlockingQueue[Either[Throwable, A]](1)

unsafeRunAsync { r =>
val fiber = unsafeRunAsyncImpl { r =>
queue.offer(r)
()
}
Expand All @@ -82,6 +82,9 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A]
} catch {
case _: InterruptedException =>
None
} finally {
if (IOFiberConstants.ioLocalPropagation)
IOLocal.setThreadLocalState(fiber.getLocalState())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import java.util.concurrent.atomic.AtomicBoolean
* system when compared to a fixed size thread pool whose worker threads all draw tasks from a
* single global work queue.
*/
private final class WorkerThread[P <: AnyRef](
private[effect] final class WorkerThread[P <: AnyRef](
idx: Int,
// Local queue instance with exclusive write access.
private[this] var queue: LocalQueue,
Expand Down Expand Up @@ -107,6 +107,8 @@ private final class WorkerThread[P <: AnyRef](
private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue()
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[effect] var currentIOFiber: IOFiber[_] = _

private[this] val RightUnit = Right(())
private[this] val noop = new Function0[Unit] with Runnable {
def apply() = ()
Expand Down
11 changes: 8 additions & 3 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,12 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
*/
def unsafeRunAsync(cb: Either[Throwable, A] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
unsafeRunAsyncImpl(cb)
()
}

private[effect] def unsafeRunAsyncImpl(cb: Either[Throwable, A] => Unit)(
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] =
unsafeRunFiber(
cb(Left(new CancellationException("The fiber was canceled"))),
t => {
Expand All @@ -1007,8 +1013,6 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
},
a => cb(Right(a))
)
()
}

def unsafeRunAsyncOutcome(cb: Outcome[Id, Throwable, A @uncheckedVariance] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
Expand Down Expand Up @@ -1111,7 +1115,8 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {

val fiber = new IOFiber[A](
Map.empty,
if (IOFiberConstants.ioLocalPropagation) IOLocal.getThreadLocalState()
else IOLocalState.empty,
{ oc =>
if (registerCallback) {
runtime.fiberErrorCbs.remove(failure)
Expand Down
30 changes: 30 additions & 0 deletions core/shared/src/main/scala/cats/effect/IOFiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,18 @@ private final class IOFiber[A](
@volatile
private[this] var outcome: OutcomeIO[A] = _

def getLocalState(): IOLocalState = localState

def setLocalState(s: IOLocalState): Unit = localState = s

override def run(): Unit = {
// insert a read barrier after every async boundary
readBarrier()

if (ioLocalPropagation) {
IOFiber.setCurrentIOFiber(this)
}

(resumeTag: @switch) match {
case 0 => execR()
case 1 => asyncContinueSuccessfulR()
Expand All @@ -121,6 +130,10 @@ private final class IOFiber[A](
case 7 => autoCedeR()
case 8 => () // DoneR
}

if (ioLocalPropagation) {
IOFiber.setCurrentIOFiber(null)
}
}

/* backing fields for `cancel` and `join` */
Expand Down Expand Up @@ -1559,6 +1572,23 @@ private object IOFiber {
@static private[IOFiber] val OutcomeCanceled = Outcome.Canceled()
@static private[effect] val RightUnit = Right(())

@static private[this] val threadLocal = new ThreadLocal[IOFiber[_]]
@static def currentIOFiber(): IOFiber[_] = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread[_]])
thread.asInstanceOf[WorkerThread[_]].currentIOFiber
else
threadLocal.get()
}

@static private def setCurrentIOFiber(f: IOFiber[_]): Unit = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread[_]])
thread.asInstanceOf[WorkerThread[_]].currentIOFiber = f
else
threadLocal.set(f)
}

@static def onFatalFailure(t: Throwable): Nothing = {
val interrupted = Thread.interrupted()

Expand Down
Loading

0 comments on commit 8091026

Please sign in to comment.