Skip to content

Commit

Permalink
feat(core): add onApplicationShutdown hook
Browse files Browse the repository at this point in the history
This will clean up the connection when the nestjs app is shutdown

Signed-off-by: Will Soto <[email protected]>
  • Loading branch information
willsoto committed Jul 11, 2020
1 parent 7c3b0c3 commit 668bc8e
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 19 deletions.
28 changes: 24 additions & 4 deletions lib/core.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
/* eslint-disable new-cap */
import {
DynamicModule,
FactoryProvider,
Inject,
Module,
OnApplicationShutdown,
Provider,
ValueProvider,
} from "@nestjs/common";
import { ModuleRef } from "@nestjs/core";
import Knex from "knex";
import { Model } from "objection";
import {
Expand All @@ -21,7 +23,13 @@ import {
} from "./interfaces";

@Module({})
export class ObjectionCoreModule {
export class ObjectionCoreModule implements OnApplicationShutdown {
constructor(
@Inject(OBJECTION_MODULE_OPTIONS)
private options: ObjectionModuleOptions,
private moduleRef: ModuleRef,
) {}

public static register(options: ObjectionModuleOptions): DynamicModule {
const BaseModel = options.Model || Model;
const connection = Knex(options.config);
Expand Down Expand Up @@ -95,7 +103,7 @@ export class ObjectionCoreModule {
};
}

public static createAsyncProviders(
private static createAsyncProviders(
options: ObjectionModuleAsyncOptions,
): Provider[] {
if (options.useExisting || options.useFactory) {
Expand All @@ -113,7 +121,7 @@ export class ObjectionCoreModule {
];
}

public static createAsyncOptionsProvider(
private static createAsyncOptionsProvider(
options: ObjectionModuleAsyncOptions,
): Provider {
if (options.useFactory) {
Expand Down Expand Up @@ -144,4 +152,16 @@ export class ObjectionCoreModule {
inject: [inject],
};
}

async onApplicationShutdown(): Promise<void> {
await this.disconnect();
}

private async disconnect(): Promise<void> {
const connection = this.moduleRef.get<Connection>(
this.options.name || KNEX_CONNECTION,
);

await connection.destroy();
}
}
16 changes: 9 additions & 7 deletions tests/core.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describe("ObjectionCoreModule", () => {
};

describe("#register", () => {
beforeEach(async () => {
beforeAll(async () => {
testingModule = await Test.createTestingModule({
imports: [
ObjectionCoreModule.register({
Expand All @@ -34,6 +34,8 @@ describe("ObjectionCoreModule", () => {
}).compile();
});

afterAll(() => testingModule.close());

test("provides a connection", () => {
const connection = testingModule.get("KnexConnection");

Expand Down Expand Up @@ -88,12 +90,12 @@ describe("ObjectionCoreModule", () => {

test("throws an error if options.useClass, useExisting, useFactory are not provided", () => {
expect(() => {
ObjectionCoreModule.createAsyncProviders({});
ObjectionCoreModule["createAsyncProviders"]({});
}).toThrowError("Invalid configuration");
});

test("leverages useClass if provided", () => {
const providers = ObjectionCoreModule.createAsyncProviders({
const providers = ObjectionCoreModule["createAsyncProviders"]({
useClass: ModuleOptionsFactory,
});

Expand All @@ -111,7 +113,7 @@ describe("ObjectionCoreModule", () => {
});

test("returns an array of providers when useExisting is passed", () => {
const providers = ObjectionCoreModule.createAsyncProviders({
const providers = ObjectionCoreModule["createAsyncProviders"]({
useExisting: ModuleOptionsFactory,
});

Expand All @@ -125,7 +127,7 @@ describe("ObjectionCoreModule", () => {
});

test("returns an array of providers when useFactory is passed", () => {
const providers = ObjectionCoreModule.createAsyncProviders({
const providers = ObjectionCoreModule["createAsyncProviders"]({
useFactory: () => ({
config,
}),
Expand Down Expand Up @@ -153,7 +155,7 @@ describe("ObjectionCoreModule", () => {
}

test("returns the appropriate provider when useFactory is passed", () => {
const provider = ObjectionCoreModule.createAsyncOptionsProvider({
const provider = ObjectionCoreModule["createAsyncOptionsProvider"]({
useFactory: () => ({
config,
}),
Expand All @@ -167,7 +169,7 @@ describe("ObjectionCoreModule", () => {
});

test("returns the appropriate provider when useExisting is passed", () => {
const provider = ObjectionCoreModule.createAsyncOptionsProvider({
const provider = ObjectionCoreModule["createAsyncOptionsProvider"]({
useExisting: ModuleOptionsFactory,
});

Expand Down
9 changes: 4 additions & 5 deletions tests/integration.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@ import { ConnectionCheck, ConnectionModule } from "./fixtures";
describe("Integration", () => {
let connectionCheck: ConnectionCheck;
let connection: Knex;
let testingModule: TestingModule;

beforeEach(async () => {
const testingModule: TestingModule = await Test.createTestingModule({
beforeAll(async () => {
testingModule = await Test.createTestingModule({
imports: [ConnectionModule],
}).compile();

connectionCheck = testingModule.get<ConnectionCheck>(ConnectionCheck);
connection = testingModule.get<Knex>(KNEX_CONNECTION);
});

afterEach(async () => {
await connection.destroy();
});
afterAll(() => testingModule.close());

test("database works", () => {
return expect(connection.raw("select 1")).resolves.toEqual([{ "1": 1 }]);
Expand Down
12 changes: 9 additions & 3 deletions tests/module.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ describe("ObjectionModule", () => {
};

describe("#register", () => {
beforeEach(async () => {
beforeAll(async () => {
testingModule = await Test.createTestingModule({
imports: [
ObjectionModule.register({
Expand All @@ -26,6 +26,8 @@ describe("ObjectionModule", () => {
}).compile();
});

afterAll(() => testingModule.close());

test("provides a connection", () => {
const connection = testingModule.get(KNEX_CONNECTION);

Expand All @@ -40,7 +42,7 @@ describe("ObjectionModule", () => {
});

describe("#registerAsync", () => {
beforeEach(async () => {
beforeAll(async () => {
testingModule = await Test.createTestingModule({
imports: [
ObjectionModule.registerAsync({
Expand All @@ -54,6 +56,8 @@ describe("ObjectionModule", () => {
}).compile();
});

afterAll(() => testingModule.close());

test("provides a connection", () => {
const connection = testingModule.get("KnexConnection");

Expand All @@ -68,7 +72,7 @@ describe("ObjectionModule", () => {
});

describe("#forFeature", () => {
beforeEach(async () => {
beforeAll(async () => {
testingModule = await Test.createTestingModule({
imports: [
ObjectionModule.register({
Expand All @@ -79,6 +83,8 @@ describe("ObjectionModule", () => {
}).compile();
});

afterAll(() => testingModule.close());

test("provides a model by token", () => {
const model = testingModule.get(User);

Expand Down

0 comments on commit 668bc8e

Please sign in to comment.