From 9f138840e18128758f835afc59cabfc039f55cfc Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 12 Mar 2024 17:44:00 -0400 Subject: [PATCH 1/5] Working tests --- .../xpack/inference/common/RateLimiter.java | 107 ++++++++++++++++++ .../inference/common/RateLimiterTests.java | 107 ++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java new file mode 100644 index 0000000000000..285b1a94feb2f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +/** + * Implements a throttler using the token bucket algorithm. 
+ */ +public class RateLimiter { + + private double tokensPerTimeUnit; + private double accumulatedTokensLimit; + private double accumulatedTokens; + private Instant nextTokenAvailability; + private TimeUnit unit; + private final Sleeper sleeper; + private final Clock clock; + + /** + * @param accumulatedTokensLimit the limit for tokens stashed in the bucket + * @param tokensPerTimeUnit the number of tokens to produce per the time unit passed in + * @param unit the time unit frequency for generating tokens + */ + public RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit) { + this(accumulatedTokensLimit, tokensPerTimeUnit, unit, new TimeUnitSleeper(), Clock.systemUTC()); + } + + // default for testing + RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit, Sleeper sleeper, Clock clock) { + this.sleeper = Objects.requireNonNull(sleeper); + this.clock = Objects.requireNonNull(clock); + nextTokenAvailability = Instant.MIN; + setRate(accumulatedTokensLimit, tokensPerTimeUnit, unit); + } + + public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) { + if (newAccumulatedTokensLimit < 0) { + throw new IllegalArgumentException("Accumulated tokens limit must be greater than or equal to 0"); + } + + if (newTokensPerTimeUnit <= 0) { + throw new IllegalArgumentException("Tokens per time unit must be greater than 0"); + } + + accumulatedTokens = Math.min(accumulatedTokens, newAccumulatedTokensLimit); + + accumulatedTokensLimit = newAccumulatedTokensLimit; + tokensPerTimeUnit = newTokensPerTimeUnit; + unit = Objects.requireNonNull(newUnit); + accumulateTokens(); + } + + /** + * Causes the thread to wait until the tokens are available + * @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here + * @throws InterruptedException + */ + public void acquire(int tokens) throws InterruptedException { + if (tokens 
<= 0) { + throw new IllegalArgumentException("Requested tokens must be positive"); + } + + double nanosToWait; + synchronized (this) { + accumulateTokens(); + var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens); + var additionalTokensRequired = tokens - accumulatedTokensToUse; + var timeUnitsToWait = additionalTokensRequired / tokensPerTimeUnit; + var unitsInNanos = unit.toNanos(1); + nanosToWait = timeUnitsToWait * unitsInNanos; + accumulatedTokens -= accumulatedTokensToUse; + nextTokenAvailability = nextTokenAvailability.plus(Duration.ofNanos((long) nanosToWait)); + } + + sleeper.sleep((long) nanosToWait); + } + + private void accumulateTokens() { + var now = Instant.now(clock); + if (now.isAfter(nextTokenAvailability)) { + var elapsedTime = unit.toChronoUnit().between(nextTokenAvailability, now); + var newTokens = tokensPerTimeUnit * elapsedTime; + accumulatedTokens = Math.min(accumulatedTokensLimit, newTokens); + nextTokenAvailability = now; + } + } + + public interface Sleeper { + void sleep(long nanosecondsToSleep) throws InterruptedException; + } + + static final class TimeUnitSleeper implements Sleeper { + public void sleep(long nanosecondsToSleep) throws InterruptedException { + TimeUnit.NANOSECONDS.sleep(nanosecondsToSleep); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java new file mode 100644 index 0000000000000..dbdbfa380f003 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.test.ESTestCase; + +import java.time.Clock; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RateLimiterTests extends ESTestCase { + public void testThrows_WhenAccumulatedTokensLimit_IsNegative() { + var exception = expectThrows( + IllegalArgumentException.class, + () -> new RateLimiter(-1, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) + ); + assertThat(exception.getMessage(), is("Accumulated tokens limit must be greater than or equal to 0")); + } + + public void testThrows_WhenTokensPerTimeUnit_IsZero() { + var exception = expectThrows( + IllegalArgumentException.class, + () -> new RateLimiter(0, 0, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) + ); + assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0")); + } + + public void testThrows_WhenTokensPerTimeUnit_IsNegative() { + var exception = expectThrows( + IllegalArgumentException.class, + () -> new RateLimiter(0, -1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) + ); + assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0")); + } + + public void testAcquire_Throws_WhenTokens_IsZero() { + var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()); + var exception = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(0)); + assertThat(exception.getMessage(), is("Requested tokens must be positive")); + } + + public void testAcquire_Throws_WhenTokens_IsNegative() { + var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()); + var exception = 
expectThrows(IllegalArgumentException.class, () -> limiter.acquire(-1)); + assertThat(exception.getMessage(), is("Requested tokens must be positive")); + } + + public void testAcquire_First_CallDoesNotSleep() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(0); + } + + public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(2); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); + } + + public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); + } + + public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(0); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); + } +} From 
eb0b08647c53e7ad4f1ffa04bdd2f8019c7562ed Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 13 Mar 2024 13:21:44 -0400 Subject: [PATCH 2/5] Adding more tests --- .../xpack/inference/common/RateLimiter.java | 45 ++++++++-- .../inference/common/RateLimiterTests.java | 88 +++++++++++++++++++ 2 files changed, 124 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java index 285b1a94feb2f..99ce21985d8b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java @@ -7,9 +7,12 @@ package org.elasticsearch.xpack.inference.common; +import org.elasticsearch.common.Strings; + import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -18,11 +21,10 @@ */ public class RateLimiter { - private double tokensPerTimeUnit; + private double tokensPerNanos; private double accumulatedTokensLimit; private double accumulatedTokens; private Instant nextTokenAvailability; - private TimeUnit unit; private final Sleeper sleeper; private final Clock clock; @@ -44,19 +46,34 @@ public RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, Time } public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) { + Objects.requireNonNull(newUnit); + if (newAccumulatedTokensLimit < 0) { throw new IllegalArgumentException("Accumulated tokens limit must be greater than or equal to 0"); } + if (newAccumulatedTokensLimit == Double.POSITIVE_INFINITY) { + throw new IllegalArgumentException( + Strings.format("Accumulated tokens limit must be less than or equal to %s", 
Double.MAX_VALUE) + ); + } + if (newTokensPerTimeUnit <= 0) { throw new IllegalArgumentException("Tokens per time unit must be greater than 0"); } + if (newTokensPerTimeUnit == Double.POSITIVE_INFINITY) { + throw new IllegalArgumentException(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE)); + } + accumulatedTokens = Math.min(accumulatedTokens, newAccumulatedTokensLimit); accumulatedTokensLimit = newAccumulatedTokensLimit; - tokensPerTimeUnit = newTokensPerTimeUnit; - unit = Objects.requireNonNull(newUnit); + + var unitsInNanos = newUnit.toNanos(1); + tokensPerNanos = newTokensPerTimeUnit / unitsInNanos; + assert tokensPerNanos != Double.POSITIVE_INFINITY : "Tokens per nanosecond should not be infinity"; + accumulateTokens(); } @@ -75,9 +92,7 @@ public void acquire(int tokens) throws InterruptedException { accumulateTokens(); var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens); var additionalTokensRequired = tokens - accumulatedTokensToUse; - var timeUnitsToWait = additionalTokensRequired / tokensPerTimeUnit; - var unitsInNanos = unit.toNanos(1); - nanosToWait = timeUnitsToWait * unitsInNanos; + nanosToWait = additionalTokensRequired / tokensPerNanos; accumulatedTokens -= accumulatedTokensToUse; nextTokenAvailability = nextTokenAvailability.plus(Duration.ofNanos((long) nanosToWait)); } @@ -88,13 +103,25 @@ public void acquire(int tokens) throws InterruptedException { private void accumulateTokens() { var now = Instant.now(clock); if (now.isAfter(nextTokenAvailability)) { - var elapsedTime = unit.toChronoUnit().between(nextTokenAvailability, now); - var newTokens = tokensPerTimeUnit * elapsedTime; + var elapsedTimeNanos = nanosBetweenExact(nextTokenAvailability, now); + var newTokens = tokensPerNanos * elapsedTimeNanos; accumulatedTokens = Math.min(accumulatedTokensLimit, newTokens); nextTokenAvailability = now; } } + private static double nanosBetweenExact(Instant start, Instant end) { + try { + return 
ChronoUnit.NANOS.between(start, end); + } catch (ArithmeticException e) { + if (end.isAfter(start)) { + return Double.POSITIVE_INFINITY; + } + + return Double.NEGATIVE_INFINITY; + } + } + public interface Sleeper { void sleep(long nanosecondsToSleep) throws InterruptedException; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java index dbdbfa380f003..85c3cc9ceb7b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.inference.common; +import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; import java.time.Clock; +import java.time.Duration; import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.is; @@ -27,6 +29,17 @@ public void testThrows_WhenAccumulatedTokensLimit_IsNegative() { assertThat(exception.getMessage(), is("Accumulated tokens limit must be greater than or equal to 0")); } + public void testThrows_WhenAccumulatedTokensLimit_IsInfinity() { + var exception = expectThrows( + IllegalArgumentException.class, + () -> new RateLimiter(Double.POSITIVE_INFINITY, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) + ); + assertThat( + exception.getMessage(), + is(Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE)) + ); + } + public void testThrows_WhenTokensPerTimeUnit_IsZero() { var exception = expectThrows( IllegalArgumentException.class, @@ -35,6 +48,14 @@ public void testThrows_WhenTokensPerTimeUnit_IsZero() { assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0")); } + public void testThrows_WhenTokensPerTimeUnit_IsInfinity() { + var exception 
= expectThrows( + IllegalArgumentException.class, + () -> new RateLimiter(0, Double.POSITIVE_INFINITY, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC()) + ); + assertThat(exception.getMessage(), is(Strings.format("Tokens per time unit must be less than or equal to %s", Double.MAX_VALUE))); + } + public void testThrows_WhenTokensPerTimeUnit_IsNegative() { var exception = expectThrows( IllegalArgumentException.class, @@ -67,6 +88,45 @@ public void testAcquire_First_CallDoesNotSleep() throws InterruptedException { verify(sleeper, times(1)).sleep(0); } + public void testAcquire_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.NANOSECONDS, sleeper, clock); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(0); + } + + public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.NANOSECONDS, sleeper, clock); + limiter.acquire(Integer.MAX_VALUE); + verify(sleeper, times(1)).sleep(0); + } + + public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + double tokensPerDay = 1; + var limiter = new RateLimiter(0, tokensPerDay, TimeUnit.DAYS, sleeper, clock); + limiter.acquire(Integer.MAX_VALUE); + + double tokensPerNano = tokensPerDay / TimeUnit.DAYS.toNanos(1); + verify(sleeper, times(1)).sleep((long) ((double) Integer.MAX_VALUE / tokensPerNano)); 
+ } + public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException { var now = Clock.systemUTC().instant(); var clock = mock(Clock.class); @@ -91,6 +151,18 @@ public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoA verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); } + public void testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(10); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(10)); + } + public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException { var now = Clock.systemUTC().instant(); var clock = mock(Clock.class); @@ -104,4 +176,20 @@ public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreD limiter.acquire(1); verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); } + + public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration() + throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(0); + when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30))); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toNanos(30)); + } } From f15bd38dd28c12151a5cfa6f5ea9be3ba4ce4a78 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 13 Mar 2024 15:24:45 -0400 Subject: [PATCH 3/5] Adding comment --- 
.../xpack/inference/common/RateLimiter.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java index 99ce21985d8b6..f38c011551c51 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java @@ -18,6 +18,16 @@ /** * Implements a throttler using the token bucket algorithm. + * + * The general approach is to define the rate limiter with a size (the accumulated tokens limit), which dictates how many + * unused tokens can be saved up, and a rate at which the tokens are created. Then, when a thread should be rate limited, + * it can attempt to acquire a certain number of tokens (typically one for each item of work it's going to do). If unused tokens + * are available in the bucket already, those will be used. If the number of available tokens covers the desired amount, + * the thread will not sleep. If the bucket does not contain enough tokens, the limiter calculates how long the thread needs to sleep + * to accumulate the requested number of tokens. + * + * Setting the accumulated tokens limit to a value greater than zero effectively allows bursts of traffic. If the accumulated + * tokens limit is set to zero, the acquiring thread is forced to wait on each call. 
*/ public class RateLimiter { From 8cc99c96df42e745f3d2bc7c10add5bd89c517d7 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 18 Mar 2024 10:48:24 -0400 Subject: [PATCH 4/5] Switching to micros and addressing feedback --- .../xpack/inference/common/RateLimiter.java | 44 ++++++++++--------- .../inference/common/RateLimiterTests.java | 32 ++++++++++---- 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java index f38c011551c51..5d5abcb471ee4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.Strings; import java.time.Clock; -import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Objects; @@ -31,7 +30,7 @@ */ public class RateLimiter { - private double tokensPerNanos; + private double tokensPerMicros; private double accumulatedTokensLimit; private double accumulatedTokens; private Instant nextTokenAvailability; @@ -62,7 +61,7 @@ public final synchronized void setRate(double newAccumulatedTokensLimit, double throw new IllegalArgumentException("Accumulated tokens limit must be greater than or equal to 0"); } - if (newAccumulatedTokensLimit == Double.POSITIVE_INFINITY) { + if (Double.isInfinite(newAccumulatedTokensLimit)) { throw new IllegalArgumentException( Strings.format("Accumulated tokens limit must be less than or equal to %s", Double.MAX_VALUE) ); @@ -80,9 +79,9 @@ public final synchronized void setRate(double newAccumulatedTokensLimit, double accumulatedTokensLimit = newAccumulatedTokensLimit; - var unitsInNanos = newUnit.toNanos(1); - tokensPerNanos = newTokensPerTimeUnit / unitsInNanos; - assert 
tokensPerNanos != Double.POSITIVE_INFINITY : "Tokens per nanosecond should not be infinity"; + var unitsInNanos = newUnit.toMicros(1); + tokensPerMicros = newTokensPerTimeUnit / unitsInNanos; + assert Double.isInfinite(tokensPerMicros) == false : "Tokens per microsecond should not be infinity"; accumulateTokens(); } @@ -90,55 +89,60 @@ public final synchronized void setRate(double newAccumulatedTokensLimit, double /** * Causes the thread to wait until the tokens are available * @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here - * @throws InterruptedException + * @throws InterruptedException if the thread is interrupted while sleeping to wait for the tokens */ public void acquire(int tokens) throws InterruptedException { if (tokens <= 0) { throw new IllegalArgumentException("Requested tokens must be positive"); } - double nanosToWait; + double microsToWait; synchronized (this) { accumulateTokens(); var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens); var additionalTokensRequired = tokens - accumulatedTokensToUse; - nanosToWait = additionalTokensRequired / tokensPerNanos; + microsToWait = additionalTokensRequired / tokensPerMicros; accumulatedTokens -= accumulatedTokensToUse; - nextTokenAvailability = nextTokenAvailability.plus(Duration.ofNanos((long) nanosToWait)); + nextTokenAvailability = nextTokenAvailability.plus((long) microsToWait, ChronoUnit.MICROS); } - sleeper.sleep((long) nanosToWait); + sleeper.sleep((long) microsToWait); } private void accumulateTokens() { var now = Instant.now(clock); if (now.isAfter(nextTokenAvailability)) { - var elapsedTimeNanos = nanosBetweenExact(nextTokenAvailability, now); - var newTokens = tokensPerNanos * elapsedTimeNanos; + var elapsedTimeNanos = microsBetweenExact(nextTokenAvailability, now); + var newTokens = tokensPerMicros * elapsedTimeNanos; accumulatedTokens = Math.min(accumulatedTokensLimit, newTokens); nextTokenAvailability = now; } } - private static double nanosBetweenExact(Instant start, Instant end) { + 
private static long microsBetweenExact(Instant start, Instant end) { try { - return ChronoUnit.NANOS.between(start, end); + return ChronoUnit.MICROS.between(start, end); } catch (ArithmeticException e) { if (end.isAfter(start)) { - return Double.POSITIVE_INFINITY; + return Long.MAX_VALUE; } - return Double.NEGATIVE_INFINITY; + return 0; } } + // default for testing + Instant getNextTokenAvailability() { + return nextTokenAvailability; + } + public interface Sleeper { - void sleep(long nanosecondsToSleep) throws InterruptedException; + void sleep(long microsecondsToSleep) throws InterruptedException; } static final class TimeUnitSleeper implements Sleeper { - public void sleep(long nanosecondsToSleep) throws InterruptedException { - TimeUnit.NANOSECONDS.sleep(nanosecondsToSleep); + public void sleep(long microsecondsToSleep) throws InterruptedException { + TimeUnit.MICROSECONDS.sleep(microsecondsToSleep); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java index 85c3cc9ceb7b6..f4ca2d3f078ca 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java @@ -12,6 +12,7 @@ import java.time.Clock; import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.is; @@ -95,7 +96,7 @@ public void testAcquire_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedExc var sleeper = mock(RateLimiter.Sleeper.class); - var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.NANOSECONDS, sleeper, clock); + var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock); limiter.acquire(1); verify(sleeper, times(1)).sleep(0); } @@ -107,7 +108,7 @@ public 
void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws Interrup var sleeper = mock(RateLimiter.Sleeper.class); - var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.NANOSECONDS, sleeper, clock); + var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock); limiter.acquire(Integer.MAX_VALUE); verify(sleeper, times(1)).sleep(0); } @@ -123,8 +124,8 @@ public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws Interrupt var limiter = new RateLimiter(0, tokensPerDay, TimeUnit.DAYS, sleeper, clock); limiter.acquire(Integer.MAX_VALUE); - double tokensPerNano = tokensPerDay / TimeUnit.DAYS.toNanos(1); - verify(sleeper, times(1)).sleep((long) ((double) Integer.MAX_VALUE / tokensPerNano)); + double tokensPerMicro = tokensPerDay / TimeUnit.DAYS.toMicros(1); + verify(sleeper, times(1)).sleep((long) ((double) Integer.MAX_VALUE / tokensPerMicro)); } public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException { @@ -136,7 +137,7 @@ public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() t var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock); limiter.acquire(2); - verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); } public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException { @@ -148,7 +149,7 @@ public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoA var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); limiter.acquire(1); - verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); } public void testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException { @@ -160,7 +161,20 @@ public void 
testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAcc var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); limiter.acquire(10); - verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(10)); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(10)); + } + + public void testAcquire_IncrementsNextTokenAvailabilityInstant_ByOneMinute() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(1); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); + assertThat(limiter.getNextTokenAvailability(), is(now.plus(1, ChronoUnit.MINUTES))); } public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException { @@ -174,7 +188,7 @@ public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreD limiter.acquire(1); verify(sleeper, times(1)).sleep(0); limiter.acquire(1); - verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toNanos(1)); + verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1)); } public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration() @@ -190,6 +204,6 @@ public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapse verify(sleeper, times(1)).sleep(0); when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30))); limiter.acquire(1); - verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toNanos(30)); + verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toMicros(30)); } } From 0496a8930ef0275a4e39dc8d942f8eb555f848cb Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 19 Mar 2024 08:56:41 -0400 Subject: [PATCH 5/5] Removing nanos and adding test for bug fix --- .../xpack/inference/common/RateLimiter.java | 10 +++++----- 
.../xpack/inference/common/RateLimiterTests.java | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java index 5d5abcb471ee4..ac28aa87f554b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java @@ -79,8 +79,8 @@ public final synchronized void setRate(double newAccumulatedTokensLimit, double accumulatedTokensLimit = newAccumulatedTokensLimit; - var unitsInNanos = newUnit.toMicros(1); - tokensPerMicros = newTokensPerTimeUnit / unitsInNanos; + var unitsInMicros = newUnit.toMicros(1); + tokensPerMicros = newTokensPerTimeUnit / unitsInMicros; assert Double.isInfinite(tokensPerMicros) == false : "Tokens per microsecond should not be infinity"; accumulateTokens(); @@ -112,9 +112,9 @@ public void acquire(int tokens) throws InterruptedException { private void accumulateTokens() { var now = Instant.now(clock); if (now.isAfter(nextTokenAvailability)) { - var elapsedTimeNanos = microsBetweenExact(nextTokenAvailability, now); - var newTokens = tokensPerMicros * elapsedTimeNanos; - accumulatedTokens = Math.min(accumulatedTokensLimit, newTokens); + var elapsedTimeMicros = microsBetweenExact(nextTokenAvailability, now); + var newTokens = tokensPerMicros * elapsedTimeMicros; + accumulatedTokens = Math.min(accumulatedTokensLimit, accumulatedTokens + newTokens); nextTokenAvailability = now; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java index f4ca2d3f078ca..46931f12aaf4f 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/RateLimiterTests.java @@ -206,4 +206,20 @@ public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapse limiter.acquire(1); verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toMicros(30)); } + + public void testAcquire_ShouldAccumulateTokens() throws InterruptedException { + var now = Clock.systemUTC().instant(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var sleeper = mock(RateLimiter.Sleeper.class); + + var limiter = new RateLimiter(10, 10, TimeUnit.MINUTES, sleeper, clock); + limiter.acquire(5); + verify(sleeper, times(1)).sleep(0); + // it should accumulate 5 tokens + when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30))); + limiter.acquire(10); + verify(sleeper, times(2)).sleep(0); + } }