From 668bc8e59457994808898213814ff5749f968ac9 Mon Sep 17 00:00:00 2001 From: Will Soto Date: Sat, 11 Jul 2020 09:35:17 -0400 Subject: [PATCH] feat(core): add onApplicationShutdown hook This will clean up the connection when the nestjs app is shutdown Signed-off-by: Will Soto --- lib/core.ts | 28 ++++++++++++++++++++++++---- tests/core.spec.ts | 16 +++++++++------- tests/integration.spec.ts | 9 ++++----- tests/module.spec.ts | 12 +++++++++--- 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/lib/core.ts b/lib/core.ts index 1febb489..d1620528 100644 --- a/lib/core.ts +++ b/lib/core.ts @@ -1,11 +1,13 @@ -/* eslint-disable new-cap */ import { DynamicModule, FactoryProvider, + Inject, Module, + OnApplicationShutdown, Provider, ValueProvider, } from "@nestjs/common"; +import { ModuleRef } from "@nestjs/core"; import Knex from "knex"; import { Model } from "objection"; import { @@ -21,7 +23,13 @@ import { } from "./interfaces"; @Module({}) -export class ObjectionCoreModule { +export class ObjectionCoreModule implements OnApplicationShutdown { + constructor( + @Inject(OBJECTION_MODULE_OPTIONS) + private options: ObjectionModuleOptions, + private moduleRef: ModuleRef, + ) {} + public static register(options: ObjectionModuleOptions): DynamicModule { const BaseModel = options.Model || Model; const connection = Knex(options.config); @@ -95,7 +103,7 @@ export class ObjectionCoreModule { }; } - public static createAsyncProviders( + private static createAsyncProviders( options: ObjectionModuleAsyncOptions, ): Provider[] { if (options.useExisting || options.useFactory) { @@ -113,7 +121,7 @@ export class ObjectionCoreModule { ]; } - public static createAsyncOptionsProvider( + private static createAsyncOptionsProvider( options: ObjectionModuleAsyncOptions, ): Provider { if (options.useFactory) { @@ -144,4 +152,16 @@ export class ObjectionCoreModule { inject: [inject], }; } + + async onApplicationShutdown(): Promise { + await this.disconnect(); + } + + private async disconnect(): Promise { + const connection = this.moduleRef.get( + this.options.name || KNEX_CONNECTION, + ); + + await connection.destroy(); + } } diff --git a/tests/core.spec.ts b/tests/core.spec.ts index 2fcde8df..23f1a9a1 100644 --- a/tests/core.spec.ts +++ b/tests/core.spec.ts @@ -24,7 +24,7 @@ describe("ObjectionCoreModule", () => { }; describe("#register", () => { - beforeEach(async () => { + beforeAll(async () => { testingModule = await Test.createTestingModule({ imports: [ ObjectionCoreModule.register({ @@ -34,6 +34,8 @@ describe("ObjectionCoreModule", () => { }).compile(); }); + afterAll(() => testingModule.close()); + test("provides a connection", () => { const connection = testingModule.get("KnexConnection"); @@ -88,12 +90,12 @@ describe("ObjectionCoreModule", () => { test("throws an error if options.useClass, useExisting, useFactory are not provided", () => { expect(() => { - ObjectionCoreModule.createAsyncProviders({}); + ObjectionCoreModule["createAsyncProviders"]({}); }).toThrowError("Invalid configuration"); }); test("leverages useClass if provided", () => { - const providers = ObjectionCoreModule.createAsyncProviders({ + const providers = ObjectionCoreModule["createAsyncProviders"]({ useClass: ModuleOptionsFactory, }); @@ -111,7 +113,7 @@ describe("ObjectionCoreModule", () => { }); test("returns an array of providers when useExisting is passed", () => { - const providers = ObjectionCoreModule.createAsyncProviders({ + const providers = ObjectionCoreModule["createAsyncProviders"]({ useExisting: ModuleOptionsFactory, }); @@ -125,7 +127,7 @@ describe("ObjectionCoreModule", () => { }); test("returns an array of providers when useFactory is passed", () => { - const providers = ObjectionCoreModule.createAsyncProviders({ + const providers = ObjectionCoreModule["createAsyncProviders"]({ useFactory: () => ({ config, }), @@ -153,7 +155,7 @@ describe("ObjectionCoreModule", () => { } test("returns the appropriate provider when useFactory is passed", () => { - const provider = ObjectionCoreModule.createAsyncOptionsProvider({ + const provider = ObjectionCoreModule["createAsyncOptionsProvider"]({ useFactory: () => ({ config, }), @@ -167,7 +169,7 @@ describe("ObjectionCoreModule", () => { }); test("returns the appropriate provider when useExisting is passed", () => { - const provider = ObjectionCoreModule.createAsyncOptionsProvider({ + const provider = ObjectionCoreModule["createAsyncOptionsProvider"]({ useExisting: ModuleOptionsFactory, }); diff --git a/tests/integration.spec.ts b/tests/integration.spec.ts index cf453dd2..51f85be7 100644 --- a/tests/integration.spec.ts +++ b/tests/integration.spec.ts @@ -7,9 +7,10 @@ import { ConnectionCheck, ConnectionModule } from "./fixtures"; describe("Integration", () => { let connectionCheck: ConnectionCheck; let connection: Knex; + let testingModule: TestingModule; - beforeEach(async () => { - const testingModule: TestingModule = await Test.createTestingModule({ + beforeAll(async () => { + testingModule = await Test.createTestingModule({ imports: [ConnectionModule], }).compile(); @@ -17,9 +18,7 @@ describe("Integration", () => { connection = testingModule.get(KNEX_CONNECTION); }); - afterEach(async () => { - await connection.destroy(); - }); + afterAll(() => testingModule.close()); test("database works", () => { return expect(connection.raw("select 1")).resolves.toEqual([{ "1": 1 }]); diff --git a/tests/module.spec.ts b/tests/module.spec.ts index 7ac2476d..39f2c27d 100644 --- a/tests/module.spec.ts +++ b/tests/module.spec.ts @@ -16,7 +16,7 @@ describe("ObjectionModule", () => { }; describe("#register", () => { - beforeEach(async () => { + beforeAll(async () => { testingModule = await Test.createTestingModule({ imports: [ ObjectionModule.register({ @@ -26,6 +26,8 @@ describe("ObjectionModule", () => { }).compile(); }); + afterAll(() => testingModule.close()); + test("provides a connection", () => { const connection = testingModule.get(KNEX_CONNECTION); @@ -40,7 +42,7 @@ describe("ObjectionModule", () => { }); describe("#registerAsync", () => { - beforeEach(async () => { + beforeAll(async () => { testingModule = await Test.createTestingModule({ imports: [ ObjectionModule.registerAsync({ @@ -54,6 +56,8 @@ describe("ObjectionModule", () => { }).compile(); }); + afterAll(() => testingModule.close()); + test("provides a connection", () => { const connection = testingModule.get("KnexConnection"); @@ -68,7 +72,7 @@ describe("ObjectionModule", () => { }); describe("#forFeature", () => { - beforeEach(async () => { + beforeAll(async () => { testingModule = await Test.createTestingModule({ imports: [ ObjectionModule.register({ @@ -79,6 +83,8 @@ describe("ObjectionModule", () => { }).compile(); }); + afterAll(() => testingModule.close()); + test("provides a model by token", () => { const model = testingModule.get(User);