Skip to content

Commit

Permalink
Expose IOLocal propagation as a ThreadLocal
Browse files Browse the repository at this point in the history
  • Loading branch information
armanbilge committed Jun 4, 2024
1 parent 925f504 commit d63a6ff
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 176 deletions.
22 changes: 0 additions & 22 deletions core/jvm/src/main/java/cats/effect/unsafe/IOLocalsConstants.java

This file was deleted.

2 changes: 1 addition & 1 deletion core/jvm/src/main/scala/cats/effect/IOPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A]
None
} finally {
if (IOFiberConstants.ioLocalPropagation)
unsafe.IOLocals.setState(fiber.getLocalState())
IOLocal.setThreadLocalState(fiber.getLocalState())
}
}

Expand Down
3 changes: 2 additions & 1 deletion core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,8 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {

val fiber = new IOFiber[A](
if (IOFiberConstants.ioLocalPropagation) unsafe.IOLocals.getState else Map.empty,
if (IOFiberConstants.ioLocalPropagation) IOLocal.getThreadLocalState()
else IOLocalState.empty,
oc =>
oc.fold(
{
Expand Down
41 changes: 40 additions & 1 deletion core/shared/src/main/scala/cats/effect/IOLocal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package cats.effect

import cats.data.AndThen

import IOFiberConstants.ioLocalPropagation

/**
* [[IOLocal]] provides a handy way of manipulating a context on different scopes.
*
Expand Down Expand Up @@ -136,7 +138,7 @@ import cats.data.AndThen
* @tparam A
* the type of the local value
*/
sealed trait IOLocal[A] {
sealed trait IOLocal[A] { self =>

protected[effect] def getOrDefault(state: IOLocalState): A

Expand Down Expand Up @@ -238,6 +240,28 @@ sealed trait IOLocal[A] {
*/
def lens[B](get: A => B)(set: A => B => A): IOLocal[B]

def unsafeToThreadLocal(): ThreadLocal[A] = new ThreadLocal[A] {
override def get(): A = if (ioLocalPropagation) {
val fiber = IOFiber.currentIOFiber()
val state = if (fiber ne null) fiber.getLocalState() else IOLocalState.empty
self.getOrDefault(state)
} else self.getOrDefault(IOLocalState.empty)

override def set(value: A): Unit = if (ioLocalPropagation) {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.set(fiber.getLocalState(), value))
}
}

override def remove(): Unit = if (ioLocalPropagation) {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.reset(fiber.getLocalState()))
}
}
}

}

object IOLocal {
Expand All @@ -255,6 +279,21 @@ object IOLocal {
*/
def apply[A](default: A): IO[IOLocal[A]] = IO(new IOLocalImpl(default))

/**
* `true` if IOLocal-Threadlocal propagation is enabled
*/
def isPropagating: Boolean = IOFiberConstants.ioLocalPropagation

private[effect] def getThreadLocalState() = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) fiber.getLocalState() else IOLocalState.empty
}

private[effect] def setThreadLocalState(state: IOLocalState) = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) fiber.setLocalState(state)
}

private final class IOLocalImpl[A](default: A) extends IOLocal[A] {

def getOrDefault(state: IOLocalState): A =
Expand Down
108 changes: 0 additions & 108 deletions core/shared/src/main/scala/cats/effect/unsafe/IOLocals.scala

This file was deleted.

60 changes: 17 additions & 43 deletions tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,61 +21,35 @@ class IOLocalsSpec extends BaseSpec {

"IOLocals" should {
"return a default value" in real {
IOLocal(42).flatMap(local => IO(IOLocals.get(local))).map(_ must beEqualTo(42))
IOLocal(42)
.flatMap(local => IO(local.unsafeToThreadLocal().get()))
.map(_ must beEqualTo(42))
}

"return a set value" in real {
IOLocal(42)
.flatMap(local => local.set(24) *> IO(IOLocals.get(local)))
.map(_ must beEqualTo(24))
for {
local <- IOLocal(42)
threadLocal <- IO(local.unsafeToThreadLocal())
_ <- local.set(24)
got <- IO(threadLocal.get())
} yield got must beEqualTo(24)
}

"unsafely set" in real {
IOLocal(42).flatMap(local =>
IO(IOLocals.set(local, 24)) *> local.get.map(_ must beEqualTo(24)))
IO(local.unsafeToThreadLocal().set(24)) *> local.get.map(_ must beEqualTo(24)))
}

"unsafely reset" in real {
IOLocal(42)
.flatMap(local => local.set(24) *> IO(IOLocals.reset(local)) *> local.get)
.map(_ must beEqualTo(42))
for {
local <- IOLocal(42)
threadLocal <- IO(local.unsafeToThreadLocal())
_ <- local.set(24)
_ <- IO(threadLocal.remove())
got <- local.get
} yield got must beEqualTo(42)
}

"unsafely update" in real {
IOLocal(42)
.flatMap(local => IO(IOLocals.update(local)(_ * 2)) *> local.get)
.map(_ must beEqualTo(84))
}

"unsafely modify" in real {
IOLocal(42)
.flatMap { local =>
IO {
IOLocals.modify(local)(x => (x * 2, x.toString)) must beEqualTo("42")
} *> local.get
}
.map(_ must beEqualTo(84))
}

"unsafely getAndSet" in real {
IOLocal(42)
.flatMap { local =>
IO {
IOLocals.getAndSet(local, 24) must beEqualTo(42)
} *> local.get
}
.map(_ must beEqualTo(24))
}

"unsafely getAndReset" in real {
IOLocal(42)
.flatMap { local =>
local.set(24) *> IO {
IOLocals.getAndReset(local) must beEqualTo(24)
} *> local.get
}
.map(_ must beEqualTo(42))
}
}

}

0 comments on commit d63a6ff

Please sign in to comment.