Skip to content

Commit

Permalink
[ML] Inference API Rate limiter (#106330)
Browse files Browse the repository at this point in the history
* Working tests

* Adding more tests

* Adding comment

* Switching to micros and addressing feedback

* Removing nanos and adding test for bug fix

---------

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
jonathan-buttner and elasticmachine authored Mar 19, 2024
1 parent ea1672b commit edbff94
Show file tree
Hide file tree
Showing 2 changed files with 373 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;

import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

/**
* Implements a throttler using the <a href="https://en.wikipedia.org/wiki/Token_bucket">token bucket algorithm</a>.
*
* The general approach is to define the rate limiter with size (accumulated tokens limit) which dictates how many
* unused tokens can be saved up, and a rate at which the tokens are created. Then when a thread should be rate limited
* it can attempt to acquire a certain number of tokens (typically one for each item of work it's going to do). If unused tokens
* are available in the bucket already, those will be used. If the number of available tokens covers the desired amount
* the thread will not sleep. If the bucket does not contain enough tokens, it will calculate how long the thread needs to sleep
* to accumulate the requested amount of tokens.
*
* By setting the accumulated tokens limit to a value greater than zero, it effectively allows bursts of traffic. If the accumulated
* tokens limit is set to zero, it will force the acquiring thread to wait on each call.
*/
public class RateLimiter {

private double tokensPerMicros;
private double accumulatedTokensLimit;
private double accumulatedTokens;
private Instant nextTokenAvailability;
private final Sleeper sleeper;
private final Clock clock;

/**
* @param accumulatedTokensLimit the limit for tokens stashed in the bucket
* @param tokensPerTimeUnit the number of tokens to produce per the time unit passed in
* @param unit the time unit frequency for generating tokens
*/
public RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit) {
this(accumulatedTokensLimit, tokensPerTimeUnit, unit, new TimeUnitSleeper(), Clock.systemUTC());
}

// default for testing
RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit, Sleeper sleeper, Clock clock) {
this.sleeper = Objects.requireNonNull(sleeper);
this.clock = Objects.requireNonNull(clock);
nextTokenAvailability = Instant.MIN;
setRate(accumulatedTokensLimit, tokensPerTimeUnit, unit);
}

public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) {
Objects.requireNonNull(newUnit);

if (newAccumulatedTokensLimit < 0) {
throw new IllegalArgumentException("Accumulated tokens limit must be greater than or equal to 0");
}

if (Double.isInfinite(newAccumulatedTokensLimit)) {
throw new IllegalArgumentException(
Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE)
);
}

if (newTokensPerTimeUnit <= 0) {
throw new IllegalArgumentException("Tokens per time unit must be greater than 0");
}

if (newTokensPerTimeUnit == Double.POSITIVE_INFINITY) {
throw new IllegalArgumentException(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE));
}

accumulatedTokens = Math.min(accumulatedTokens, newAccumulatedTokensLimit);

accumulatedTokensLimit = newAccumulatedTokensLimit;

var unitsInMicros = newUnit.toMicros(1);
tokensPerMicros = newTokensPerTimeUnit / unitsInMicros;
assert Double.isInfinite(tokensPerMicros) == false : "Tokens per microsecond should not be infinity";

accumulateTokens();
}

/**
* Causes the thread to wait until the tokens are available
* @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here
* @throws InterruptedException _
*/
public void acquire(int tokens) throws InterruptedException {
if (tokens <= 0) {
throw new IllegalArgumentException("Requested tokens must be positive");
}

double microsToWait;
synchronized (this) {
accumulateTokens();
var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens);
var additionalTokensRequired = tokens - accumulatedTokensToUse;
microsToWait = additionalTokensRequired / tokensPerMicros;
accumulatedTokens -= accumulatedTokensToUse;
nextTokenAvailability = nextTokenAvailability.plus((long) microsToWait, ChronoUnit.MICROS);
}

sleeper.sleep((long) microsToWait);
}

