diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
index 410212e1ff046..eeaed893482a3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
@@ -198,4 +198,10 @@ public ThreadCache getCache() {
     public void initialized() {
         initialized = true;
     }
+
+    @Override
+    public void uninitialize() {
+        initialized = false;
+    }
+
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
index baea4af62f1c0..fcf2f6b13b78a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
@@ -107,7 +107,7 @@ public final String applicationId() {
     }
 
     @Override
-    public final Set<TopicPartition> partitions() {
+    public Set<TopicPartition> partitions() {
         return partitions;
     }
 
@@ -226,6 +226,9 @@ void registerStateStores() {
         }
     }
 
+    void reinitializeStateStoresForPartitions(final TopicPartition partitions) {
+        stateMgr.reinitializeStateStoresForPartitions(partitions, processorContext);
+    }
 
     /**
      * @throws ProcessorStateException if there is an error while closing the state manager
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java
index 0d9d04de5cc3d..cfce57588e086 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java
@@ -50,7 +50,7 @@ class AssignedTasks implements RestoringTasks {
     // IQ may access this map.
     private Map<TaskId, Task> running = new ConcurrentHashMap<>();
     private Map<TopicPartition, Task> runningByPartition = new HashMap<>();
-    private Map<TopicPartition, Task> restoringByPartition = new HashMap<>();
+    private Map<TopicPartition, StreamTask> restoringByPartition = new HashMap<>();
     private int committed = 0;
 
@@ -122,7 +122,8 @@ Set<TopicPartition> initializeNewTasks() {
                 try {
                     if (!entry.getValue().initializeStateStores()) {
                         log.debug("Transitioning {} {} to restoring", taskTypeName, entry.getKey());
-                        addToRestoring(entry.getValue());
+                        // cast is safe, because StandbyTask always returns `true` in `initializeStateStores()` above
+                        addToRestoring((StreamTask) entry.getValue());
                     } else {
                         transitionToRunning(entry.getValue(), readyPartitions);
                     }
@@ -278,7 +279,7 @@ boolean maybeResumeSuspendedTask(final TaskId taskId, final Set<TopicPartition>
         return false;
     }
 
-    private void addToRestoring(final Task task) {
+    private void addToRestoring(final StreamTask task) {
         restoring.put(task.id(), task);
         for (TopicPartition topicPartition : task.partitions()) {
             restoringByPartition.put(topicPartition, task);
@@ -307,7 +308,7 @@ private void transitionToRunning(final Task task, final Set<TopicPartition> read
     }
 
     @Override
-    public Task restoringTaskFor(final TopicPartition partition) {
+    public StreamTask restoringTaskFor(final TopicPartition partition) {
        return restoringByPartition.get(partition);
     }
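One subtlety in the AssignedTasks change above is the new downcast in `initializeNewTasks()`: it is safe only because `StandbyTask.initializeStateStores()` always returns `true`, so a task that reports `false` must be a `StreamTask`. A minimal sketch of that argument, using hypothetical stand-in classes rather than the real `Task`/`StreamTask`/`StandbyTask` hierarchy:

```java
// Stand-in types illustrating why the (StreamTask) cast above is safe:
// only the stream-task variant can report "not yet initialized", so any
// task reaching the restoring branch must be one.
abstract class TaskStub {
    abstract boolean initializeStateStores();
}

class StreamTaskStub extends TaskStub {
    @Override
    boolean initializeStateStores() {
        return false; // may still need restoration
    }
}

class StandbyTaskStub extends TaskStub {
    @Override
    boolean initializeStateStores() {
        return true; // standby tasks never enter the restoring phase
    }
}

public class CastSafetyDemo {
    static void initialize(final TaskStub task) {
        if (!task.initializeStateStores()) {
            // only StreamTaskStub instances can reach this branch
            final StreamTaskStub restoring = (StreamTaskStub) task;
            System.out.println("restoring: " + restoring.getClass().getSimpleName());
        }
    }

    public static void main(final String[] args) {
        initialize(new StreamTaskStub());  // takes the restoring branch
        initialize(new StandbyTaskStub()); // does not
    }
}
```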
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
index 57bb3ac81a6d2..b5719b111f0cf 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
@@ -53,4 +53,9 @@ public interface InternalProcessorContext extends ProcessorContext {
      * Mark this context as being initialized
      */
     void initialized();
+
+    /**
+     * Mark this context as being uninitialized
+     */
+    void uninitialize();
 }
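The new `uninitialize()` hook is the inverse of `initialized()`: it lets the state manager flip a context back to the uninitialized state so a wiped store can run through `init()` again. A minimal sketch of the flag lifecycle this implies, with a simplified stand-in interface rather than the real `InternalProcessorContext`:

```java
// Simplified stand-in for the initialized/uninitialize flag that
// AbstractProcessorContext toggles; the real interface carries many
// more methods than shown here.
interface ContextLifecycle {
    void initialized();
    void uninitialize();
    boolean isInitialized();
}

class FlaggedContext implements ContextLifecycle {
    private boolean initialized = false;

    @Override
    public void initialized() {
        initialized = true;
    }

    @Override
    public void uninitialize() {
        // flipping the flag back allows a re-created store to call init() again
        initialized = false;
    }

    @Override
    public boolean isInitialized() {
        return initialized;
    }
}
```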
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index 5536ac12be7d6..9ccb458cf6611 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -19,8 +19,10 @@
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
@@ -33,9 +35,11 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 public class ProcessorStateManager implements StateManager {
 
@@ -49,6 +53,7 @@ public class ProcessorStateManager implements StateManager {
     private final String logPrefix;
     private final boolean isStandby;
     private final ChangelogReader changelogReader;
+    private final boolean eosEnabled;
     private final Map<String, StateStore> stores;
     private final Map<String, StateStore> globalStores;
     private final Map<TopicPartition, Long> offsetLimits;
@@ -98,6 +103,7 @@ public ProcessorStateManager(final TaskId taskId,
         checkpoint = new OffsetCheckpoint(new File(baseDir, CHECKPOINT_FILE_NAME));
         checkpointedOffsets = new HashMap<>(checkpoint.read());
 
+        this.eosEnabled = eosEnabled;
         if (eosEnabled) {
             // delete the checkpoint file after finish loading its stored offsets
             checkpoint.delete();
@@ -176,6 +182,61 @@ public Map<TopicPartition, Long> checkpointed() {
         return partitionsAndOffsets;
     }
 
+    void reinitializeStateStoresForPartitions(final TopicPartition topicPartition,
+                                              final InternalProcessorContext processorContext) {
+        final Map<String, String> changelogTopicToStore = inverseOneToOneMap(storeToChangelogTopic);
+        final Set<String> storeToBeReinitialized = new HashSet<>();
+        final Map<String, StateStore> storesCopy = new HashMap<>(stores);
+
+        checkpointedOffsets.remove(topicPartition);
+        storeToBeReinitialized.add(changelogTopicToStore.get(topicPartition.topic()));
+
+        if (!eosEnabled) {
+            try {
+                checkpoint.write(checkpointedOffsets);
+            } catch (final IOException fatalException) {
+                log.error("Failed to write offset checkpoint file to {} while re-initializing {}: {}", checkpoint, stores, fatalException);
+                throw new StreamsException("Failed to reinitialize stores.", fatalException);
+            }
+        }
+
+        for (final Map.Entry<String, StateStore> entry : storesCopy.entrySet()) {
+            final StateStore stateStore = entry.getValue();
+            final String storeName = stateStore.name();
+            if (storeToBeReinitialized.contains(storeName)) {
+                try {
+                    stateStore.close();
+                } catch (final RuntimeException ignoreAndSwallow) { /* ignore */ }
+                processorContext.uninitialize();
+                stores.remove(entry.getKey());
+
+                try {
+                    Utils.delete(new File(baseDir + File.separator + "rocksdb" + File.separator + storeName));
+                } catch (final IOException fatalException) {
+                    log.error("Failed to reinitialize store {}.", storeName, fatalException);
+                    throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
+                }
+
+                try {
+                    Utils.delete(new File(baseDir + File.separator + storeName));
+                } catch (final IOException fatalException) {
+                    log.error("Failed to reinitialize store {}.", storeName, fatalException);
+                    throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
+                }
+
+                stateStore.init(processorContext, stateStore);
+            }
+        }
+    }
+
+    private Map<String, String> inverseOneToOneMap(final Map<String, String> origin) {
+        final Map<String, String> reversedMap = new HashMap<>();
+        for (final Map.Entry<String, String> entry : origin.entrySet()) {
+            reversedMap.put(entry.getValue(), entry.getKey());
+        }
+        return reversedMap;
+    }
+
     List<ConsumerRecord<byte[], byte[]>> updateStandbyStates(final TopicPartition storePartition,
                                                              final List<ConsumerRecord<byte[], byte[]>> records) {
         final long limit = offsetLimit(storePartition);
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java
index 6ed28fdf63c06..3671b493f1992 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java
@@ -19,5 +19,5 @@
 import org.apache.kafka.common.TopicPartition;
 
 public interface RestoringTasks {
-    Task restoringTaskFor(final TopicPartition partition);
+    StreamTask restoringTaskFor(final TopicPartition partition);
 }
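`reinitializeStateStoresForPartitions()` has to map a changelog partition back to the store it backs, which is what the private `inverseOneToOneMap()` helper provides. A self-contained sketch of that inversion, with hypothetical store and topic names:

```java
import java.util.HashMap;
import java.util.Map;

public class InverseMapExample {
    // Mirrors the private inverseOneToOneMap(...) helper above: flips a
    // store -> changelog-topic mapping into changelog-topic -> store.
    static Map<String, String> inverse(final Map<String, String> origin) {
        final Map<String, String> reversed = new HashMap<>();
        for (final Map.Entry<String, String> entry : origin.entrySet()) {
            reversed.put(entry.getValue(), entry.getKey());
        }
        return reversed;
    }

    public static void main(final String[] args) {
        // hypothetical names, for illustration only
        final Map<String, String> storeToChangelog = new HashMap<>();
        storeToChangelog.put("counts-store", "app-counts-store-changelog");

        // prints {app-counts-store-changelog=counts-store}: given a changelog
        // TopicPartition, the topic name now resolves to the store to wipe
        System.out.println(inverse(storeToChangelog));
    }
}
```

Note the helper assumes the store-to-changelog mapping is one-to-one; two stores sharing a changelog topic would silently collide in the reversed map.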
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
index 33dce9e755814..e0bac939b8b96 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
@@ -26,13 +26,13 @@ public class StateRestorer {
 
     static final int NO_CHECKPOINT = -1;
 
-    private final Long checkpoint;
     private final long offsetLimit;
     private final boolean persistent;
     private final String storeName;
     private final TopicPartition partition;
     private final CompositeRestoreListener compositeRestoreListener;
 
+    private long checkpointOffset;
     private long restoredOffset;
     private long startingOffset;
     private long endingOffset;
@@ -45,7 +45,7 @@ public class StateRestorer {
                   final String storeName) {
         this.partition = partition;
         this.compositeRestoreListener = compositeRestoreListener;
-        this.checkpoint = checkpoint;
+        this.checkpointOffset = checkpoint == null ? NO_CHECKPOINT : checkpoint;
         this.offsetLimit = offsetLimit;
         this.persistent = persistent;
         this.storeName = storeName;
@@ -56,7 +56,15 @@ public TopicPartition partition() {
     }
 
     long checkpoint() {
-        return checkpoint == null ? NO_CHECKPOINT : checkpoint;
+        return checkpointOffset;
+    }
+
+    void setCheckpointOffset(final long checkpointOffset) {
+        this.checkpointOffset = checkpointOffset;
+    }
+
+    public String storeName() {
+        return storeName;
     }
 
     void restoreStarted() {
@@ -67,7 +75,8 @@ void restoreDone() {
         compositeRestoreListener.onRestoreEnd(partition, storeName, restoredNumRecords());
     }
 
-    void restoreBatchCompleted(long currentRestoredOffset, int numRestored) {
+    void restoreBatchCompleted(final long currentRestoredOffset,
+                               final int numRestored) {
         compositeRestoreListener.onBatchRestored(partition, storeName, currentRestoredOffset, numRestored);
     }
 
@@ -79,7 +88,7 @@ boolean isPersistent() {
         return persistent;
     }
 
-    void setUserRestoreListener(StateRestoreListener userRestoreListener) {
+    void setUserRestoreListener(final StateRestoreListener userRestoreListener) {
         this.compositeRestoreListener.setUserRestoreListener(userRestoreListener);
     }
 
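The `StateRestorer` change replaces an immutable, nullable `Long` checkpoint that was null-checked on every read with a `long` normalized once in the constructor; making it mutable is what allows `StoreChangelogReader` to overwrite a missing checkpoint via `setCheckpointOffset()` before the store is wiped and re-registered. A small sketch of that pattern, as a simplified stand-in class rather than the real `StateRestorer`:

```java
public class CheckpointHolderDemo {
    static final long NO_CHECKPOINT = -1L; // the diff uses an int constant; simplified here

    private long checkpointOffset;

    CheckpointHolderDemo(final Long checkpoint) {
        // normalize once in the constructor instead of null-checking on
        // every read, mirroring what StateRestorer now does
        this.checkpointOffset = checkpoint == null ? NO_CHECKPOINT : checkpoint;
    }

    long checkpoint() {
        return checkpointOffset;
    }

    void setCheckpointOffset(final long offset) {
        // lets the changelog reader replace a missing checkpoint with the
        // consumer's current position before the store is re-initialized
        this.checkpointOffset = offset;
    }

    public static void main(final String[] args) {
        final CheckpointHolderDemo restorer = new CheckpointHolderDemo(null);
        System.out.println(restorer.checkpoint()); // -1 (NO_CHECKPOINT)
        restorer.setCheckpointOffset(42L);
        System.out.println(restorer.checkpoint()); // 42
    }
}
```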
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index 8d85b1d8fafb3..34350c17eb0e5 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -69,7 +69,7 @@ public void register(final StateRestorer restorer) {
      */
     public Collection<TopicPartition> restore(final RestoringTasks active) {
         if (!needsInitializing.isEmpty()) {
-            initialize();
+            initialize(active);
         }
 
         if (needsRestoring.isEmpty()) {
@@ -90,7 +90,7 @@ public Collection<TopicPartition> restore(final RestoringTasks active) {
         return completed();
     }
 
-    private void initialize() {
+    private void initialize(final RestoringTasks active) {
         if (!consumer.subscription().isEmpty()) {
             throw new IllegalStateException("Restore consumer should not be subscribed to any topics (" + consumer.subscription() + ")");
         }
@@ -99,8 +99,8 @@ private void initialize() {
         // the needsInitializing map is not empty, meaning we do not know the metadata for some of them yet
         refreshChangelogInfo();
 
-        Map<TopicPartition, StateRestorer> initializable = new HashMap<>();
-        for (Map.Entry<TopicPartition, StateRestorer> entry : needsInitializing.entrySet()) {
+        final Map<TopicPartition, StateRestorer> initializable = new HashMap<>();
+        for (final Map.Entry<TopicPartition, StateRestorer> entry : needsInitializing.entrySet()) {
             final TopicPartition topicPartition = entry.getKey();
             if (hasPartition(topicPartition)) {
                 initializable.put(entry.getKey(), entry.getValue());
@@ -144,11 +144,12 @@ private void initialize() {
 
         // set up restorer for those initializable
         if (!initializable.isEmpty()) {
-            startRestoration(initializable);
+            startRestoration(initializable, active);
         }
     }
 
-    private void startRestoration(final Map<TopicPartition, StateRestorer> initialized) {
+    private void startRestoration(final Map<TopicPartition, StateRestorer> initialized,
+                                  final RestoringTasks active) {
         log.debug("Start restoring state stores from changelog topics {}", initialized.keySet());
 
         final Set<TopicPartition> assignment = new HashSet<>(consumer.assignment());
@@ -157,26 +158,47 @@ private void startRestoration(final Map<TopicPartition, StateRestorer> initializ
         final List<StateRestorer> needsPositionUpdate = new ArrayList<>();
 
         for (final StateRestorer restorer : initialized.values()) {
+            final TopicPartition restoringPartition = restorer.partition();
             if (restorer.checkpoint() != StateRestorer.NO_CHECKPOINT) {
-                consumer.seek(restorer.partition(), restorer.checkpoint());
-                logRestoreOffsets(restorer.partition(),
-                                  restorer.checkpoint(),
-                                  endOffsets.get(restorer.partition()));
-                restorer.setStartingOffset(consumer.position(restorer.partition()));
+                consumer.seek(restoringPartition, restorer.checkpoint());
+                logRestoreOffsets(restoringPartition,
+                                  restorer.checkpoint(),
+                                  endOffsets.get(restoringPartition));
+                restorer.setStartingOffset(consumer.position(restoringPartition));
                 restorer.restoreStarted();
             } else {
-                consumer.seekToBeginning(Collections.singletonList(restorer.partition()));
+                consumer.seekToBeginning(Collections.singletonList(restoringPartition));
                 needsPositionUpdate.add(restorer);
             }
         }
 
         for (final StateRestorer restorer : needsPositionUpdate) {
-            final long position = consumer.position(restorer.partition());
-            logRestoreOffsets(restorer.partition(),
-                              position,
-                              endOffsets.get(restorer.partition()));
-            restorer.setStartingOffset(position);
-            restorer.restoreStarted();
+            final TopicPartition restoringPartition = restorer.partition();
+            final StreamTask task = active.restoringTaskFor(restoringPartition);
+
+            // If the checkpoint does not exist, the task was not shut down gracefully before;
+            // in this case, if EOS is turned on, we should wipe out the state and re-initialize the task
+            if (task.eosEnabled) {
+                log.info("No checkpoint found for task {} state store {} changelog {} with EOS turned on. " +
+                        "Reinitializing the task and restoring its state from the beginning.", task.id, restorer.storeName(), restoringPartition);
+
+                // we remove the partitions here, because they will be added back within
+                // `task.reinitializeStateStoresForPartitions()` that calls `register()` internally again
+                needsInitializing.remove(restoringPartition);
+                restorer.setCheckpointOffset(consumer.position(restoringPartition));
+
+                task.reinitializeStateStoresForPartitions(restoringPartition);
+                stateRestorers.get(restoringPartition).restoreStarted();
+            } else {
+                log.info("Restoring task {}'s state store {} from beginning of the changelog {}", task.id, restorer.storeName(), restoringPartition);
+
+                final long position = consumer.position(restoringPartition);
+                logRestoreOffsets(restoringPartition,
+                                  position,
+                                  endOffsets.get(restoringPartition));
+                restorer.setStartingOffset(position);
+                restorer.restoreStarted();
+            }
         }
 
         needsRestoring.putAll(initialized);
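The behavioral core of this patch is the second loop above: a restorer that reaches `needsPositionUpdate` has no checkpoint, which means the previous run did not shut down cleanly. Under EOS the local store may then contain uncommitted (dirty) writes, so it cannot be trusted and is wiped and rebuilt from the changelog; without EOS the old restore-from-beginning path is kept. A minimal decision sketch under those assumptions, with stand-in types rather than the real Streams classes:

```java
// A minimal sketch of the branch added in startRestoration(): when a
// restorer has no checkpoint, EOS-enabled tasks wipe and re-create their
// stores, while non-EOS tasks simply restore from the log's beginning.
public final class NoCheckpointPolicy {
    enum Action { WIPE_AND_REINITIALIZE, RESTORE_FROM_BEGINNING }

    static Action decide(final boolean eosEnabled) {
        // a missing checkpoint file means the previous run did not shut
        // down gracefully; under EOS the local store may hold uncommitted
        // writes, so it cannot be trusted and must be rebuilt
        return eosEnabled ? Action.WIPE_AND_REINITIALIZE
                          : Action.RESTORE_FROM_BEGINNING;
    }

    public static void main(final String[] args) {
        System.out.println(decide(true));  // WIPE_AND_REINITIALIZE
        System.out.println(decide(false)); // RESTORE_FROM_BEGINNING
    }
}
```

Discarding the checkpoint-less store rather than trusting it is what preserves the exactly-once guarantee: any uncommitted local writes are dropped and replaced by committed changelog records.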
" + + "Reinitializing the task and restore its state from the beginning.", task.id, restorer.storeName(), restoringPartition); + + // we move the partitions here, because they will be added back within + // `task.reinitializeStateStoresForPartitions()` that calls `register()` internally again + needsInitializing.remove(restoringPartition); + restorer.setCheckpointOffset(consumer.position(restoringPartition)); + + task.reinitializeStateStoresForPartitions(restoringPartition); + stateRestorers.get(restoringPartition).restoreStarted(); + } else { + log.info("Restoring task {}'s state store {} from beginning of the changelog {} ", task.id, restorer.storeName(), restoringPartition); + + final long position = consumer.position(restoringPartition); + logRestoreOffsets(restoringPartition, + position, + endOffsets.get(restoringPartition)); + restorer.setStartingOffset(position); + restorer.restoreStarted(); + } } needsRestoring.putAll(initialized); diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java index 6c7b2b43c9d11..d07781f5b2453 100644 --- a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java @@ -392,7 +392,7 @@ public void shouldNotViolateEosIfOneTaskFailsWithState() throws Exception { // the app commits after each 10 records per partition, and thus will have 2*5 uncommitted writes // and store updates (ie, another 5 uncommitted writes to a changelog topic per partition) // - // the failure gets inject after 20 committed and 30 uncommitted records got received + // the failure gets inject after 20 committed and 10 uncommitted records got received // -> the failure only kills one thread // after fail over, we should read 40 committed records and the state stores should contain the correct sums // per key (even if some records got processed twice) @@ -402,7 +402,7 @@ public void shouldNotViolateEosIfOneTaskFailsWithState() throws Exception { streams.start(); final List> committedDataBeforeFailure = prepareData(0L, 10L, 0L, 1L); - final List> uncommittedDataBeforeFailure = prepareData(10L, 15L, 0L, 1L); + final List> uncommittedDataBeforeFailure = prepareData(10L, 15L, 0L, 1L, 2L, 3L); final List> dataBeforeFailure = new ArrayList<>(); dataBeforeFailure.addAll(committedDataBeforeFailure); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java index 5c8b7c4403392..01439add706ef 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java @@ -40,8 +40,8 @@ public class AssignedTasksTest { - private final Task t1 = EasyMock.createMock(Task.class); - private final Task t2 = EasyMock.createMock(Task.class); + private final StreamTask t1 = EasyMock.createMock(StreamTask.class); + private final StreamTask t2 = EasyMock.createMock(StreamTask.class); private final TopicPartition tp1 = new TopicPartition("t1", 0); private final TopicPartition tp2 = new TopicPartition("t2", 0); private final TopicPartition changeLog1 = new TopicPartition("cl1", 0); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
index bb0c51e15465b..a0e2140e92d3b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
@@ -59,7 +59,7 @@ public class StoreChangelogReaderTest {
     @Mock(type = MockType.NICE)
     private RestoringTasks active;
     @Mock(type = MockType.NICE)
-    private Task task;
+    private StreamTask task;
 
     private final MockStateRestoreListener callback = new MockStateRestoreListener();
     private final CompositeRestoreListener restoreListener = new CompositeRestoreListener(callback);
@@ -107,6 +107,10 @@ public void shouldRestoreAllMessagesFromBeginningWhenCheckpointNull() {
         final int messages = 10;
         setupConsumer(messages, topicPartition);
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
+
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
+
         changelogReader.restore(active);
         assertThat(callback.restored.size(), equalTo(messages));
     }
@@ -115,8 +119,7 @@ public void shouldRestoreAllMessagesFromBeginningWhenCheckpointNull() {
     public void shouldRestoreMessagesFromCheckpoint() {
         final int messages = 10;
         setupConsumer(messages, topicPartition);
-        changelogReader.register(new StateRestorer(topicPartition, restoreListener, 5L, Long.MAX_VALUE, true,
-                                                   "storeName"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, 5L, Long.MAX_VALUE, true, "storeName"));
 
         changelogReader.restore(active);
         assertThat(callback.restored.size(), equalTo(5));
@@ -126,9 +129,9 @@ public void shouldRestoreMessagesFromCheckpoint() {
     public void shouldClearAssignmentAtEndOfRestore() {
         final int messages = 1;
         setupConsumer(messages, topicPartition);
-        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true,
-                                                   "storeName"));
-
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         assertThat(consumer.assignment(), equalTo(Collections.emptySet()));
     }
@@ -136,9 +139,10 @@ public void shouldClearAssignmentAtEndOfRestore() {
     @Test
     public void shouldRestoreToLimitWhenSupplied() {
         setupConsumer(10, topicPartition);
-        final StateRestorer restorer = new StateRestorer(topicPartition, restoreListener, null, 3, true,
-                                                         "storeName");
+        final StateRestorer restorer = new StateRestorer(topicPartition, restoreListener, null, 3, true, "storeName");
         changelogReader.register(restorer);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         assertThat(callback.restored.size(), equalTo(3));
         assertThat(restorer.restoredOffset(), equalTo(3L));
@@ -156,14 +160,14 @@ public void shouldRestoreMultipleStores() {
         setupConsumer(5, one);
         setupConsumer(3, two);
 
-        changelogReader
-            .register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName1"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName1"));
         changelogReader.register(new StateRestorer(one, restoreListener1, null, Long.MAX_VALUE, true, "storeName2"));
         changelogReader.register(new StateRestorer(two, restoreListener2, null, Long.MAX_VALUE, true, "storeName3"));
 
-        expect(active.restoringTaskFor(one)).andReturn(null);
-        expect(active.restoringTaskFor(two)).andReturn(null);
-        replay(active);
+        expect(active.restoringTaskFor(one)).andStubReturn(task);
+        expect(active.restoringTaskFor(two)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
 
         assertThat(callback.restored.size(), equalTo(10));
@@ -188,9 +192,10 @@ public void shouldRestoreAndNotifyMultipleStores() throws Exception {
         changelogReader.register(new StateRestorer(one, restoreListener1, null, Long.MAX_VALUE, true, "storeName2"));
         changelogReader.register(new StateRestorer(two, restoreListener2, null, Long.MAX_VALUE, true, "storeName3"));
 
-        expect(active.restoringTaskFor(one)).andReturn(null);
-        expect(active.restoringTaskFor(two)).andReturn(null);
-        replay(active);
+        expect(active.restoringTaskFor(one)).andStubReturn(task);
+        expect(active.restoringTaskFor(two)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
 
         assertThat(callback.restored.size(), equalTo(10));
@@ -210,8 +215,10 @@ public void shouldRestoreAndNotifyMultipleStores() throws Exception {
     @Test
     public void shouldOnlyReportTheLastRestoredOffset() {
         setupConsumer(10, topicPartition);
-        changelogReader
-            .register(new StateRestorer(topicPartition, restoreListener, null, 5, true, "storeName1"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, 5, true, "storeName1"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
+
         changelogReader.restore(active);
 
         assertThat(callback.restored.size(), equalTo(5));
@@ -270,6 +277,8 @@ public void shouldNotRestoreAnythingWhenCheckpointAtEndOffset() {
     public void shouldReturnRestoredOffsetsForPersistentStores() {
         setupConsumer(10, topicPartition);
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         final Map<TopicPartition, Long> restoredOffsets = changelogReader.restoredOffsets();
         assertThat(restoredOffsets, equalTo(Collections.singletonMap(topicPartition, 10L)));
@@ -279,6 +288,8 @@ public void shouldReturnRestoredOffsetsForPersistentStores() {
     public void shouldNotReturnRestoredOffsetsForNonPersistentStore() {
         setupConsumer(10, topicPartition);
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         final Map<TopicPartition, Long> restoredOffsets = changelogReader.restoredOffsets();
         assertThat(restoredOffsets, equalTo(Collections.emptyMap()));
@@ -292,8 +303,9 @@ public void shouldIgnoreNullKeysWhenRestoring() {
         consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), 1, (byte[]) null, bytes));
         consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), 2, bytes, bytes));
         consumer.assign(Collections.singletonList(topicPartition));
-        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false,
-                                                   "storeName"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
 
         assertThat(callback.restored, CoreMatchers.equalTo(Utils.mkList(KeyValue.pair(bytes, bytes), KeyValue.pair(bytes, bytes))));
@@ -318,10 +330,11 @@ public void shouldRestorePartitionsRegisteredPostInitialization() {
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false, "storeName"));
 
         final TopicPartition postInitialization = new TopicPartition("other", 0);
-        expect(active.restoringTaskFor(topicPartition)).andReturn(null);
-        expect(active.restoringTaskFor(topicPartition)).andReturn(null);
-        expect(active.restoringTaskFor(postInitialization)).andReturn(null);
-        replay(active);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        expect(active.restoringTaskFor(postInitialization)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
 
         assertTrue(changelogReader.restore(active).isEmpty());
 
@@ -348,7 +361,7 @@ public void shouldThrowTaskMigratedExceptionIfEndOffsetGetsExceededDuringRestore
         consumer.updateEndOffsets(Collections.singletonMap(topicPartition, 5L));
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
 
-        expect(active.restoringTaskFor(topicPartition)).andReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
         replay(active);
 
         try {
@@ -371,8 +384,8 @@ public void shouldThrowTaskMigratedExceptionIfEndOffsetGetsExceededDuringRestore
         consumer.updateEndOffsets(Collections.singletonMap(topicPartition, 6L));
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
 
-        expect(active.restoringTaskFor(topicPartition)).andReturn(task);
-        replay(active);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         try {
             changelogReader.restore(active);
             fail("Should have thrown task migrated exception");
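Because `restore()` now dereferences the task returned by `restoringTaskFor()` for every partition without a checkpoint, the tests above switch from one-shot `andReturn(null)` expectations to `andStubReturn(task)` and add the task mock to `replay()`. A self-contained illustration of the EasyMock difference, using a hypothetical `Lookup` interface (requires EasyMock on the classpath):

```java
import static org.easymock.EasyMock.createMock;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.replay;

// Hypothetical Lookup interface, used only to contrast EasyMock's
// andReturn (a one-shot expectation) with andStubReturn (answers any
// number of calls), which is the change applied throughout the test file.
public class StubVsReturnExample {
    interface Lookup {
        String find(String key);
    }

    public static void main(final String[] args) {
        final Lookup oneShot = createMock(Lookup.class);
        expect(oneShot.find("a")).andReturn("x"); // exactly one call expected
        replay(oneShot);
        System.out.println(oneShot.find("a"));    // ok; a second call would fail

        final Lookup stubbed = createMock(Lookup.class);
        expect(stubbed.find("a")).andStubReturn("x"); // any number of calls
        replay(stubbed);
        System.out.println(stubbed.find("a"));
        System.out.println(stubbed.find("a"));        // still fine with a stub
    }
}
```

With a stub, the expectation answers repeated invocations in any order, which is exactly what the reader's repeated `restoringTaskFor()` lookups now require.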