diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTask.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTask.java index 70679be32121c..5881acf4114c0 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTask.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTask.java @@ -118,7 +118,8 @@ private void execute(ReindexJob reindexJob) { String taskId = getPersistentTaskId(); long allocationId = getAllocationId(); Consumer<BulkByScrollTask.Status> committedCallback = childTask::setCommittedStatus; - ReindexTaskStateUpdater taskUpdater = new ReindexTaskStateUpdater(reindexIndexClient, taskId, allocationId, committedCallback); + ReindexTaskStateUpdater taskUpdater = new ReindexTaskStateUpdater(reindexIndexClient, client.threadPool(), taskId, allocationId, + committedCallback); taskUpdater.assign(new ActionListener<>() { @Override public void onResponse(ReindexTaskStateDoc stateDoc) { diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdater.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdater.java index 88b41e85df110..e0c8e004a2695 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdater.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdater.java @@ -24,9 +24,11 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.threadpool.ThreadPool; -import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; public class ReindexTaskStateUpdater implements Reindexer.CheckpointListener { @@ -36,18 +38,20 @@ public class ReindexTaskStateUpdater implements Reindexer.CheckpointListener { private static final Logger logger = LogManager.getLogger(ReindexTask.class); private final ReindexIndexClient reindexIndexClient; + private final ThreadPool threadPool; private final String persistentTaskId; private final long allocationId; private final Consumer<BulkByScrollTask.Status> committedCallback; - private final Semaphore semaphore = new Semaphore(1); + private ThrottlingConsumer<Tuple<ScrollableHitSource.Checkpoint, BulkByScrollTask.Status>> checkpointThrottler; private int assignmentAttempts = 0; private ReindexTaskState lastState; - private boolean isDone = false; + private AtomicBoolean isDone = new AtomicBoolean(); - public ReindexTaskStateUpdater(ReindexIndexClient reindexIndexClient, String persistentTaskId, long allocationId, + public ReindexTaskStateUpdater(ReindexIndexClient reindexIndexClient, ThreadPool threadPool, String persistentTaskId, long allocationId, Consumer<BulkByScrollTask.Status> committedCallback) { this.reindexIndexClient = reindexIndexClient; + this.threadPool = threadPool; this.persistentTaskId = persistentTaskId; this.allocationId = allocationId; // TODO: At some point I think we would like to replace a single universal callback to a listener that @@ -70,7 +74,12 @@ public void onResponse(ReindexTaskState taskState) { reindexIndexClient.updateReindexTaskDoc(persistentTaskId, newDoc, term, seqNo, new ActionListener<>() { @Override public void onResponse(ReindexTaskState newTaskState) { + assert checkpointThrottler == null; 
lastState = newTaskState; + checkpointThrottler = new ThrottlingConsumer<>( + (t, whenDone) -> updateCheckpoint(t.v1(), t.v2(), whenDone), + newTaskState.getStateDoc().getReindexRequest().getCheckpointInterval(), System::nanoTime, threadPool + ); listener.onResponse(newTaskState.getStateDoc()); } @@ -114,61 +123,58 @@ public void onFailure(Exception ex) { @Override public void onCheckpoint(ScrollableHitSource.Checkpoint checkpoint, BulkByScrollTask.Status status) { - // TODO: Need some kind of throttling here, no need to do this all the time. - // only do one checkpoint at a time, in case checkpointing is too slow. - if (semaphore.tryAcquire()) { - if (isDone) { - semaphore.release(); - } else { - ReindexTaskStateDoc nextState = lastState.getStateDoc().withCheckpoint(checkpoint, status); - // TODO: This can fail due to conditional update. Need to hook into ability to cancel reindex process - long term = lastState.getPrimaryTerm(); - long seqNo = lastState.getSeqNo(); - reindexIndexClient.updateReindexTaskDoc(persistentTaskId, nextState, term, seqNo, new ActionListener<>() { - @Override - public void onResponse(ReindexTaskState taskState) { - lastState = taskState; - committedCallback.accept(status); - semaphore.release(); - } - - @Override - public void onFailure(Exception e) { - semaphore.release(); - } - }); + assert checkpointThrottler != null; + checkpointThrottler.accept(Tuple.tuple(checkpoint, status)); + } + + private void updateCheckpoint(ScrollableHitSource.Checkpoint checkpoint, BulkByScrollTask.Status status, Runnable whenDone) { + ReindexTaskStateDoc nextState = lastState.getStateDoc().withCheckpoint(checkpoint, status); + // TODO: This can fail due to conditional update. Need to hook into ability to cancel reindex process + long term = lastState.getPrimaryTerm(); + long seqNo = lastState.getSeqNo(); + reindexIndexClient.updateReindexTaskDoc(persistentTaskId, nextState, term, seqNo, new ActionListener<>() { + @Override + public void onResponse(ReindexTaskState taskState) { + lastState = taskState; + committedCallback.accept(status); + whenDone.run(); } - } + + @Override + public void onFailure(Exception e) { + whenDone.run(); + } + }); } public void finish(@Nullable BulkByScrollResponse reindexResponse, @Nullable ElasticsearchException exception, ActionListener<ReindexTaskStateDoc> listener) { - // TODO: Move to try acquire and a scheduled retry if there is currently contention - semaphore.acquireUninterruptibly(); - if (isDone) { - semaphore.release(); + assert checkpointThrottler != null; + if (isDone.compareAndSet(false, true) == false) { listener.onFailure(new ElasticsearchException("Reindex task already finished locally")); } else { - ReindexTaskStateDoc state = lastState.getStateDoc().withFinishedState(reindexResponse, exception); - isDone = true; - long term = lastState.getPrimaryTerm(); - long seqNo = lastState.getSeqNo(); - reindexIndexClient.updateReindexTaskDoc(persistentTaskId, state, term, seqNo, new ActionListener<>() { - @Override - public void onResponse(ReindexTaskState taskState) { - lastState = null; - semaphore.release(); - listener.onResponse(taskState.getStateDoc()); + checkpointThrottler.close(() -> writeFinishedState(reindexResponse, exception, listener)); + } + } - } + private void writeFinishedState(@Nullable BulkByScrollResponse reindexResponse, @Nullable ElasticsearchException exception, + ActionListener<ReindexTaskStateDoc> listener) { + ReindexTaskStateDoc state = lastState.getStateDoc().withFinishedState(reindexResponse, exception); + long term = 
lastState.getPrimaryTerm();
+        long seqNo = lastState.getSeqNo();
+        reindexIndexClient.updateReindexTaskDoc(persistentTaskId, state, term, seqNo, new ActionListener<>() {
+            @Override
+            public void onResponse(ReindexTaskState taskState) {
+                lastState = null;
+                listener.onResponse(taskState.getStateDoc());
-                @Override
-                public void onFailure(Exception e) {
-                    lastState = null;
-                    semaphore.release();
-                    listener.onFailure(e);
-                }
-            });
-        }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                lastState = null;
+                listener.onFailure(e);
+            }
+        });
     }
 }
diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ThrottlingConsumer.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ThrottlingConsumer.java
new file mode 100644
index 0000000000000..6ee3b0c5af9df
--- /dev/null
+++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ThrottlingConsumer.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.index.reindex;
+
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.threadpool.Scheduler;
+import org.elasticsearch.threadpool.ThreadPool;
+
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.LongSupplier;
+import java.util.function.Supplier;
+
+/**
+ * A throttling consumer that forwards input to an outbound consumer, discarding any values that arrive too quickly, but eventually
+ * forwarding the last value received once the minimumInterval has passed.
+ * The outbound consumer is only called by a single thread at a time. The outbound consumer should be non-blocking.
+ * TODO: could be moved to a more generic place for reuse?
+ * @param <T> type of object passed through.
+ */
+public class ThrottlingConsumer<T> implements Consumer<T> {
+    private final ThreadPool threadPool;
+    private final BiConsumer<T, Runnable> outbound;
+    private final TimeValue minimumInterval;
+    private final Object lock = new Object();
+    private final LongSupplier nanoTimeSource;
+
+    // state protected by lock.
+    private long lastWriteTimeNanos;
+    private T value;
+    private Scheduler.ScheduledCancellable scheduledWrite;
+    private boolean outboundActive;
+    private boolean closed;
+    private Runnable onClosed;
+
+    public ThrottlingConsumer(BiConsumer<T, Runnable> outbound, TimeValue minimumInterval,
+                              LongSupplier nanoTimeSource, ThreadPool threadPool) {
+        Supplier<ThreadContext.StoredContext> restorableContext = threadPool.getThreadContext().newRestorableContext(false);
+        this.outbound = (value, whenDone) -> {
+            try (ThreadContext.StoredContext ignored = restorableContext.get()) {
+                outbound.accept(value, whenDone);
+            }
+        };
+        this.minimumInterval = minimumInterval;
+        this.threadPool = threadPool;
+        this.nanoTimeSource = nanoTimeSource;
+        this.lastWriteTimeNanos = nanoTimeSource.getAsLong();
+    }
+
+    @Override
+    public void accept(T newValue) {
+        long now = nanoTimeSource.getAsLong();
+        synchronized (lock) {
+            if (closed) {
+                return;
+            }
+            this.value = newValue;
+            if (scheduledWrite == null) {
+                // schedule is non-blocking
+                scheduledWrite = threadPool.schedule(this::onScheduleTimeout, getDelay(now), ThreadPool.Names.SAME);
+            }
+        }
+    }
+
+    private TimeValue getDelay(long now) {
+        long nanos = lastWriteTimeNanos + minimumInterval.nanos() - now;
+        return nanos < 0 ? TimeValue.ZERO : TimeValue.timeValueNanos(nanos);
+    }
+
+    private void onScheduleTimeout() {
+        T value;
+        long now = nanoTimeSource.getAsLong();
+        synchronized (lock) {
+            if (closed) {
+                return;
+            }
+            value = this.value;
+            lastWriteTimeNanos = now;
+            outboundActive = true;
+        }
+
+        outbound.accept(value, () -> {
+            // mutable state is protected by lock, so synchronize on lock (not this) like the rest of the class
+            synchronized (lock) {
+                outboundActive = false;
+                if (closed == false) {
+                    if (value != this.value) {
+                        scheduledWrite = threadPool.schedule(this::onScheduleTimeout, minimumInterval,
+                            ThreadPool.Names.SAME);
+                    } else {
+                        scheduledWrite = null;
+                    }
+                }
+            }
+
+            // safe to read outside the lock: onClosed is only set under lock while outboundActive, which we observed under lock above
+            if (onClosed != null) {
+                onClosed.run();
+            }
+        });
+    }
+
+    /**
+     * Async close this consumer. Any state submitted since the last outbound call will be discarded, as will any new inbound accept calls.
+     * @param onClosed called when closed, which guarantees no more calls on outbound consumer. Must be non-blocking.
+ */ + public void close(Runnable onClosed) { + synchronized (lock) { + assert closed == false : "multiple closes not supported"; + closed = true; + if (scheduledWrite != null) { + if (outboundActive) { + this.onClosed = onClosed; + } else { + scheduledWrite.cancel(); + } + scheduledWrite = null; + } + } + if (this.onClosed == null) { + onClosed.run(); + } + } +} diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexFailoverIT.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexFailoverIT.java index 7cdac5e69b5a4..2670697bc15e7 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexFailoverIT.java +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexFailoverIT.java @@ -104,6 +104,7 @@ public void testReindexFailover() throws Throwable { ReindexRequestBuilder copy = reindex().source("source").destination("dest").refresh(true); ReindexRequest reindexRequest = copy.request(); reindexRequest.setScroll(TimeValue.timeValueSeconds(scrollTimeout)); + reindexRequest.setCheckpointInterval(TimeValue.timeValueMillis(100)); StartReindexJobAction.Request request = new StartReindexJobAction.Request(reindexRequest, false); copy.source().setSize(10); @@ -173,7 +174,8 @@ public void testReindexFailover() throws Throwable { assertThat(seqNos.length(), greaterThan(docCount)); } // The first 9 should not be replayed, we restart from at least seqNo 9. - assertThat(seqNos.length(), lessThan(Math.toIntExact(docCount + hitsAfterRestart - 9))); + assertThat("docCount: " + docCount + " hitsAfterRestart " + hitsAfterRestart, seqNos.length()-1, + lessThan(Math.toIntExact(docCount + hitsAfterRestart - 9))); }); diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdaterTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdaterTests.java index 41ac67d242cb3..1d8ef8885d582 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdaterTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ReindexTaskStateUpdaterTests.java @@ -44,7 +44,8 @@ public void testEnsureLowerAssignmentFails() throws Exception { ReindexIndexClient reindexClient = getReindexClient(); createDoc(reindexClient, taskId); - ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, taskId, 1, (s) -> {}); + ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, client().threadPool(), + taskId, 1, (s) -> {}); CountDownLatch successLatch = new CountDownLatch(1); updater.assign(new ActionListener<>() { @@ -61,7 +62,8 @@ public void onFailure(Exception exception) { }); successLatch.await(); - ReindexTaskStateUpdater oldAllocationUpdater = new ReindexTaskStateUpdater(reindexClient, taskId, 0, (s) -> {}); + ReindexTaskStateUpdater oldAllocationUpdater = new ReindexTaskStateUpdater(reindexClient, client().threadPool(), + taskId, 0, (s) -> {}); CountDownLatch failureLatch = new CountDownLatch(1); AtomicReference<Exception> exceptionRef = new AtomicReference<>(); @@ -80,7 +82,6 @@ public void onFailure(Exception exception) { }); failureLatch.await(); assertThat(exceptionRef.get().getMessage(), equalTo("A newer task has already been allocated")); - } public void testEnsureHighestAllocationIsWinningAssignment() throws Exception { @@ -93,7 +94,7 @@ public void testEnsureHighestAllocationIsWinningAssignment() throws Exception { Collections.shuffle(assignments, random()); for (Integer i : 
assignments) { - ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, taskId, i, (s) -> {}); + ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, client().threadPool(), taskId, i, (s) -> {}); new Thread(() -> { updater.assign(new ActionListener<>() { @Override @@ -124,7 +125,8 @@ public void testNewAllocationWillStopCheckpoints() throws Exception { AtomicInteger committed = new AtomicInteger(0); - ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, taskId, 0, (s) -> committed.incrementAndGet()); + ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, client().threadPool(), + taskId, 0, (s) -> committed.incrementAndGet()); CountDownLatch firstAssignmentLatch = new CountDownLatch(1); updater.assign(new ActionListener<>() { @@ -145,7 +147,8 @@ public void onFailure(Exception exception) { updater.onCheckpoint(new ScrollableHitSource.Checkpoint(10), status); assertBusy(() -> assertEquals(1, committed.get())); - ReindexTaskStateUpdater newAllocationUpdater = new ReindexTaskStateUpdater(reindexClient, taskId, 1, (s) -> {}); + ReindexTaskStateUpdater newAllocationUpdater = new ReindexTaskStateUpdater(reindexClient, client().threadPool(), + taskId, 1, (s) -> {}); CountDownLatch secondAssignmentLatch = new CountDownLatch(1); newAllocationUpdater.assign(new ActionListener<>() { @@ -179,7 +182,8 @@ public void testFinishWillStopCheckpoints() throws Exception { AtomicInteger committed = new AtomicInteger(0); - ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, taskId, 0, (s) -> committed.incrementAndGet()); + ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, client().threadPool(), + taskId, 0, (s) -> committed.incrementAndGet()); CountDownLatch firstAssignmentLatch = new CountDownLatch(1); updater.assign(new ActionListener<>() { @@ -218,7 +222,8 @@ public void testFinishStoresResult() throws Exception { AtomicInteger committed = new AtomicInteger(0); - ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, taskId, 0, (s) -> committed.incrementAndGet()); + ReindexTaskStateUpdater updater = new ReindexTaskStateUpdater(reindexClient, client().threadPool(), + taskId, 0, (s) -> committed.incrementAndGet()); CountDownLatch firstAssignmentLatch = new CountDownLatch(1); updater.assign(new ActionListener<>() { @@ -258,7 +263,8 @@ public void onFailure(Exception exception) { } private void createDoc(ReindexIndexClient client, String taskId) { - ReindexRequest request = reindex().source("source").destination("dest").refresh(true).request(); + ReindexRequest request = + reindex().source("source").destination("dest").refresh(true).request().setCheckpointInterval(TimeValue.ZERO); PlainActionFuture<ReindexTaskState> future = PlainActionFuture.newFuture(); client.createReindexTaskDoc(taskId, new ReindexTaskStateDoc(request), future); diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ThrottlingConsumerTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ThrottlingConsumerTests.java new file mode 100644 index 0000000000000..a9ff6e808c60a --- /dev/null +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ThrottlingConsumerTests.java @@ -0,0 +1,282 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.reindex; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Delayed; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; + +public class ThrottlingConsumerTests extends ESTestCase { + + private DeterministicSchedulerThreadPool threadPool; + + @Before + public void createThreadPool() { + threadPool = new DeterministicSchedulerThreadPool(getTestName()); + } + + @After + public void shutdownThreadPool() { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + + public void testThrottling() { + AtomicReference<Long> throttledValue = new AtomicReference<>(); + // we start random, maybe we even overflow, which is fine. + AtomicLong time = new AtomicLong(randomLong()); + ThrottlingConsumer<Long> throttler + = new ThrottlingConsumer<>(wrap(throttledValue::set), TimeValue.timeValueNanos(10), time::get, threadPool); + + assertNull(throttledValue.get()); + threadPool.validate(false, null, false); + + for (int i = 0; i < randomIntBetween(0, 10); ++i) { + throttler.accept(randomLong()); + assertNull(throttledValue.get()); + threadPool.validate(true, TimeValue.timeValueNanos(10), false); + } + + time.addAndGet(randomLongBetween(0, 20)); + long expectedValue = randomLong(); + throttler.accept(expectedValue); + assertNull(throttledValue.get()); + + threadPool.runScheduledTask(); + + threadPool.validate(false, null, false); + assertThat(throttledValue.get(), equalTo(expectedValue)); + + long timePassed = randomLongBetween(5, 15); + time.addAndGet(timePassed); + + expectedValue = randomLong(); + throttler.accept(expectedValue); + + threadPool.validate(true, TimeValue.timeValueNanos(Math.max(10-timePassed, 0)), false); + threadPool.runScheduledTask(); + + threadPool.validate(false, null, false); + assertThat(throttledValue.get(), equalTo(expectedValue)); + + // maybe provoke overflow. 
+ time.addAndGet(randomLongBetween(10, Long.MAX_VALUE)); + + expectedValue = randomLong(); + throttler.accept(expectedValue); + + threadPool.validate(true, TimeValue.timeValueNanos(0), false); + threadPool.runScheduledTask(); + + threadPool.validate(false, null, false); + assertThat(throttledValue.get(), equalTo(expectedValue)); + } + + public void testRetainThreadContext() { + String headerValue = Long.toString(randomLong()); + threadPool.getThreadContext().putHeader("test-header", headerValue); + AtomicLong time = new AtomicLong(randomLong()); + AtomicInteger invocationCount = new AtomicInteger(); + Consumer<Long> validateThreadContextConsumer = v -> { + assertEquals(headerValue, threadPool.getThreadContext().getHeader("test-header")); + invocationCount.incrementAndGet(); + }; + ThrottlingConsumer<Long> throttler + = new ThrottlingConsumer<>(wrap(validateThreadContextConsumer), TimeValue.timeValueNanos(10), time::get, threadPool); + + throttler.accept(randomLong()); + threadPool.validate(true, TimeValue.timeValueNanos(10), false); + threadPool.getThreadContext().stashContext(); + threadPool.runScheduledTask(); + + time.addAndGet(randomLongBetween(10,1000)); + throttler.accept(randomLong()); + threadPool.validate(true, TimeValue.timeValueNanos(0), false); + threadPool.runScheduledTask(); + + assertEquals(2, invocationCount.get()); + } + + public void testCloseNormal() { + AtomicLong time = new AtomicLong(randomLong()); + ThrottlingConsumer<Long> throttler + = new ThrottlingConsumer<>(wrap(l -> fail()), TimeValue.timeValueNanos(10), time::get, threadPool); + + AtomicBoolean closed = new AtomicBoolean(); + throttler.close(() -> assertTrue(closed.compareAndSet(false, true))); + assertTrue(closed.get()); + + throttler.accept(randomLong()); + threadPool.validate(false, null, false); + } + + public void testCloseCancel() { + AtomicLong time = new AtomicLong(randomLong()); + ThrottlingConsumer<Long> throttler + = new ThrottlingConsumer<>(wrap(l -> fail()), TimeValue.timeValueNanos(10), time::get, threadPool); + + throttler.accept(randomLong()); + threadPool.validate(true, TimeValue.timeValueNanos(10), false); + + AtomicBoolean closed = new AtomicBoolean(); + throttler.close(() -> assertTrue(closed.compareAndSet(false, true))); + assertTrue(closed.get()); + + threadPool.validate(true, TimeValue.timeValueNanos(10), true); + } + + public void testCloseNotifyWhenOutboundActive() throws Exception { + AtomicLong time = new AtomicLong(randomLong()); + CyclicBarrier rendezvous = new CyclicBarrier(2); + Consumer<Long> waitingConsumer = x -> { + try { + rendezvous.await(); + rendezvous.await(); + } catch (InterruptedException | BrokenBarrierException e) { + throw new AssertionError(e); + } + }; + ThrottlingConsumer<Long> throttler + = new ThrottlingConsumer<>(wrap(waitingConsumer), TimeValue.timeValueNanos(10), time::get, threadPool); + + throttler.accept(randomLong()); + threadPool.validate(true, TimeValue.timeValueNanos(10), false); + + Future<?> future = threadPool.generic().submit(() -> threadPool.runScheduledTask()); + + rendezvous.await(); + + AtomicBoolean closed = new AtomicBoolean(); + throttler.close(() -> assertTrue(closed.compareAndSet(false, true))); + assertFalse(closed.get()); + + rendezvous.await(); + future.get(10, TimeUnit.SECONDS); + + assertTrue(closed.get()); + + threadPool.validate(false, null, false); + } + + public void testCloseNotifyWhenOutboundActiveAsyncConsumer() { + AtomicLong time = new AtomicLong(randomLong()); + AtomicReference<Runnable> whenDoneReference = new 
AtomicReference<>(); + BiConsumer<Long, Runnable> deferredResponseConsumer = (x, runnable) -> { + assertTrue(whenDoneReference.compareAndSet(null, runnable)); + }; + ThrottlingConsumer<Long> throttler + = new ThrottlingConsumer<>(deferredResponseConsumer, TimeValue.timeValueNanos(10), time::get, threadPool); + + throttler.accept(randomLong()); + threadPool.validate(true, TimeValue.timeValueNanos(10), false); + threadPool.runScheduledTask(); + assertNotNull(whenDoneReference.get()); + + AtomicBoolean closed = new AtomicBoolean(); + throttler.close(() -> assertTrue(closed.compareAndSet(false, true))); + assertFalse(closed.get()); + + whenDoneReference.get().run(); + + assertTrue(closed.get()); + threadPool.validate(false, null, false); + } + + private <T> BiConsumer<T, Runnable> wrap(Consumer<T> consumer) { + return (t, whenDone) -> { + try { + consumer.accept(t); + } finally { + whenDone.run(); + } + }; + } + + private static class DeterministicSchedulerThreadPool extends TestThreadPool { + private Runnable command; + private TimeValue delay; + private boolean cancelled; + + private DeterministicSchedulerThreadPool(String name, ExecutorBuilder<?>... customBuilders) { + super(name, customBuilders); + } + + @Override + public ScheduledCancellable schedule(Runnable command, TimeValue delay, String executor) { + assertNull(this.command); + assertNull(this.delay); + this.command = command; + this.delay = delay; + return new ScheduledCancellable() { + @Override + public long getDelay(TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public int compareTo(Delayed o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean cancel() { + assertFalse(cancelled); + cancelled = true; + return true; + } + + @Override + public boolean isCancelled() { + throw new UnsupportedOperationException(); + } + }; + } + + public void runScheduledTask() { + assertNotNull(command); + command.run(); + command = null; + delay = null; + } + + public void validate(boolean hasCommand, TimeValue delay, boolean cancelled) { + assertEquals(hasCommand, this.command != null); + assertEquals(delay, this.delay); + assertEquals(cancelled, this.cancelled); + } + } +} diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ThrottlingConsumerThreadTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ThrottlingConsumerThreadTests.java new file mode 100644 index 0000000000000..8ab1306232d92 --- /dev/null +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ThrottlingConsumerThreadTests.java @@ -0,0 +1,137 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.index.reindex; + +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class ThrottlingConsumerThreadTests extends ESTestCase { + + private ThreadPool threadPool; + @Before + public void createThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void shutdownThreadPool() { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + + public void testMultiThread() throws Exception { + int throttleInterval = randomIntBetween(1, 20); + CountDownLatch rounds = new CountDownLatch(randomIntBetween(10, 1000/throttleInterval)); // a second max + // we only pass increasing numbers and keep track of the before and after value (latest and committed) in order to validate + // the output. We use that the throttler must read the value again after invoking the consumer to enforce that the next + // consumer invocation must see at least the committed values from last round. + ConcurrentMap<Thread, Long> latest = new ConcurrentHashMap<>(); + ConcurrentMap<Thread, Long> committed = new ConcurrentHashMap<>(); + ConcurrentMap<Thread, Long> lastRoundCommitted = new ConcurrentHashMap<>(); + BiConsumer<Tuple<Thread, Long>, Runnable> validatingConsumer = new BiConsumer<>() { + private AtomicBoolean active = new AtomicBoolean(); + @Override + public void accept(Tuple<Thread, Long> value, Runnable whenDone) { + assertTrue(active.compareAndSet(false, true)); + assertThat(value.v2(), greaterThanOrEqualTo(lastRoundCommitted.get(value.v1()))); + assertThat(value.v2(), lessThanOrEqualTo(latest.get(value.v1()))); + lastRoundCommitted.putAll(committed); + Runnable work = () -> { + try { + Thread.sleep(1); // simulate hard work + } catch (InterruptedException e) { + throw new AssertionError(e); + } + assertTrue(active.compareAndSet(true, false)); + whenDone.run(); + rounds.countDown(); + }; + if (randomBoolean()) { + work.run(); + } else { + threadPool.generic().submit(work); + } + } + }; + ThrottlingConsumer<Tuple<Thread, Long>> throttler + = new ThrottlingConsumer<>(validatingConsumer, TimeValue.timeValueMillis(throttleInterval), System::nanoTime, threadPool); + AtomicBoolean stopped = new AtomicBoolean(); + List<Thread> threads = IntStream.range(0, randomIntBetween(2, 5)).mapToObj(i -> new Thread(getTestName() + "-" + i) { + private long value = randomLongBetween(0, Integer.MAX_VALUE); + { + setDaemon(true); + } + + public void run() { + while (!stopped.get()) { + latest.put(this, value); + throttler.accept(Tuple.tuple(this, value)); + committed.put(this, value); + value = value + randomIntBetween(0, Integer.MAX_VALUE); + } + } + }).collect(Collectors.toList()); + + threads.forEach(t -> { + committed.put(t, Long.MIN_VALUE); + lastRoundCommitted.put(t, Long.MIN_VALUE); + latest.put(t, Long.MIN_VALUE); + }); + + CountDownLatch 
closed = new CountDownLatch(1); + AutoCloseable shutdown = () -> { + throttler.close(() -> closed.countDown()); + stopped.set(true); + + threads.forEach(t -> { + try { + t.join(10000); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + }); + + threads.forEach(t -> assertFalse(t.isAlive())); + }; + + threads.forEach(Thread::start); + + try (shutdown) { + assertTrue(rounds.await(10, TimeUnit.SECONDS)); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/reindex/ReindexRequest.java b/server/src/main/java/org/elasticsearch/index/reindex/ReindexRequest.java index 0f0306e7eed00..a1cc8caf2ac3b 100644 --- a/server/src/main/java/org/elasticsearch/index/reindex/ReindexRequest.java +++ b/server/src/main/java/org/elasticsearch/index/reindex/ReindexRequest.java @@ -20,6 +20,7 @@ package org.elasticsearch.index.reindex; import org.apache.logging.log4j.LogManager; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.CompositeIndicesRequest; import org.elasticsearch.action.index.IndexRequest; @@ -67,6 +68,8 @@ */ public class ReindexRequest extends AbstractBulkIndexByScrollRequest<ReindexRequest> implements CompositeIndicesRequest, ToXContentObject { + private static final TimeValue DEFAULT_CHECKPOINT_INTERVAL = TimeValue.timeValueSeconds(10); + /** * Prototype for index requests. */ @@ -74,6 +77,8 @@ public class ReindexRequest extends AbstractBulkIndexByScrollRequest<ReindexRequ private RemoteInfo remoteInfo; + private TimeValue checkpointInterval = DEFAULT_CHECKPOINT_INTERVAL; + public ReindexRequest() { this(new SearchRequest(), new IndexRequest(), true); } @@ -91,6 +96,10 @@ public ReindexRequest(StreamInput in) throws IOException { super(in); destination = new IndexRequest(in); remoteInfo = in.readOptionalWriteable(RemoteInfo::new); + // todo: version in backport. + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + checkpointInterval = in.readTimeValue(); + } } @Override @@ -260,6 +269,22 @@ public RemoteInfo getRemoteInfo() { return remoteInfo; } + /** + * Get the checkpoint interval, the interval between persisting progress state. + */ + public TimeValue getCheckpointInterval() { + return checkpointInterval; + } + + /** + * Set the checkpoint interval, the interval between persisting progress state. Defaults to 10 seconds. 
+ */ + public ReindexRequest setCheckpointInterval(TimeValue interval) { + assert interval != null; + this.checkpointInterval = interval; + return this; + } + @Override public ReindexRequest forSlice(TaskId slicingTask, SearchRequest slice, int totalSlices) { ReindexRequest sliced = doForSlice(new ReindexRequest(slice, destination, false), slicingTask, totalSlices); @@ -272,6 +297,10 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); destination.writeTo(out); out.writeOptionalWriteable(remoteInfo); + // todo: version in backport + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeTimeValue(checkpointInterval); + } } @Override @@ -357,7 +386,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params, boolea if (ActiveShardCount.DEFAULT.equals(getWaitForActiveShards()) == false) { builder.field("wait_for_active_shards", getWaitForActiveShards().toString()); } - + builder.field("checkpoint_interval_millis", getCheckpointInterval().getMillis()); builder.endObject(); } } @@ -435,6 +464,9 @@ private static void setParams(ReindexRequest reindexRequest, ReindexParams param if (params.waitForActiveShardCount != null) { reindexRequest.setWaitForActiveShards(params.waitForActiveShardCount); } + if (params.checkpointInterval != null) { + reindexRequest.setCheckpointInterval(params.checkpointInterval); + } } public static ReindexRequest fromXContent(XContentParser parser) throws IOException { @@ -563,7 +595,7 @@ private static class ReindexParams { @SuppressWarnings("unchecked") public static final ConstructingObjectParser<ReindexParams, Void> PARAMS_PARSER = new ConstructingObjectParser<>("reindex_params", a -> new ReindexParams( - (Boolean) a[0], (Long) a[1], (Long) a[2], (Integer) a[3], (Float) a[4], (String) a[5])); + (Boolean) a[0], (Long) a[1], (Long) a[2], (Integer) a[3], (Float) a[4], (String) a[5], (Long) a[6])); static { PARAMS_PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), new ParseField("refresh")); @@ -572,6 +604,7 @@ private static class ReindexParams { PARAMS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("slices")); PARAMS_PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), new ParseField("requests_per_second")); PARAMS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("wait_for_active_shards")); + PARAMS_PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), new ParseField("checkpoint_interval_millis")); } private final boolean refresh; @@ -580,10 +613,10 @@ private static class ReindexParams { private final Integer slices; private final Float requestsPerSecond; private final ActiveShardCount waitForActiveShardCount; - + private final TimeValue checkpointInterval; private ReindexParams(boolean refresh, Long scroll, Long timeout, Integer slices, Float requestsPerSecond, - String waitForActiveShardCount) { + String waitForActiveShardCount, Long checkpointInterval) { this.refresh = refresh; if (scroll != null) { this.scroll = TimeValue.timeValueMillis(scroll); @@ -604,6 +637,12 @@ private ReindexParams(boolean refresh, Long scroll, Long timeout, Integer slices } else { this.waitForActiveShardCount = null; } + + if (checkpointInterval != null) { + this.checkpointInterval = TimeValue.timeValueMillis(checkpointInterval); + } else { + this.checkpointInterval = null; + } } public static ReindexParams fromXContent(XContentParser parser) throws IOException { diff --git 
a/server/src/test/java/org/elasticsearch/index/reindex/ReindexRequestTests.java b/server/src/test/java/org/elasticsearch/index/reindex/ReindexRequestTests.java index 2b3dcc00ef09a..3162046f06b1e 100644 --- a/server/src/test/java/org/elasticsearch/index/reindex/ReindexRequestTests.java +++ b/server/src/test/java/org/elasticsearch/index/reindex/ReindexRequestTests.java @@ -155,6 +155,10 @@ protected ReindexRequest createTestInstance() { } } + if (randomBoolean()) { + reindexRequest.setCheckpointInterval(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + return reindexRequest; } @@ -192,6 +196,7 @@ protected void assertEqualInstances(ReindexRequest expectedInstance, ReindexRequ assertEquals(expectedInstance.getSlices(), newInstance.getSlices()); assertEquals(expectedInstance.getRequestsPerSecond(), newInstance.getRequestsPerSecond(), 0.1); assertEquals(expectedInstance.getWaitForActiveShards(), newInstance.getWaitForActiveShards()); + assertEquals(expectedInstance.getCheckpointInterval(), newInstance.getCheckpointInterval()); } }
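For reference, a minimal usage sketch of the new ThrottlingConsumer, separate from the change itself. It wires the class the same way ReindexTaskStateUpdater does, but with a trivial outbound consumer; the TestThreadPool, the 100ms interval, the loop and the printed payload are illustrative assumptions, while the constructor, accept and close signatures are the ones introduced above.

import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.reindex.ThrottlingConsumer;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.concurrent.TimeUnit;

public class ThrottlingConsumerUsageSketch {
    public static void main(String[] args) throws Exception {
        ThreadPool threadPool = new TestThreadPool("throttling-consumer-sketch");
        try {
            // Forward at most one value per 100ms. The outbound consumer must run whenDone once it is
            // finished, otherwise the throttler never schedules another forward.
            ThrottlingConsumer<Long> throttler = new ThrottlingConsumer<>(
                (value, whenDone) -> {
                    System.out.println("persisting progress: " + value);
                    whenDone.run();
                },
                TimeValue.timeValueMillis(100), System::nanoTime, threadPool);

            for (long i = 0; i < 1000; i++) {
                throttler.accept(i); // coalesced: only the latest value per interval reaches the outbound consumer
            }

            Thread.sleep(200); // crude wait for the scheduled forward, good enough for a sketch
            throttler.close(() -> System.out.println("closed, no further outbound calls"));
        } finally {
            ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
        }
    }
}

The whenDone callback is what drives the class: the throttler only schedules the next forward, or records that it is idle, from inside that callback, which is also why close() defers its onClosed notification while an outbound call is still in flight.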
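Similarly, a sketch of how a caller would pick a non-default checkpoint interval on a ReindexRequest, based on the setter added in this diff. The index names are placeholders, and setSourceIndices/setDestIndex are assumed to be the existing request helpers rather than anything added here. When the request is serialized as XContent, the interval is written as checkpoint_interval_millis, matching the field registered on the ReindexParams parser.

import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.reindex.ReindexRequest;

public class CheckpointIntervalSketch {
    public static ReindexRequest reindexWithCheckpoints() {
        ReindexRequest request = new ReindexRequest();
        request.setSourceIndices("source"); // placeholder index names
        request.setDestIndex("dest");
        // Persist reindex progress at most every 30 seconds; the default introduced above is 10 seconds,
        // and ReindexFailoverIT lowers it to 100ms so checkpoints happen quickly enough for the test.
        request.setCheckpointInterval(TimeValue.timeValueSeconds(30));
        return request;
    }
}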