private void accumulateTokens() {
var now = Instant.now(clock);
if (now.isAfter(nextTokenAvailability)) {
var elapsedTimeMicros = microsBetweenExact(nextTokenAvailability, now);
var newTokens = tokensPerMicros * elapsedTimeMicros;
accumulatedTokens = Math.min(accumulatedTokensLimit, accumulatedTokens + newTokens);
nextTokenAvailability = now;
}
}

private static long microsBetweenExact(Instant start, Instant end) {
try {
return ChronoUnit.MICROS.between(start, end);
} catch (ArithmeticException e) {
if (end.isAfter(start)) {
return Long.MAX_VALUE;
}

return 0;
}
}

// default for testing
Instant getNextTokenAvailability() {
return nextTokenAvailability;
}

public interface Sleeper {
void sleep(long microsecondsToSleep) throws InterruptedException;
}

static final class TimeUnitSleeper implements Sleeper {
public void sleep(long microsecondsToSleep) throws InterruptedException {
TimeUnit.MICROSECONDS.sleep(microsecondsToSleep);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;

import java.time.Clock;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RateLimiterTests extends ESTestCase {
public void testThrows_WhenAccumulatedTokensLimit_IsNegative() {
var exception = expectThrows(
IllegalArgumentException.class,
() -> new RateLimiter(-1, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC())
);
assertThat(exception.getMessage(), is("Accumulated tokens limit must be greater than or equal to 0"));
}

public void testThrows_WhenAccumulatedTokensLimit_IsInfinity() {
var exception = expectThrows(
IllegalArgumentException.class,
() -> new RateLimiter(Double.POSITIVE_INFINITY, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC())
);
assertThat(
exception.getMessage(),
is(Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE))
);
}

public void testThrows_WhenTokensPerTimeUnit_IsZero() {
var exception = expectThrows(
IllegalArgumentException.class,
() -> new RateLimiter(0, 0, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC())
);
assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0"));
}

public void testThrows_WhenTokensPerTimeUnit_IsInfinity() {
var exception = expectThrows(
IllegalArgumentException.class,
() -> new RateLimiter(0, Double.POSITIVE_INFINITY, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC())
);
assertThat(exception.getMessage(), is(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE)));
}

public void testThrows_WhenTokensPerTimeUnit_IsNegative() {
var exception = expectThrows(
IllegalArgumentException.class,
() -> new RateLimiter(0, -1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC())
);
assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0"));
}

public void testAcquire_Throws_WhenTokens_IsZero() {
var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC());
var exception = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(0));
assertThat(exception.getMessage(), is("Requested tokens must be positive"));
}

public void testAcquire_Throws_WhenTokens_IsNegative() {
var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC());
var exception = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(-1));
assertThat(exception.getMessage(), is("Requested tokens must be positive"));
}

public void testAcquire_First_CallDoesNotSleep() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);
}

public void testAcquire_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);
}

public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock);
limiter.acquire(Integer.MAX_VALUE);
verify(sleeper, times(1)).sleep(0);
}

public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

double tokensPerDay = 1;
var limiter = new RateLimiter(0, tokensPerDay, TimeUnit.DAYS, sleeper, clock);
limiter.acquire(Integer.MAX_VALUE);

double tokensPerMicro = tokensPerDay / TimeUnit.DAYS.toMicros(1);
verify(sleeper, times(1)).sleep((long) ((double) Integer.MAX_VALUE / tokensPerMicro));
}

public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(2);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));
}

public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));
}

public void testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(10);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(10));
}

public void testAcquire_IncrementsNextTokenAvailabilityInstant_ByOneMinute() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));
assertThat(limiter.getNextTokenAvailability(), is(now.plus(1, ChronoUnit.MINUTES)));
}

public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));
}

public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration()
throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);
when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30)));
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toMicros(30));
}

public void testAcquire_ShouldAccumulateTokens() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(10, 10, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(5);
verify(sleeper, times(1)).sleep(0);
// it should accumulate 5 tokens
when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30)));
limiter.acquire(10);
verify(sleeper, times(2)).sleep(0);
}
}

0 comments on commit edbff94

Please sign in to comment.