diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e7fa8b5547f0..4d07052d55ff0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Do not fail replica shard due to primary closure ([#4133](https://github.com/opensearch-project/OpenSearch/pull/4133)) - Add timeout on Mockito.verify to reduce flakiness in testReplicationOnDone test ([#4314](https://github.com/opensearch-project/OpenSearch/pull/4314)) - Commit workflow for dependabot changelog helper ([#4331](https://github.com/opensearch-project/OpenSearch/pull/4331)) +- Fixed cancellation of segment replication events ([#4225](https://github.com/opensearch-project/OpenSearch/pull/4225)) ### Security - CVE-2022-25857 org.yaml:snakeyaml DOS vulnerability ([#4341](https://github.com/opensearch-project/OpenSearch/pull/4341)) diff --git a/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java b/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java index 8884ef2cddd0a..15a9bf9e4c492 100644 --- a/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java +++ b/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java @@ -81,6 +81,7 @@ import org.opensearch.indices.recovery.PeerRecoveryTargetService; import org.opensearch.indices.recovery.RecoveryListener; import org.opensearch.indices.recovery.RecoveryState; +import org.opensearch.indices.replication.SegmentReplicationSourceService; import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.indices.replication.common.ReplicationState; @@ -152,6 +153,7 @@ public IndicesClusterStateService( final ThreadPool threadPool, final PeerRecoveryTargetService recoveryTargetService, final SegmentReplicationTargetService segmentReplicationTargetService, + final SegmentReplicationSourceService segmentReplicationSourceService, final ShardStateAction shardStateAction, final NodeMappingRefreshAction nodeMappingRefreshAction, final RepositoriesService repositoriesService, @@ -170,6 +172,7 @@ public IndicesClusterStateService( threadPool, checkpointPublisher, segmentReplicationTargetService, + segmentReplicationSourceService, recoveryTargetService, shardStateAction, nodeMappingRefreshAction, @@ -191,6 +194,7 @@ public IndicesClusterStateService( final ThreadPool threadPool, final SegmentReplicationCheckpointPublisher checkpointPublisher, final SegmentReplicationTargetService segmentReplicationTargetService, + final SegmentReplicationSourceService segmentReplicationSourceService, final PeerRecoveryTargetService recoveryTargetService, final ShardStateAction shardStateAction, final NodeMappingRefreshAction nodeMappingRefreshAction, @@ -211,6 +215,7 @@ public IndicesClusterStateService( // if segrep feature flag is not enabled, don't wire the target service as an IndexEventListener.
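Both replication services hook into shard lifecycle events through IndexEventListener, which is what the flag-gated registration just below sets up. A minimal sketch of the two callbacks this change relies on (signatures taken from this diff; the bodies are illustrative placeholders, not the real implementations):

```java
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.shard.IndexEventListener;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;

class CancellationHooksSketch implements IndexEventListener {
    @Override
    public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) {
        // a shard on this node is closing -> cancel any replication it participates in
    }

    @Override
    public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) {
        // a promotion is visible as oldRouting.primary() == false && newRouting.primary() == true
    }
}
```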
if (FeatureFlags.isEnabled(FeatureFlags.REPLICATION_TYPE)) { indexEventListeners.add(segmentReplicationTargetService); + indexEventListeners.add(segmentReplicationSourceService); } this.builtInIndexListener = Collections.unmodifiableList(indexEventListeners); this.indicesService = indicesService; diff --git a/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java b/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java index dfebe5f7cabf2..828aa29192fe3 100644 --- a/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java +++ b/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java @@ -37,7 +37,6 @@ * @opensearch.internal */ class OngoingSegmentReplications { - private final RecoverySettings recoverySettings; private final IndicesService indicesService; private final Map copyStateMap; @@ -161,6 +160,20 @@ synchronized void cancel(IndexShard shard, String reason) { cancelHandlers(handler -> handler.getCopyState().getShard().shardId().equals(shard.shardId()), reason); } + /** + * Cancel all Replication events for the given allocation ID, intended to be called when a primary is shutting down. + * + * @param allocationId {@link String} - Allocation ID. + * @param reason {@link String} - Reason for the cancel + */ + synchronized void cancel(String allocationId, String reason) { + final SegmentReplicationSourceHandler handler = allocationIdToHandlers.remove(allocationId); + if (handler != null) { + handler.cancel(reason); + removeCopyState(handler.getCopyState()); + } + } + /** * Cancel any ongoing replications for a given {@link DiscoveryNode} * @@ -168,7 +181,6 @@ synchronized void cancel(IndexShard shard, String reason) { */ void cancelReplication(DiscoveryNode node) { cancelHandlers(handler -> handler.getTargetNode().equals(node), "Node left"); - } /** @@ -243,11 +255,7 @@ private void cancelHandlers(Predicate p .map(SegmentReplicationSourceHandler::getAllocationId) .collect(Collectors.toList()); for (String allocationId : allocationIds) { - final SegmentReplicationSourceHandler handler = allocationIdToHandlers.remove(allocationId); - if (handler != null) { - handler.cancel(reason); - removeCopyState(handler.getCopyState()); - } + cancel(allocationId, reason); } } } diff --git a/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java b/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java index 08dc0b97b31d5..aa0b5416dd0ff 100644 --- a/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java +++ b/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java @@ -87,4 +87,10 @@ public void getSegmentFiles( ); transportClient.executeRetryableAction(GET_SEGMENT_FILES, request, responseListener, reader); } + + @Override + public void cancel() { + transportClient.cancel(); + } + } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java index 8628a266ea7d0..b2e7487fff4b2 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java @@ -9,6 +9,7 @@ package org.opensearch.indices.replication; import org.opensearch.action.ActionListener; +import 
org.opensearch.common.util.CancellableThreads.ExecutionCancelledException; import org.opensearch.index.store.Store; import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; @@ -47,4 +48,9 @@ void getSegmentFiles( Store store, ActionListener listener ); + + /** + * Cancel any ongoing requests; any pending listeners should be resolved with onFailure and a {@link ExecutionCancelledException}. + */ + default void cancel() {} } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java index 2d21653c1924c..022d90b41d8ee 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java @@ -113,6 +113,16 @@ public synchronized void sendFiles(GetSegmentFilesRequest request, ActionListene final Closeable releaseResources = () -> IOUtils.close(resources); try { timer.start(); + cancellableThreads.setOnCancel((reason, beforeCancelEx) -> { + final RuntimeException e = new CancellableThreads.ExecutionCancelledException( + "replication was canceled reason [" + reason + "]" + ); + if (beforeCancelEx != null) { + e.addSuppressed(beforeCancelEx); + } + IOUtils.closeWhileHandlingException(releaseResources, () -> future.onFailure(e)); + throw e; + }); final Consumer onFailure = e -> { assert Transports.assertNotTransportThread(SegmentReplicationSourceHandler.this + "[onFailure]"); IOUtils.closeWhileHandlingException(releaseResources, () -> future.onFailure(e)); @@ -153,6 +163,7 @@ public synchronized void sendFiles(GetSegmentFilesRequest request, ActionListene final MultiChunkTransfer transfer = segmentFileTransferHandler .createTransfer(shard.store(), storeFileMetadata, () -> 0, sendFileStep); resources.add(transfer); + cancellableThreads.checkForCancel(); transfer.start(); sendFileStep.whenComplete(r -> { diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java index 0cee731fde2cb..db3f87201b774 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java @@ -15,6 +15,7 @@ import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateListener; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.common.component.AbstractLifecycleComponent; @@ -42,7 +43,25 @@ * * @opensearch.internal */ -public final class SegmentReplicationSourceService extends AbstractLifecycleComponent implements ClusterStateListener, IndexEventListener { +public class SegmentReplicationSourceService extends AbstractLifecycleComponent implements ClusterStateListener, IndexEventListener { + + // Empty implementation, only required while Segment Replication is under a feature flag.
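The NO_OP constant defined next is a null object: when the feature flag is off, Node.java (later in this diff) binds it so injection sites never hold a null service. Here is the pattern in isolation, a sketch with generic names rather than the OpenSearch types:

```java
interface ReplicationSourceHooks {
    // a shared do-nothing instance; always safe for callers to invoke
    ReplicationSourceHooks NO_OP = (allocationId, reason) -> {};

    void cancel(String allocationId, String reason);
}

class Wiring {
    static ReplicationSourceHooks bind(boolean featureEnabled, ReplicationSourceHooks real) {
        // the flag picks the implementation; consumers never need a null check
        return featureEnabled ? real : ReplicationSourceHooks.NO_OP;
    }
}
```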
+ public static final SegmentReplicationSourceService NO_OP = new SegmentReplicationSourceService() { + @Override + public void clusterChanged(ClusterChangedEvent event) { + // NoOp; + } + + @Override + public void beforeIndexShardClosed(ShardId shardId, IndexShard indexShard, Settings indexSettings) { + // NoOp; + } + + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + // NoOp; + } + }; private static final Logger logger = LogManager.getLogger(SegmentReplicationSourceService.class); private final RecoverySettings recoverySettings; @@ -62,6 +81,14 @@ public static class Actions { private final OngoingSegmentReplications ongoingSegmentReplications; + // Used only for empty implementation. + private SegmentReplicationSourceService() { + recoverySettings = null; + ongoingSegmentReplications = null; + transportService = null; + indicesService = null; + } + public SegmentReplicationSourceService( IndicesService indicesService, TransportService transportService, @@ -163,10 +190,25 @@ protected void doClose() throws IOException { } + /** + * Cancels any replications on this node to a replica shard that is about to be closed. + */ @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { if (indexShard != null) { ongoingSegmentReplications.cancel(indexShard, "shard is closed"); } } + + /** + * Cancels any replications on this node to a replica that has been promoted to primary. + */ + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + if (indexShard != null && oldRouting.primary() == false && newRouting.primary()) { + ongoingSegmentReplications.cancel(indexShard.routingEntry().allocationId().getId(), "Relocating primary shard."); + } + } + } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java index f865ba1332186..2e2e6df007c5c 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java @@ -35,7 +35,8 @@ public enum Stage { GET_CHECKPOINT_INFO((byte) 3), FILE_DIFF((byte) 4), GET_FILES((byte) 5), - FINALIZE_REPLICATION((byte) 6); + FINALIZE_REPLICATION((byte) 6), + CANCELLED((byte) 7); private static final Stage[] STAGES = new Stage[Stage.values().length]; @@ -118,6 +119,10 @@ protected void validateAndSetStage(Stage expected, Stage next) { "can't move replication to stage [" + next + "]. 
current stage: [" + stage + "] (expected [" + expected + "])" ); } + stopTimersAndSetStage(next); + } + + private void stopTimersAndSetStage(Stage next) { // save the timing data for the current step stageTimer.stop(); timingData.add(new Tuple<>(stage.name(), stageTimer.time())); @@ -155,6 +160,14 @@ public void setStage(Stage stage) { overallTimer.stop(); timingData.add(new Tuple<>("OVERALL", overallTimer.time())); break; + case CANCELLED: + if (this.stage == Stage.DONE) { + throw new IllegalStateException("can't move replication to Cancelled state from Done."); + } + stopTimersAndSetStage(Stage.CANCELLED); + overallTimer.stop(); + timingData.add(new Tuple<>("OVERALL", overallTimer.time())); + break; default: throw new IllegalArgumentException("unknown SegmentReplicationState.Stage [" + stage + "]"); } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java index a658ffc09d590..d1d6104a416ca 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java @@ -17,6 +17,7 @@ import org.apache.lucene.store.ByteBuffersDataInput; import org.apache.lucene.store.ByteBuffersIndexInput; import org.apache.lucene.store.ChecksumIndexInput; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.action.StepListener; @@ -103,7 +104,15 @@ public String description() { @Override public void notifyListener(OpenSearchException e, boolean sendShardFailure) { - listener.onFailure(state(), e, sendShardFailure); + // Cancellations are still passed to our SegmentReplicationListener as failures. If we failed because of a cancellation, + // update the stage. + final Throwable cancelledException = ExceptionsHelper.unwrap(e, CancellableThreads.ExecutionCancelledException.class); + if (cancelledException != null) { + state.setStage(SegmentReplicationState.Stage.CANCELLED); + listener.onFailure(state(), (CancellableThreads.ExecutionCancelledException) cancelledException, sendShardFailure); + } else { + listener.onFailure(state(), e, sendShardFailure); + } } @Override @@ -134,6 +143,14 @@ public void writeFileChunk( * @param listener {@link ActionListener} listener. */ public void startReplication(ActionListener listener) { + cancellableThreads.setOnCancel((reason, beforeCancelEx) -> { + // This method only executes when cancellation is triggered by this node and caught by a call to checkForCancel; + // SegmentReplicationSource does not share CancellableThreads. + final CancellableThreads.ExecutionCancelledException executionCancelledException = + new CancellableThreads.ExecutionCancelledException("replication was canceled reason [" + reason + "]"); + notifyListener(executionCancelledException, false); + throw executionCancelledException; + }); state.setStage(SegmentReplicationState.Stage.REPLICATING); final StepListener checkpointInfoListener = new StepListener<>(); final StepListener getFilesListener = new StepListener<>(); @@ -141,6 +158,7 @@ public void startReplication(ActionListener listener) { logger.trace("[shardId {}] Replica starting replication [id {}]", shardId().getId(), getId()); // Get list of files to copy from this checkpoint. 
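The setOnCancel/checkForCancel pairing registered above is the core of the cancellation model: cancel(reason) only records intent, and the hook fires on the replicating thread at the next checkForCancel() call, which is why checks are inserted before each stage transition below. Roughly, using the API names from this diff (the sequencing comments are my reading of it):

```java
import org.opensearch.common.util.CancellableThreads;

class CancellationFlowSketch {
    static void demo() {
        CancellableThreads cancellableThreads = new CancellableThreads();
        cancellableThreads.setOnCancel((reason, beforeCancelEx) -> {
            throw new CancellableThreads.ExecutionCancelledException(
                "replication was canceled reason [" + reason + "]"
            );
        });
        cancellableThreads.checkForCancel();       // passes; nothing has been cancelled yet
        cancellableThreads.cancel("shard closed"); // records the reason
        cancellableThreads.checkForCancel();       // now runs the hook, which throws
    }
}
```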
+ cancellableThreads.checkForCancel(); state.setStage(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO); source.getCheckpointMetadata(getId(), checkpoint, checkpointInfoListener); @@ -154,6 +172,7 @@ public void startReplication(ActionListener listener) { private void getFiles(CheckpointInfoResponse checkpointInfo, StepListener getFilesListener) throws IOException { + cancellableThreads.checkForCancel(); state.setStage(SegmentReplicationState.Stage.FILE_DIFF); final Store.MetadataSnapshot snapshot = checkpointInfo.getSnapshot(); Store.MetadataSnapshot localMetadata = getMetadataSnapshot(); @@ -188,12 +207,14 @@ private void getFiles(CheckpointInfoResponse checkpointInfo, StepListener listener) { - state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); ActionListener.completeWith(listener, () -> { + cancellableThreads.checkForCancel(); + state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); multiFileWriter.renameAllTempFiles(); final Store store = store(); store.incRef(); @@ -261,4 +282,10 @@ Store.MetadataSnapshot getMetadataSnapshot() throws IOException { } return store.getMetadata(indexShard.getSegmentInfosSnapshot().get()); } + + @Override + protected void onCancel(String reason) { + cancellableThreads.cancel(reason); + source.cancel(); + } } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index a79ce195ad83b..9e6b66dc4d7d6 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -11,10 +11,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; +import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.common.Nullable; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.ShardId; @@ -64,6 +67,11 @@ public void beforeIndexShardClosed(ShardId shardId, IndexShard indexShard, Setti public synchronized void onNewCheckpoint(ReplicationCheckpoint receivedCheckpoint, IndexShard replicaShard) { // noOp; } + + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + // noOp; + } }; // Used only for empty implementation. @@ -74,6 +82,10 @@ private SegmentReplicationTargetService() { sourceFactory = null; } + public ReplicationRef get(long replicationId) { + return onGoingReplications.get(replicationId); + } + /** * The internal actions * @@ -102,6 +114,9 @@ public SegmentReplicationTargetService( ); } + /** + * Cancel any replications on this node for a replica that is about to be closed. + */ @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { if (indexShard != null) { @@ -109,11 +124,22 @@ public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexSh } } + + /** + * Cancel any replications on this node for a replica that has just been promoted to primary. 
+ */ + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + if (oldRouting != null && oldRouting.primary() == false && newRouting.primary()) { + onGoingReplications.cancelForShard(indexShard.shardId(), "shard has been promoted to primary"); + } + } + /** * Invoked when a new checkpoint is received from a primary shard. * It checks if a new checkpoint should be processed or not and starts replication if needed. - * @param receivedCheckpoint received checkpoint that is checked for processing - * @param replicaShard replica shard on which checkpoint is received + * + * @param receivedCheckpoint received checkpoint that is checked for processing + * @param replicaShard replica shard on which checkpoint is received */ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedCheckpoint, final IndexShard replicaShard) { logger.trace(() -> new ParameterizedMessage("Replica received new replication checkpoint from primary [{}]", receivedCheckpoint)); @@ -180,12 +206,19 @@ public void onReplicationFailure(SegmentReplicationState state, OpenSearchExcept } } - public void startReplication( + public SegmentReplicationTarget startReplication( final ReplicationCheckpoint checkpoint, final IndexShard indexShard, final SegmentReplicationListener listener ) { - startReplication(new SegmentReplicationTarget(checkpoint, indexShard, sourceFactory.get(indexShard), listener)); + final SegmentReplicationTarget target = new SegmentReplicationTarget( + checkpoint, + indexShard, + sourceFactory.get(indexShard), + listener + ); + startReplication(target); + return target; } // pkg-private for integration tests @@ -248,7 +281,17 @@ public void onResponse(Void o) { @Override public void onFailure(Exception e) { - onGoingReplications.fail(replicationId, new OpenSearchException("Segment Replication failed", e), true); + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof CancellableThreads.ExecutionCancelledException) { + if (onGoingReplications.getTarget(replicationId) != null) { + // if the target still exists in our collection, the primary initiated the cancellation; fail the replication + // but do not fail the shard. Cancellations initiated by this node from Index events will be removed with + // onGoingReplications.cancel and not appear in the collection when this listener resolves. 
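+ // Two cancellation paths meet in this handler: a primary-initiated cancel arrives as a failed response while + // the target is still registered and is failed here without failing the shard, whereas a locally-initiated + // cancel (shard closed or promoted) has already removed the target, making the guard above false.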
+ onGoingReplications.fail(replicationId, (CancellableThreads.ExecutionCancelledException) cause, false); + } + } else { + onGoingReplications.fail(replicationId, new OpenSearchException("Segment Replication failed", e), true); + } } }); } diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 3f4eadc52fd2a..92e9815313fa0 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -969,6 +969,7 @@ protected Node( .toInstance(new SegmentReplicationSourceService(indicesService, transportService, recoverySettings)); } else { b.bind(SegmentReplicationTargetService.class).toInstance(SegmentReplicationTargetService.NO_OP); + b.bind(SegmentReplicationSourceService.class).toInstance(SegmentReplicationSourceService.NO_OP); } } b.bind(HttpServerTransport.class).toInstance(httpServerTransport); @@ -1112,6 +1113,9 @@ public Node start() throws NodeValidationException { assert transportService.getLocalNode().equals(localNodeFactory.getNode()) : "transportService has a different local node than the factory provided"; injector.getInstance(PeerRecoverySourceService.class).start(); + if (FeatureFlags.isEnabled(REPLICATION_TYPE)) { + injector.getInstance(SegmentReplicationSourceService.class).start(); + } // Load (and maybe upgrade) the metadata stored on disk final GatewayMetaState gatewayMetaState = injector.getInstance(GatewayMetaState.class); @@ -1287,6 +1291,9 @@ public synchronized void close() throws IOException { // close filter/fielddata caches after indices toClose.add(injector.getInstance(IndicesStore.class)); toClose.add(injector.getInstance(PeerRecoverySourceService.class)); + if (FeatureFlags.isEnabled(REPLICATION_TYPE)) { + toClose.add(injector.getInstance(SegmentReplicationSourceService.class)); + } toClose.add(() -> stopWatch.stop().start("cluster")); toClose.add(injector.getInstance(ClusterService.class)); toClose.add(() -> stopWatch.stop().start("node_connections_service")); diff --git a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java index 23371a39871c7..88a3bdad53d0c 100644 --- a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java @@ -8,11 +8,18 @@ package org.opensearch.index.shard; +import org.junit.Assert; +import org.opensearch.OpenSearchException; +import org.opensearch.action.ActionListener; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.IndexSettings; import org.opensearch.index.engine.DocIdSeqNoAndSource; @@ -21,12 +28,28 @@ import org.opensearch.index.engine.NRTReplicationEngineFactory; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.replication.OpenSearchIndexLevelReplicationTestCase; +import org.opensearch.index.store.Store; +import org.opensearch.index.store.StoreFileMetadata; 
+import org.opensearch.indices.recovery.RecoverySettings; +import org.opensearch.indices.replication.CheckpointInfoResponse; +import org.opensearch.indices.replication.GetSegmentFilesResponse; +import org.opensearch.indices.replication.SegmentReplicationSource; +import org.opensearch.indices.replication.SegmentReplicationSourceFactory; +import org.opensearch.indices.replication.SegmentReplicationState; +import org.opensearch.indices.replication.SegmentReplicationTarget; +import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import org.opensearch.indices.replication.common.CopyState; import org.opensearch.indices.replication.common.ReplicationType; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static java.util.Arrays.asList; import static org.hamcrest.Matchers.equalTo; @@ -34,6 +57,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class SegmentReplicationIndexShardTests extends OpenSearchIndexLevelReplicationTestCase { @@ -241,6 +265,213 @@ public void testNRTReplicaPromotedAsPrimary() throws Exception { } } + public void testReplicaPromotedWhileReplicating() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + final IndexShard oldPrimary = shards.getPrimary(); + final IndexShard nextPrimary = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + oldPrimary.refresh("Test"); + shards.syncGlobalCheckpoint(); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + resolveCheckpointInfoResponseListener(listener, oldPrimary); + ShardRouting oldRouting = nextPrimary.shardRouting; + try { + shards.promoteReplicaToPrimary(nextPrimary); + } catch (IOException e) { + Assert.fail("Promotion should not fail"); + } + targetService.shardRoutingChanged(nextPrimary, oldRouting, nextPrimary.shardRouting); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList())); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(nextPrimary, targetService); + // wait for replica to finish being promoted, and assert doc counts. 
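+ // (A primary operation permit is only granted once the shard has fully switched to primary mode, + // so the latch below acts as a barrier for the asynchronous promotion.)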
+ final CountDownLatch latch = new CountDownLatch(1); + nextPrimary.acquirePrimaryOperationPermit(new ActionListener<>() { + @Override + public void onResponse(Releasable releasable) { + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + }, ThreadPool.Names.GENERIC, ""); + latch.await(); + assertEquals(nextPrimary.getEngine().getClass(), InternalEngine.class); + nextPrimary.refresh("test"); + + oldPrimary.close("demoted", false); + oldPrimary.store().close(); + IndexShard newReplica = shards.addReplicaWithExistingPath(oldPrimary.shardPath(), oldPrimary.routingEntry().currentNodeId()); + shards.recoverReplica(newReplica); + + assertDocCount(nextPrimary, numDocs); + assertDocCount(newReplica, numDocs); + + nextPrimary.refresh("test"); + replicateSegments(nextPrimary, shards.getReplicas()); + final List docsAfterRecovery = getDocIdAndSeqNos(shards.getPrimary()); + for (IndexShard shard : shards.getReplicas()) { + assertThat(shard.routingEntry().toString(), getDocIdAndSeqNos(shard), equalTo(docsAfterRecovery)); + } + } + } + + public void testReplicaClosesWhileReplicating_AfterGetCheckpoint() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + // trigger a cancellation by closing the replica. 
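+ // (beforeIndexShardClosed is invoked directly rather than actually closing the shard; it is the same + // IndexEventListener hook the cluster-state wiring routes through, so the production cancellation path is exercised.)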
+ targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + resolveCheckpointInfoResponseListener(listener, primary); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + Assert.fail("Should not be reached"); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + public void testReplicaClosesWhileReplicating_AfterGetSegmentFiles() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + resolveCheckpointInfoResponseListener(listener, primary); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + // resolve the listener to indicate the source has responded, then close the replica to trigger cancellation. + listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList())); + targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + public void testPrimaryCancelsExecution() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + listener.onFailure(new CancellableThreads.ExecutionCancelledException("Cancelled")); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) {} + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + private SegmentReplicationTargetService newTargetService(SegmentReplicationSourceFactory sourceFactory) { + return new SegmentReplicationTargetService( + threadPool, + new RecoverySettings(Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)), + mock(TransportService.class), + 
sourceFactory + ); + } + /** * Assert persisted and searchable doc counts. This method should not be used while docs are concurrently indexed because * it asserts point in time seqNos are relative to the doc counts. @@ -253,4 +484,48 @@ private void assertDocCounts(IndexShard indexShard, int expectedPersistedDocCoun // processed cp should be 1 less than our searchable doc count. assertEquals(expectedSearchableDocCount - 1, indexShard.getProcessedLocalCheckpoint()); } + + private void resolveCheckpointInfoResponseListener(ActionListener listener, IndexShard primary) { + try { + final CopyState copyState = new CopyState(ReplicationCheckpoint.empty(primary.shardId), primary); + listener.onResponse( + new CheckpointInfoResponse( + copyState.getCheckpoint(), + copyState.getMetadataSnapshot(), + copyState.getInfosBytes(), + copyState.getPendingDeleteFiles() + ) + ); + } catch (IOException e) { + logger.error("Unexpected error computing CopyState", e); + Assert.fail("Failed to compute copyState"); + } + } + + private void startReplicationAndAssertCancellation(IndexShard replica, SegmentReplicationTargetService targetService) + throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final SegmentReplicationTarget target = targetService.startReplication( + ReplicationCheckpoint.empty(replica.shardId), + replica, + new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + Assert.fail("Replication should not complete"); + } + + @Override + public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { + assertTrue(e instanceof CancellableThreads.ExecutionCancelledException); + assertFalse(sendShardFailure); + assertEquals(SegmentReplicationState.Stage.CANCELLED, state.getStage()); + latch.countDown(); + } + } + ); + + latch.await(2, TimeUnit.SECONDS); + assertEquals("Should have resolved listener with failure", 0, latch.getCount()); + assertNull(targetService.get(target.getId())); + } } diff --git a/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java b/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java index 1f2360abde2ad..22481b5a7b99f 100644 --- a/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java +++ b/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java @@ -66,6 +66,7 @@ import org.opensearch.index.shard.PrimaryReplicaSyncer; import org.opensearch.index.shard.ShardId; import org.opensearch.indices.recovery.PeerRecoveryTargetService; +import org.opensearch.indices.replication.SegmentReplicationSourceService; import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.repositories.RepositoriesService; @@ -572,6 +573,7 @@ private IndicesClusterStateService createIndicesClusterStateService( threadPool, SegmentReplicationCheckpointPublisher.EMPTY, SegmentReplicationTargetService.NO_OP, + SegmentReplicationSourceService.NO_OP, recoveryTargetService, shardStateAction, null, diff --git a/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java b/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java index 38c55620e1223..f49ee0471b5e8 100644 --- 
a/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java @@ -14,6 +14,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.IndexService; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; @@ -31,6 +33,8 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; @@ -154,6 +158,51 @@ public void testCancelReplication() throws IOException { assertEquals(0, replications.cachedCopyStateSize()); } + public void testCancelReplication_AfterSendFilesStarts() throws IOException, InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + OngoingSegmentReplications replications = new OngoingSegmentReplications(mockIndicesService, recoverySettings); + // add a doc and refresh so primary has more than one segment. + indexDoc(primary, "1", "{\"foo\" : \"baz\"}", XContentType.JSON, "foobar"); + primary.refresh("Test"); + final CheckpointInfoRequest request = new CheckpointInfoRequest( + 1L, + replica.routingEntry().allocationId().getId(), + primaryDiscoveryNode, + testCheckpoint + ); + final FileChunkWriter segmentSegmentFileChunkWriter = (fileMetadata, position, content, lastChunk, totalTranslogOps, listener) -> { + // cancel the replication as soon as the writer starts sending files. 
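+ // (FileChunkWriter is effectively the seam between the handler and the transport layer, so cancelling from + // inside it simulates a cancel that lands after sendFiles() has begun but before the first chunk completes.)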
+ replications.cancel(replica.routingEntry().allocationId().getId(), "Test"); + }; + final CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); + assertEquals(1, replications.size()); + assertEquals(1, replications.cachedCopyStateSize()); + getSegmentFilesRequest = new GetSegmentFilesRequest( + 1L, + replica.routingEntry().allocationId().getId(), + replicaDiscoveryNode, + new ArrayList<>(copyState.getMetadataSnapshot().asMap().values()), + testCheckpoint + ); + replications.startSegmentCopy(getSegmentFilesRequest, new ActionListener<>() { + @Override + public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { + Assert.fail("Expected onFailure to be invoked."); + } + + @Override + public void onFailure(Exception e) { + assertEquals(CancellableThreads.ExecutionCancelledException.class, e.getClass()); + assertEquals(0, copyState.refCount()); + assertEquals(0, replications.size()); + assertEquals(0, replications.cachedCopyStateSize()); + latch.countDown(); + } + }); + latch.await(2, TimeUnit.SECONDS); + assertEquals("listener should have resolved with failure", 0, latch.getCount()); + } + public void testMultipleReplicasUseSameCheckpoint() throws IOException { IndexShard secondReplica = newShard(primary.shardId(), false); recoverReplica(secondReplica, primary, true); diff --git a/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java b/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java index 6bce74be569c3..323445bee1274 100644 --- a/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java @@ -9,12 +9,14 @@ package org.opensearch.indices.replication; import org.apache.lucene.util.Version; +import org.junit.Assert; import org.opensearch.action.ActionListener; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.core.internal.io.IOUtils; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; @@ -28,6 +30,8 @@ import java.util.Arrays; import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.Mockito.mock; @@ -126,6 +130,39 @@ public void testGetSegmentFiles() { assertTrue(capturedRequest.request instanceof GetSegmentFilesRequest); } + public void testGetSegmentFiles_CancelWhileRequestOpen() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final ReplicationCheckpoint checkpoint = new ReplicationCheckpoint( + indexShard.shardId(), + PRIMARY_TERM, + SEGMENTS_GEN, + SEQ_NO, + VERSION + ); + StoreFileMetadata testMetadata = new StoreFileMetadata("testFile", 1L, "checksum", Version.LATEST); + replicationSource.getSegmentFiles( + REPLICATION_ID, + checkpoint, + Arrays.asList(testMetadata), + mock(Store.class), + new ActionListener<>() { + @Override + public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { + Assert.fail("onFailure response expected."); + } + + @Override + public void onFailure(Exception e) { + assertEquals(e.getClass(), 
CancellableThreads.ExecutionCancelledException.class); + latch.countDown(); + } + } + ); + replicationSource.cancel(); + latch.await(2, TimeUnit.SECONDS); + assertEquals("listener should have resolved in a failure", 0, latch.getCount()); + } + private DiscoveryNode newDiscoveryNode(String nodeName) { return new DiscoveryNode( nodeName, diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java index 2c52772649acc..a6e169dbc3d61 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java @@ -18,6 +18,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; import org.opensearch.index.store.StoreFileMetadata; @@ -28,6 +29,8 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.Mockito.mock; @@ -197,4 +200,47 @@ public void testReplicationAlreadyRunning() throws IOException { handler.sendFiles(getSegmentFilesRequest, mock(ActionListener.class)); Assert.assertThrows(OpenSearchException.class, () -> { handler.sendFiles(getSegmentFilesRequest, mock(ActionListener.class)); }); } + + public void testCancelReplication() throws IOException, InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + chunkWriter = mock(FileChunkWriter.class); + + final ReplicationCheckpoint latestReplicationCheckpoint = primary.getLatestReplicationCheckpoint(); + final CopyState copyState = new CopyState(latestReplicationCheckpoint, primary); + SegmentReplicationSourceHandler handler = new SegmentReplicationSourceHandler( + localNode, + chunkWriter, + threadPool, + copyState, + primary.routingEntry().allocationId().getId(), + 5000, + 1 + ); + + final GetSegmentFilesRequest getSegmentFilesRequest = new GetSegmentFilesRequest( + 1L, + replica.routingEntry().allocationId().getId(), + replicaDiscoveryNode, + Collections.emptyList(), + latestReplicationCheckpoint + ); + + // cancel before the transfer starts. Cancels during copy will be tested in SegmentFileTransferHandlerTests, which uses the same + // cancellableThreads. 
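+ // (sendFiles() registers its onCancel hook and calls checkForCancel() before starting the transfer, so a + // cancel issued here is observed immediately and the listener resolves with ExecutionCancelledException.)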
+ handler.cancel("test"); + handler.sendFiles(getSegmentFilesRequest, new ActionListener<>() { + @Override + public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { + Assert.fail("Expected failure."); + } + + @Override + public void onFailure(Exception e) { + assertEquals(CancellableThreads.ExecutionCancelledException.class, e.getClass()); + latch.countDown(); + } + }); + latch.await(2, TimeUnit.SECONDS); + assertEquals("listener should have resolved with failure", 0, latch.getCount()); + } } diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index de739f4ca834a..7d9b0f09f21cd 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -9,21 +9,22 @@ package org.opensearch.indices.replication; import org.junit.Assert; -import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.index.engine.NRTReplicationEngineFactory; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; -import org.opensearch.indices.recovery.RecoverySettings; +import org.opensearch.index.store.Store; +import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; -import org.opensearch.indices.replication.common.ReplicationLuceneIndex; -import org.opensearch.transport.TransportService; +import org.opensearch.indices.replication.common.ReplicationType; import java.io.IOException; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -35,12 +36,12 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.times; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.eq; public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { - private IndexShard indexShard; + private IndexShard replicaShard; + private IndexShard primaryShard; private ReplicationCheckpoint checkpoint; private SegmentReplicationSource replicationSource; private SegmentReplicationTargetService sut; @@ -52,20 +53,20 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { public void setUp() throws Exception { super.setUp(); final Settings settings = Settings.builder() - .put(IndexMetadata.SETTING_REPLICATION_TYPE, "SEGMENT") + .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT) .put("node.name", SegmentReplicationTargetServiceTests.class.getSimpleName()) .build(); final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings); - final TransportService transportService = mock(TransportService.class); - indexShard = newStartedShard(false, settings); - checkpoint = new ReplicationCheckpoint(indexShard.shardId(), 0L, 0L, 0L, 0L); + primaryShard = newStartedShard(true); + replicaShard = newShard(false, settings, new 
NRTReplicationEngineFactory()); + recoverReplica(replicaShard, primaryShard, true); + checkpoint = new ReplicationCheckpoint(replicaShard.shardId(), 0L, 0L, 0L, 0L); SegmentReplicationSourceFactory replicationSourceFactory = mock(SegmentReplicationSourceFactory.class); replicationSource = mock(SegmentReplicationSource.class); - when(replicationSourceFactory.get(indexShard)).thenReturn(replicationSource); + when(replicationSourceFactory.get(replicaShard)).thenReturn(replicationSource); - sut = new SegmentReplicationTargetService(threadPool, recoverySettings, transportService, replicationSourceFactory); - initialCheckpoint = indexShard.getLatestReplicationCheckpoint(); + sut = prepareForReplication(primaryShard); + initialCheckpoint = replicaShard.getLatestReplicationCheckpoint(); aheadCheckpoint = new ReplicationCheckpoint( initialCheckpoint.getShardId(), initialCheckpoint.getPrimaryTerm(), @@ -77,44 +78,58 @@ public void setUp() throws Exception { @Override public void tearDown() throws Exception { - closeShards(indexShard); + closeShards(primaryShard, replicaShard); super.tearDown(); } - public void testTargetReturnsSuccess_listenerCompletes() { - final SegmentReplicationTarget target = new SegmentReplicationTarget( - checkpoint, - indexShard, - replicationSource, - new SegmentReplicationTargetService.SegmentReplicationListener() { - @Override - public void onReplicationDone(SegmentReplicationState state) { - assertEquals(SegmentReplicationState.Stage.DONE, state.getStage()); - } + public void testsSuccessfulReplication_listenerCompletes() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + sut.startReplication(checkpoint, replicaShard, new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + assertEquals(SegmentReplicationState.Stage.DONE, state.getStage()); + latch.countDown(); + } - @Override - public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { - Assert.fail(); - } + @Override + public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { + logger.error("Unexpected error", e); + Assert.fail("Test should succeed"); } - ); - final SegmentReplicationTarget spy = Mockito.spy(target); - doAnswer(invocation -> { - // set up stage correctly so the transition in markAsDone succeeds on listener completion - moveTargetToFinalStage(target); - final ActionListener listener = invocation.getArgument(0); - listener.onResponse(null); - return null; - }).when(spy).startReplication(any()); - sut.startReplication(spy); + }); + latch.await(2, TimeUnit.SECONDS); + assertEquals(0, latch.getCount()); } - public void testTargetThrowsException() { + public void testReplicationFails() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); final OpenSearchException expectedError = new OpenSearchException("Fail"); + SegmentReplicationSource source = new SegmentReplicationSource() { + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + listener.onFailure(expectedError); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + Assert.fail("Should not be called"); + } + }; final SegmentReplicationTarget target = new SegmentReplicationTarget( checkpoint, - 
indexShard, - replicationSource, + replicaShard, + source, new SegmentReplicationTargetService.SegmentReplicationListener() { @Override public void onReplicationDone(SegmentReplicationState state) { @@ -123,24 +138,21 @@ public void onReplicationDone(SegmentReplicationState state) { @Override public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { - assertEquals(SegmentReplicationState.Stage.INIT, state.getStage()); + // failures leave state object in last entered stage. + assertEquals(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO, state.getStage()); assertEquals(expectedError, e.getCause()); - assertTrue(sendShardFailure); + latch.countDown(); } } ); - final SegmentReplicationTarget spy = Mockito.spy(target); - doAnswer(invocation -> { - final ActionListener listener = invocation.getArgument(0); - listener.onFailure(expectedError); - return null; - }).when(spy).startReplication(any()); - sut.startReplication(spy); + sut.startReplication(target); + latch.await(2, TimeUnit.SECONDS); + assertEquals(0, latch.getCount()); } public void testAlreadyOnNewCheckpoint() { SegmentReplicationTargetService spy = spy(sut); - spy.onNewCheckpoint(indexShard.getLatestReplicationCheckpoint(), indexShard); + spy.onNewCheckpoint(replicaShard.getLatestReplicationCheckpoint(), replicaShard); verify(spy, times(0)).startReplication(any(), any(), any()); } @@ -149,7 +161,7 @@ public void testShardAlreadyReplicating() throws InterruptedException { SegmentReplicationTargetService serviceSpy = spy(sut); final SegmentReplicationTarget target = new SegmentReplicationTarget( checkpoint, - indexShard, + replicaShard, replicationSource, mock(SegmentReplicationTargetService.SegmentReplicationListener.class) ); @@ -161,7 +173,7 @@ public void testShardAlreadyReplicating() throws InterruptedException { doAnswer(invocation -> { final ActionListener listener = invocation.getArgument(0); // a new checkpoint arrives before we've completed. - serviceSpy.onNewCheckpoint(aheadCheckpoint, indexShard); + serviceSpy.onNewCheckpoint(aheadCheckpoint, replicaShard); listener.onResponse(null); latch.countDown(); return null; @@ -173,12 +185,12 @@ public void testShardAlreadyReplicating() throws InterruptedException { // wait for the new checkpoint to arrive, before the listener completes. 
latch.await(30, TimeUnit.SECONDS); - verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(indexShard), any()); + verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(replicaShard), any()); } public void testNewCheckpointBehindCurrentCheckpoint() { SegmentReplicationTargetService spy = spy(sut); - spy.onNewCheckpoint(checkpoint, indexShard); + spy.onNewCheckpoint(checkpoint, replicaShard); verify(spy, times(0)).startReplication(any(), any(), any()); } @@ -190,22 +202,6 @@ public void testShardNotStarted() throws IOException { closeShards(shard); } - public void testNewCheckpoint_validationPassesAndReplicationFails() throws IOException { - allowShardFailures(); - SegmentReplicationTargetService spy = spy(sut); - IndexShard spyShard = spy(indexShard); - ArgumentCaptor captor = ArgumentCaptor.forClass( - SegmentReplicationTargetService.SegmentReplicationListener.class - ); - doNothing().when(spy).startReplication(any(), any(), any()); - spy.onNewCheckpoint(aheadCheckpoint, spyShard); - verify(spy, times(1)).startReplication(any(), any(), captor.capture()); - SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue(); - listener.onFailure(new SegmentReplicationState(new ReplicationLuceneIndex()), new OpenSearchException("testing"), true); - verify(spyShard).failShard(any(), any()); - closeShard(indexShard, false); - } - /** * here we are starting a new shard in PrimaryMode and testing that we don't process a checkpoint on shard when it is in PrimaryMode. */ @@ -215,70 +211,10 @@ public void testRejectCheckpointOnShardPrimaryMode() throws IOException { // Starting a new shard in PrimaryMode. IndexShard primaryShard = newStartedShard(true); IndexShard spyShard = spy(primaryShard); - doNothing().when(spy).startReplication(any(), any(), any()); spy.onNewCheckpoint(aheadCheckpoint, spyShard); // Verify that checkpoint is not processed as shard is in PrimaryMode. 
         verify(spy, times(0)).startReplication(any(), any(), any());
         closeShards(primaryShard);
     }
-
-    public void testReplicationOnDone() throws IOException {
-        SegmentReplicationTargetService spy = spy(sut);
-        IndexShard spyShard = spy(indexShard);
-        ReplicationCheckpoint cp = indexShard.getLatestReplicationCheckpoint();
-        ReplicationCheckpoint newCheckpoint = new ReplicationCheckpoint(
-            cp.getShardId(),
-            cp.getPrimaryTerm(),
-            cp.getSegmentsGen(),
-            cp.getSeqNo(),
-            cp.getSegmentInfosVersion() + 1
-        );
-        ReplicationCheckpoint anotherNewCheckpoint = new ReplicationCheckpoint(
-            cp.getShardId(),
-            cp.getPrimaryTerm(),
-            cp.getSegmentsGen(),
-            cp.getSeqNo(),
-            cp.getSegmentInfosVersion() + 2
-        );
-        ArgumentCaptor<SegmentReplicationTargetService.SegmentReplicationListener> captor = ArgumentCaptor.forClass(
-            SegmentReplicationTargetService.SegmentReplicationListener.class
-        );
-        doNothing().when(spy).startReplication(any(), any(), any());
-        spy.onNewCheckpoint(newCheckpoint, spyShard);
-        spy.onNewCheckpoint(anotherNewCheckpoint, spyShard);
-        verify(spy, times(1)).startReplication(eq(newCheckpoint), any(), captor.capture());
-        verify(spy, times(1)).onNewCheckpoint(eq(anotherNewCheckpoint), any());
-        SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue();
-        listener.onDone(new SegmentReplicationState(new ReplicationLuceneIndex()));
-        doNothing().when(spy).onNewCheckpoint(any(), any());
-        verify(spy, timeout(100).times(2)).onNewCheckpoint(eq(anotherNewCheckpoint), any());
-        closeShard(indexShard, false);
-    }
-
-    public void testBeforeIndexShardClosed_CancelsOngoingReplications() {
-        final SegmentReplicationTarget target = new SegmentReplicationTarget(
-            checkpoint,
-            indexShard,
-            replicationSource,
-            mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
-        );
-        final SegmentReplicationTarget spy = Mockito.spy(target);
-        sut.startReplication(spy);
-        sut.beforeIndexShardClosed(indexShard.shardId(), indexShard, Settings.EMPTY);
-        verify(spy, times(1)).cancel(any());
-    }
-
-    /**
-     * Move the {@link SegmentReplicationTarget} object through its {@link SegmentReplicationState.Stage} values in order
-     * until the final, non-terminal stage.
-     */
-    private void moveTargetToFinalStage(SegmentReplicationTarget target) {
-        SegmentReplicationState.Stage[] stageValues = SegmentReplicationState.Stage.values();
-        assertEquals(target.state().getStage(), SegmentReplicationState.Stage.INIT);
-        // Skip the first two stages (DONE and INIT) and iterate until the last value
-        for (int i = 2; i < stageValues.length; i++) {
-            target.state().setStage(stageValues[i]);
-        }
-    }
 }
diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
index 4d3b841e203de..ff4005d9bcedf 100644
--- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
+++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
@@ -185,6 +185,7 @@
 import org.opensearch.indices.recovery.PeerRecoveryTargetService;
 import org.opensearch.indices.recovery.RecoverySettings;
 import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
+import org.opensearch.indices.replication.SegmentReplicationSourceService;
 import org.opensearch.indices.replication.SegmentReplicationTargetService;
 import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher;
 import org.opensearch.ingest.IngestService;
@@ -1857,6 +1858,7 @@ public void onFailure(final Exception e) {
                         transportService,
                         new SegmentReplicationSourceFactory(transportService, recoverySettings, clusterService)
                     ),
+                    SegmentReplicationSourceService.NO_OP,
                     shardStateAction,
                     new NodeMappingRefreshAction(transportService, metadataMappingService),
                     repositoriesService,
diff --git a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java
index 08004b7e42fea..1b40cb4f2dfa3 100644
--- a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java
+++ b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java
@@ -68,7 +68,6 @@
 import org.opensearch.common.lucene.uid.Versions;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
-import org.opensearch.common.unit.TimeValue;
 import org.opensearch.common.util.BigArrays;
 import org.opensearch.common.xcontent.XContentType;
 import org.opensearch.core.internal.io.IOUtils;
@@ -112,7 +111,10 @@
 import org.opensearch.indices.replication.CheckpointInfoResponse;
 import org.opensearch.indices.replication.GetSegmentFilesResponse;
 import org.opensearch.indices.replication.SegmentReplicationSource;
+import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
+import org.opensearch.indices.replication.SegmentReplicationState;
 import org.opensearch.indices.replication.SegmentReplicationTarget;
+import org.opensearch.indices.replication.SegmentReplicationTargetService;
 import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
 import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher;
 import org.opensearch.indices.replication.common.CopyState;
@@ -127,8 +129,10 @@
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.TestThreadPool;
 import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.transport.TransportService;

 import java.io.IOException;
+import java.util.ArrayList;
 import java.nio.file.Path;
 import java.util.Arrays;
 import java.util.Collections;
@@ -146,7 +150,9 @@
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 import static org.opensearch.cluster.routing.TestShardRouting.newShardRouting;

 /**
@@ -1171,35 +1177,40 @@ public static Engine.Warmer createTestWarmer(IndexSettings indexSettings) {
     }

     /**
-     * Segment Replication specific test method - Replicate segments to a list of replicas from a given primary.
-     * This test will use a real {@link SegmentReplicationTarget} for each replica with a mock {@link SegmentReplicationSource} that
-     * writes all segments directly to the target.
+     * Segment Replication specific test method - Creates a {@link SegmentReplicationTargetService} to perform replications that has
+     * been configured to return the given primaryShard's current segments.
+     *
+     * @param primaryShard {@link IndexShard} - The primary shard to replicate from.
      */
-    public final void replicateSegments(IndexShard primaryShard, List<IndexShard> replicaShards) throws IOException, InterruptedException {
-        final CountDownLatch countDownLatch = new CountDownLatch(replicaShards.size());
-        Store.MetadataSnapshot primaryMetadata;
-        try (final GatedCloseable<SegmentInfos> segmentInfosSnapshot = primaryShard.getSegmentInfosSnapshot()) {
-            final SegmentInfos primarySegmentInfos = segmentInfosSnapshot.get();
-            primaryMetadata = primaryShard.store().getMetadata(primarySegmentInfos);
-        }
-        final CopyState copyState = new CopyState(ReplicationCheckpoint.empty(primaryShard.shardId), primaryShard);
-
-        final ReplicationCollection<SegmentReplicationTarget> replicationCollection = new ReplicationCollection<>(logger, threadPool);
-        final SegmentReplicationSource source = new SegmentReplicationSource() {
+    public final SegmentReplicationTargetService prepareForReplication(IndexShard primaryShard) {
+        final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class);
+        final SegmentReplicationTargetService targetService = new SegmentReplicationTargetService(
+            threadPool,
+            new RecoverySettings(Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)),
+            mock(TransportService.class),
+            sourceFactory
+        );
+        final SegmentReplicationSource replicationSource = new SegmentReplicationSource() {
             @Override
             public void getCheckpointMetadata(
                 long replicationId,
                 ReplicationCheckpoint checkpoint,
                 ActionListener<CheckpointInfoResponse> listener
             ) {
-                listener.onResponse(
-                    new CheckpointInfoResponse(
-                        copyState.getCheckpoint(),
-                        copyState.getMetadataSnapshot(),
-                        copyState.getInfosBytes(),
-                        copyState.getPendingDeleteFiles()
-                    )
-                );
+                try {
+                    final CopyState copyState = new CopyState(ReplicationCheckpoint.empty(primaryShard.shardId), primaryShard);
+                    listener.onResponse(
+                        new CheckpointInfoResponse(
+                            copyState.getCheckpoint(),
+                            copyState.getMetadataSnapshot(),
+                            copyState.getInfosBytes(),
+                            copyState.getPendingDeleteFiles()
+                        )
+                    );
+                } catch (IOException e) {
+                    logger.error("Unexpected error computing CopyState", e);
+                    Assert.fail("Failed to compute copyState");
+                }
             }

             @Override
@@ -1211,9 +1222,7 @@ public void getSegmentFiles(
                 ActionListener<GetSegmentFilesResponse> listener
             ) {
                 try (
-                    final ReplicationCollection.ReplicationRef<SegmentReplicationTarget> replicationRef = replicationCollection.get(
-                        replicationId
-                    )
+                    final ReplicationCollection.ReplicationRef<SegmentReplicationTarget> replicationRef = targetService.get(replicationId)
                 ) {
                     writeFileChunks(replicationRef.get(), primaryShard, filesToFetch.toArray(new StoreFileMetadata[] {}));
                 } catch (IOException e) {
@@ -1222,15 +1231,43 @@ public void getSegmentFiles(
                 listener.onResponse(new GetSegmentFilesResponse(filesToFetch));
             }
         };
+        when(sourceFactory.get(any())).thenReturn(replicationSource);
+        return targetService;
+    }

+    /**
+     * Segment Replication specific test method - Replicate segments to a list of replicas from a given primary.
+     * This test will use a real {@link SegmentReplicationTarget} for each replica with a mock {@link SegmentReplicationSource} that
+     * writes all segments directly to the target.
+     * @param primaryShard - {@link IndexShard} The current primary shard.
+     * @param replicaShards - Replicas that will be updated.
+     * @return {@link List} List of target components orchestrating replication.
+     */
+    public final List<SegmentReplicationTarget> replicateSegments(IndexShard primaryShard, List<IndexShard> replicaShards)
+        throws IOException, InterruptedException {
+        final SegmentReplicationTargetService targetService = prepareForReplication(primaryShard);
+        return replicateSegments(targetService, primaryShard, replicaShards);
+    }
+
+    public final List<SegmentReplicationTarget> replicateSegments(
+        SegmentReplicationTargetService targetService,
+        IndexShard primaryShard,
+        List<IndexShard> replicaShards
+    ) throws IOException, InterruptedException {
+        final CountDownLatch countDownLatch = new CountDownLatch(replicaShards.size());
+        Store.MetadataSnapshot primaryMetadata;
+        try (final GatedCloseable<SegmentInfos> segmentInfosSnapshot = primaryShard.getSegmentInfosSnapshot()) {
+            final SegmentInfos primarySegmentInfos = segmentInfosSnapshot.get();
+            primaryMetadata = primaryShard.store().getMetadata(primarySegmentInfos);
+        }
+        List<SegmentReplicationTarget> ids = new ArrayList<>();
         for (IndexShard replica : replicaShards) {
-            final SegmentReplicationTarget target = new SegmentReplicationTarget(
+            final SegmentReplicationTarget target = targetService.startReplication(
                 ReplicationCheckpoint.empty(replica.shardId),
                 replica,
-                source,
-                new ReplicationListener() {
+                new SegmentReplicationTargetService.SegmentReplicationListener() {
                     @Override
-                    public void onDone(ReplicationState state) {
+                    public void onReplicationDone(SegmentReplicationState state) {
                         try (final GatedCloseable<SegmentInfos> snapshot = replica.getSegmentInfosSnapshot()) {
                             final SegmentInfos replicaInfos = snapshot.get();
                             final Store.MetadataSnapshot replicaMetadata = replica.store().getMetadata(replicaInfos);
@@ -1241,31 +1278,22 @@ public void onDone(ReplicationState state) {
                             assertEquals(primaryMetadata.getCommitUserData(), replicaMetadata.getCommitUserData());
                         } catch (Exception e) {
                             throw ExceptionsHelper.convertToRuntime(e);
+                        } finally {
+                            countDownLatch.countDown();
                         }
-                        countDownLatch.countDown();
                     }

                     @Override
-                    public void onFailure(ReplicationState state, OpenSearchException e, boolean sendShardFailure) {
+                    public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) {
                         logger.error("Unexpected replication failure in test", e);
                         Assert.fail("test replication should not fail: " + e);
                     }
                 }
             );
-            replicationCollection.start(target, TimeValue.timeValueMillis(5000));
-            target.startReplication(new ActionListener<>() {
-                @Override
-                public void onResponse(Void o) {
-                    replicationCollection.markAsDone(target.getId());
-                }
-
-                @Override
-                public void onFailure(Exception e) {
-                    replicationCollection.fail(target.getId(), new OpenSearchException("Segment Replication failed", e), true);
-                }
-            });
+            ids.add(target);
+            countDownLatch.await(1, TimeUnit.SECONDS);
         }
-        countDownLatch.await(3, TimeUnit.SECONDS);
+        return ids;
     }

     private void writeFileChunks(SegmentReplicationTarget target, IndexShard primary, StoreFileMetadata[] files) throws IOException {
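
Reviewer sketch, not part of the patch: the two test-framework helpers above are meant to be used together, with prepareForReplication wiring a SegmentReplicationTargetService against a mocked SegmentReplicationSourceFactory, and replicateSegments driving one replication round per replica. A minimal consumer test might look like the following, assuming a test class extending IndexShardTestCase with started primary and replica shards; the test name and shard variables are illustrative, not part of this change:

    public void testReplicaReceivesPrimarySegments() throws Exception {
        // Target service whose mocked source factory serves the primary's current segments.
        final SegmentReplicationTargetService targetService = prepareForReplication(primary);
        // Replicate to the replica; one SegmentReplicationTarget is returned per replica shard.
        final List<SegmentReplicationTarget> targets = replicateSegments(targetService, primary, Arrays.asList(replica));
        assertEquals(1, targets.size());
    }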