Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Remove threading from tests #113212

Merged
merged 4 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,6 @@ tests:
- class: org.elasticsearch.xpack.esql.action.ManyShardsIT
method: testConcurrentQueries
issue: https://github.com/elastic/elasticsearch/issues/112424
- class: org.elasticsearch.xpack.inference.external.http.RequestBasedTaskRunnerTests
method: testLoopOneAtATime
issue: https://github.com/elastic/elasticsearch/issues/112471
- class: org.elasticsearch.ingest.geoip.IngestGeoIpClientYamlTestSuiteIT
issue: https://github.com/elastic/elasticsearch/issues/111497
- class: org.elasticsearch.smoketest.SmokeTestIngestWithAllDepsClientYamlTestSuiteIT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class RequestBasedTaskRunner {
* Else, offload to a new thread so we do not block another threadpool's thread.
*/
public void requestNextRun() {
if (loopCount.getAndIncrement() == 0) {
if (isRunning.get() && loopCount.getAndIncrement() == 0) {
var currentThreadPool = EsExecutors.executorName(Thread.currentThread().getName());
if (executorServiceName.equalsIgnoreCase(currentThreadPool)) {
run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,150 +7,98 @@

package org.elasticsearch.xpack.inference.external.http;

import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.spy;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class RequestBasedTaskRunnerTests extends ESTestCase {
private ThreadPool threadPool;

@Before
public void setUp() throws Exception {
super.setUp();
threadPool = spy(createThreadPool(inferenceUtilityPool()));
threadPool = mock();
when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
}

@After
public void tearDown() throws Exception {
terminate(threadPool);
super.tearDown();
}
public void testRequestWhileLoopingWillRerunCommand() {
var expectedTimesRerun = randomInt(5);
AtomicInteger counter = new AtomicInteger(0);

public void testLoopOneAtATime() throws Exception {
// count the number of times the runnable is called
var counter = new AtomicInteger(0);

// block the runnable and wait for the test thread to take an action
var lock = new ReentrantLock();
var condition = lock.newCondition();
Runnable block = () -> {
try {
try {
lock.lock();
condition.await();
} finally {
lock.unlock();
}
} catch (InterruptedException e) {
fail(e, "did not unblock the thread in time, likely during threadpool terminate");
}
};
Runnable unblock = () -> {
try {
lock.lock();
condition.signalAll();
} finally {
lock.unlock();
var requestNextRun = new AtomicReference<Runnable>();
Runnable command = () -> {
if (counter.getAndIncrement() < expectedTimesRerun) {
requestNextRun.get().run();
}
};

var runner = new RequestBasedTaskRunner(() -> {
counter.incrementAndGet();
block.run();
}, threadPool, UTILITY_THREAD_POOL_NAME);

// given we have not called requestNextRun, then no thread should have started
assertThat(counter.get(), equalTo(0));
verify(threadPool, times(0)).executor(UTILITY_THREAD_POOL_NAME);

var runner = new RequestBasedTaskRunner(command, threadPool, UTILITY_THREAD_POOL_NAME);
requestNextRun.set(runner::requestNextRun);
runner.requestNextRun();

// given that we have called requestNextRun, then 1 thread should run once
assertBusy(() -> {
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
assertThat(counter.get(), equalTo(1));
});

// given that we have called requestNextRun while a thread was running, and the thread was blocked
runner.requestNextRun();
// then 1 thread should run once
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
assertThat(counter.get(), equalTo(1));
verify(threadPool, times(1)).executor(eq(UTILITY_THREAD_POOL_NAME));
verifyNoMoreInteractions(threadPool);
assertThat(counter.get(), equalTo(expectedTimesRerun + 1));
}

// given the thread is unblocked
unblock.run();
// then 1 thread should run twice
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
assertBusy(() -> assertThat(counter.get(), equalTo(2)));
public void testRequestWhileNotLoopingWillQueueCommand() {
AtomicInteger counter = new AtomicInteger(0);

// given the thread is unblocked again, but there were only two calls to requestNextRun
unblock.run();
// then 1 thread should run twice
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
assertBusy(() -> assertThat(counter.get(), equalTo(2)));
var runner = new RequestBasedTaskRunner(counter::incrementAndGet, threadPool, UTILITY_THREAD_POOL_NAME);

// given no thread is running, when we call requestNextRun
runner.requestNextRun();
// then a second thread should start for the third run
assertBusy(() -> {
verify(threadPool, times(2)).executor(UTILITY_THREAD_POOL_NAME);
assertThat(counter.get(), equalTo(3));
});

// given the thread is unblocked, then it should exit and rejoin the threadpool
unblock.run();
assertTrue("Test thread should unblock after all runs complete", terminate(threadPool));

// final check - we ran three times on two threads
verify(threadPool, times(2)).executor(UTILITY_THREAD_POOL_NAME);
assertThat(counter.get(), equalTo(3));
for (int i = 1; i < randomInt(10); i++) {
runner.requestNextRun();
verify(threadPool, times(i)).executor(eq(UTILITY_THREAD_POOL_NAME));
assertThat(counter.get(), equalTo(i));
}
;
}

public void testCancel() throws Exception {
// count the number of times the runnable is called
var counter = new AtomicInteger(0);
var latch = new CountDownLatch(1);
var runner = new RequestBasedTaskRunner(() -> {
counter.incrementAndGet();
try {
latch.await();
} catch (InterruptedException e) {
fail(e, "did not unblock the thread in time, likely during threadpool terminate");
}
}, threadPool, UTILITY_THREAD_POOL_NAME);
public void testCancelBeforeRunning() {
AtomicInteger counter = new AtomicInteger(0);

// given that we have called requestNextRun, then 1 thread should run once
var runner = new RequestBasedTaskRunner(counter::incrementAndGet, threadPool, UTILITY_THREAD_POOL_NAME);
runner.cancel();
runner.requestNextRun();
assertBusy(() -> {
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
assertThat(counter.get(), equalTo(1));
});

// given that a thread is running, three more calls will be queued
runner.requestNextRun();
runner.requestNextRun();
verifyNoInteractions(threadPool);
assertThat(counter.get(), equalTo(0));
}

public void testCancelWhileRunning() {
var expectedTimesRerun = randomInt(5);
AtomicInteger counter = new AtomicInteger(0);

var runnerRef = new AtomicReference<RequestBasedTaskRunner>();
Runnable command = () -> {
if (counter.getAndIncrement() < expectedTimesRerun) {
runnerRef.get().requestNextRun();
}
runnerRef.get().cancel();
};
var runner = new RequestBasedTaskRunner(command, threadPool, UTILITY_THREAD_POOL_NAME);
runnerRef.set(runner);
runner.requestNextRun();

// when we cancel the thread, then the thread should immediately exit and rejoin
runner.cancel();
latch.countDown();
assertTrue("Test thread should unblock after all runs complete", terminate(threadPool));
verify(threadPool, times(1)).executor(eq(UTILITY_THREAD_POOL_NAME));
verifyNoMoreInteractions(threadPool);
assertThat(counter.get(), equalTo(1));

// given that we called cancel, when we call requestNextRun then no thread should start
runner.requestNextRun();
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
verifyNoMoreInteractions(threadPool);
assertThat(counter.get(), equalTo(1));
}

Expand Down