
Commit

Add rate limiter primitives (#235)
Co-authored-by: pablf <[email protected]>
Co-authored-by: adamw <[email protected]>
3 people authored Nov 15, 2024
1 parent 2d0b4c0 commit 6bdd3d0
Showing 7 changed files with 880 additions and 1 deletion.
7 changes: 6 additions & 1 deletion core/src/main/scala/ox/fork.scala
@@ -184,11 +184,16 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] =
end new
end forkCancellable

/** Same as [[fork]], but discards the resulting [[Fork]], to avoid compiler warnings. That is, the fork is run only for its side-effects,
/** Same as [[fork]], but discards the resulting [[Fork]], to avoid compiler warnings. That is, the fork is run only for its side effects,
* it's not possible to join it.
*/
inline def forkDiscard[T](inline f: T)(using Ox): Unit = fork(f).discard

/** Same as [[forkUser]], but discards the resulting [[Fork]], to avoid compiler warnings. That is, the fork is run only for its side
* effects, it's not possible to join it.
*/
inline def forkUserDiscard[T](inline f: T)(using Ox): Unit = forkUser(f).discard

private trait ForkUsingResult[T](result: CompletableFuture[T]) extends Fork[T]:
override def join(): T = unwrapExecutionException(result.get())
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean =
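For context, a minimal usage sketch of the new discarding fork variants within a supervised scope; the logRequest and sendMetrics helpers are hypothetical placeholders, not part of this commit:

import ox.*

def logRequest(path: String): Unit = println(s"request: $path") // hypothetical side effect
def sendMetrics(): Unit = println("metrics sent")                // hypothetical side effect

@main def forkDiscardExample(): Unit =
  supervised:
    // fire-and-forget daemon fork: the Fork value is discarded, avoiding an "unused value" warning
    forkDiscard(logRequest("GET /health"))
    // user fork: the enclosing scope waits for it to complete before ending
    forkUserDiscard(sendMetrics())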
78 changes: 78 additions & 0 deletions core/src/main/scala/ox/resilience/RateLimiter.scala
@@ -0,0 +1,78 @@
package ox.resilience

import scala.concurrent.duration.FiniteDuration
import ox.*

import scala.annotation.tailrec

/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped when the rate limit is reached. */
class RateLimiter private (algorithm: RateLimiterAlgorithm):
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */
def runBlocking[T](operation: => T): T =
algorithm.acquire()
operation

/** Runs the operation, or drops it if the rate limit is reached.
*
* @return
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped.
*/
def runOrDrop[T](operation: => T): Option[T] =
if algorithm.tryAcquire() then Some(operation)
else None

end RateLimiter

object RateLimiter:
def apply(algorithm: RateLimiterAlgorithm)(using Ox): RateLimiter =
@tailrec
def update(): Unit =
val waitTime = algorithm.getNextUpdate
val millis = waitTime / 1000000
val nanos = waitTime % 1000000
Thread.sleep(millis, nanos.toInt)
algorithm.update()
update()
end update

forkDiscard(update())
new RateLimiter(algorithm)
end apply

/** Creates a rate limiter using a fixed window algorithm.
  *
  * Must be run within an [[Ox]] concurrency scope, as a background fork is created to replenish the rate limiter.
  *
  * @param maxOperations
  *   Maximum number of operations that are allowed to **start** within a time [[window]].
  * @param window
  *   Interval of time between replenishing the rate limiter. The rate limiter is replenished to allow up to [[maxOperations]] in the
  *   next time window.
  */
def fixedWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window))

/** Creates a rate limiter using a sliding window algorithm.
  *
  * Must be run within an [[Ox]] concurrency scope, as a background fork is created to replenish the rate limiter.
  *
  * @param maxOperations
  *   Maximum number of operations that are allowed to **start** within any [[window]] of time.
  * @param window
  *   Length of the window.
  */
def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window))

/** Creates a rate limiter using the token/leaky bucket algorithm.
  *
  * Must be run within an [[Ox]] concurrency scope, as a background fork is created to replenish the rate limiter.
  *
  * @param maxTokens
  *   Maximum capacity of tokens in the bucket, limiting the number of operations that are allowed to **start** concurrently.
  * @param refillInterval
  *   Interval of time between adding a single token to the bucket.
  */
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
end RateLimiter
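A minimal usage sketch of the new RateLimiter API; the 10-operations-per-second limit and the fetch helper are illustrative assumptions, not part of this commit:

import ox.*
import ox.resilience.RateLimiter
import scala.concurrent.duration.*

def fetch(path: String): Int = path.length // hypothetical rate-limited operation

@main def rateLimiterExample(): Unit =
  supervised:
    // at most 10 operations may start within each 1-second window
    val limiter = RateLimiter.fixedWindow(10, 1.second)
    // blocks until a permit is available, then runs the operation
    val blocked: Int = limiter.runBlocking(fetch("/users"))
    // runs only if a permit is currently available, otherwise returns None without running the operation
    val maybe: Option[Int] = limiter.runOrDrop(fetch("/orders"))
    println((blocked, maybe))

runBlocking waits until the background fork replenishes the limiter, while runOrDrop returns immediately.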
142 changes: 142 additions & 0 deletions core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala
@@ -0,0 +1,142 @@
package ox.resilience

import scala.concurrent.duration.FiniteDuration
import scala.collection.immutable.Queue
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.Semaphore
import scala.annotation.tailrec

/** Determines the algorithm to use for the rate limiter. */
trait RateLimiterAlgorithm:

/** Acquires a permit to execute the operation. This method should block until a permit is available. */
final def acquire(): Unit =
acquire(1)

/** Acquires the given number of permits to execute the operation. This method should block until the permits are available. */
def acquire(permits: Int): Unit

/** Tries to acquire a permit to execute the operation. This method should not block. */
final def tryAcquire(): Boolean =
tryAcquire(1)

/** Tries to acquire permits to execute the operation. This method should not block. */
def tryAcquire(permits: Int): Boolean

/** Updates the internal state of the rate limiter to check whether new operations can be accepted. */
def update(): Unit

/** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */
def getNextUpdate: Long

end RateLimiterAlgorithm

object RateLimiterAlgorithm:
/** Fixed window algorithm: allows starting at most `rate` operations in consecutive segments of duration `per`. */
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val lastUpdate = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(rate)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L

def update(): Unit =
val now = System.nanoTime()
lastUpdate.set(now)
semaphore.release(rate - semaphore.availablePermits())
end update

end FixedWindow

/** Sliding window algorithm: allows starting at most `rate` operations within any span of `per` before the current time. */
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
// stores the timestamp and the number of permits acquired after calling acquire or tryAcquire successfully
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
private val semaphore = new Semaphore(rate)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)
addTimestampToLog(permits)

def tryAcquire(permits: Int): Boolean =
if semaphore.tryAcquire(permits) then
addTimestampToLog(permits)
true
else false

private def addTimestampToLog(permits: Int): Unit =
val now = System.nanoTime()
log.updateAndGet { q =>
q.enqueue((now, permits))
}
()

def getNextUpdate: Long =
log.get().headOption match
case None =>
// no logs so no need to update until `per` has passed
per.toNanos
case Some(record) =>
// the oldest log entry determines the next update point
val waitTime = record._1 + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L
end getNextUpdate

def update(): Unit =
val now = System.nanoTime()
// retrieve the current queue and reset it; elements added concurrently will be merged back below
val q = log.getAndUpdate(_ => Queue[(Long, Int)]())
// remove records older than window size
val qUpdated = removeRecords(q, now)
// merge old records with the ones concurrently added
val _ = log.updateAndGet(qNew =>
qNew.foldLeft(qUpdated) { case (queue, record) =>
queue.enqueue(record)
}
)
end update

@tailrec
private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] =
q.dequeueOption match
case None => q
case Some((head, tail)) =>
if head._1 + per.toNanos < now then
val (_, permits) = head
semaphore.release(permits)
removeRecords(tail, now)
else q

end SlidingWindow

/** Token/leaky bucket algorithm: a token, allowing a new operation to start, is added every `per`, up to a maximum of `rate` tokens. */
case class LeakyBucket(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val refillInterval = per.toNanos
private val lastRefillTime = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(1)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastRefillTime.get() + refillInterval - System.nanoTime()
if waitTime > 0 then waitTime else 0L

def update(): Unit =
val now = System.nanoTime()
lastRefillTime.set(now)
if semaphore.availablePermits() < rate then semaphore.release()

end LeakyBucket
end RateLimiterAlgorithm
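Since RateLimiter.apply accepts any RateLimiterAlgorithm, custom strategies can be plugged in. Below is a minimal sketch (not part of this commit) of a simplified fixed-window-style algorithm, illustrating the trait's contract: acquire blocks, tryAcquire does not, getNextUpdate returns the nanoseconds until the next replenishment, and update tops the permits back up.

import ox.resilience.RateLimiterAlgorithm
import java.util.concurrent.Semaphore

// Hypothetical custom algorithm: replenishes up to maxOps permits roughly once per second.
class PerSecondWindow(maxOps: Int) extends RateLimiterAlgorithm:
  private val semaphore = new Semaphore(maxOps)

  def acquire(permits: Int): Unit = semaphore.acquire(permits)          // blocks until permits are available
  def tryAcquire(permits: Int): Boolean = semaphore.tryAcquire(permits) // never blocks
  def getNextUpdate: Long = 1_000_000_000L                              // next replenishment in ~1s, in nanoseconds
  def update(): Unit = semaphore.release(maxOps - semaphore.availablePermits()) // top back up to maxOps

Inside an Ox scope it could then be passed to the constructor, e.g. RateLimiter(PerSecondWindow(100)).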
126 changes: 126 additions & 0 deletions core/src/test/scala/ox/resilience/RateLimiterInterfaceTest.scala
@@ -0,0 +1,126 @@
package ox.resilience

import ox.*
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.{EitherValues, TryValues}
import scala.concurrent.duration.*

class RateLimiterInterfaceTest extends AnyFlatSpec with Matchers with EitherValues with TryValues:
behavior of "RateLimiter interface"

it should "drop or block operation depending on method used for fixed rate algorithm" in {
supervised:
val rateLimiter = RateLimiter.fixedWindow(2, FiniteDuration(1, "second"))

var executions = 0
def operation =
executions += 1
0

val result1 = rateLimiter.runOrDrop(operation)
val result2 = rateLimiter.runOrDrop(operation)
val result3 = rateLimiter.runOrDrop(operation)
val result4 = rateLimiter.runBlocking(operation)
val result5 = rateLimiter.runBlocking(operation)
val result6 = rateLimiter.runOrDrop(operation)

result1 shouldBe Some(0)
result2 shouldBe Some(0)
result3 shouldBe None
result4 shouldBe 0
result5 shouldBe 0
result6 shouldBe None
executions shouldBe 4
}

it should "drop or block operation depending on method used for sliding window algorithm" in {
supervised:
val rateLimiter = RateLimiter.slidingWindow(2, FiniteDuration(1, "second"))

var executions = 0
def operation =
executions += 1
0

val result1 = rateLimiter.runOrDrop(operation)
val result2 = rateLimiter.runOrDrop(operation)
val result3 = rateLimiter.runOrDrop(operation)
val result4 = rateLimiter.runBlocking(operation)
val result5 = rateLimiter.runBlocking(operation)
val result6 = rateLimiter.runOrDrop(operation)

result1 shouldBe Some(0)
result2 shouldBe Some(0)
result3 shouldBe None
result4 shouldBe 0
result5 shouldBe 0
result6 shouldBe None
executions shouldBe 4
}

it should "drop or block operation depending on method used for bucket algorithm" in {
supervised:
val rateLimiter = RateLimiter.leakyBucket(2, FiniteDuration(1, "second"))

var executions = 0
def operation =
executions += 1
0

val result1 = rateLimiter.runOrDrop(operation)
val result2 = rateLimiter.runOrDrop(operation)
val result3 = rateLimiter.runOrDrop(operation)
val result4 = rateLimiter.runBlocking(operation)
val result5 = rateLimiter.runBlocking(operation)
val result6 = rateLimiter.runOrDrop(operation)

result1 shouldBe Some(0)
result2 shouldBe None
result3 shouldBe None
result4 shouldBe 0
result5 shouldBe 0
result6 shouldBe None
executions shouldBe 3
}

it should "drop or block operation concurrently" in {
supervised:
val rateLimiter = RateLimiter.fixedWindow(2, FiniteDuration(1, "second"))

def operation = 0

var result1: Option[Int] = Some(-1)
var result2: Option[Int] = Some(-1)
var result3: Option[Int] = Some(-1)
var result4: Int = -1
var result5: Int = -1
var result6: Int = -1

// run two operations to block the rate limiter
rateLimiter.runOrDrop(operation).discard
rateLimiter.runOrDrop(operation).discard

// operations with runOrDrop should be dropped while operations with runBlocking should wait
supervised:
forkUserDiscard:
result1 = rateLimiter.runOrDrop(operation)
forkUserDiscard:
result2 = rateLimiter.runOrDrop(operation)
forkUserDiscard:
result3 = rateLimiter.runOrDrop(operation)
forkUserDiscard:
result4 = rateLimiter.runBlocking(operation)
forkUserDiscard:
result5 = rateLimiter.runBlocking(operation)
forkUserDiscard:
result6 = rateLimiter.runBlocking(operation)

result1 shouldBe None
result2 shouldBe None
result3 shouldBe None
result4 shouldBe 0
result5 shouldBe 0
result6 shouldBe 0
}
end RateLimiterInterfaceTest