Skip to content

Commit

Permalink
Merge pull request #14139 from mksony/chore/ensure-application-init-b…
Browse files Browse the repository at this point in the history
…efore-accepting-sigterm

chore(core): defer application shutdown until init finishes
  • Loading branch information
kamilmysliwiec authored Nov 15, 2024
2 parents 49dc36d + 5c6986f commit dade6d5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
12 changes: 10 additions & 2 deletions packages/core/nest-application-context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ export class NestApplicationContext<
private shutdownCleanupRef?: (...args: unknown[]) => unknown;
private _instanceLinksHost: InstanceLinksHost;
private _moduleRefsForHooksByDistance?: Array<Module>;
private initializationPromise?: Promise<void>;

protected get instanceLinksHost() {
if (!this._instanceLinksHost) {
Expand Down Expand Up @@ -234,8 +235,8 @@ export class NestApplicationContext<
if (this.isInitialized) {
return this;
}
await this.callInitHook();
await this.callBootstrapHook();
this.initializationPromise = this.internalInit();
await this.initializationPromise;

this.isInitialized = true;
return this;
Expand All @@ -246,6 +247,7 @@ export class NestApplicationContext<
* @returns {Promise<void>}
*/
public async close(signal?: string): Promise<void> {
await this.initializationPromise;
await this.callDestroyHook();
await this.callBeforeShutdownHook(signal);
await this.dispose();
Expand Down Expand Up @@ -333,6 +335,7 @@ export class NestApplicationContext<
return;
}
receivedSignal = true;
await this.initializationPromise;
await this.callDestroyHook();
await this.callBeforeShutdownHook(signal);
await this.dispose();
Expand Down Expand Up @@ -431,6 +434,11 @@ export class NestApplicationContext<
}
}

private async internalInit() {
await this.callInitHook();
await this.callBootstrapHook();
}

private getModulesToTriggerHooksOn(): Module[] {
if (this._moduleRefsForHooksByDistance) {
return this._moduleRefsForHooksByDistance;
Expand Down
60 changes: 59 additions & 1 deletion packages/core/test/nest-application-context.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InjectionToken, Scope } from '@nestjs/common';
import { InjectionToken, Provider, Scope } from '@nestjs/common';
import { expect } from 'chai';
import * as sinon from 'sinon';
import { ContextIdFactory } from '../helpers/context-id-factory';
Expand All @@ -7,13 +7,15 @@ import { Injector } from '../injector/injector';
import { InstanceLoader } from '../injector/instance-loader';
import { GraphInspector } from '../inspector/graph-inspector';
import { NestApplicationContext } from '../nest-application-context';
import { setTimeout } from 'timers/promises';

describe('NestApplicationContext', () => {
class A {}

async function testHelper(
injectionKey: InjectionToken,
scope: Scope,
additionalProviders: Array<Provider> = [],
): Promise<NestApplicationContext> {
const nestContainer = new NestContainer();
const injector = new Injector();
Expand All @@ -33,6 +35,10 @@ describe('NestApplicationContext', () => {
moduleRef.token,
);

for (const provider of additionalProviders) {
nestContainer.addProvider(provider, moduleRef.token);
}

nestContainer.addInjectable(
{
provide: injectionKey,
Expand Down Expand Up @@ -96,6 +102,58 @@ describe('NestApplicationContext', () => {
expect(processUp).to.be.false;
expect(promisesResolved).to.be.true;
});

it('should defer shutdown until all init hooks are resolved', async () => {
const clock = sinon.useFakeTimers({
toFake: ['setTimeout'],
});
const signal = 'SIGTERM';

const onModuleInitStub = sinon.stub();
const onApplicationShutdownStub = sinon.stub();

class B {
async onModuleInit() {
await setTimeout(5000);
onModuleInitStub();
}

async onApplicationShutdown() {
await setTimeout(1000);
onApplicationShutdownStub();
}
}

const applicationContext = await testHelper(A, Scope.DEFAULT, [
{ provide: B, useClass: B, scope: Scope.DEFAULT },
]);
applicationContext.enableShutdownHooks([signal]);

const ignoreProcessSignal = () => {
// noop to prevent process from exiting
};
process.on(signal, ignoreProcessSignal);

const deferredShutdown = async () => {
setTimeout(1);
process.kill(process.pid, signal);
};
Promise.all([applicationContext.init(), deferredShutdown()]);

await clock.nextAsync();
expect(onModuleInitStub.called).to.be.false;
expect(onApplicationShutdownStub.called).to.be.false;

await clock.nextAsync();
expect(onModuleInitStub.called).to.be.true;
expect(onApplicationShutdownStub.called).to.be.false;

await clock.nextAsync();
expect(onModuleInitStub.called).to.be.true;
expect(onApplicationShutdownStub.called).to.be.true;

clock.restore();
});
});

describe('get', () => {
Expand Down

0 comments on commit dade6d5

Please sign in to comment.