Skip to content

Commit

Permalink
Merge branch 'main' into update/auxlib-0.5.3
Browse files Browse the repository at this point in the history
  • Loading branch information
kovstas authored Jun 8, 2024
2 parents b3f4037 + f3169cb commit 724d0ee
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 58 deletions.
104 changes: 46 additions & 58 deletions src/main/scala/dev/kovstas/fs2throttler/Throttler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@
package dev.kovstas.fs2throttler

import cats.Applicative
import cats.effect.Temporal
import cats.effect.kernel.Clock
import cats.effect.{Ref, Temporal}
import cats.implicits._
import fs2.{Pipe, Pull, Stream}

import scala.concurrent.duration._

object Throttler {
Expand Down Expand Up @@ -88,75 +87,64 @@ object Throttler {
mode: ThrottleMode,
burst: Long,
fnCost: O => F[Long]
): Pipe[F, O, O] = {
): Pipe[F, O, O] = { in =>
val capacity = if (elements + burst <= 0) Long.MaxValue else elements + burst
val interval = duration.toNanos / capacity

def go(
s: Stream[F, O],
bucket: Ref[F, (Long, FiniteDuration)],
capacity: Long,
interval: Long
tokens: => Long,
time: => Long
): Pull[F, O, Unit] = {
s.pull.uncons1.flatMap {
case Some((head, tail)) =>
Pull
.eval(for {
cost <- fnCost(head)
now <- Clock[F].monotonic
delay <- bucket.modify { case (tokens, lastUpdate) =>
if (interval == 0) {
((0, now), Duration.Zero)
} else {
val elapsed = (now - lastUpdate).toNanos
val tokensArrived =
if (elapsed >= interval) {
elapsed / interval
} else 0

val nextTime = lastUpdate + (tokensArrived * interval).nanos
val available = math.min(tokens + tokensArrived, capacity)

if (cost <= available) {
((available - cost, nextTime), Duration.Zero)
} else {
val timePassed = now.toNanos - nextTime.toNanos
val waitingTime = (cost - available) * interval
val delay = (waitingTime - timePassed).nanos

((0, now + delay), delay)
}
}
Pull.eval(fnCost(head) product Clock[F].monotonic.map(_.toNanos)).flatMap { case (cost, now) =>
val (remainingTokens, nextTime, delay) = {
val elapsed = now - time

val tokensArrived =
if (elapsed >= interval) {
elapsed / interval
} else 0
val nextTime = time + tokensArrived * interval
val available = math.min(tokens + tokensArrived, capacity)

if (cost <= available) {
(available - cost, nextTime, 0L)
} else {
val timePassed = now - nextTime
val waitingTime = (cost - available) * interval
val delay = waitingTime - timePassed

(0L, now + delay, delay)
}
}

if (delay == 0) {
Pull.output1(head) >> go(tail, remainingTokens, nextTime)
} else
mode match {
case Enforcing =>
go(tail, remainingTokens, nextTime)
case Shaping =>
Pull.sleep(delay.nanos) >> Pull.output1(head) >> go(tail, remainingTokens, nextTime)
}
continueF = Pull.output1(head) >> go(tail, bucket, capacity, interval)
result <-
if (delay == Duration.Zero) {
Applicative[F].pure(continueF)
} else {
mode match {
case Enforcing =>
Applicative[F].pure(go(tail, bucket, capacity, interval))
case Shaping =>
Clock[F].delayBy(Applicative[F].pure(continueF), delay)
}
}
} yield result)
.flatMap(identity)
}

case None =>
Pull.done
}
}

in =>
val capacity = if (elements + burst <= 0) Long.MaxValue else elements + burst

for {
bucket <- Stream.eval(
Ref.ofEffect(
Clock[F].monotonic.map((capacity, _))
)
)
stream <- go(in, bucket, capacity, duration.toNanos / capacity).stream
} yield stream
if (interval == 0) {
in
} else {
Stream
.eval(Clock[F].monotonic)
.flatMap { time =>
go(in, elements, time.toNanos).stream
}
}

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class ThrottlerSpec extends munit.FunSuite {
.unsafeToFuture()(runtime)

ctx.tick()
ctx.advanceAndTick(500.millis)
assertEquals(elements.toList, List(0, 1))
ctx.advanceAndTick(2.seconds)
assertEquals(elements.toList, List(0, 1, 2, 3, 4))
Expand Down

0 comments on commit 724d0ee

Please sign in to comment.