Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Inference API Rate limiter #106330

Merged
merged 10 commits into from
Mar 19, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;

import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

/**
 * Implements a throttler using the <a href="https://en.wikipedia.org/wiki/Token_bucket">token bucket algorithm</a>.
 *
 * The general approach is to define the rate limiter with size (accumulated tokens limit) which dictates how many
 * unused tokens can be saved up, and a rate at which the tokens are created. Then when a thread should be rate limited
 * it can attempt to acquire a certain number of tokens (typically one for each item of work it's going to do). If unused tokens
 * are available in the bucket already, those will be used. If the number of available tokens covers the desired amount
 * the thread will not sleep. If the bucket does not contain enough tokens, it will calculate how long the thread needs to sleep
 * to accumulate the requested amount of tokens.
 *
 * By setting the accumulated tokens limit to a value greater than zero, it effectively allows bursts of traffic. If the accumulated
 * tokens limit is set to zero, it will force the acquiring thread to wait on each call.
 *
 * All reads and writes of the bucket state happen while holding this instance's monitor. The sleep performed by
 * {@link #acquire(int)} happens after the monitor has been released.
 */
public class RateLimiter {

    // Rate at which tokens are generated, normalized to tokens per microsecond
    private double tokensPerMicros;
    // Upper bound for the number of unused tokens that can be stashed in the bucket
    private double accumulatedTokensLimit;
    // Unused tokens currently in the bucket
    private double accumulatedTokens;
    // The instant up to which token generation has already been accounted for (or reserved by callers of acquire)
    private Instant nextTokenAvailability;
    private final Sleeper sleeper;
    private final Clock clock;

