-
Notifications
You must be signed in to change notification settings - Fork 25k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] Inference API Rate limiter (#106330)
* Working tests * Adding more tests * Adding comment * Switching to micros and addressing feedback * Removing nanos and adding test for bug fix --------- Co-authored-by: Elastic Machine <[email protected]>
- Loading branch information
1 parent
ea1672b
commit edbff94
Showing
2 changed files
with
373 additions
and
0 deletions.
There are no files selected for viewing
148 changes: 148 additions & 0 deletions
148
.../plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.common; | ||
|
||
import org.elasticsearch.common.Strings; | ||
|
||
import java.time.Clock; | ||
import java.time.Instant; | ||
import java.time.temporal.ChronoUnit; | ||
import java.util.Objects; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
/** | ||
* Implements a throttler using the <a href="https://en.wikipedia.org/wiki/Token_bucket">token bucket algorithm</a>. | ||
* | ||
* The general approach is to define the rate limiter with size (accumulated tokens limit) which dictates how many | ||
* unused tokens can be saved up, and a rate at which the tokens are created. Then when a thread should be rate limited | ||
* it can attempt to acquire a certain number of tokens (typically one for each item of work it's going to do). If unused tokens | ||
* are available in the bucket already, those will be used. If the number of available tokens covers the desired amount | ||
* the thread will not sleep. If the bucket does not contain enough tokens, it will calculate how long the thread needs to sleep | ||
* to accumulate the requested amount of tokens. | ||
* | ||
* By setting the accumulated tokens limit to a value greater than zero, it effectively allows bursts of traffic. If the accumulated | ||
* tokens limit is set to zero, it will force the acquiring thread to wait on each call. | ||
*/ | ||
public class RateLimiter { | ||
|
||
private double tokensPerMicros; | ||
private double accumulatedTokensLimit; | ||
private double accumulatedTokens; | ||
private Instant nextTokenAvailability; | ||
private final Sleeper sleeper; | ||
private final Clock clock; | ||
|
||
/** | ||
* @param accumulatedTokensLimit the limit for tokens stashed in the bucket | ||
* @param tokensPerTimeUnit the number of tokens to produce per the time unit passed in | ||
* @param unit the time unit frequency for generating tokens | ||
*/ | ||
public RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit) { | ||
this(accumulatedTokensLimit, tokensPerTimeUnit, unit, new TimeUnitSleeper(), Clock.systemUTC()); | ||
} | ||
|
||
// default for testing | ||
RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit, Sleeper sleeper, Clock clock) { | ||
this.sleeper = Objects.requireNonNull(sleeper); | ||
this.clock = Objects.requireNonNull(clock); | ||
nextTokenAvailability = Instant.MIN; | ||
setRate(accumulatedTokensLimit, tokensPerTimeUnit, unit); | ||
} | ||
|
||
public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) { | ||
Objects.requireNonNull(newUnit); | ||
|
||
if (newAccumulatedTokensLimit < 0) { | ||
throw new IllegalArgumentException("Accumulated tokens limit must be greater than or equal to 0"); | ||
} | ||
|
||
if (Double.isInfinite(newAccumulatedTokensLimit)) { | ||
throw new IllegalArgumentException( | ||
Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE) | ||
); | ||
} | ||
|
||
if (newTokensPerTimeUnit <= 0) { | ||
throw new IllegalArgumentException("Tokens per time unit must be greater than 0"); | ||
} | ||
|
||
if (newTokensPerTimeUnit == Double.POSITIVE_INFINITY) { | ||
throw new IllegalArgumentException(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE)); | ||
} | ||
|
||
accumulatedTokens = Math.min(accumulatedTokens, newAccumulatedTokensLimit); | ||
|
||
accumulatedTokensLimit = newAccumulatedTokensLimit; | ||
|
||
var unitsInMicros = newUnit.toMicros(1); | ||
tokensPerMicros = newTokensPerTimeUnit / unitsInMicros; | ||
assert Double.isInfinite(tokensPerMicros) == false : "Tokens per microsecond should not be infinity"; | ||
|
||
accumulateTokens(); | ||
} | ||
|
||
/** | ||
* Causes the thread to wait until the tokens are available | ||
* @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here | ||
* @throws InterruptedException _ | ||
*/ | ||
public void acquire(int tokens) throws InterruptedException { | ||
if (tokens <= 0) { | ||
throw new IllegalArgumentException("Requested tokens must be positive"); | ||
} | ||
|
||
double microsToWait; | ||
synchronized (this) { | ||
accumulateTokens(); | ||
var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens); | ||
var additionalTokensRequired = tokens - accumulatedTokensToUse; | ||
microsToWait = additionalTokensRequired / tokensPerMicros; | ||
accumulatedTokens -= accumulatedTokensToUse; | ||
nextTokenAvailability = nextTokenAvailability.plus((long) microsToWait, ChronoUnit.MICROS); | ||
} | ||
|
||
sleeper.sleep((long) microsToWait); | ||
} | ||
|
||
private void accumulateTokens() { | ||
var now = Instant.now(clock); | ||
if (now.isAfter(nextTokenAvailability)) { | ||
var elapsedTimeMicros = microsBetweenExact(nextTokenAvailability, now); | ||
var newTokens = tokensPerMicros * elapsedTimeMicros; | ||
accumulatedTokens = Math.min(accumulatedTokensLimit, accumulatedTokens + newTokens); | ||
nextTokenAvailability = now; | ||
} | ||
} | ||
|
||
private static long microsBetweenExact(Instant start, Instant end) { | ||
try { | ||
return ChronoUnit.MICROS.between(start, end); | ||
} catch (ArithmeticException e) { | ||
if (end.isAfter(start)) { | ||
return Long.MAX_VALUE; | ||
} | ||
|
||
return 0; | ||
} | ||
} | ||
|
||
// default for testing | ||
Instant getNextTokenAvailability() { | ||
return nextTokenAvailability; | ||
} | ||
|
||
public interface Sleeper { | ||
void sleep(long microsecondsToSleep) throws InterruptedException; | ||
} | ||
|
||
static final class TimeUnitSleeper implements Sleeper { | ||
public void sleep(long microsecondsToSleep) throws InterruptedException { | ||
TimeUnit.MICROSECONDS.sleep(microsecondsToSleep); | ||
} | ||
} | ||
} |
225 changes: 225 additions & 0 deletions
225
...in/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.common; | ||
|
||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.test.ESTestCase; | ||
|
||
import java.time.Clock; | ||
import java.time.Duration; | ||
import java.time.temporal.ChronoUnit; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
import static org.hamcrest.Matchers.is; | ||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.times; | ||
import static org.mockito.Mockito.verify; | ||
import static org.mockito.Mockito.when; | ||
|
||
public class RateLimiterTests extends ESTestCase { | ||
public void testThrows_WhenAccumulatedTokensLimit_IsNegative() { | ||
var exception = expectThrows( | ||
IllegalArgumentException.class, | ||
() -> new RateLimiter(-1, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) | ||
); | ||
assertThat(exception.getMessage(), is("Accumulated tokens limit must be greater than or equal to 0")); | ||
} | ||
|
||
public void testThrows_WhenAccumulatedTokensLimit_IsInfinity() { | ||
var exception = expectThrows( | ||
IllegalArgumentException.class, | ||
() -> new RateLimiter(Double.POSITIVE_INFINITY, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) | ||
); | ||
assertThat( | ||
exception.getMessage(), | ||
is(Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE)) | ||
); | ||
} | ||
|
||
public void testThrows_WhenTokensPerTimeUnit_IsZero() { | ||
var exception = expectThrows( | ||
IllegalArgumentException.class, | ||
() -> new RateLimiter(0, 0, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) | ||
); | ||
assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0")); | ||
} | ||
|
||
public void testThrows_WhenTokensPerTimeUnit_IsInfinity() { | ||
var exception = expectThrows( | ||
IllegalArgumentException.class, | ||
() -> new RateLimiter(0, Double.POSITIVE_INFINITY, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) | ||
); | ||
assertThat(exception.getMessage(), is(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE))); | ||
} | ||
|
||
public void testThrows_WhenTokensPerTimeUnit_IsNegative() { | ||
var exception = expectThrows( | ||
IllegalArgumentException.class, | ||
() -> new RateLimiter(0, -1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) | ||
); | ||
assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0")); | ||
} | ||
|
||
public void testAcquire_Throws_WhenTokens_IsZero() { | ||
var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()); | ||
var exception = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(0)); | ||
assertThat(exception.getMessage(), is("Requested tokens must be positive")); | ||
} | ||
|
||
public void testAcquire_Throws_WhenTokens_IsNegative() { | ||
var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()); | ||
var exception = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(-1)); | ||
assertThat(exception.getMessage(), is("Requested tokens must be positive")); | ||
} | ||
|
||
public void testAcquire_First_CallDoesNotSleep() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(0); | ||
} | ||
|
||
public void testAcquire_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(0); | ||
} | ||
|
||
public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock); | ||
limiter.acquire(Integer.MAX_VALUE); | ||
verify(sleeper, times(1)).sleep(0); | ||
} | ||
|
||
public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
double tokensPerDay = 1; | ||
var limiter = new RateLimiter(0, tokensPerDay, TimeUnit.DAYS, sleeper, clock); | ||
limiter.acquire(Integer.MAX_VALUE); | ||
|
||
double tokensPerMicro = tokensPerDay / TimeUnit.DAYS.toMicros(1); | ||
verify(sleeper, times(1)).sleep((long) ((double) Integer.MAX_VALUE / tokensPerMicro)); | ||
} | ||
|
||
public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(2); | ||
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); | ||
} | ||
|
||
public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); | ||
} | ||
|
||
public void testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(10); | ||
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(10)); | ||
} | ||
|
||
public void testAcquire_IncrementsNextTokenAvailabilityInstant_ByOneMinute() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); | ||
assertThat(limiter.getNextTokenAvailability(), is(now.plus(1, ChronoUnit.MINUTES))); | ||
} | ||
|
||
public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(0); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); | ||
} | ||
|
||
public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration() | ||
throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(0); | ||
when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30))); | ||
limiter.acquire(1); | ||
verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toMicros(30)); | ||
} | ||
|
||
public void testAcquire_ShouldAccumulateTokens() throws InterruptedException { | ||
var now = Clock.systemUTC().instant(); | ||
var clock = mock(Clock.class); | ||
when(clock.instant()).thenReturn(now); | ||
|
||
var sleeper = mock(RateLimiter.Sleeper.class); | ||
|
||
var limiter = new RateLimiter(10, 10, TimeUnit.MINUTES, sleeper, clock); | ||
limiter.acquire(5); | ||
verify(sleeper, times(1)).sleep(0); | ||
// it should accumulate 5 tokens | ||
when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30))); | ||
limiter.acquire(10); | ||
verify(sleeper, times(2)).sleep(0); | ||
} | ||
} |