Skip to content

Commit

Permalink
Introduce PollerProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
armanbilge committed Aug 22, 2024
1 parent 7168625 commit 7c610bb
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,8 @@ abstract class PollingSystem {

/**
* Creates a new instance of the user-facing interface.
*
* @param access
* callback to obtain a thread-local `Poller`.
* @return
* an instance of the user-facing interface `Api`.
*/
def makeApi(access: (Poller => Unit) => Unit): Api
def makeApi(provider: PollerProvider[Poller]): Api

/**
* Creates a new instance of the thread-local data structure used for polling.
Expand Down Expand Up @@ -109,7 +104,20 @@ abstract class PollingSystem {

}

private object PollingSystem {
trait PollerProvider[P] {

/**
* Register a callback to obtain a thread-local `Poller`
*/
def accessPoller(cb: P => Unit): Unit

/**
* Returns `true` if it is safe to interact with this `Poller`
*/
def ownPoller(poller: P): Boolean
}

object PollingSystem {

/**
* Type alias for a `PollingSystem` that has a specified `Poller` type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type

(
threadPool,
pollingSystem.makeApi(threadPool.accessPoller),
pollingSystem.makeApi(threadPool),
{ () =>
unregisterMBeans()
threadPool.shutdown()
Expand Down
14 changes: 7 additions & 7 deletions core/jvm/src/main/scala/cats/effect/unsafe/SelectorSystem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ import java.nio.channels.spi.{AbstractSelector, SelectorProvider}

import SelectorSystem._

final class SelectorSystem private (provider: SelectorProvider) extends PollingSystem {
final class SelectorSystem private (selectorProvider: SelectorProvider) extends PollingSystem {

type Api = Selector

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): Selector =
new SelectorImpl(access, provider)
def makeApi(provider: PollerProvider[Poller]): Selector =
new SelectorImpl(provider, selectorProvider)

def makePoller(): Poller = new Poller(provider.openSelector())
def makePoller(): Poller = new Poller(selectorProvider.openSelector())

def closePoller(poller: Poller): Unit =
poller.selector.close()
Expand Down Expand Up @@ -107,15 +107,15 @@ final class SelectorSystem private (provider: SelectorProvider) extends PollingS
}

final class SelectorImpl private[SelectorSystem] (
access: (Poller => Unit) => Unit,
poller: PollerProvider[Poller],
val provider: SelectorProvider
) extends Selector {

def select(ch: SelectableChannel, ops: Int): IO[Int] = IO.async { selectCb =>
IO.async_[CallbackNode] { cb =>
access { data =>
poller.accessPoller { poller =>
try {
val selector = data.selector
val selector = poller.selector
val key = ch.keyFor(selector)

val node = if (key eq null) { // not yet registered on this selector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object SleepSystem extends PollingSystem {

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): Api = this
def makeApi(provider: PollerProvider[Poller]): Api = this

def makePoller(): Poller = this

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ import WorkStealingThreadPool._
* contention. Work stealing is tried using a linear search starting from a random worker thread
* index.
*/
private[effect] final class WorkStealingThreadPool[P](
private[effect] final class WorkStealingThreadPool[P <: AnyRef](
threadCount: Int, // number of worker threads
private[unsafe] val threadPrefix: String, // prefix for the name of worker threads
private[unsafe] val blockerThreadPrefix: String, // prefix for the name of worker threads currently in a blocking region
Expand All @@ -71,7 +71,8 @@ private[effect] final class WorkStealingThreadPool[P](
system: PollingSystem.WithPoller[P],
reportFailure0: Throwable => Unit
) extends ExecutionContextExecutor
with Scheduler {
with Scheduler
with PollerProvider[P] {

import TracingConstants._
import WorkStealingThreadPoolConstants._
Expand All @@ -87,7 +88,7 @@ private[effect] final class WorkStealingThreadPool[P](
private[unsafe] val pollers: Array[P] =
new Array[AnyRef](threadCount).asInstanceOf[Array[P]]

private[unsafe] def accessPoller(cb: P => Unit): Unit = {
def accessPoller(cb: P => Unit): Unit = {

// figure out where we are
val thread = Thread.currentThread()
Expand All @@ -101,6 +102,14 @@ private[effect] final class WorkStealingThreadPool[P](
} else scheduleExternal(() => accessPoller(cb))
}

def ownPoller(poller: P): Boolean = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread[_]]) {
val worker = thread.asInstanceOf[WorkerThread[P]]
worker.ownsPoller(poller)
} else false
}

/**
* Atomic variable for used for publishing changes to the references in the `workerThreads`
* array. Worker threads can be changed whenever blocking code is encountered on the pool.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import java.util.concurrent.atomic.AtomicBoolean
* system when compared to a fixed size thread pool whose worker threads all draw tasks from a
* single global work queue.
*/
private final class WorkerThread[P](
private final class WorkerThread[P <: AnyRef](
idx: Int,
// Local queue instance with exclusive write access.
private[this] var queue: LocalQueue,
Expand Down Expand Up @@ -291,6 +291,9 @@ private final class WorkerThread[P](
foreign.toMap
}

private[unsafe] def ownsPoller(poller: P): Boolean =
poller eq _poller

private[unsafe] def ownsTimers(timers: TimerHeap): Boolean =
sleepers eq timers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ object EpollSystem extends PollingSystem {

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): Api =
new FileDescriptorPollerImpl(access)
def makeApi(provider: PollerProvider[Poller]): Api =
new FileDescriptorPollerImpl(provider)

def makePoller(): Poller = {
val fd = epoll_create1(0)
Expand All @@ -67,7 +67,7 @@ object EpollSystem extends PollingSystem {
def interrupt(targetThread: Thread, targetPoller: Poller): Unit = ()

private final class FileDescriptorPollerImpl private[EpollSystem] (
access: (Poller => Unit) => Unit)
provider: PollerProvider[Poller])
extends FileDescriptorPoller {

def registerFileDescriptor(
Expand All @@ -78,7 +78,7 @@ object EpollSystem extends PollingSystem {
Resource {
(Mutex[IO], Mutex[IO]).flatMapN { (readMutex, writeMutex) =>
IO.async_[(PollHandle, IO[Unit])] { cb =>
access { epoll =>
provider.accessPoller { epoll =>
val handle = new PollHandle(readMutex, writeMutex)
epoll.register(fd, reads, writes, handle, cb)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type
): (ExecutionContext with Scheduler, system.Api, () => Unit) = {
val loop = new EventLoopExecutorScheduler[system.Poller](64, system)
val poller = loop.poller
(loop, system.makeApi(cb => cb(poller)), () => loop.shutdown())
val api = system.makeApi(
new PollerProvider[system.Poller] {
def accessPoller(cb: system.Poller => Unit) = cb(poller)
def ownPoller(poller: system.Poller) = true
}
)
(loop, api, () => loop.shutdown())
}

def createDefaultPollingSystem(): PollingSystem =
Expand Down
14 changes: 7 additions & 7 deletions core/native/src/main/scala/cats/effect/unsafe/KqueueSystem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ object KqueueSystem extends PollingSystem {

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): FileDescriptorPoller =
new FileDescriptorPollerImpl(access)
def makeApi(provider: PollerProvider[Poller]): FileDescriptorPoller =
new FileDescriptorPollerImpl(provider)

def makePoller(): Poller = {
val fd = kqueue()
Expand All @@ -67,7 +67,7 @@ object KqueueSystem extends PollingSystem {
def interrupt(targetThread: Thread, targetPoller: Poller): Unit = ()

private final class FileDescriptorPollerImpl private[KqueueSystem] (
access: (Poller => Unit) => Unit
provider: PollerProvider[Poller]
) extends FileDescriptorPoller {
def registerFileDescriptor(
fd: Int,
Expand All @@ -76,7 +76,7 @@ object KqueueSystem extends PollingSystem {
): Resource[IO, FileDescriptorPollHandle] =
Resource.eval {
(Mutex[IO], Mutex[IO]).mapN {
new PollHandle(access, fd, _, _)
new PollHandle(provider, fd, _, _)
}
}
}
Expand All @@ -86,7 +86,7 @@ object KqueueSystem extends PollingSystem {
(filter.toLong << 32) | ident.toLong

private final class PollHandle(
access: (Poller => Unit) => Unit,
provider: PollerProvider[Poller],
fd: Int,
readMutex: Mutex[IO],
writeMutex: Mutex[IO]
Expand All @@ -101,7 +101,7 @@ object KqueueSystem extends PollingSystem {
else
IO.async[Unit] { kqcb =>
IO.async_[Option[IO[Unit]]] { cb =>
access { kqueue =>
provider.accessPoller { kqueue =>
kqueue.evSet(fd, EVFILT_READ, EV_ADD.toUShort, kqcb)
cb(Right(Some(IO(kqueue.removeCallback(fd, EVFILT_READ)))))
}
Expand All @@ -121,7 +121,7 @@ object KqueueSystem extends PollingSystem {
else
IO.async[Unit] { kqcb =>
IO.async_[Option[IO[Unit]]] { cb =>
access { kqueue =>
provider.accessPoller { kqueue =>
kqueue.evSet(fd, EVFILT_WRITE, EV_ADD.toUShort, kqcb)
cb(Right(Some(IO(kqueue.removeCallback(fd, EVFILT_WRITE)))))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ abstract class PollingExecutorScheduler(pollEvery: Int)
type Poller = outer.type
private[this] var needsPoll = true
def close(): Unit = ()
def makeApi(access: (Poller => Unit) => Unit): Api = outer
def makeApi(provider: PollerProvider[Poller]): Api = outer
def makePoller(): Poller = outer
def closePoller(poller: Poller): Unit = ()
def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object SleepSystem extends PollingSystem {

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): Api = this
def makeApi(provider: PollerProvider[Poller]): Api = this

def makePoller(): Poller = this

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import cats.effect.std.Semaphore
import cats.effect.unsafe.{
IORuntime,
IORuntimeConfig,
PollerProvider,
PollingSystem,
SleepSystem,
WorkStealingThreadPool
Expand Down Expand Up @@ -513,10 +514,10 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala
}
}

def makeApi(access: (Poller => Unit) => Unit): DummySystem.Api =
def makeApi(provider: PollerProvider[Poller]): DummySystem.Api =
new DummyPoller {
def poll = IO.async_[Unit] { cb =>
access { poller =>
provider.accessPoller { poller =>
poller.getAndUpdate(cb :: _)
()
}
Expand Down

0 comments on commit 7c610bb

Please sign in to comment.