From b0d892b9bb7b8b3202259b3de15ee551a2011a3f Mon Sep 17 00:00:00 2001 From: Marc Handalian Date: Wed, 31 Aug 2022 13:12:24 -0700 Subject: [PATCH] Segment Replication - Implement segment replication event cancellation. (#4225) * Segment Replication. Fix Cancellation of replication events. This PR updates segment replication paths to correctly cancel replication events on the primary and replica. In the source service, any ongoing event for a primary that is sending to a replica that shuts down or is promoted as a new primary is cancelled. In the target service, any ongoing event for a replica that is promoted as a new primary or is fetching from a primary that shuts down is cancelled. It wires up SegmentReplicationSourceService as an IndexEventListener so that it can respond to events and cancel any ongoing transfer state. This change also includes some test cleanup for segment replication to rely on actual components over mocks. Signed-off-by: Marc Handalian Fix to not start/stop SegmentReplicationSourceService as a lifecycle component with feature flag off. Signed-off-by: Marc Handalian Update logic to properly mark SegmentReplicationTarget as cancelled when cancel is initiated by the primary. Signed-off-by: Marc Handalian Minor updates from self review. Signed-off-by: Marc Handalian * Add missing changelog entry.
Signed-off-by: Marc Handalian Signed-off-by: Marc Handalian (cherry picked from commit 19d1a2b027fef8b981560969bf428476d700bd07) --- CHANGELOG.md | 1 + .../cluster/IndicesClusterStateService.java | 5 + .../OngoingSegmentReplications.java | 22 +- .../PrimaryShardReplicationSource.java | 6 + .../replication/SegmentReplicationSource.java | 6 + .../SegmentReplicationSourceHandler.java | 11 + .../SegmentReplicationSourceService.java | 44 ++- .../replication/SegmentReplicationState.java | 15 +- .../replication/SegmentReplicationTarget.java | 31 +- .../SegmentReplicationTargetService.java | 53 +++- .../main/java/org/opensearch/node/Node.java | 7 + .../SegmentReplicationIndexShardTests.java | 275 ++++++++++++++++++ ...ClusterStateServiceRandomUpdatesTests.java | 2 + .../OngoingSegmentReplicationsTests.java | 49 ++++ .../PrimaryShardReplicationSourceTests.java | 37 +++ .../SegmentReplicationSourceHandlerTests.java | 46 +++ .../SegmentReplicationTargetServiceTests.java | 205 +++++-------- .../snapshots/SnapshotResiliencyTests.java | 2 + .../index/shard/IndexShardTestCase.java | 118 +++++--- 19 files changed, 739 insertions(+), 196 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 800786d504cf0..3ce717efdb73a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Added - Github workflow for changelog verification ([#4085](https://github.com/opensearch-project/OpenSearch/pull/4085)) - Add timing data and more granular stages to SegmentReplicationState ([#4367](https://github.com/opensearch-project/OpenSearch/pull/4367)) +- Fixed cancellation of segment replication events ([#4225](https://github.com/opensearch-project/OpenSearch/pull/4225)) ### Changed diff --git a/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java b/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java index ed66fb448ba95..c994e582971ef 100644 --- 
a/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java +++ b/server/src/main/java/org/opensearch/indices/cluster/IndicesClusterStateService.java @@ -81,6 +81,7 @@ import org.opensearch.indices.recovery.PeerRecoveryTargetService; import org.opensearch.indices.recovery.RecoveryListener; import org.opensearch.indices.recovery.RecoveryState; +import org.opensearch.indices.replication.SegmentReplicationSourceService; import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.indices.replication.common.ReplicationState; @@ -152,6 +153,7 @@ public IndicesClusterStateService( final ThreadPool threadPool, final PeerRecoveryTargetService recoveryTargetService, final SegmentReplicationTargetService segmentReplicationTargetService, + final SegmentReplicationSourceService segmentReplicationSourceService, final ShardStateAction shardStateAction, final NodeMappingRefreshAction nodeMappingRefreshAction, final RepositoriesService repositoriesService, @@ -170,6 +172,7 @@ public IndicesClusterStateService( threadPool, checkpointPublisher, segmentReplicationTargetService, + segmentReplicationSourceService, recoveryTargetService, shardStateAction, nodeMappingRefreshAction, @@ -191,6 +194,7 @@ public IndicesClusterStateService( final ThreadPool threadPool, final SegmentReplicationCheckpointPublisher checkpointPublisher, final SegmentReplicationTargetService segmentReplicationTargetService, + final SegmentReplicationSourceService segmentReplicationSourceService, final PeerRecoveryTargetService recoveryTargetService, final ShardStateAction shardStateAction, final NodeMappingRefreshAction nodeMappingRefreshAction, @@ -211,6 +215,7 @@ public IndicesClusterStateService( // if segrep feature flag is not enabled, don't wire the target serivce as an IndexEventListener. 
if (FeatureFlags.isEnabled(FeatureFlags.REPLICATION_TYPE)) { indexEventListeners.add(segmentReplicationTargetService); + indexEventListeners.add(segmentReplicationSourceService); } this.builtInIndexListener = Collections.unmodifiableList(indexEventListeners); this.indicesService = indicesService; diff --git a/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java b/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java index dfebe5f7cabf2..828aa29192fe3 100644 --- a/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java +++ b/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java @@ -37,7 +37,6 @@ * @opensearch.internal */ class OngoingSegmentReplications { - private final RecoverySettings recoverySettings; private final IndicesService indicesService; private final Map copyStateMap; @@ -161,6 +160,20 @@ synchronized void cancel(IndexShard shard, String reason) { cancelHandlers(handler -> handler.getCopyState().getShard().shardId().equals(shard.shardId()), reason); } + /** + * Cancel all Replication events for the given allocation ID, intended to be called when a primary is shutting down. + * + * @param allocationId {@link String} - Allocation ID. 
+ * @param reason {@link String} - Reason for the cancel + */ + synchronized void cancel(String allocationId, String reason) { + final SegmentReplicationSourceHandler handler = allocationIdToHandlers.remove(allocationId); + if (handler != null) { + handler.cancel(reason); + removeCopyState(handler.getCopyState()); + } + } + /** * Cancel any ongoing replications for a given {@link DiscoveryNode} * @@ -168,7 +181,6 @@ synchronized void cancel(IndexShard shard, String reason) { */ void cancelReplication(DiscoveryNode node) { cancelHandlers(handler -> handler.getTargetNode().equals(node), "Node left"); - } /** @@ -243,11 +255,7 @@ private void cancelHandlers(Predicate p .map(SegmentReplicationSourceHandler::getAllocationId) .collect(Collectors.toList()); for (String allocationId : allocationIds) { - final SegmentReplicationSourceHandler handler = allocationIdToHandlers.remove(allocationId); - if (handler != null) { - handler.cancel(reason); - removeCopyState(handler.getCopyState()); - } + cancel(allocationId, reason); } } } diff --git a/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java b/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java index 08dc0b97b31d5..aa0b5416dd0ff 100644 --- a/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java +++ b/server/src/main/java/org/opensearch/indices/replication/PrimaryShardReplicationSource.java @@ -87,4 +87,10 @@ public void getSegmentFiles( ); transportClient.executeRetryableAction(GET_SEGMENT_FILES, request, responseListener, reader); } + + @Override + public void cancel() { + transportClient.cancel(); + } + } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java index 8628a266ea7d0..b2e7487fff4b2 100644 --- 
a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSource.java @@ -9,6 +9,7 @@ package org.opensearch.indices.replication; import org.opensearch.action.ActionListener; +import org.opensearch.common.util.CancellableThreads.ExecutionCancelledException; import org.opensearch.index.store.Store; import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; @@ -47,4 +48,9 @@ void getSegmentFiles( Store store, ActionListener listener ); + + /** + * Cancel any ongoing requests, should resolve any ongoing listeners with onFailure with a {@link ExecutionCancelledException}. + */ + default void cancel() {} } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java index 2d21653c1924c..022d90b41d8ee 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java @@ -113,6 +113,16 @@ public synchronized void sendFiles(GetSegmentFilesRequest request, ActionListene final Closeable releaseResources = () -> IOUtils.close(resources); try { timer.start(); + cancellableThreads.setOnCancel((reason, beforeCancelEx) -> { + final RuntimeException e = new CancellableThreads.ExecutionCancelledException( + "replication was canceled reason [" + reason + "]" + ); + if (beforeCancelEx != null) { + e.addSuppressed(beforeCancelEx); + } + IOUtils.closeWhileHandlingException(releaseResources, () -> future.onFailure(e)); + throw e; + }); final Consumer onFailure = e -> { assert Transports.assertNotTransportThread(SegmentReplicationSourceHandler.this + "[onFailure]"); IOUtils.closeWhileHandlingException(releaseResources, () -> 
future.onFailure(e)); @@ -153,6 +163,7 @@ public synchronized void sendFiles(GetSegmentFilesRequest request, ActionListene final MultiChunkTransfer transfer = segmentFileTransferHandler .createTransfer(shard.store(), storeFileMetadata, () -> 0, sendFileStep); resources.add(transfer); + cancellableThreads.checkForCancel(); transfer.start(); sendFileStep.whenComplete(r -> { diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java index 0cee731fde2cb..db3f87201b774 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java @@ -15,6 +15,7 @@ import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateListener; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.common.component.AbstractLifecycleComponent; @@ -42,7 +43,25 @@ * * @opensearch.internal */ -public final class SegmentReplicationSourceService extends AbstractLifecycleComponent implements ClusterStateListener, IndexEventListener { +public class SegmentReplicationSourceService extends AbstractLifecycleComponent implements ClusterStateListener, IndexEventListener { + + // Empty Implementation, only required while Segment Replication is under feature flag. 
+ public static final SegmentReplicationSourceService NO_OP = new SegmentReplicationSourceService() { + @Override + public void clusterChanged(ClusterChangedEvent event) { + // NoOp; + } + + @Override + public void beforeIndexShardClosed(ShardId shardId, IndexShard indexShard, Settings indexSettings) { + // NoOp; + } + + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + // NoOp; + } + }; private static final Logger logger = LogManager.getLogger(SegmentReplicationSourceService.class); private final RecoverySettings recoverySettings; @@ -62,6 +81,14 @@ public static class Actions { private final OngoingSegmentReplications ongoingSegmentReplications; + // Used only for empty implementation. + private SegmentReplicationSourceService() { + recoverySettings = null; + ongoingSegmentReplications = null; + transportService = null; + indicesService = null; + } + public SegmentReplicationSourceService( IndicesService indicesService, TransportService transportService, @@ -163,10 +190,25 @@ protected void doClose() throws IOException { } + /** + * + * Cancels any replications on this node to a replica shard that is about to be closed. + */ @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { if (indexShard != null) { ongoingSegmentReplications.cancel(indexShard, "shard is closed"); } } + + /** + * Cancels any replications on this node to a replica that has been promoted as primary (a null old routing is ignored). 
+ */ + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + if (indexShard != null && oldRouting != null && oldRouting.primary() == false && newRouting.primary()) { + ongoingSegmentReplications.cancel(indexShard.routingEntry().allocationId().getId(), "Relocating primary shard."); + } + } + } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java index f865ba1332186..2e2e6df007c5c 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java @@ -35,7 +35,8 @@ public enum Stage { GET_CHECKPOINT_INFO((byte) 3), FILE_DIFF((byte) 4), GET_FILES((byte) 5), - FINALIZE_REPLICATION((byte) 6); + FINALIZE_REPLICATION((byte) 6), + CANCELLED((byte) 7); private static final Stage[] STAGES = new Stage[Stage.values().length]; @@ -118,6 +119,10 @@ protected void validateAndSetStage(Stage expected, Stage next) { "can't move replication to stage [" + next + "]. 
current stage: [" + stage + "] (expected [" + expected + "])" ); } + stopTimersAndSetStage(next); + } + + private void stopTimersAndSetStage(Stage next) { // save the timing data for the current step stageTimer.stop(); timingData.add(new Tuple<>(stage.name(), stageTimer.time())); @@ -155,6 +160,14 @@ public void setStage(Stage stage) { overallTimer.stop(); timingData.add(new Tuple<>("OVERALL", overallTimer.time())); break; + case CANCELLED: + if (this.stage == Stage.DONE) { + throw new IllegalStateException("can't move replication to Cancelled state from Done."); + } + stopTimersAndSetStage(Stage.CANCELLED); + overallTimer.stop(); + timingData.add(new Tuple<>("OVERALL", overallTimer.time())); + break; default: throw new IllegalArgumentException("unknown SegmentReplicationState.Stage [" + stage + "]"); } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java index a658ffc09d590..d1d6104a416ca 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java @@ -17,6 +17,7 @@ import org.apache.lucene.store.ByteBuffersDataInput; import org.apache.lucene.store.ByteBuffersIndexInput; import org.apache.lucene.store.ChecksumIndexInput; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.action.StepListener; @@ -103,7 +104,15 @@ public String description() { @Override public void notifyListener(OpenSearchException e, boolean sendShardFailure) { - listener.onFailure(state(), e, sendShardFailure); + // Cancellations are still passed to our SegmentReplicationListener as failures; if we have failed because of cancellation, + // update the stage. 
+ final Throwable cancelledException = ExceptionsHelper.unwrap(e, CancellableThreads.ExecutionCancelledException.class); + if (cancelledException != null) { + state.setStage(SegmentReplicationState.Stage.CANCELLED); + listener.onFailure(state(), (CancellableThreads.ExecutionCancelledException) cancelledException, sendShardFailure); + } else { + listener.onFailure(state(), e, sendShardFailure); + } } @Override @@ -134,6 +143,14 @@ public void writeFileChunk( * @param listener {@link ActionListener} listener. */ public void startReplication(ActionListener listener) { + cancellableThreads.setOnCancel((reason, beforeCancelEx) -> { + // This method only executes when cancellation is triggered by this node and caught by a call to checkForCancel, + // SegmentReplicationSource does not share CancellableThreads. + final CancellableThreads.ExecutionCancelledException executionCancelledException = + new CancellableThreads.ExecutionCancelledException("replication was canceled reason [" + reason + "]"); + notifyListener(executionCancelledException, false); + throw executionCancelledException; + }); state.setStage(SegmentReplicationState.Stage.REPLICATING); final StepListener checkpointInfoListener = new StepListener<>(); final StepListener getFilesListener = new StepListener<>(); @@ -141,6 +158,7 @@ public void startReplication(ActionListener listener) { logger.trace("[shardId {}] Replica starting replication [id {}]", shardId().getId(), getId()); // Get list of files to copy from this checkpoint. 
+ cancellableThreads.checkForCancel(); state.setStage(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO); source.getCheckpointMetadata(getId(), checkpoint, checkpointInfoListener); @@ -154,6 +172,7 @@ public void startReplication(ActionListener listener) { private void getFiles(CheckpointInfoResponse checkpointInfo, StepListener getFilesListener) throws IOException { + cancellableThreads.checkForCancel(); state.setStage(SegmentReplicationState.Stage.FILE_DIFF); final Store.MetadataSnapshot snapshot = checkpointInfo.getSnapshot(); Store.MetadataSnapshot localMetadata = getMetadataSnapshot(); @@ -188,12 +207,14 @@ private void getFiles(CheckpointInfoResponse checkpointInfo, StepListener listener) { - state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); ActionListener.completeWith(listener, () -> { + cancellableThreads.checkForCancel(); + state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); multiFileWriter.renameAllTempFiles(); final Store store = store(); store.incRef(); @@ -261,4 +282,10 @@ Store.MetadataSnapshot getMetadataSnapshot() throws IOException { } return store.getMetadata(indexShard.getSegmentInfosSnapshot().get()); } + + @Override + protected void onCancel(String reason) { + cancellableThreads.cancel(reason); + source.cancel(); + } } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index a79ce195ad83b..9e6b66dc4d7d6 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -11,10 +11,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; import 
org.opensearch.action.ActionListener; +import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.common.Nullable; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.ShardId; @@ -64,6 +67,11 @@ public void beforeIndexShardClosed(ShardId shardId, IndexShard indexShard, Setti public synchronized void onNewCheckpoint(ReplicationCheckpoint receivedCheckpoint, IndexShard replicaShard) { // noOp; } + + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + // noOp; + } }; // Used only for empty implementation. @@ -74,6 +82,10 @@ private SegmentReplicationTargetService() { sourceFactory = null; } + public ReplicationRef get(long replicationId) { + return onGoingReplications.get(replicationId); + } + /** * The internal actions * @@ -102,6 +114,9 @@ public SegmentReplicationTargetService( ); } + /** + * Cancel any replications on this node for a replica that is about to be closed. + */ @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { if (indexShard != null) { @@ -109,11 +124,22 @@ public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexSh } } + /** + * Cancel any replications on this node for a replica that has just been promoted as the new primary. + */ + @Override + public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { + if (oldRouting != null && oldRouting.primary() == false && newRouting.primary()) { + onGoingReplications.cancelForShard(indexShard.shardId(), "shard has been promoted to primary"); + } + } + /** * Invoked when a new checkpoint is received from a primary shard. 
* It checks if a new checkpoint should be processed or not and starts replication if needed. - * @param receivedCheckpoint received checkpoint that is checked for processing - * @param replicaShard replica shard on which checkpoint is received + * + * @param receivedCheckpoint received checkpoint that is checked for processing + * @param replicaShard replica shard on which checkpoint is received */ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedCheckpoint, final IndexShard replicaShard) { logger.trace(() -> new ParameterizedMessage("Replica received new replication checkpoint from primary [{}]", receivedCheckpoint)); @@ -180,12 +206,19 @@ public void onReplicationFailure(SegmentReplicationState state, OpenSearchExcept } } - public void startReplication( + public SegmentReplicationTarget startReplication( final ReplicationCheckpoint checkpoint, final IndexShard indexShard, final SegmentReplicationListener listener ) { - startReplication(new SegmentReplicationTarget(checkpoint, indexShard, sourceFactory.get(indexShard), listener)); + final SegmentReplicationTarget target = new SegmentReplicationTarget( + checkpoint, + indexShard, + sourceFactory.get(indexShard), + listener + ); + startReplication(target); + return target; } // pkg-private for integration tests @@ -248,7 +281,17 @@ public void onResponse(Void o) { @Override public void onFailure(Exception e) { - onGoingReplications.fail(replicationId, new OpenSearchException("Segment Replication failed", e), true); + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof CancellableThreads.ExecutionCancelledException) { + if (onGoingReplications.getTarget(replicationId) != null) { + // if the target still exists in our collection, the primary initiated the cancellation, fail the replication + // but do not fail the shard. 
Cancellations initiated by this node from Index events will be removed with + // onGoingReplications.cancel and not appear in the collection when this listener resolves. + onGoingReplications.fail(replicationId, (CancellableThreads.ExecutionCancelledException) cause, false); + } + } else { + onGoingReplications.fail(replicationId, new OpenSearchException("Segment Replication failed", e), true); + } } }); } diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 0ac8471be7087..f4a93b80cecd6 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -963,6 +963,7 @@ protected Node( .toInstance(new SegmentReplicationSourceService(indicesService, transportService, recoverySettings)); } else { b.bind(SegmentReplicationTargetService.class).toInstance(SegmentReplicationTargetService.NO_OP); + b.bind(SegmentReplicationSourceService.class).toInstance(SegmentReplicationSourceService.NO_OP); } } b.bind(HttpServerTransport.class).toInstance(httpServerTransport); @@ -1106,6 +1107,9 @@ public Node start() throws NodeValidationException { assert transportService.getLocalNode().equals(localNodeFactory.getNode()) : "transportService has a different local node than the factory provided"; injector.getInstance(PeerRecoverySourceService.class).start(); + if (FeatureFlags.isEnabled(REPLICATION_TYPE)) { + injector.getInstance(SegmentReplicationSourceService.class).start(); + } // Load (and maybe upgrade) the metadata stored on disk final GatewayMetaState gatewayMetaState = injector.getInstance(GatewayMetaState.class); @@ -1281,6 +1285,9 @@ public synchronized void close() throws IOException { // close filter/fielddata caches after indices toClose.add(injector.getInstance(IndicesStore.class)); toClose.add(injector.getInstance(PeerRecoverySourceService.class)); + if (FeatureFlags.isEnabled(REPLICATION_TYPE)) { + 
toClose.add(injector.getInstance(SegmentReplicationSourceService.class)); + } toClose.add(() -> stopWatch.stop().start("cluster")); toClose.add(injector.getInstance(ClusterService.class)); toClose.add(() -> stopWatch.stop().start("node_connections_service")); diff --git a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java index 23371a39871c7..88a3bdad53d0c 100644 --- a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java @@ -8,11 +8,18 @@ package org.opensearch.index.shard; +import org.junit.Assert; +import org.opensearch.OpenSearchException; +import org.opensearch.action.ActionListener; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.IndexSettings; import org.opensearch.index.engine.DocIdSeqNoAndSource; @@ -21,12 +28,28 @@ import org.opensearch.index.engine.NRTReplicationEngineFactory; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.replication.OpenSearchIndexLevelReplicationTestCase; +import org.opensearch.index.store.Store; +import org.opensearch.index.store.StoreFileMetadata; +import org.opensearch.indices.recovery.RecoverySettings; +import org.opensearch.indices.replication.CheckpointInfoResponse; +import org.opensearch.indices.replication.GetSegmentFilesResponse; +import 
org.opensearch.indices.replication.SegmentReplicationSource; +import org.opensearch.indices.replication.SegmentReplicationSourceFactory; +import org.opensearch.indices.replication.SegmentReplicationState; +import org.opensearch.indices.replication.SegmentReplicationTarget; +import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import org.opensearch.indices.replication.common.CopyState; import org.opensearch.indices.replication.common.ReplicationType; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static java.util.Arrays.asList; import static org.hamcrest.Matchers.equalTo; @@ -34,6 +57,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class SegmentReplicationIndexShardTests extends OpenSearchIndexLevelReplicationTestCase { @@ -241,6 +265,213 @@ public void testNRTReplicaPromotedAsPrimary() throws Exception { } } + public void testReplicaPromotedWhileReplicating() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + final IndexShard oldPrimary = shards.getPrimary(); + final IndexShard nextPrimary = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + oldPrimary.refresh("Test"); + shards.syncGlobalCheckpoint(); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + 
SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + resolveCheckpointInfoResponseListener(listener, oldPrimary); + ShardRouting oldRouting = nextPrimary.shardRouting; + try { + shards.promoteReplicaToPrimary(nextPrimary); + } catch (IOException e) { + Assert.fail("Promotion should not fail"); + } + targetService.shardRoutingChanged(nextPrimary, oldRouting, nextPrimary.shardRouting); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList())); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(nextPrimary, targetService); + // wait for replica to finish being promoted, and assert doc counts. + final CountDownLatch latch = new CountDownLatch(1); + nextPrimary.acquirePrimaryOperationPermit(new ActionListener<>() { + @Override + public void onResponse(Releasable releasable) { + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + }, ThreadPool.Names.GENERIC, ""); + latch.await(); + assertEquals(nextPrimary.getEngine().getClass(), InternalEngine.class); + nextPrimary.refresh("test"); + + oldPrimary.close("demoted", false); + oldPrimary.store().close(); + IndexShard newReplica = shards.addReplicaWithExistingPath(oldPrimary.shardPath(), oldPrimary.routingEntry().currentNodeId()); + shards.recoverReplica(newReplica); + + assertDocCount(nextPrimary, numDocs); + assertDocCount(newReplica, numDocs); + + nextPrimary.refresh("test"); + replicateSegments(nextPrimary, shards.getReplicas()); + final List docsAfterRecovery = getDocIdAndSeqNos(shards.getPrimary()); + for (IndexShard shard : shards.getReplicas()) { + 
assertThat(shard.routingEntry().toString(), getDocIdAndSeqNos(shard), equalTo(docsAfterRecovery)); + } + } + } + + public void testReplicaClosesWhileReplicating_AfterGetCheckpoint() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + // trigger a cancellation by closing the replica. + targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + resolveCheckpointInfoResponseListener(listener, primary); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + Assert.fail("Should not be reached"); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + public void testReplicaClosesWhileReplicating_AfterGetSegmentFiles() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = 
mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + resolveCheckpointInfoResponseListener(listener, primary); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + // randomly resolve the listener, indicating the source has resolved. + listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList())); + targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + public void testPrimaryCancelsExecution() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + final int numDocs = shards.indexDocs(randomInt(10)); + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new SegmentReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + listener.onFailure(new CancellableThreads.ExecutionCancelledException("Cancelled")); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, 
+ ActionListener listener + ) {} + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + private SegmentReplicationTargetService newTargetService(SegmentReplicationSourceFactory sourceFactory) { + return new SegmentReplicationTargetService( + threadPool, + new RecoverySettings(Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)), + mock(TransportService.class), + sourceFactory + ); + } + /** * Assert persisted and searchable doc counts. This method should not be used while docs are concurrently indexed because * it asserts point in time seqNos are relative to the doc counts. @@ -253,4 +484,48 @@ private void assertDocCounts(IndexShard indexShard, int expectedPersistedDocCoun // processed cp should be 1 less than our searchable doc count. assertEquals(expectedSearchableDocCount - 1, indexShard.getProcessedLocalCheckpoint()); } + + private void resolveCheckpointInfoResponseListener(ActionListener listener, IndexShard primary) { + try { + final CopyState copyState = new CopyState(ReplicationCheckpoint.empty(primary.shardId), primary); + listener.onResponse( + new CheckpointInfoResponse( + copyState.getCheckpoint(), + copyState.getMetadataSnapshot(), + copyState.getInfosBytes(), + copyState.getPendingDeleteFiles() + ) + ); + } catch (IOException e) { + logger.error("Unexpected error computing CopyState", e); + Assert.fail("Failed to compute copyState"); + } + } + + private void startReplicationAndAssertCancellation(IndexShard replica, SegmentReplicationTargetService targetService) + throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final SegmentReplicationTarget target = targetService.startReplication( + ReplicationCheckpoint.empty(replica.shardId), + replica, + new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void 
onReplicationDone(SegmentReplicationState state) { + Assert.fail("Replication should not complete"); + } + + @Override + public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { + assertTrue(e instanceof CancellableThreads.ExecutionCancelledException); + assertFalse(sendShardFailure); + assertEquals(SegmentReplicationState.Stage.CANCELLED, state.getStage()); + latch.countDown(); + } + } + ); + + latch.await(2, TimeUnit.SECONDS); + assertEquals("Should have resolved listener with failure", 0, latch.getCount()); + assertNull(targetService.get(target.getId())); + } } diff --git a/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java b/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java index 1f2360abde2ad..22481b5a7b99f 100644 --- a/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java +++ b/server/src/test/java/org/opensearch/indices/cluster/IndicesClusterStateServiceRandomUpdatesTests.java @@ -66,6 +66,7 @@ import org.opensearch.index.shard.PrimaryReplicaSyncer; import org.opensearch.index.shard.ShardId; import org.opensearch.indices.recovery.PeerRecoveryTargetService; +import org.opensearch.indices.replication.SegmentReplicationSourceService; import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.repositories.RepositoriesService; @@ -572,6 +573,7 @@ private IndicesClusterStateService createIndicesClusterStateService( threadPool, SegmentReplicationCheckpointPublisher.EMPTY, SegmentReplicationTargetService.NO_OP, + SegmentReplicationSourceService.NO_OP, recoveryTargetService, shardStateAction, null, diff --git a/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java 
b/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java index 38c55620e1223..f49ee0471b5e8 100644 --- a/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java @@ -14,6 +14,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.IndexService; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; @@ -31,6 +33,8 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; @@ -154,6 +158,51 @@ public void testCancelReplication() throws IOException { assertEquals(0, replications.cachedCopyStateSize()); } + public void testCancelReplication_AfterSendFilesStarts() throws IOException, InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + OngoingSegmentReplications replications = new OngoingSegmentReplications(mockIndicesService, recoverySettings); + // add a doc and refresh so primary has more than one segment. + indexDoc(primary, "1", "{\"foo\" : \"baz\"}", XContentType.JSON, "foobar"); + primary.refresh("Test"); + final CheckpointInfoRequest request = new CheckpointInfoRequest( + 1L, + replica.routingEntry().allocationId().getId(), + primaryDiscoveryNode, + testCheckpoint + ); + final FileChunkWriter segmentSegmentFileChunkWriter = (fileMetadata, position, content, lastChunk, totalTranslogOps, listener) -> { + // cancel the replication as soon as the writer starts sending files. 
+ replications.cancel(replica.routingEntry().allocationId().getId(), "Test"); + }; + final CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); + assertEquals(1, replications.size()); + assertEquals(1, replications.cachedCopyStateSize()); + getSegmentFilesRequest = new GetSegmentFilesRequest( + 1L, + replica.routingEntry().allocationId().getId(), + replicaDiscoveryNode, + new ArrayList<>(copyState.getMetadataSnapshot().asMap().values()), + testCheckpoint + ); + replications.startSegmentCopy(getSegmentFilesRequest, new ActionListener<>() { + @Override + public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { + Assert.fail("Expected onFailure to be invoked."); + } + + @Override + public void onFailure(Exception e) { + assertEquals(CancellableThreads.ExecutionCancelledException.class, e.getClass()); + assertEquals(0, copyState.refCount()); + assertEquals(0, replications.size()); + assertEquals(0, replications.cachedCopyStateSize()); + latch.countDown(); + } + }); + latch.await(2, TimeUnit.SECONDS); + assertEquals("listener should have resolved with failure", 0, latch.getCount()); + } + public void testMultipleReplicasUseSameCheckpoint() throws IOException { IndexShard secondReplica = newShard(primary.shardId(), false); recoverReplica(secondReplica, primary, true); diff --git a/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java b/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java index 6bce74be569c3..323445bee1274 100644 --- a/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/PrimaryShardReplicationSourceTests.java @@ -9,12 +9,14 @@ package org.opensearch.indices.replication; import org.apache.lucene.util.Version; +import org.junit.Assert; import org.opensearch.action.ActionListener; import 
org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.core.internal.io.IOUtils; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; @@ -28,6 +30,8 @@ import java.util.Arrays; import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.Mockito.mock; @@ -126,6 +130,39 @@ public void testGetSegmentFiles() { assertTrue(capturedRequest.request instanceof GetSegmentFilesRequest); } + public void testGetSegmentFiles_CancelWhileRequestOpen() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final ReplicationCheckpoint checkpoint = new ReplicationCheckpoint( + indexShard.shardId(), + PRIMARY_TERM, + SEGMENTS_GEN, + SEQ_NO, + VERSION + ); + StoreFileMetadata testMetadata = new StoreFileMetadata("testFile", 1L, "checksum", Version.LATEST); + replicationSource.getSegmentFiles( + REPLICATION_ID, + checkpoint, + Arrays.asList(testMetadata), + mock(Store.class), + new ActionListener<>() { + @Override + public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { + Assert.fail("onFailure response expected."); + } + + @Override + public void onFailure(Exception e) { + assertEquals(e.getClass(), CancellableThreads.ExecutionCancelledException.class); + latch.countDown(); + } + } + ); + replicationSource.cancel(); + latch.await(2, TimeUnit.SECONDS); + assertEquals("listener should have resolved in a failure", 0, latch.getCount()); + } + private DiscoveryNode newDiscoveryNode(String nodeName) { return new DiscoveryNode( nodeName, diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java 
b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java index 2c52772649acc..a6e169dbc3d61 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java @@ -18,6 +18,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CancellableThreads; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; import org.opensearch.index.store.StoreFileMetadata; @@ -28,6 +29,8 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.Mockito.mock; @@ -197,4 +200,47 @@ public void testReplicationAlreadyRunning() throws IOException { handler.sendFiles(getSegmentFilesRequest, mock(ActionListener.class)); Assert.assertThrows(OpenSearchException.class, () -> { handler.sendFiles(getSegmentFilesRequest, mock(ActionListener.class)); }); } + + public void testCancelReplication() throws IOException, InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + chunkWriter = mock(FileChunkWriter.class); + + final ReplicationCheckpoint latestReplicationCheckpoint = primary.getLatestReplicationCheckpoint(); + final CopyState copyState = new CopyState(latestReplicationCheckpoint, primary); + SegmentReplicationSourceHandler handler = new SegmentReplicationSourceHandler( + localNode, + chunkWriter, + threadPool, + copyState, + primary.routingEntry().allocationId().getId(), + 5000, + 1 + ); + + final GetSegmentFilesRequest getSegmentFilesRequest = new GetSegmentFilesRequest( + 1L, + replica.routingEntry().allocationId().getId(), + replicaDiscoveryNode, + Collections.emptyList(), + 
latestReplicationCheckpoint + ); + + // cancel before xfer starts. Cancels during copy will be tested in SegmentFileTransferHandlerTests, that uses the same + // cancellableThreads. + handler.cancel("test"); + handler.sendFiles(getSegmentFilesRequest, new ActionListener<>() { + @Override + public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { + Assert.fail("Expected failure."); + } + + @Override + public void onFailure(Exception e) { + assertEquals(CancellableThreads.ExecutionCancelledException.class, e.getClass()); + latch.countDown(); + } + }); + latch.await(2, TimeUnit.SECONDS); + assertEquals("listener should have resolved with failure", 0, latch.getCount()); + } } diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index e218f09aad575..7d9b0f09f21cd 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -9,38 +9,39 @@ package org.opensearch.indices.replication; import org.junit.Assert; -import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.index.engine.NRTReplicationEngineFactory; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; -import org.opensearch.indices.recovery.RecoverySettings; +import org.opensearch.index.store.Store; +import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; -import 
org.opensearch.indices.replication.common.ReplicationLuceneIndex; -import org.opensearch.transport.TransportService; +import org.opensearch.indices.replication.common.ReplicationType; import java.io.IOException; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.eq; public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { - private IndexShard indexShard; + private IndexShard replicaShard; + private IndexShard primaryShard; private ReplicationCheckpoint checkpoint; private SegmentReplicationSource replicationSource; private SegmentReplicationTargetService sut; @@ -52,20 +53,20 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { public void setUp() throws Exception { super.setUp(); final Settings settings = Settings.builder() - .put(IndexMetadata.SETTING_REPLICATION_TYPE, "SEGMENT") + .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT) .put("node.name", SegmentReplicationTargetServiceTests.class.getSimpleName()) .build(); final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings); - final TransportService transportService = mock(TransportService.class); - indexShard = newStartedShard(false, settings); - checkpoint = new ReplicationCheckpoint(indexShard.shardId(), 0L, 0L, 0L, 0L); + primaryShard = 
newStartedShard(true); + replicaShard = newShard(false, settings, new NRTReplicationEngineFactory()); + recoverReplica(replicaShard, primaryShard, true); + checkpoint = new ReplicationCheckpoint(replicaShard.shardId(), 0L, 0L, 0L, 0L); SegmentReplicationSourceFactory replicationSourceFactory = mock(SegmentReplicationSourceFactory.class); replicationSource = mock(SegmentReplicationSource.class); - when(replicationSourceFactory.get(indexShard)).thenReturn(replicationSource); + when(replicationSourceFactory.get(replicaShard)).thenReturn(replicationSource); - sut = new SegmentReplicationTargetService(threadPool, recoverySettings, transportService, replicationSourceFactory); - initialCheckpoint = indexShard.getLatestReplicationCheckpoint(); + sut = prepareForReplication(primaryShard); + initialCheckpoint = replicaShard.getLatestReplicationCheckpoint(); aheadCheckpoint = new ReplicationCheckpoint( initialCheckpoint.getShardId(), initialCheckpoint.getPrimaryTerm(), @@ -77,44 +78,58 @@ public void setUp() throws Exception { @Override public void tearDown() throws Exception { - closeShards(indexShard); + closeShards(primaryShard, replicaShard); super.tearDown(); } - public void testTargetReturnsSuccess_listenerCompletes() { - final SegmentReplicationTarget target = new SegmentReplicationTarget( - checkpoint, - indexShard, - replicationSource, - new SegmentReplicationTargetService.SegmentReplicationListener() { - @Override - public void onReplicationDone(SegmentReplicationState state) { - assertEquals(SegmentReplicationState.Stage.DONE, state.getStage()); - } + public void testsSuccessfulReplication_listenerCompletes() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + sut.startReplication(checkpoint, replicaShard, new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + assertEquals(SegmentReplicationState.Stage.DONE, state.getStage()); + latch.countDown(); 
+ } - @Override - public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { - Assert.fail(); - } + @Override + public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) { + logger.error("Unexpected error", e); + Assert.fail("Test should succeed"); } - ); - final SegmentReplicationTarget spy = Mockito.spy(target); - doAnswer(invocation -> { - // set up stage correctly so the transition in markAsDone succeeds on listener completion - moveTargetToFinalStage(target); - final ActionListener listener = invocation.getArgument(0); - listener.onResponse(null); - return null; - }).when(spy).startReplication(any()); - sut.startReplication(spy); + }); + latch.await(2, TimeUnit.SECONDS); + assertEquals(0, latch.getCount()); } - public void testTargetThrowsException() { + public void testReplicationFails() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); final OpenSearchException expectedError = new OpenSearchException("Fail"); + SegmentReplicationSource source = new SegmentReplicationSource() { + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + listener.onFailure(expectedError); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + Store store, + ActionListener listener + ) { + Assert.fail("Should not be called"); + } + }; final SegmentReplicationTarget target = new SegmentReplicationTarget( checkpoint, - indexShard, - replicationSource, + replicaShard, + source, new SegmentReplicationTargetService.SegmentReplicationListener() { @Override public void onReplicationDone(SegmentReplicationState state) { @@ -123,24 +138,21 @@ public void onReplicationDone(SegmentReplicationState state) { @Override public void onReplicationFailure(SegmentReplicationState state, 
OpenSearchException e, boolean sendShardFailure) { - assertEquals(SegmentReplicationState.Stage.INIT, state.getStage()); + // failures leave state object in last entered stage. + assertEquals(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO, state.getStage()); assertEquals(expectedError, e.getCause()); - assertTrue(sendShardFailure); + latch.countDown(); } } ); - final SegmentReplicationTarget spy = Mockito.spy(target); - doAnswer(invocation -> { - final ActionListener listener = invocation.getArgument(0); - listener.onFailure(expectedError); - return null; - }).when(spy).startReplication(any()); - sut.startReplication(spy); + sut.startReplication(target); + latch.await(2, TimeUnit.SECONDS); + assertEquals(0, latch.getCount()); } public void testAlreadyOnNewCheckpoint() { SegmentReplicationTargetService spy = spy(sut); - spy.onNewCheckpoint(indexShard.getLatestReplicationCheckpoint(), indexShard); + spy.onNewCheckpoint(replicaShard.getLatestReplicationCheckpoint(), replicaShard); verify(spy, times(0)).startReplication(any(), any(), any()); } @@ -149,7 +161,7 @@ public void testShardAlreadyReplicating() throws InterruptedException { SegmentReplicationTargetService serviceSpy = spy(sut); final SegmentReplicationTarget target = new SegmentReplicationTarget( checkpoint, - indexShard, + replicaShard, replicationSource, mock(SegmentReplicationTargetService.SegmentReplicationListener.class) ); @@ -161,7 +173,7 @@ public void testShardAlreadyReplicating() throws InterruptedException { doAnswer(invocation -> { final ActionListener listener = invocation.getArgument(0); // a new checkpoint arrives before we've completed. - serviceSpy.onNewCheckpoint(aheadCheckpoint, indexShard); + serviceSpy.onNewCheckpoint(aheadCheckpoint, replicaShard); listener.onResponse(null); latch.countDown(); return null; @@ -173,12 +185,12 @@ public void testShardAlreadyReplicating() throws InterruptedException { // wait for the new checkpoint to arrive, before the listener completes. 
latch.await(30, TimeUnit.SECONDS); - verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(indexShard), any()); + verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(replicaShard), any()); } public void testNewCheckpointBehindCurrentCheckpoint() { SegmentReplicationTargetService spy = spy(sut); - spy.onNewCheckpoint(checkpoint, indexShard); + spy.onNewCheckpoint(checkpoint, replicaShard); verify(spy, times(0)).startReplication(any(), any(), any()); } @@ -190,22 +202,6 @@ public void testShardNotStarted() throws IOException { closeShards(shard); } - public void testNewCheckpoint_validationPassesAndReplicationFails() throws IOException { - allowShardFailures(); - SegmentReplicationTargetService spy = spy(sut); - IndexShard spyShard = spy(indexShard); - ArgumentCaptor captor = ArgumentCaptor.forClass( - SegmentReplicationTargetService.SegmentReplicationListener.class - ); - doNothing().when(spy).startReplication(any(), any(), any()); - spy.onNewCheckpoint(aheadCheckpoint, spyShard); - verify(spy, times(1)).startReplication(any(), any(), captor.capture()); - SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue(); - listener.onFailure(new SegmentReplicationState(new ReplicationLuceneIndex()), new OpenSearchException("testing"), true); - verify(spyShard).failShard(any(), any()); - closeShard(indexShard, false); - } - /** * here we are starting a new shard in PrimaryMode and testing that we don't process a checkpoint on shard when it is in PrimaryMode. */ @@ -215,71 +211,10 @@ public void testRejectCheckpointOnShardPrimaryMode() throws IOException { // Starting a new shard in PrimaryMode. IndexShard primaryShard = newStartedShard(true); IndexShard spyShard = spy(primaryShard); - doNothing().when(spy).startReplication(any(), any(), any()); spy.onNewCheckpoint(aheadCheckpoint, spyShard); // Verify that checkpoint is not processed as shard is in PrimaryMode. 
verify(spy, times(0)).startReplication(any(), any(), any()); closeShards(primaryShard); } - - public void testReplicationOnDone() throws IOException { - SegmentReplicationTargetService spy = spy(sut); - IndexShard spyShard = spy(indexShard); - ReplicationCheckpoint cp = indexShard.getLatestReplicationCheckpoint(); - ReplicationCheckpoint newCheckpoint = new ReplicationCheckpoint( - cp.getShardId(), - cp.getPrimaryTerm(), - cp.getSegmentsGen(), - cp.getSeqNo(), - cp.getSegmentInfosVersion() + 1 - ); - ReplicationCheckpoint anotherNewCheckpoint = new ReplicationCheckpoint( - cp.getShardId(), - cp.getPrimaryTerm(), - cp.getSegmentsGen(), - cp.getSeqNo(), - cp.getSegmentInfosVersion() + 2 - ); - ArgumentCaptor captor = ArgumentCaptor.forClass( - SegmentReplicationTargetService.SegmentReplicationListener.class - ); - doNothing().when(spy).startReplication(any(), any(), any()); - spy.onNewCheckpoint(newCheckpoint, spyShard); - spy.onNewCheckpoint(anotherNewCheckpoint, spyShard); - verify(spy, times(1)).startReplication(eq(newCheckpoint), any(), captor.capture()); - verify(spy, times(1)).onNewCheckpoint(eq(anotherNewCheckpoint), any()); - SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue(); - listener.onDone(new SegmentReplicationState(new ReplicationLuceneIndex())); - doNothing().when(spy).onNewCheckpoint(any(), any()); - verify(spy, timeout(0).times(2)).onNewCheckpoint(eq(anotherNewCheckpoint), any()); - closeShard(indexShard, false); - - } - - public void testBeforeIndexShardClosed_CancelsOngoingReplications() { - final SegmentReplicationTarget target = new SegmentReplicationTarget( - checkpoint, - indexShard, - replicationSource, - mock(SegmentReplicationTargetService.SegmentReplicationListener.class) - ); - final SegmentReplicationTarget spy = Mockito.spy(target); - sut.startReplication(spy); - sut.beforeIndexShardClosed(indexShard.shardId(), indexShard, Settings.EMPTY); - verify(spy, times(1)).cancel(any()); - } - - /** - * Move 
the {@link SegmentReplicationTarget} object through its {@link SegmentReplicationState.Stage} values in order - * until the final, non-terminal stage. - */ - private void moveTargetToFinalStage(SegmentReplicationTarget target) { - SegmentReplicationState.Stage[] stageValues = SegmentReplicationState.Stage.values(); - assertEquals(target.state().getStage(), SegmentReplicationState.Stage.INIT); - // Skip the first two stages (DONE and INIT) and iterate until the last value - for (int i = 2; i < stageValues.length; i++) { - target.state().setStage(stageValues[i]); - } - } } diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index e9ef5ba30c865..060c416f1a75c 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -184,6 +184,7 @@ import org.opensearch.indices.recovery.PeerRecoveryTargetService; import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.indices.replication.SegmentReplicationSourceFactory; +import org.opensearch.indices.replication.SegmentReplicationSourceService; import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.ingest.IngestService; @@ -1855,6 +1856,7 @@ public void onFailure(final Exception e) { transportService, new SegmentReplicationSourceFactory(transportService, recoverySettings, clusterService) ), + SegmentReplicationSourceService.NO_OP, shardStateAction, new NodeMappingRefreshAction(transportService, metadataMappingService), repositoriesService, diff --git a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java index cad6579ac941d..02442a6d1ad4d 100644 --- 
a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java @@ -64,7 +64,6 @@ import org.opensearch.common.lucene.uid.Versions; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.BigArrays; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.internal.io.IOUtils; @@ -105,7 +104,10 @@ import org.opensearch.indices.replication.CheckpointInfoResponse; import org.opensearch.indices.replication.GetSegmentFilesResponse; import org.opensearch.indices.replication.SegmentReplicationSource; +import org.opensearch.indices.replication.SegmentReplicationSourceFactory; +import org.opensearch.indices.replication.SegmentReplicationState; import org.opensearch.indices.replication.SegmentReplicationTarget; +import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.indices.replication.common.CopyState; @@ -120,8 +122,10 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -138,7 +142,9 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.opensearch.cluster.routing.TestShardRouting.newShardRouting; /** @@ -1135,35 
+1141,41 @@ public static Engine.Warmer createTestWarmer(IndexSettings indexSettings) {
     }
 
     /**
-     * Segment Replication specific test method - Replicate segments to a list of replicas from a given primary.
-     * This test will use a real {@link SegmentReplicationTarget} for each replica with a mock {@link SegmentReplicationSource} that
-     * writes all segments directly to the target.
+     * Segment Replication specific test method - Creates a {@link SegmentReplicationTargetService} to perform replications that has
+     * been configured to return the given primaryShard's current segments.
+     *
+     * @param primaryShard {@link IndexShard} - The primary shard to replicate from.
+     * @return {@link SegmentReplicationTargetService} - The service whose mocked source factory hands out the source built here.
      */
-    public final void replicateSegments(IndexShard primaryShard, List<IndexShard> replicaShards) throws IOException, InterruptedException {
-        final CountDownLatch countDownLatch = new CountDownLatch(replicaShards.size());
-        Store.MetadataSnapshot primaryMetadata;
-        try (final GatedCloseable<SegmentInfos> segmentInfosSnapshot = primaryShard.getSegmentInfosSnapshot()) {
-            final SegmentInfos primarySegmentInfos = segmentInfosSnapshot.get();
-            primaryMetadata = primaryShard.store().getMetadata(primarySegmentInfos);
-        }
-        final CopyState copyState = new CopyState(ReplicationCheckpoint.empty(primaryShard.shardId), primaryShard);
-
-        final ReplicationCollection<SegmentReplicationTarget> replicationCollection = new ReplicationCollection<>(logger, threadPool);
-        final SegmentReplicationSource source = new SegmentReplicationSource() {
+    public final SegmentReplicationTargetService prepareForReplication(IndexShard primaryShard) {
+        final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class);
+        final SegmentReplicationTargetService targetService = new SegmentReplicationTargetService(
+            threadPool,
+            new RecoverySettings(Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)),
+            mock(TransportService.class),
+            sourceFactory
+        );
+        final SegmentReplicationSource replicationSource = new SegmentReplicationSource() {
             @Override
             public void getCheckpointMetadata(
                 long replicationId,
                 ReplicationCheckpoint checkpoint,
                 ActionListener<CheckpointInfoResponse> listener
             ) {
-                listener.onResponse(
-                    new CheckpointInfoResponse(
-                        copyState.getCheckpoint(),
-                        copyState.getMetadataSnapshot(),
-                        copyState.getInfosBytes(),
-                        copyState.getPendingDeleteFiles()
-                    )
-                );
+                try {
+                    final CopyState copyState = new CopyState(ReplicationCheckpoint.empty(primaryShard.shardId), primaryShard);
+                    listener.onResponse(
+                        new CheckpointInfoResponse(
+                            copyState.getCheckpoint(),
+                            copyState.getMetadataSnapshot(),
+                            copyState.getInfosBytes(),
+                            copyState.getPendingDeleteFiles()
+                        )
+                    );
+                } catch (IOException e) {
+                    logger.error("Unexpected error computing CopyState", e);
+                    Assert.fail("Failed to compute copyState");
+                }
             }
 
             @Override
@@ -1175,9 +1187,7 @@ public void getSegmentFiles(
                 ActionListener<GetSegmentFilesResponse> listener
             ) {
                 try (
-                    final ReplicationCollection.ReplicationRef<SegmentReplicationTarget> replicationRef = replicationCollection.get(
-                        replicationId
-                    )
+                    final ReplicationCollection.ReplicationRef<SegmentReplicationTarget> replicationRef = targetService.get(replicationId)
                 ) {
                     writeFileChunks(replicationRef.get(), primaryShard, filesToFetch.toArray(new StoreFileMetadata[] {}));
                 } catch (IOException e) {
@@ -1186,15 +1196,52 @@ public void getSegmentFiles(
                 listener.onResponse(new GetSegmentFilesResponse(filesToFetch));
             }
         };
+        when(sourceFactory.get(any())).thenReturn(replicationSource);
+        return targetService;
+    }
 
+    /**
+     * Segment Replication specific test method - Replicate segments to a list of replicas from a given primary.
+     * This test will use a real {@link SegmentReplicationTarget} for each replica with a mock {@link SegmentReplicationSource} that
+     * writes all segments directly to the target.
+     * @param primaryShard - {@link IndexShard} The current primary shard.
+     * @param replicaShards - Replicas that will be updated.
+     * @return {@link List} List of target components orchestrating replication.
+     */
+    public final List<SegmentReplicationTarget> replicateSegments(IndexShard primaryShard, List<IndexShard> replicaShards)
+        throws IOException, InterruptedException {
+        final SegmentReplicationTargetService targetService = prepareForReplication(primaryShard);
+        return replicateSegments(targetService, primaryShard, replicaShards);
+    }
+
+    /**
+     * Segment Replication specific test method - Replicate segments to a list of replicas from a given primary, using the
+     * supplied {@link SegmentReplicationTargetService} to start one replication event per replica.
+     * Returns once every replica's onReplicationDone callback has run, or once the wait on the latch times out.
+     * @param targetService - {@link SegmentReplicationTargetService} The service used to start each replication.
+     * @param primaryShard - {@link IndexShard} The current primary shard.
+     * @param replicaShards - Replicas that will be updated.
+     * @return {@link List} List of target components orchestrating replication.
+     */
+    public final List<SegmentReplicationTarget> replicateSegments(
+        SegmentReplicationTargetService targetService,
+        IndexShard primaryShard,
+        List<IndexShard> replicaShards
+    ) throws IOException, InterruptedException {
+        final CountDownLatch countDownLatch = new CountDownLatch(replicaShards.size());
+        Store.MetadataSnapshot primaryMetadata;
+        try (final GatedCloseable<SegmentInfos> segmentInfosSnapshot = primaryShard.getSegmentInfosSnapshot()) {
+            final SegmentInfos primarySegmentInfos = segmentInfosSnapshot.get();
+            primaryMetadata = primaryShard.store().getMetadata(primarySegmentInfos);
+        }
+        List<SegmentReplicationTarget> ids = new ArrayList<>();
         for (IndexShard replica : replicaShards) {
-            final SegmentReplicationTarget target = new SegmentReplicationTarget(
+            final SegmentReplicationTarget target = targetService.startReplication(
                 ReplicationCheckpoint.empty(replica.shardId),
                 replica,
-                source,
-                new ReplicationListener() {
+                new SegmentReplicationTargetService.SegmentReplicationListener() {
                     @Override
-                    public void onDone(ReplicationState state) {
+                    public void onReplicationDone(SegmentReplicationState state) {
                         try (final GatedCloseable<SegmentInfos> snapshot = replica.getSegmentInfosSnapshot()) {
                             final SegmentInfos replicaInfos = snapshot.get();
                             final Store.MetadataSnapshot replicaMetadata = replica.store().getMetadata(replicaInfos);
@@ -1205,31 +1252,24 @@ public void onDone(ReplicationState state) {
                             assertEquals(primaryMetadata.getCommitUserData(), replicaMetadata.getCommitUserData());
                         } catch (Exception e) {
                             throw ExceptionsHelper.convertToRuntime(e);
+                        } finally {
+                            countDownLatch.countDown();
                         }
-                        countDownLatch.countDown();
                     }
 
                     @Override
-                    public void onFailure(ReplicationState state, OpenSearchException e, boolean sendShardFailure) {
+                    public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) {
                         logger.error("Unexpected replication failure in test", e);
                         Assert.fail("test replication should not fail: " + e);
                     }
                 }
             );
-            replicationCollection.start(target, TimeValue.timeValueMillis(5000));
-            target.startReplication(new ActionListener<>() {
-                @Override
-                public void onResponse(Void o) {
-                    replicationCollection.markAsDone(target.getId());
-                }
-
-                @Override
-                public void onFailure(Exception e) {
-                    replicationCollection.fail(target.getId(), new OpenSearchException("Segment Replication failed", e), true);
-                }
-            });
+            ids.add(target);
         }
+        // Wait once, after every replication has been started: the latch is sized to replicaShards.size(), so a wait inside the
+        // loop could only be released during the last iteration and would stall each earlier one for its whole timeout.
         countDownLatch.await(3, TimeUnit.SECONDS);
+        return ids;
     }
 
     private void writeFileChunks(SegmentReplicationTarget target, IndexShard primary, StoreFileMetadata[] files) throws IOException {