diff --git a/src/consumer.ts b/src/consumer.ts index dc2df77..11e081c 100644 --- a/src/consumer.ts +++ b/src/consumer.ts @@ -55,6 +55,7 @@ export class Consumer extends TypedEventEmitter { private authenticationErrorTimeout: number; private pollingWaitTimeMs: number; private heartbeatInterval: number; + private areMessagesInFlight: boolean; public abortController: AbortController; constructor(options: ConsumerOptions) { @@ -142,7 +143,36 @@ export class Consumer extends TypedEventEmitter { this.emit('aborted'); } - this.emit('stopped'); + if (options?.waitForInFlightMessages) { + this.waitForInFlightMessagesToComplete(options?.waitTimeoutMs || 0).then( + () => { + this.emit('stopped'); + } + ); + } else { + this.emit('stopped'); + } + } + + /** + * Wait for in flight messages to complete. + * @param {number} waitTimeoutMs + * @private + */ + private async waitForInFlightMessagesToComplete( + waitTimeoutMs: number + ): Promise { + const startedAt = Date.now(); + while (this.areMessagesInFlight) { + if (waitTimeoutMs && Date.now() - startedAt > waitTimeoutMs) { + logger.debug( + 'waiting_for_in_flight_messages_to_complete_wait_timeout_exceeded' + ); + return; + } + logger.debug('waiting_for_in_flight_messages_to_complete'); + await new Promise((resolve) => setTimeout(resolve, 1000)); + } } /** @@ -270,6 +300,8 @@ export class Consumer extends TypedEventEmitter { }); }, 1000); + this.areMessagesInFlight = true; + if (this.handleMessageBatch) { await this.processMessageBatch(response.Messages); } else { @@ -279,6 +311,7 @@ export class Consumer extends TypedEventEmitter { clearInterval(handlerProcessingDebugger); this.emit('response_processed'); + this.areMessagesInFlight = false; } else if (response) { this.emit('empty'); } diff --git a/src/types.ts b/src/types.ts index fd93c88..a73df34 100644 --- a/src/types.ts +++ b/src/types.ts @@ -141,6 +141,20 @@ export interface StopOptions { * @defaultvalue `false` */ abort?: boolean; + + /** + * Default to `false`, if you want the stop action to wait for in-flight messages + * to be processed before emitting 'stopped' set this to `true`. + * @defaultvalue `false` + */ + waitForInFlightMessages?: boolean; + + /** + * if `waitForInFlightMessages` is set to `true`, this option will be used to + * determine how long to wait for in-flight messages to be processed before + * emitting 'stopped'. + */ + waitTimeoutMs?: number; } export interface Events { diff --git a/test/tests/consumer.test.ts b/test/tests/consumer.test.ts index e69bb67..0865c23 100644 --- a/test/tests/consumer.test.ts +++ b/test/tests/consumer.test.ts @@ -1481,6 +1481,77 @@ describe('Consumer', () => { sandbox.assert.calledOnce(handleAbort); sandbox.assert.calledOnce(handleStop); }); + + it('waits for in-flight messages before emitting stopped (no timeout)', async () => { + const handleStop = sandbox.stub().returns(null); + const handleResponseProcessed = sandbox.stub().returns(null); + + // A slow message handler + handleMessage = sandbox.stub().callsFake(async () => { + await new Promise((resolve) => setTimeout(resolve, 5000)); + }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + authenticationErrorTimeout: AUTHENTICATION_ERROR_TIMEOUT + }); + + consumer.on('stopped', handleStop); + consumer.on('response_processed', handleResponseProcessed); + + consumer.start(); + await clock.nextAsync(); + consumer.stop({ waitForInFlightMessages: true }); + + await clock.runAllAsync(); + + sandbox.assert.calledOnce(handleStop); + sandbox.assert.calledOnce(handleResponseProcessed); + sandbox.assert.calledOnce(handleMessage); + assert.ok(handleMessage.calledBefore(handleStop)); + + // handleResponseProcessed is called after handleMessage, indicating + // messages were allowed to complete before 'stopped' was emitted + assert.ok(handleResponseProcessed.calledBefore(handleStop)); + }); + + it('waits for in-flight messages before emitting stopped (timeout reached)', async () => { + const handleStop = sandbox.stub().returns(null); + const handleResponseProcessed = sandbox.stub().returns(null); + + // A slow message handler + handleMessage = sandbox.stub().callsFake(async () => { + await new Promise((resolve) => setTimeout(resolve, 5000)); + }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + authenticationErrorTimeout: AUTHENTICATION_ERROR_TIMEOUT + }); + + consumer.on('stopped', handleStop); + consumer.on('response_processed', handleResponseProcessed); + + consumer.start(); + await clock.nextAsync(); + consumer.stop({ waitForInFlightMessages: true, waitTimeoutMs: 500 }); + + await clock.runAllAsync(); + + sandbox.assert.calledOnce(handleStop); + sandbox.assert.calledOnce(handleResponseProcessed); + sandbox.assert.calledOnce(handleMessage); + assert(handleMessage.calledBefore(handleStop)); + + // Stop was called before the message could be processed, because we reached timeout. + assert(handleStop.calledBefore(handleResponseProcessed)); + }); }); describe('isRunning', async () => {