diff --git a/core/jvm/src/main/java/cats/effect/unsafe/IOLocalsConstants.java b/core/jvm/src/main/java/cats/effect/unsafe/IOLocalsConstants.java deleted file mode 100644 index 14fa36d9a0..0000000000 --- a/core/jvm/src/main/java/cats/effect/unsafe/IOLocalsConstants.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2020-2024 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe; - -// defined in Java since Scala doesn't let us define static fields -final class IOLocalsConstants { - static final boolean ioLocalPropagation = Boolean.getBoolean("cats.effect.ioLocalPropagation"); -} diff --git a/core/jvm/src/main/scala/cats/effect/IOPlatform.scala b/core/jvm/src/main/scala/cats/effect/IOPlatform.scala index afe36ca1fb..c53654eafc 100644 --- a/core/jvm/src/main/scala/cats/effect/IOPlatform.scala +++ b/core/jvm/src/main/scala/cats/effect/IOPlatform.scala @@ -84,7 +84,7 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A] None } finally { if (IOFiberConstants.ioLocalPropagation) - unsafe.IOLocals.setState(fiber.getLocalState()) + IOLocal.setThreadLocalState(fiber.getLocalState()) } } diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index bd8074ce7d..dc9d55e46e 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -1076,7 +1076,8 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = { val fiber = new IOFiber[A]( - if (IOFiberConstants.ioLocalPropagation) unsafe.IOLocals.getState else Map.empty, + if (IOFiberConstants.ioLocalPropagation) IOLocal.getThreadLocalState() + else IOLocalState.empty, oc => oc.fold( { diff --git a/core/shared/src/main/scala/cats/effect/IOLocal.scala b/core/shared/src/main/scala/cats/effect/IOLocal.scala index 16e1c8cdac..423b788941 100644 --- a/core/shared/src/main/scala/cats/effect/IOLocal.scala +++ b/core/shared/src/main/scala/cats/effect/IOLocal.scala @@ -18,6 +18,8 @@ package cats.effect import cats.data.AndThen +import IOFiberConstants.ioLocalPropagation + /** * [[IOLocal]] provides a handy way of manipulating a context on different scopes. * @@ -136,7 +138,7 @@ import cats.data.AndThen * @tparam A * the type of the local value */ -sealed trait IOLocal[A] { +sealed trait IOLocal[A] { self => protected[effect] def getOrDefault(state: IOLocalState): A @@ -238,6 +240,28 @@ sealed trait IOLocal[A] { */ def lens[B](get: A => B)(set: A => B => A): IOLocal[B] + def unsafeToThreadLocal(): ThreadLocal[A] = new ThreadLocal[A] { + override def get(): A = if (ioLocalPropagation) { + val fiber = IOFiber.currentIOFiber() + val state = if (fiber ne null) fiber.getLocalState() else IOLocalState.empty + self.getOrDefault(state) + } else self.getOrDefault(IOLocalState.empty) + + override def set(value: A): Unit = if (ioLocalPropagation) { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) { + fiber.setLocalState(self.set(fiber.getLocalState(), value)) + } + } + + override def remove(): Unit = if (ioLocalPropagation) { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) { + fiber.setLocalState(self.reset(fiber.getLocalState())) + } + } + } + } object IOLocal { @@ -255,6 +279,21 @@ object IOLocal { */ def apply[A](default: A): IO[IOLocal[A]] = IO(new IOLocalImpl(default)) + /** + * `true` if IOLocal-Threadlocal propagation is enabled + */ + def isPropagating: Boolean = IOFiberConstants.ioLocalPropagation + + private[effect] def getThreadLocalState() = { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) fiber.getLocalState() else IOLocalState.empty + } + + private[effect] def setThreadLocalState(state: IOLocalState) = { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) fiber.setLocalState(state) + } + private final class IOLocalImpl[A](default: A) extends IOLocal[A] { def getOrDefault(state: IOLocalState): A = diff --git a/core/shared/src/main/scala/cats/effect/unsafe/IOLocals.scala b/core/shared/src/main/scala/cats/effect/unsafe/IOLocals.scala deleted file mode 100644 index 86724e42b3..0000000000 --- a/core/shared/src/main/scala/cats/effect/unsafe/IOLocals.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright 2020-2024 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect -package unsafe - -import IOLocalsConstants.ioLocalPropagation - -object IOLocals { - - /** - * `true` if IOLocal propagation is enabled - */ - def arePropagating: Boolean = ioLocalPropagation - - def get[A](iol: IOLocal[A]): A = if (ioLocalPropagation) { - val fiber = IOFiber.currentIOFiber() - val state = if (fiber ne null) fiber.getLocalState() else IOLocalState.empty - iol.getOrDefault(state) - } else iol.getOrDefault(IOLocalState.empty) - - def set[A](iol: IOLocal[A], value: A): Unit = if (ioLocalPropagation) { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) { - fiber.setLocalState(iol.set(fiber.getLocalState(), value)) - } - } - - def reset[A](iol: IOLocal[A]): Unit = if (ioLocalPropagation) { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) { - fiber.setLocalState(iol.reset(fiber.getLocalState())) - } - } - - def update[A](iol: IOLocal[A])(f: A => A): Unit = if (ioLocalPropagation) { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) { - val state = fiber.getLocalState() - fiber.setLocalState(iol.set(state, f(iol.getOrDefault(state)))) - } - } - - def modify[A, B](iol: IOLocal[A])(f: A => (A, B)): B = if (ioLocalPropagation) { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) { - val state = fiber.getLocalState() - val (a2, b) = f(iol.getOrDefault(state)) - fiber.setLocalState(iol.set(state, a2)) - b - } else f(iol.getOrDefault(IOLocalState.empty))._2 - } else f(iol.getOrDefault(IOLocalState.empty))._2 - - def getAndSet[A](iol: IOLocal[A], a: A): A = if (ioLocalPropagation) { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) { - val state = fiber.getLocalState() - fiber.setLocalState(iol.set(state, a)) - iol.getOrDefault(state) - } else iol.getOrDefault(IOLocalState.empty) - } else iol.getOrDefault(IOLocalState.empty) - - def getAndReset[A](iol: IOLocal[A]): A = if (ioLocalPropagation) { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) { - val state = fiber.getLocalState() - fiber.setLocalState(iol.reset(state)) - iol.getOrDefault(state) - } else iol.getOrDefault(IOLocalState.empty) - } else iol.getOrDefault(IOLocalState.empty) - - private[effect] def getState = { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) fiber.getLocalState() else IOLocalState.empty - } - - private[effect] def setState(state: IOLocalState) = { - val fiber = IOFiber.currentIOFiber() - if (fiber ne null) fiber.setLocalState(state) - } - - // private[effect] def getAndClearState() = { - // val thread = Thread.currentThread() - // if (thread.isInstanceOf[WorkerThread[_]]) { - // val worker = thread.asInstanceOf[WorkerThread[_]] - // val state = worker.ioLocalState - // worker.ioLocalState = IOLocalState.empty - // state - // } else { - // val state = threadLocal.get() - // threadLocal.set(IOLocalState.empty) - // state - // } - // } -} diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala index fe66a6fa9f..97c06e7047 100644 --- a/tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala +++ b/tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala @@ -21,61 +21,35 @@ class IOLocalsSpec extends BaseSpec { "IOLocals" should { "return a default value" in real { - IOLocal(42).flatMap(local => IO(IOLocals.get(local))).map(_ must beEqualTo(42)) + IOLocal(42) + .flatMap(local => IO(local.unsafeToThreadLocal().get())) + .map(_ must beEqualTo(42)) } "return a set value" in real { - IOLocal(42) - .flatMap(local => local.set(24) *> IO(IOLocals.get(local))) - .map(_ must beEqualTo(24)) + for { + local <- IOLocal(42) + threadLocal <- IO(local.unsafeToThreadLocal()) + _ <- local.set(24) + got <- IO(threadLocal.get()) + } yield got must beEqualTo(24) } "unsafely set" in real { IOLocal(42).flatMap(local => - IO(IOLocals.set(local, 24)) *> local.get.map(_ must beEqualTo(24))) + IO(local.unsafeToThreadLocal().set(24)) *> local.get.map(_ must beEqualTo(24))) } "unsafely reset" in real { - IOLocal(42) - .flatMap(local => local.set(24) *> IO(IOLocals.reset(local)) *> local.get) - .map(_ must beEqualTo(42)) + for { + local <- IOLocal(42) + threadLocal <- IO(local.unsafeToThreadLocal()) + _ <- local.set(24) + _ <- IO(threadLocal.remove()) + got <- local.get + } yield got must beEqualTo(42) } - "unsafely update" in real { - IOLocal(42) - .flatMap(local => IO(IOLocals.update(local)(_ * 2)) *> local.get) - .map(_ must beEqualTo(84)) - } - - "unsafely modify" in real { - IOLocal(42) - .flatMap { local => - IO { - IOLocals.modify(local)(x => (x * 2, x.toString)) must beEqualTo("42") - } *> local.get - } - .map(_ must beEqualTo(84)) - } - - "unsafely getAndSet" in real { - IOLocal(42) - .flatMap { local => - IO { - IOLocals.getAndSet(local, 24) must beEqualTo(42) - } *> local.get - } - .map(_ must beEqualTo(24)) - } - - "unsafely getAndReset" in real { - IOLocal(42) - .flatMap { local => - local.set(24) *> IO { - IOLocals.getAndReset(local) must beEqualTo(24) - } *> local.get - } - .map(_ must beEqualTo(42)) - } } }