diff --git a/src/consumer.ts b/src/consumer.ts index 36119fcb..02c589bd 100644 --- a/src/consumer.ts +++ b/src/consumer.ts @@ -38,6 +38,8 @@ export class Consumer extends TypedEventEmitter { private queueUrl: string; private handleMessage: (message: Message) => Promise; private handleMessageBatch: (message: Message[]) => Promise; + private preReceiveMessageCallback?: () => Promise; + private postReceiveMessageCallback?: () => Promise; private sqs: SQSClient; private handleMessageTimeout: number; private attributeNames: string[]; @@ -57,6 +59,8 @@ export class Consumer extends TypedEventEmitter { this.queueUrl = options.queueUrl; this.handleMessage = options.handleMessage; this.handleMessageBatch = options.handleMessageBatch; + this.preReceiveMessageCallback = options.preReceiveMessageCallback; + this.postReceiveMessageCallback = options.postReceiveMessageCallback; this.handleMessageTimeout = options.handleMessageTimeout; this.attributeNames = options.attributeNames || []; this.messageAttributeNames = options.messageAttributeNames || []; @@ -223,10 +227,18 @@ export class Consumer extends TypedEventEmitter { params: ReceiveMessageCommandInput ): Promise { try { - return await this.sqs.send( + if (this.preReceiveMessageCallback) { + await this.preReceiveMessageCallback(); + } + const result = await this.sqs.send( new ReceiveMessageCommand(params), this.sqsSendOptions ); + if (this.postReceiveMessageCallback) { + await this.postReceiveMessageCallback(); + } + + return result; } catch (err) { throw toSQSError(err, `SQS receive message failed: ${err.message}`); } diff --git a/src/types.ts b/src/types.ts index c933dd0a..d26492d1 100644 --- a/src/types.ts +++ b/src/types.ts @@ -103,6 +103,22 @@ export interface ConsumerOptions { * the successful messages only. */ handleMessageBatch?(messages: Message[]): Promise; + /** + * An `async` function (or function that returns a `Promise`) to be called right + * before the SQS Client sends a receive message command. + * + * This function is usefull if SQS Client module exports have been modified, for + * example to add middlewares. + */ + preReceiveMessageCallback?(): Promise; + /** + * An `async` function (or function that returns a `Promise`) to be called right + * after the SQS Client sends a receive message command. + * + * This function is usefull if SQS Client module exports have been modified, for + * example to add middlewares. + */ + postReceiveMessageCallback?(): Promise; } export type UpdatableOptions = diff --git a/test/tests/consumer.test.ts b/test/tests/consumer.test.ts index ae732318..54314ce5 100644 --- a/test/tests/consumer.test.ts +++ b/test/tests/consumer.test.ts @@ -479,6 +479,30 @@ describe('Consumer', () => { sandbox.assert.calledWith(handleMessage, response.Messages[0]); }); + it('calls the preReceiveMessageCallback and postReceiveMessageCallback function before receiving a message', async () => { + let callbackCalls = 0; + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + authenticationErrorTimeout: AUTHENTICATION_ERROR_TIMEOUT, + preReceiveMessageCallback: async () => { + callbackCalls++; + }, + postReceiveMessageCallback: async () => { + callbackCalls++; + } + }); + + consumer.start(); + await pEvent(consumer, 'message_processed'); + consumer.stop(); + + assert.equal(callbackCalls, 2); + }); + it('deletes the message when the handleMessage function is called', async () => { handleMessage.resolves();