-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: pablf <[email protected]> Co-authored-by: adamw <[email protected]>
- Loading branch information
1 parent
2d0b4c0
commit 6bdd3d0
Showing
7 changed files
with
880 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package ox.resilience | ||
|
||
import scala.concurrent.duration.FiniteDuration | ||
import ox.* | ||
|
||
import scala.annotation.tailrec | ||
|
||
/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. */ | ||
class RateLimiter private (algorithm: RateLimiterAlgorithm): | ||
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */ | ||
def runBlocking[T](operation: => T): T = | ||
algorithm.acquire() | ||
operation | ||
|
||
/** Runs or drops the operation, if the rate limit is reached. | ||
* | ||
* @return | ||
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped. | ||
*/ | ||
def runOrDrop[T](operation: => T): Option[T] = | ||
if algorithm.tryAcquire() then Some(operation) | ||
else None | ||
|
||
end RateLimiter | ||
|
||
object RateLimiter: | ||
def apply(algorithm: RateLimiterAlgorithm)(using Ox): RateLimiter = | ||
@tailrec | ||
def update(): Unit = | ||
val waitTime = algorithm.getNextUpdate | ||
val millis = waitTime / 1000000 | ||
val nanos = waitTime % 1000000 | ||
Thread.sleep(millis, nanos.toInt) | ||
algorithm.update() | ||
update() | ||
end update | ||
|
||
forkDiscard(update()) | ||
new RateLimiter(algorithm) | ||
end apply | ||
|
||
/** Creates a rate limiter using a fixed window algorithm. | ||
* | ||
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter. | ||
* | ||
* @param maxOperations | ||
* Maximum number of operations that are allowed to **start** within a time [[window]]. | ||
* @param window | ||
* Interval of time between replenishing the rate limiter. THe rate limiter is replenished to allow up to [[maxOperations]] in the next | ||
* time window. | ||
*/ | ||
def fixedWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter = | ||
apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window)) | ||
|
||
/** Creates a rate limiter using a sliding window algorithm. | ||
* | ||
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter. | ||
* | ||
* @param maxOperations | ||
* Maximum number of operations that are allowed to **start** within any [[window]] of time. | ||
* @param window | ||
* Length of the window. | ||
*/ | ||
def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter = | ||
apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window)) | ||
|
||
/** Rate limiter with token/leaky bucket algorithm. | ||
* | ||
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter. | ||
* | ||
* @param maxTokens | ||
* Max capacity of tokens in the algorithm, limiting the operations that are allowed to **start** concurrently. | ||
* @param refillInterval | ||
* Interval of time between adding a single token to the bucket. | ||
*/ | ||
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): RateLimiter = | ||
apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval)) | ||
end RateLimiter |
142 changes: 142 additions & 0 deletions
142
core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
package ox.resilience | ||
|
||
import scala.concurrent.duration.FiniteDuration | ||
import scala.collection.immutable.Queue | ||
import java.util.concurrent.atomic.AtomicLong | ||
import java.util.concurrent.atomic.AtomicReference | ||
import java.util.concurrent.Semaphore | ||
import scala.annotation.tailrec | ||
|
||
/** Determines the algorithm to use for the rate limiter */ | ||
trait RateLimiterAlgorithm: | ||
|
||
/** Acquires a permit to execute the operation. This method should block until a permit is available. */ | ||
final def acquire(): Unit = | ||
acquire(1) | ||
|
||
/** Acquires permits to execute the operation. This method should block until a permit is available. */ | ||
def acquire(permits: Int): Unit | ||
|
||
/** Tries to acquire a permit to execute the operation. This method should not block. */ | ||
final def tryAcquire(): Boolean = | ||
tryAcquire(1) | ||
|
||
/** Tries to acquire permits to execute the operation. This method should not block. */ | ||
def tryAcquire(permits: Int): Boolean | ||
|
||
/** Updates the internal state of the rate limiter to check whether new operations can be accepted. */ | ||
def update(): Unit | ||
|
||
/** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */ | ||
def getNextUpdate: Long | ||
|
||
end RateLimiterAlgorithm | ||
|
||
object RateLimiterAlgorithm: | ||
/** Fixed window algorithm: allows starting at most `rate` operations in consecutively segments of duration `per`. */ | ||
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm: | ||
private val lastUpdate = new AtomicLong(System.nanoTime()) | ||
private val semaphore = new Semaphore(rate) | ||
|
||
def acquire(permits: Int): Unit = | ||
semaphore.acquire(permits) | ||
|
||
def tryAcquire(permits: Int): Boolean = | ||
semaphore.tryAcquire(permits) | ||
|
||
def getNextUpdate: Long = | ||
val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime() | ||
if waitTime > 0 then waitTime else 0L | ||
|
||
def update(): Unit = | ||
val now = System.nanoTime() | ||
lastUpdate.set(now) | ||
semaphore.release(rate - semaphore.availablePermits()) | ||
end update | ||
|
||
end FixedWindow | ||
|
||
/** Sliding window algorithm: allows to start at most `rate` operations in the lapse of `per` before current time. */ | ||
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm: | ||
// stores the timestamp and the number of permits acquired after calling acquire or tryAcquire successfully | ||
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]()) | ||
private val semaphore = new Semaphore(rate) | ||
|
||
def acquire(permits: Int): Unit = | ||
semaphore.acquire(permits) | ||
addTimestampToLog(permits) | ||
|
||
def tryAcquire(permits: Int): Boolean = | ||
if semaphore.tryAcquire(permits) then | ||
addTimestampToLog(permits) | ||
true | ||
else false | ||
|
||
private def addTimestampToLog(permits: Int): Unit = | ||
val now = System.nanoTime() | ||
log.updateAndGet { q => | ||
q.enqueue((now, permits)) | ||
} | ||
() | ||
|
||
def getNextUpdate: Long = | ||
log.get().headOption match | ||
case None => | ||
// no logs so no need to update until `per` has passed | ||
per.toNanos | ||
case Some(record) => | ||
// oldest log provides the new updating point | ||
val waitTime = record._1 + per.toNanos - System.nanoTime() | ||
if waitTime > 0 then waitTime else 0L | ||
end getNextUpdate | ||
|
||
def update(): Unit = | ||
val now = System.nanoTime() | ||
// retrieving current queue to append it later if some elements were added concurrently | ||
val q = log.getAndUpdate(_ => Queue[(Long, Int)]()) | ||
// remove records older than window size | ||
val qUpdated = removeRecords(q, now) | ||
// merge old records with the ones concurrently added | ||
val _ = log.updateAndGet(qNew => | ||
qNew.foldLeft(qUpdated) { case (queue, record) => | ||
queue.enqueue(record) | ||
} | ||
) | ||
end update | ||
|
||
@tailrec | ||
private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] = | ||
q.dequeueOption match | ||
case None => q | ||
case Some((head, tail)) => | ||
if head._1 + per.toNanos < now then | ||
val (_, permits) = head | ||
semaphore.release(permits) | ||
removeRecords(tail, now) | ||
else q | ||
|
||
end SlidingWindow | ||
|
||
/** Token/leaky bucket algorithm It adds a token to start an new operation each `per` with a maximum number of tokens of `rate`. */ | ||
case class LeakyBucket(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm: | ||
private val refillInterval = per.toNanos | ||
private val lastRefillTime = new AtomicLong(System.nanoTime()) | ||
private val semaphore = new Semaphore(1) | ||
|
||
def acquire(permits: Int): Unit = | ||
semaphore.acquire(permits) | ||
|
||
def tryAcquire(permits: Int): Boolean = | ||
semaphore.tryAcquire(permits) | ||
|
||
def getNextUpdate: Long = | ||
val waitTime = lastRefillTime.get() + refillInterval - System.nanoTime() | ||
if waitTime > 0 then waitTime else 0L | ||
|
||
def update(): Unit = | ||
val now = System.nanoTime() | ||
lastRefillTime.set(now) | ||
if semaphore.availablePermits() < rate then semaphore.release() | ||
|
||
end LeakyBucket | ||
end RateLimiterAlgorithm |
126 changes: 126 additions & 0 deletions
126
core/src/test/scala/ox/resilience/RateLimiterInterfaceTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package ox.resilience | ||
|
||
import ox.* | ||
import org.scalatest.flatspec.AnyFlatSpec | ||
import org.scalatest.matchers.should.Matchers | ||
import org.scalatest.{EitherValues, TryValues} | ||
import scala.concurrent.duration.* | ||
|
||
class RateLimiterInterfaceTest extends AnyFlatSpec with Matchers with EitherValues with TryValues: | ||
behavior of "RateLimiter interface" | ||
|
||
it should "drop or block operation depending on method used for fixed rate algorithm" in { | ||
supervised: | ||
val rateLimiter = RateLimiter.fixedWindow(2, FiniteDuration(1, "second")) | ||
|
||
var executions = 0 | ||
def operation = | ||
executions += 1 | ||
0 | ||
|
||
val result1 = rateLimiter.runOrDrop(operation) | ||
val result2 = rateLimiter.runOrDrop(operation) | ||
val result3 = rateLimiter.runOrDrop(operation) | ||
val result4 = rateLimiter.runBlocking(operation) | ||
val result5 = rateLimiter.runBlocking(operation) | ||
val result6 = rateLimiter.runOrDrop(operation) | ||
|
||
result1 shouldBe Some(0) | ||
result2 shouldBe Some(0) | ||
result3 shouldBe None | ||
result4 shouldBe 0 | ||
result5 shouldBe 0 | ||
result6 shouldBe None | ||
executions shouldBe 4 | ||
} | ||
|
||
it should "drop or block operation depending on method used for sliding window algorithm" in { | ||
supervised: | ||
val rateLimiter = RateLimiter.slidingWindow(2, FiniteDuration(1, "second")) | ||
|
||
var executions = 0 | ||
def operation = | ||
executions += 1 | ||
0 | ||
|
||
val result1 = rateLimiter.runOrDrop(operation) | ||
val result2 = rateLimiter.runOrDrop(operation) | ||
val result3 = rateLimiter.runOrDrop(operation) | ||
val result4 = rateLimiter.runBlocking(operation) | ||
val result5 = rateLimiter.runBlocking(operation) | ||
val result6 = rateLimiter.runOrDrop(operation) | ||
|
||
result1 shouldBe Some(0) | ||
result2 shouldBe Some(0) | ||
result3 shouldBe None | ||
result4 shouldBe 0 | ||
result5 shouldBe 0 | ||
result6 shouldBe None | ||
executions shouldBe 4 | ||
} | ||
|
||
it should "drop or block operation depending on method used for bucket algorithm" in { | ||
supervised: | ||
val rateLimiter = RateLimiter.leakyBucket(2, FiniteDuration(1, "second")) | ||
|
||
var executions = 0 | ||
def operation = | ||
executions += 1 | ||
0 | ||
|
||
val result1 = rateLimiter.runOrDrop(operation) | ||
val result2 = rateLimiter.runOrDrop(operation) | ||
val result3 = rateLimiter.runOrDrop(operation) | ||
val result4 = rateLimiter.runBlocking(operation) | ||
val result5 = rateLimiter.runBlocking(operation) | ||
val result6 = rateLimiter.runOrDrop(operation) | ||
|
||
result1 shouldBe Some(0) | ||
result2 shouldBe None | ||
result3 shouldBe None | ||
result4 shouldBe 0 | ||
result5 shouldBe 0 | ||
result6 shouldBe None | ||
executions shouldBe 3 | ||
} | ||
|
||
it should "drop or block operation concurrently" in { | ||
supervised: | ||
val rateLimiter = RateLimiter.fixedWindow(2, FiniteDuration(1, "second")) | ||
|
||
def operation = 0 | ||
|
||
var result1: Option[Int] = Some(-1) | ||
var result2: Option[Int] = Some(-1) | ||
var result3: Option[Int] = Some(-1) | ||
var result4: Int = -1 | ||
var result5: Int = -1 | ||
var result6: Int = -1 | ||
|
||
// run two operations to block the rate limiter | ||
rateLimiter.runOrDrop(operation).discard | ||
rateLimiter.runOrDrop(operation).discard | ||
|
||
// operations with runOrDrop should be dropped while operations with runBlocking should wait | ||
supervised: | ||
forkUserDiscard: | ||
result1 = rateLimiter.runOrDrop(operation) | ||
forkUserDiscard: | ||
result2 = rateLimiter.runOrDrop(operation) | ||
forkUserDiscard: | ||
result3 = rateLimiter.runOrDrop(operation) | ||
forkUserDiscard: | ||
result4 = rateLimiter.runBlocking(operation) | ||
forkUserDiscard: | ||
result5 = rateLimiter.runBlocking(operation) | ||
forkUserDiscard: | ||
result6 = rateLimiter.runBlocking(operation) | ||
|
||
result1 shouldBe None | ||
result2 shouldBe None | ||
result3 shouldBe None | ||
result4 shouldBe 0 | ||
result5 shouldBe 0 | ||
result6 shouldBe 0 | ||
} | ||
end RateLimiterInterfaceTest |
Oops, something went wrong.