Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Inference API Rate limiter #106330

Merged
merged 10 commits into from
Mar 19, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

/**
* Implements a throttler using the <a href="https://en.wikipedia.org/wiki/Token_bucket">token bucket algorithm</a>.
*
* The general approach is to define the rate limiter with size (accumulated tokens limit) which dictates how many
* unused tokens can be saved up, and a rate at which the tokens are created. Then when a thread should be rate limited
* it can attempt to acquire a certain number of tokens (typically one for each item of work it's going to do). If unused tokens
* are available in the bucket already, those will be used. If the number of available tokens covers the desired amount
* the thread will not sleep. If the bucket does not contain enough tokens, it will calculate how long the thread needs to sleep
* to accumulate the requested amount of tokens.
*
* By setting the accumulated tokens limit to a value greater than zero, it effectively allows bursts of traffic. If the accumulated
* tokens limit is set to zero, it will force the acquiring thread to wait on each call.
*/
public class RateLimiter {

// Token generation rate expressed in tokens per nanosecond; recomputed on every setRate() call.
private double tokensPerNanos;
// Upper bound applied to accumulatedTokens (the bucket size).
private double accumulatedTokensLimit;
// Tokens currently stored in the bucket and available to acquire() without waiting.
private double accumulatedTokens;
// Reference instant for token accounting: starts at Instant.MIN, is set to "now" by accumulateTokens()
// and pushed forward by the computed wait time in acquire().
private Instant nextTokenAvailability;
// Strategy used by acquire() to block the calling thread.
private final Sleeper sleeper;
// Time source used to read the current instant.
private final Clock clock;

/**
 * Creates a rate limiter that reads time from {@link Clock#systemUTC()} and blocks using a {@link TimeUnitSleeper}.
 *
 * @param accumulatedTokensLimit the limit for tokens stashed in the bucket
 * @param tokensPerTimeUnit the number of tokens to produce per the time unit passed in
 * @param unit the time unit frequency for generating tokens
 */
public RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit) {
    this(accumulatedTokensLimit, tokensPerTimeUnit, unit, new TimeUnitSleeper(), Clock.systemUTC());
}

// Package-private so tests can supply their own sleeper and clock.
RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit, Sleeper sleeper, Clock clock) {
    // Start from the earliest representable instant; setRate() below triggers the first token accumulation.
    this.nextTokenAvailability = Instant.MIN;
    this.sleeper = Objects.requireNonNull(sleeper);
    this.clock = Objects.requireNonNull(clock);
    this.setRate(accumulatedTokensLimit, tokensPerTimeUnit, unit);
}

/**
* Replaces the bucket size and the token generation rate after validating the new values.
*
* Tokens already in the bucket are clamped to the new limit, the rate is converted to tokens per nanosecond and
* accumulateTokens() is then called so the bucket is refreshed against the current time. The method holds this
* instance's monitor for its whole duration.
*
* NOTE(review): the range checks below are plain comparisons, so a NaN argument is not rejected by them.
*
* @param newAccumulatedTokensLimit the new limit for tokens stashed in the bucket; must be &gt;= 0 and not positive infinity
* @param newTokensPerTimeUnit the new number of tokens produced per {@code newUnit}; must be &gt; 0 and not positive infinity
* @param newUnit the time unit the rate is expressed in; must not be null
* @throws NullPointerException if {@code newUnit} is null
* @throws IllegalArgumentException if the limit or the rate is outside the ranges described above
*/
public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) {
Objects.requireNonNull(newUnit);

if (newAccumulatedTokensLimit < 0) {
throw new IllegalArgumentException("Accumulated tokens limit must be greater than or equal to 0");
}

if (newAccumulatedTokensLimit == Double.POSITIVE_INFINITY) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (newAccumulatedTokensLimit == Double.POSITIVE_INFINITY) {
if (Double.isInfinite(newAccumulatedTokensLimit)) {

throw new IllegalArgumentException(
Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE)
);
}

if (newTokensPerTimeUnit <= 0) {
throw new IllegalArgumentException("Tokens per time unit must be greater than 0");
}

if (newTokensPerTimeUnit == Double.POSITIVE_INFINITY) {
throw new IllegalArgumentException(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE));
}

// Shrinking the bucket must not leave more tokens stored than the new limit allows.
accumulatedTokens = Math.min(accumulatedTokens, newAccumulatedTokensLimit);

accumulatedTokensLimit = newAccumulatedTokensLimit;

// Normalise the rate to a single internal unit (tokens per nanosecond).
var unitsInNanos = newUnit.toNanos(1);
tokensPerNanos = newTokensPerTimeUnit / unitsInNanos;
assert tokensPerNanos != Double.POSITIVE_INFINITY : "Tokens per nanosecond should not be infinity";

accumulateTokens();
}

/**
 * Blocks the calling thread until the requested number of tokens has been granted.
 *
 * Tokens already stored in the bucket are consumed first. Any shortfall is converted into a wait time using the
 * configured rate. The bookkeeping happens while holding this instance's monitor; the sleep itself happens
 * after the monitor has been released.
 *
 * @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here
 * @throws IllegalArgumentException if {@code tokens} is zero or negative
 * @throws InterruptedException if the sleeper is interrupted while waiting
 */