    /**
     * @param accumulatedTokensLimit the limit for tokens stashed in the bucket
     * @param tokensPerTimeUnit the number of tokens to produce per the time unit passed in
     * @param unit the time unit frequency for generating tokens
     * @throws IllegalArgumentException if the limit or the rate is rejected by {@link #setRate(double, double, TimeUnit)}
     */
    public RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit) {
        this(accumulatedTokensLimit, tokensPerTimeUnit, unit, new TimeUnitSleeper(), Clock.systemUTC());
    }

    // default for testing
    RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit, Sleeper sleeper, Clock clock) {
        this.sleeper = Objects.requireNonNull(sleeper);
        this.clock = Objects.requireNonNull(clock);
        // Starting from the minimum instant makes the first accumulation saturate, so the bucket starts out full
        nextTokenAvailability = Instant.MIN;
        setRate(accumulatedTokensLimit, tokensPerTimeUnit, unit);
    }

    /**
     * Replaces the bucket size and the token generation rate.
     *
     * All arguments are validated before any state is modified, so a rejected call leaves the limiter untouched.
     *
     * @param newAccumulatedTokensLimit the new limit for tokens stashed in the bucket, must be a finite value >= 0
     * @param newTokensPerTimeUnit the new number of tokens to produce per time unit, must be a finite value > 0
     * @param newUnit the time unit frequency for generating tokens
     * @throws NullPointerException if {@code newUnit} is null
     * @throws IllegalArgumentException if the limit or rate is NaN, infinite or out of range, or if the rate cannot be
     * represented as a finite number of tokens per microsecond
     */
    public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) {
        Objects.requireNonNull(newUnit);

        // NaN fails every ordered comparison below, so it has to be rejected explicitly
        if (Double.isNaN(newAccumulatedTokensLimit)) {
            throw new IllegalArgumentException("Accumulated tokens limit must be a number");
        }

        if (newAccumulatedTokensLimit < 0) {
            throw new IllegalArgumentException("Accumulated tokens limit must be greater than or equal to 0");
        }

        if (Double.isInfinite(newAccumulatedTokensLimit)) {
            throw new IllegalArgumentException(
                Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE)
            );
        }

        if (Double.isNaN(newTokensPerTimeUnit)) {
            throw new IllegalArgumentException("Tokens per time unit must be a number");
        }

        if (newTokensPerTimeUnit <= 0) {
            throw new IllegalArgumentException("Tokens per time unit must be greater than 0");
        }

        if (newTokensPerTimeUnit == Double.POSITIVE_INFINITY) {
            throw new IllegalArgumentException(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE));
        }

        var newTokensPerMicros = toTokensPerMicros(newTokensPerTimeUnit, newUnit);
        if (Double.isInfinite(newTokensPerMicros)) {
            throw new IllegalArgumentException("Tokens per time unit is too large for the provided time unit");
        }

        // If the new limit is smaller than what has been accumulated already, drop the excess tokens
        accumulatedTokens = Math.min(accumulatedTokens, newAccumulatedTokensLimit);
        accumulatedTokensLimit = newAccumulatedTokensLimit;
        tokensPerMicros = newTokensPerMicros;

        accumulateTokens();
    }

    /**
     * Normalizes a rate expressed in an arbitrary time unit to tokens per microsecond.
     */
    private static double toTokensPerMicros(double tokensPerTimeUnit, TimeUnit unit) {
        long unitInMicros = unit.toMicros(1);
        if (unitInMicros > 0) {
            return tokensPerTimeUnit / unitInMicros;
        }

        // The unit is finer than a microsecond (NANOSECONDS), so toMicros(1) truncated to 0. Scale the rate up by the
        // number of such units in one microsecond instead of dividing by zero.
        return tokensPerTimeUnit * unit.convert(1, TimeUnit.MICROSECONDS);
    }

    /**
     * Causes the thread to wait until the tokens are available.
     *
     * The bucket bookkeeping is done while holding the monitor; the sleep itself happens after the monitor is released.
     * The sleep duration only covers the tokens this call could not take from the bucket.
     *
     * @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here
     * @throws IllegalArgumentException if {@code tokens} is not positive
     * @throws InterruptedException if the sleeper is interrupted while waiting
     */
    public void acquire(int tokens) throws InterruptedException {
        if (tokens <= 0) {
            throw new IllegalArgumentException("Requested tokens must be positive");
        }

        double microsToWait;
        synchronized (this) {
            accumulateTokens();
            // Use whatever is stashed in the bucket first, the remainder has to be waited for
            var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens);
            var additionalTokensRequired = tokens - accumulatedTokensToUse;
            microsToWait = additionalTokensRequired / tokensPerMicros;
            accumulatedTokens -= accumulatedTokensToUse;
            // Reserve the time needed to generate the missing tokens so it is not counted again by a later accumulation
            nextTokenAvailability = nextTokenAvailability.plus((long) microsToWait, ChronoUnit.MICROS);
        }

        // The narrowing cast saturates at Long.MAX_VALUE for durations that do not fit in a long
        sleeper.sleep((long) microsToWait);
    }

    /**
     * Adds the tokens generated since {@code nextTokenAvailability} to the bucket, capped at the accumulated tokens limit.
     * Does nothing while {@code nextTokenAvailability} is not in the past. Callers in this class hold the monitor.
     */
    private void accumulateTokens() {
        var now = Instant.now(clock);
        if (now.isAfter(nextTokenAvailability)) {
            var elapsedTimeMicros = microsBetweenExact(nextTokenAvailability, now);
            var newTokens = tokensPerMicros * elapsedTimeMicros;
            accumulatedTokens = Math.min(accumulatedTokensLimit, accumulatedTokens + newTokens);
            nextTokenAvailability = now;
        }
    }

    /**
     * Returns the number of microseconds between the two instants. If the computation overflows, the result saturates
     * to {@link Long#MAX_VALUE} when {@code end} is after {@code start} and to 0 otherwise.
     */
    private static long microsBetweenExact(Instant start, Instant end) {
        try {
            return ChronoUnit.MICROS.between(start, end);
        } catch (ArithmeticException e) {
            if (end.isAfter(start)) {
                return Long.MAX_VALUE;
            }

            return 0;
        }
    }

    // default for testing
    synchronized Instant getNextTokenAvailability() {
        return nextTokenAvailability;
    }

    /**
     * Abstraction over how the acquiring thread waits, so the waiting can be replaced in tests.
     */
    @FunctionalInterface
    public interface Sleeper {
        void sleep(long microsecondsToSleep) throws InterruptedException;
    }

    /**
     * Sleeps the calling thread via {@link TimeUnit#sleep(long)}.
     */
    static final class TimeUnitSleeper implements Sleeper {
        @Override
        public void sleep(long microsecondsToSleep) throws InterruptedException {
            TimeUnit.MICROSECONDS.sleep(microsecondsToSleep);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;

import java.time.Clock;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RateLimiterTests extends ESTestCase {

    // ---- construction / setRate validation ----

    public void testThrows_WhenAccumulatedTokensLimit_IsNegative() {
        assertConstructionFails(-1, 1, "Accumulated tokens limit must be greater than or equal to 0");
    }

    public void testThrows_WhenAccumulatedTokensLimit_IsInfinity() {
        assertConstructionFails(
            Double.POSITIVE_INFINITY,
            1,
            Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE)
        );
    }

    public void testThrows_WhenTokensPerTimeUnit_IsZero() {
        assertConstructionFails(0, 0, "Tokens per time unit must be greater than 0");
    }

    public void testThrows_WhenTokensPerTimeUnit_IsInfinity() {
        assertConstructionFails(
            0,
            Double.POSITIVE_INFINITY,
            Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE)
        );
    }

    public void testThrows_WhenTokensPerTimeUnit_IsNegative() {
        assertConstructionFails(0, -1, "Tokens per time unit must be greater than 0");
    }

    // ---- acquire argument validation ----

    public void testAcquire_Throws_WhenTokens_IsZero() {
        assertAcquireRejects(0);
    }

    public void testAcquire_Throws_WhenTokens_IsNegative() {
        assertAcquireRejects(-1);
    }

    // ---- acquire with a clock that does not move ----

    public void testAcquire_First_CallDoesNotSleep() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(1);

        verify(sleeper).sleep(0);
    }

    public void testAcquire_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, frozenClock());

        limiter.acquire(1);

        verify(sleeper).sleep(0);
    }

    public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, frozenClock());

        limiter.acquire(Integer.MAX_VALUE);

        verify(sleeper).sleep(0);
    }

    public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws InterruptedException {
        double tokensPerDay = 1;
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, tokensPerDay, TimeUnit.DAYS, sleeper, frozenClock());

        limiter.acquire(Integer.MAX_VALUE);

        // The expected wait does not fit in a long, the cast saturates the same way the limiter's does
        double tokensPerMicro = tokensPerDay / TimeUnit.DAYS.toMicros(1);
        long expectedSleepMicros = (long) ((double) Integer.MAX_VALUE / tokensPerMicro);
        verify(sleeper).sleep(expectedSleepMicros);
    }

    public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        // one token comes out of the bucket, the second one has to be waited for
        limiter.acquire(2);

        verify(sleeper).sleep(TimeUnit.MINUTES.toMicros(1));
    }

    public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(1);

        verify(sleeper).sleep(TimeUnit.MINUTES.toMicros(1));
    }

    public void testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(10);

        verify(sleeper).sleep(TimeUnit.MINUTES.toMicros(10));
    }

    public void testAcquire_IncrementsNextTokenAvailabilityInstant_ByOneMinute() throws InterruptedException {
        var now = Clock.systemUTC().instant();
        var clock = mock(Clock.class);
        when(clock.instant()).thenReturn(now);
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock);

        limiter.acquire(1);

        verify(sleeper).sleep(TimeUnit.MINUTES.toMicros(1));
        var expectedAvailability = now.plus(1, ChronoUnit.MINUTES);
        assertThat(limiter.getNextTokenAvailability(), is(expectedAvailability));
    }

    public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        // the first call drains the bucket
        limiter.acquire(1);
        verify(sleeper).sleep(0);

        // the second call finds it empty
        limiter.acquire(1);
        verify(sleeper).sleep(TimeUnit.MINUTES.toMicros(1));
    }

    // ---- acquire with a clock that advances between calls ----

    public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration()
        throws InterruptedException {
        var now = Clock.systemUTC().instant();
        var clock = mock(Clock.class);
        when(clock.instant()).thenReturn(now);
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);

        limiter.acquire(1);
        verify(sleeper).sleep(0);

        // half of the time required for one token passes
        var thirtySecondsLater = now.plus(Duration.ofSeconds(30));
        when(clock.instant()).thenReturn(thirtySecondsLater);

        limiter.acquire(1);
        verify(sleeper).sleep(TimeUnit.SECONDS.toMicros(30));
    }

    public void testAcquire_ShouldAccumulateTokens() throws InterruptedException {
        var now = Clock.systemUTC().instant();
        var clock = mock(Clock.class);
        when(clock.instant()).thenReturn(now);
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(10, 10, TimeUnit.MINUTES, sleeper, clock);

        limiter.acquire(5);
        verify(sleeper).sleep(0);

        // it should accumulate 5 tokens
        var thirtySecondsLater = now.plus(Duration.ofSeconds(30));
        when(clock.instant()).thenReturn(thirtySecondsLater);

        limiter.acquire(10);
        verify(sleeper, times(2)).sleep(0);
    }

    // ---- helpers ----

    /**
     * Asserts that building a per-second limiter with the given bucket size and rate is rejected with the expected message.
     */
    private void assertConstructionFails(double accumulatedTokensLimit, double tokensPerTimeUnit, String expectedMessage) {
        var failure = expectThrows(
            IllegalArgumentException.class,
            () -> new RateLimiter(
                accumulatedTokensLimit,
                tokensPerTimeUnit,
                TimeUnit.SECONDS,
                new RateLimiter.TimeUnitSleeper(),
                Clock.systemUTC()
            )
        );
        assertThat(failure.getMessage(), is(expectedMessage));
    }

    /**
     * Asserts that acquiring the given number of tokens is rejected as a non-positive request.
     */
    private void assertAcquireRejects(int tokens) {
        var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC());
        var failure = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(tokens));
        assertThat(failure.getMessage(), is("Requested tokens must be positive"));
    }

    /**
     * Creates a mocked clock that keeps returning the instant at which this method was called.
     */
    private static Clock frozenClock() {
        var now = Clock.systemUTC().instant();
        var clock = mock(Clock.class);
        when(clock.instant()).thenReturn(now);
        return clock;
    }
}