diff --git a/packages/core/nest-application-context.ts b/packages/core/nest-application-context.ts index eacd82ec2dc..ce9c12f1ae9 100644 --- a/packages/core/nest-application-context.ts +++ b/packages/core/nest-application-context.ts @@ -54,6 +54,7 @@ export class NestApplicationContext< private shutdownCleanupRef?: (...args: unknown[]) => unknown; private _instanceLinksHost: InstanceLinksHost; private _moduleRefsForHooksByDistance?: Array; + private initializationPromise?: Promise; protected get instanceLinksHost() { if (!this._instanceLinksHost) { @@ -234,8 +235,8 @@ export class NestApplicationContext< if (this.isInitialized) { return this; } - await this.callInitHook(); - await this.callBootstrapHook(); + this.initializationPromise = this.internalInit(); + await this.initializationPromise; this.isInitialized = true; return this; @@ -246,6 +247,7 @@ export class NestApplicationContext< * @returns {Promise} */ public async close(signal?: string): Promise { + await this.initializationPromise; await this.callDestroyHook(); await this.callBeforeShutdownHook(signal); await this.dispose(); @@ -333,6 +335,7 @@ export class NestApplicationContext< return; } receivedSignal = true; + await this.initializationPromise; await this.callDestroyHook(); await this.callBeforeShutdownHook(signal); await this.dispose(); @@ -431,6 +434,11 @@ export class NestApplicationContext< } } + private async internalInit() { + await this.callInitHook(); + await this.callBootstrapHook(); + } + private getModulesToTriggerHooksOn(): Module[] { if (this._moduleRefsForHooksByDistance) { return this._moduleRefsForHooksByDistance; diff --git a/packages/core/test/nest-application-context.spec.ts b/packages/core/test/nest-application-context.spec.ts index d83dba163e9..198a84c9768 100644 --- a/packages/core/test/nest-application-context.spec.ts +++ b/packages/core/test/nest-application-context.spec.ts @@ -1,4 +1,4 @@ -import { InjectionToken, Scope } from '@nestjs/common'; +import { InjectionToken, Provider, Scope } from '@nestjs/common'; import { expect } from 'chai'; import * as sinon from 'sinon'; import { ContextIdFactory } from '../helpers/context-id-factory'; @@ -7,6 +7,7 @@ import { Injector } from '../injector/injector'; import { InstanceLoader } from '../injector/instance-loader'; import { GraphInspector } from '../inspector/graph-inspector'; import { NestApplicationContext } from '../nest-application-context'; +import { setTimeout } from 'timers/promises'; describe('NestApplicationContext', () => { class A {} @@ -14,6 +15,7 @@ describe('NestApplicationContext', () => { async function testHelper( injectionKey: InjectionToken, scope: Scope, + additionalProviders: Array = [], ): Promise { const nestContainer = new NestContainer(); const injector = new Injector(); @@ -33,6 +35,10 @@ describe('NestApplicationContext', () => { moduleRef.token, ); + for (const provider of additionalProviders) { + nestContainer.addProvider(provider, moduleRef.token); + } + nestContainer.addInjectable( { provide: injectionKey, @@ -96,6 +102,58 @@ describe('NestApplicationContext', () => { expect(processUp).to.be.false; expect(promisesResolved).to.be.true; }); + + it('should defer shutdown until all init hooks are resolved', async () => { + const clock = sinon.useFakeTimers({ + toFake: ['setTimeout'], + }); + const signal = 'SIGTERM'; + + const onModuleInitStub = sinon.stub(); + const onApplicationShutdownStub = sinon.stub(); + + class B { + async onModuleInit() { + await setTimeout(5000); + onModuleInitStub(); + } + + async onApplicationShutdown() { + await setTimeout(1000); + onApplicationShutdownStub(); + } + } + + const applicationContext = await testHelper(A, Scope.DEFAULT, [ + { provide: B, useClass: B, scope: Scope.DEFAULT }, + ]); + applicationContext.enableShutdownHooks([signal]); + + const ignoreProcessSignal = () => { + // noop to prevent process from exiting + }; + process.on(signal, ignoreProcessSignal); + + const deferredShutdown = async () => { + setTimeout(1); + process.kill(process.pid, signal); + }; + Promise.all([applicationContext.init(), deferredShutdown()]); + + await clock.nextAsync(); + expect(onModuleInitStub.called).to.be.false; + expect(onApplicationShutdownStub.called).to.be.false; + + await clock.nextAsync(); + expect(onModuleInitStub.called).to.be.true; + expect(onApplicationShutdownStub.called).to.be.false; + + await clock.nextAsync(); + expect(onModuleInitStub.called).to.be.true; + expect(onApplicationShutdownStub.called).to.be.true; + + clock.restore(); + }); }); describe('get', () => {