public void acquire(int tokens) throws InterruptedException {
    if (tokens < 1) {
        throw new IllegalArgumentException("Requested tokens must be positive");
    }

    final long sleepNanos;
    synchronized (this) {
        accumulateTokens();

        // Take as much as possible from the bucket before computing how long the remainder takes to generate.
        final double takenFromBucket = Math.min(tokens, accumulatedTokens);
        accumulatedTokens -= takenFromBucket;

        final double shortfall = tokens - takenFromBucket;
        final double nanosToWait = shortfall / tokensPerNanos;
        sleepNanos = (long) nanosToWait;

        // Push the reference instant forward by the time this caller is about to wait.
        nextTokenAvailability = nextTokenAvailability.plus(Duration.ofNanos(sleepNanos));
    }

    sleeper.sleep(sleepNanos);
}

/**
* Refreshes the bucket from the time elapsed since nextTokenAvailability.
*
* Nothing happens unless the clock's current instant is strictly after nextTokenAvailability. Otherwise the elapsed
* nanoseconds are multiplied by tokensPerNanos, the result (capped at accumulatedTokensLimit) is stored in
* accumulatedTokens, and nextTokenAvailability is moved to the current instant.
*
* NOTE(review): the assignment below stores only the newly generated tokens; the previous value of
* accumulatedTokens is not added in.
*
* This method is not synchronized itself; the visible callers (setRate and acquire) invoke it while holding this
* instance's monitor.
*/
private void accumulateTokens() {
var now = Instant.now(clock);
if (now.isAfter(nextTokenAvailability)) {
var elapsedTimeNanos = nanosBetweenExact(nextTokenAvailability, now);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called in the ctor via setRate() at which point nextTokenAvailability == Instant.MIN. Because the calculated elapsedTimeNanos is high the class will be initialised with accumulatedTokens == accumulatedTokensLimit.

That seems reasonable to me, or at least as good as initialising accumulatedTokens to 0. Just want to check that is the intention

Copy link
Contributor Author

@jonathan-buttner jonathan-buttner Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that was intentional. My thinking was that the first request can move forward without having to wait for tokens to accumulate if the limit was set to a positive number. If we always want it to start as 0 that's fine with me too though.

var newTokens = tokensPerNanos * elapsedTimeNanos;
accumulatedTokens = Math.min(accumulatedTokensLimit, newTokens);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this include the previously accumulated tokens?

Suggested change
accumulatedTokens = Math.min(accumulatedTokensLimit, newTokens);
accumulatedTokens = Math.min(accumulatedTokensLimit, accumulatedTokens + newTokens);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep thanks for that.

nextTokenAvailability = now;
}
}

/**
* Returns the number of nanoseconds from {@code start} to {@code end}, widened to a double.
*
* If ChronoUnit.NANOS.between throws an ArithmeticException (presumably because the span does not fit in a long),
* the result is Double.POSITIVE_INFINITY when {@code end} is after {@code start} and Double.NEGATIVE_INFINITY
* otherwise.
*
* NOTE(review): a negative-infinity result flows into Math.min in accumulateTokens and would make
* accumulatedTokens negative infinity.
*
* @param start the beginning of the interval
* @param end the end of the interval
* @return the elapsed nanoseconds, or an infinity with the sign of the interval when it cannot be computed
*/
private static double nanosBetweenExact(Instant start, Instant end) {
try {
return ChronoUnit.NANOS.between(start, end);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TemporalUnit.between() returns a long not double

} catch (ArithmeticException e) {
if (end.isAfter(start)) {
return Double.POSITIVE_INFINITY;
}

return Double.NEGATIVE_INFINITY;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the accumulateTokens code:

            var newTokens = tokensPerNanos * elapsedTimeNanos;
            accumulatedTokens = Math.min(accumulatedTokensLimit, newTokens);

Math.min(a_positive_number, Double.NEGATIVE_INFINITY) returns Double.NEGATIVE_INFINITY. If accumulatedTokens becomes a -ve number I think that could cause errors.

One option is to return 0 ( not +ve or -ve infinity). Using ChronoUnit.MICRO or ChronoUnit.MILLIS reduces the chance of an arithmetic overflow

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll switch it to 0 and use micros instead.

}
}

/**
 * Abstraction over blocking the current thread, allowing the waiting strategy to be replaced.
 */
public interface Sleeper {
    /**
     * Blocks the calling thread for the given duration.
     *
     * @param nanosecondsToSleep how long to sleep, in nanoseconds
     * @throws InterruptedException if the thread is interrupted while sleeping
     */
    void sleep(long nanosecondsToSleep) throws InterruptedException;
}

/**
 * {@link Sleeper} implementation that delegates to {@link TimeUnit#sleep(long)} with nanosecond granularity.
 */
