Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IOLocal propagation for unsafe access #3636

Merged
merged 39 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0b88c01
POC thread-local iolocals
armanbilge May 16, 2023
db743e2
Simplify and optimize
armanbilge May 16, 2023
716ef32
Special-case for `WorkerThread`
armanbilge May 17, 2023
0a69caf
Load locals in `unsafeRunFiber`
armanbilge May 17, 2023
2775064
Dump locals in more places
armanbilge May 18, 2023
270764f
Refactor `IOLocal`
armanbilge May 21, 2023
d55489d
Use new `IOLocal` APIs in `IOLocals`
armanbilge May 21, 2023
2cf72a5
Mark `IOLocal` methods as `final`
armanbilge May 21, 2023
cb3859d
Add `IOLocalsSpec`
armanbilge Jun 10, 2023
7dce01c
Rename property to `ioLocalPropagation` and fixes
armanbilge Jun 28, 2023
5e171ac
Bump base version
armanbilge Jun 28, 2023
c2f312d
Add files I forgot tocommit :)
armanbilge Jun 28, 2023
638930d
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Jun 28, 2023
9174c6a
Add MiMa filters
armanbilge Jun 28, 2023
1987e3a
Fix scaladoc links
armanbilge Jun 28, 2023
02a43a6
Alias the disambiguations
armanbilge Jun 28, 2023
a7bf748
Copy locals back out after blocking unsafe run
armanbilge Sep 5, 2023
145fc0e
Merge remote-tracking branch 'upstream/series/3.x' into topic/thread-…
armanbilge Sep 5, 2023
fa99a5c
Expose status of `IOLocal` propagation
armanbilge Sep 25, 2023
6cad03c
`propagating` -> `arePropagating`
armanbilge Sep 29, 2023
bb5d4b1
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Sep 30, 2023
7517755
Use `real` instead of `ticked`
armanbilge Sep 30, 2023
8d8e004
Formatting
armanbilge Sep 30, 2023
3589db4
Try keeping the current fiber as a thread-local instead
armanbilge Sep 30, 2023
522677e
Revert spurious whitespace changes
armanbilge Sep 30, 2023
6cc4d38
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge May 27, 2024
ac88480
Update headers
armanbilge May 27, 2024
49e5c30
Update platform headers
armanbilge May 27, 2024
925f504
Remove unused class
armanbilge May 27, 2024
d63a6ff
Expose `IOLocal` propagation as a `ThreadLocal`
armanbilge Jun 4, 2024
d4549fb
`unsafeToThreadLocal()` throws if propagation disabled
armanbilge Jun 4, 2024
2502045
Add scaladoc
armanbilge Jun 5, 2024
535fc8a
Factor out to JVM-only
armanbilge Jun 5, 2024
d854799
Bikeshed API and docs
armanbilge Jun 5, 2024
f070552
Formatting
armanbilge Jun 5, 2024
2cf1d8a
Delete dead code
armanbilge Jun 5, 2024
0eec9dd
Document `ThreadLocal` propagation
armanbilge Aug 5, 2024
af84973
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 14, 2024
1adf368
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ private object IOFiberConstants {
final val CedeR = 6
final val AutoCedeR = 7
final val DoneR = 8

final val dumpLocals = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ private[unsafe] sealed abstract class WorkerThread private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool): Boolean
private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle
private[unsafe] def index: Int
private[unsafe] var ioLocalState: IOLocalState
}
2 changes: 2 additions & 0 deletions core/jvm/src/main/java/cats/effect/IOFiberConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ final class IOFiberConstants {
static final byte CedeR = 6;
static final byte AutoCedeR = 7;
static final byte DoneR = 8;

static final boolean dumpLocals = Boolean.getBoolean("cats.effect.tracing.dumpLocals");
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bikesheddable configuration for opting-in. So the rest of us don't have to pay the penalty 😇

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this specifically "tracing", even if that's the most obvious use case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh woops, this was a very lazy copy-pasta. I copied it from the system properties we use to configure fiber tracing. We should rename it anyway, dumpLocals is not quite right I think 😅

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about cats.effect.localContextPropagation similar to Monix's monix.environment.localContextPropagation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I liked that! I went with cats.effect.ioLocalPropagation.

}
2 changes: 2 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ private final class WorkerThread(
private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue()
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[unsafe] var ioLocalState: IOLocalState = IOLocalState.empty

val nameIndex: Int = pool.blockedWorkerThreadNamingIndex.getAndIncrement()

// Constructor code.
Expand Down
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {

val fiber = new IOFiber[A](
Map.empty,
if (IOFiberConstants.dumpLocals) unsafe.IOLocals.getState else Map.empty,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can even go in the opposite direction for IO#unsafeRun* 😁

It's less clear if/how to do this for fibers started in a Dispatcher, since they should be inheriting locals from the fiber backing the Dispatcher.

oc =>
oc.fold(
{
Expand Down
8 changes: 8 additions & 0 deletions core/shared/src/main/scala/cats/effect/IOFiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ private final class IOFiber[A](
pushTracingEvent(cur.event)
}

if (dumpLocals) {
IOLocals.setState(localState)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question: can't we simply do this when we get scheduled on a thread? We know when we're on a thread and we know when we get off of it, so can't we simply set and clear the state respectively at those points?

Copy link
Member Author

@armanbilge armanbilge Sep 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we can't, unless we unify how the state is represented. Currently it's a var to an immutable map in the fiber and also in the thread. While the fiber is running its copy of the var may be updated effectually in the runloop so the thread-local copy would need to be kept in sync with that. Or we could drive all updates through the thread-local copy of the var, but then there would be a penalty for accessing it esp. if we are not running on a worker thread.

Putting aside technical issues, nobody should be unsafely messing about with IOLocals outside of a properly suspended side-effect block and this strategy enforces that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What we can do is set the current fiber in a thread local every time we get scheduled on a thread. Then the unsafe IOLocals manipulations can operate on the state via the current fiber and we don't need to pay the penalty for every delay block. Based on the benchmarks this strategy is seeming more attractive 😅

Note this would leave the fiber's IOLocal state exposed to unsafe manipulations outside of delay blocks.

}
armanbilge marked this conversation as resolved.
Show resolved Hide resolved

var error: Throwable = null
val r =
try cur.thunk()
Expand All @@ -260,6 +264,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (dumpLocals) {
localState = IOLocals.getAndClearState()
}

val next =
if (error == null) succeeded(r, 0)
else failed(error, 0)
Expand Down
3 changes: 3 additions & 0 deletions core/shared/src/main/scala/cats/effect/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ package object effect {
val Ref = cekernel.Ref

private[effect] type IOLocalState = scala.collection.immutable.Map[IOLocal[_], Any]
private[effect] object IOLocalState {
val empty: IOLocalState = scala.collection.immutable.Map.empty
}

private[effect] type ByteStack = ByteStack.T
}
68 changes: 68 additions & 0 deletions core/shared/src/main/scala/cats/effect/unsafe/IOLocals.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package cats.effect
package unsafe

// TODO handle defaults and lenses. all do-able, just needs refactoring ...
object IOLocals {
djspiewak marked this conversation as resolved.
Show resolved Hide resolved
djspiewak marked this conversation as resolved.
Show resolved Hide resolved

def get[A](iol: IOLocal[A]): A = {
val thread = Thread.currentThread()
val state =
if (thread.isInstanceOf[WorkerThread])
thread.asInstanceOf[WorkerThread].ioLocalState
else
threadLocal.get
state(iol).asInstanceOf[A]
}
armanbilge marked this conversation as resolved.
Show resolved Hide resolved

def set[A](iol: IOLocal[A], value: A): Unit = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread])
thread.asInstanceOf[WorkerThread].ioLocalState += (iol -> value)
else
threadLocal.set(threadLocal.get() + (iol -> value))
armanbilge marked this conversation as resolved.
Show resolved Hide resolved
}

def reset[A](iol: IOLocal[A]): Unit = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread])
thread.asInstanceOf[WorkerThread].ioLocalState -= iol
else
threadLocal.set(threadLocal.get() - iol)
}

// TODO other ops from IOLocal

private[this] val threadLocal = new ThreadLocal[IOLocalState] {
override def initialValue() = IOLocalState.empty
}

private[effect] def getState = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread])
thread.asInstanceOf[WorkerThread].ioLocalState
else
threadLocal.get()
}

private[effect] def setState(state: IOLocalState) = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread])
thread.asInstanceOf[WorkerThread].ioLocalState = state
else
threadLocal.set(state)
}

private[effect] def getAndClearState() = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread]) {
val worker = thread.asInstanceOf[WorkerThread]
val state = worker.ioLocalState
worker.ioLocalState = IOLocalState.empty
state
} else {
val state = threadLocal.get()
threadLocal.set(IOLocalState.empty)
state
}
}
}