static final class TimeUnitSleeper implements Sleeper {
    @Override
    public void sleep(long nanosecondsToSleep) throws InterruptedException {
        TimeUnit.NANOSECONDS.sleep(nanosecondsToSleep);
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;

import java.time.Clock;
import java.time.Duration;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link RateLimiter}: argument validation in the constructor and in {@code acquire}, and the sleep
 * durations requested from a mocked {@link RateLimiter.Sleeper} under a mocked {@link Clock}.
 */
public class RateLimiterTests extends ESTestCase {

    public void testThrows_WhenAccumulatedTokensLimit_IsNegative() {
        var thrown = expectInvalidRate(-1, 1);
        assertThat(thrown.getMessage(), is("Accumulated tokens limit must be greater than or equal to 0"));
    }

    public void testThrows_WhenAccumulatedTokensLimit_IsInfinity() {
        var thrown = expectInvalidRate(Double.POSITIVE_INFINITY, 1);
        var expectedMessage = Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE);
        assertThat(thrown.getMessage(), is(expectedMessage));
    }

    public void testThrows_WhenTokensPerTimeUnit_IsZero() {
        var thrown = expectInvalidRate(0, 0);
        assertThat(thrown.getMessage(), is("Tokens per time unit must be greater than 0"));
    }

    public void testThrows_WhenTokensPerTimeUnit_IsInfinity() {
        var thrown = expectInvalidRate(0, Double.POSITIVE_INFINITY);
        var expectedMessage = Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE);
        assertThat(thrown.getMessage(), is(expectedMessage));
    }

    public void testThrows_WhenTokensPerTimeUnit_IsNegative() {
        var thrown = expectInvalidRate(0, -1);
        assertThat(thrown.getMessage(), is("Tokens per time unit must be greater than 0"));
    }

    public void testAcquire_Throws_WhenTokens_IsZero() {
        var thrown = expectInvalidAcquire(0);
        assertThat(thrown.getMessage(), is("Requested tokens must be positive"));
    }

    public void testAcquire_Throws_WhenTokens_IsNegative() {
        var thrown = expectInvalidAcquire(-1);
        assertThat(thrown.getMessage(), is("Requested tokens must be positive"));
    }

    public void testAcquire_First_CallDoesNotSleep() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(1);

        verify(sleeper, times(1)).sleep(0);
    }

    public void testAcquire_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.NANOSECONDS, sleeper, frozenClock());

        limiter.acquire(1);

        verify(sleeper, times(1)).sleep(0);
    }

    public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.NANOSECONDS, sleeper, frozenClock());

        limiter.acquire(Integer.MAX_VALUE);

        verify(sleeper, times(1)).sleep(0);
    }

    public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        double tokensPerDay = 1;
        var limiter = new RateLimiter(0, tokensPerDay, TimeUnit.DAYS, sleeper, frozenClock());

        limiter.acquire(Integer.MAX_VALUE);

        // Mirror the limiter's own arithmetic so the expected value matches to the nanosecond.
        double tokensPerNano = tokensPerDay / TimeUnit.DAYS.toNanos(1);
        long expectedSleepNanos = (long) ((double) Integer.MAX_VALUE / tokensPerNano);
        verify(sleeper, times(1)).sleep(expectedSleepNanos);
    }

    public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(2);

        verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1));
    }

    public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(1);

        verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1));
    }

    public void testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(10);

        verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(10));
    }

    public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException {
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, frozenClock());

        limiter.acquire(1);
        verify(sleeper, times(1)).sleep(0);

        limiter.acquire(1);
        verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1));
    }

    public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration()
        throws InterruptedException {
        var clock = frozenClock();
        var start = clock.instant();
        var sleeper = mock(RateLimiter.Sleeper.class);
        var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);

        limiter.acquire(1);
        verify(sleeper, times(1)).sleep(0);

        // Advance the mocked clock by half of the time needed to generate one token.
        when(clock.instant()).thenReturn(start.plus(Duration.ofSeconds(30)));
        limiter.acquire(1);
        verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toNanos(30));
    }

    /**
     * Builds a limiter with the given settings (per second, real sleeper and clock) and returns the
     * {@link IllegalArgumentException} its construction is expected to throw.
     */
    private static IllegalArgumentException expectInvalidRate(double accumulatedTokensLimit, double tokensPerTimeUnit) {
        return expectThrows(
            IllegalArgumentException.class,
            () -> new RateLimiter(
                accumulatedTokensLimit,
                tokensPerTimeUnit,
                TimeUnit.SECONDS,
                new RateLimiter.TimeUnitSleeper(),
                Clock.systemUTC()
            )
        );
    }

    /**
     * Creates a valid limiter and returns the {@link IllegalArgumentException} that acquiring {@code tokens} is
     * expected to throw.
     */
    private static IllegalArgumentException expectInvalidAcquire(int tokens) {
        var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC());
        return expectThrows(IllegalArgumentException.class, () -> limiter.acquire(tokens));
    }

    /**
     * Returns a mocked clock whose {@code instant()} always reports the instant at which this helper was called.
     */
    private static Clock frozenClock() {
        var instant = Clock.systemUTC().instant();
        var clock = mock(Clock.class);
        when(clock.instant()).thenReturn(instant);
        return clock;
    }
}