From 4040089a611df048953b73c82eff69ae5ad5f552 Mon Sep 17 00:00:00 2001 From: Michael Bromley Date: Wed, 2 Sep 2020 15:28:08 +0200 Subject: [PATCH] feat(core): Create Transaction decorator --- .../e2e/database-transactions.e2e-spec.ts | 27 +++++---- packages/core/src/api/common/parse-context.ts | 12 +++- .../src/api/common/request-context.service.ts | 4 +- .../decorators/request-context.decorator.ts | 2 +- .../api/decorators/transaction.decorator.ts | 15 +++++ packages/core/src/api/index.ts | 1 + .../core/src/api/middleware/auth-guard.ts | 3 +- .../core/src/api/middleware/id-interceptor.ts | 5 +- .../api/middleware/transaction-interceptor.ts | 59 +++++++++++++++++++ .../validate-custom-fields-interceptor.ts | 21 ++----- packages/core/src/common/constants.ts | 2 + .../core/src/entity/order/order.entity.ts | 15 ++--- packages/core/src/service/index.ts | 1 - .../core/src/service/initializer.service.ts | 42 +++++++++++++ packages/core/src/service/service.module.ts | 35 ++--------- .../service/services/collection.service.ts | 8 +-- .../src/service/services/order.service.ts | 2 + .../src/service/services/tax-rate.service.ts | 4 ++ .../transaction/transactional-connection.ts | 46 +++++++++------ .../src/service/transaction/unit-of-work.ts | 42 ------------- 20 files changed, 202 insertions(+), 144 deletions(-) create mode 100644 packages/core/src/api/decorators/transaction.decorator.ts create mode 100644 packages/core/src/api/middleware/transaction-interceptor.ts create mode 100644 packages/core/src/service/initializer.service.ts delete mode 100644 packages/core/src/service/transaction/unit-of-work.ts diff --git a/packages/core/e2e/database-transactions.e2e-spec.ts b/packages/core/e2e/database-transactions.e2e-spec.ts index bc00ae212e..8cba7246ed 100644 --- a/packages/core/e2e/database-transactions.e2e-spec.ts +++ b/packages/core/e2e/database-transactions.e2e-spec.ts @@ -2,12 +2,14 @@ import { Injectable } from '@nestjs/common'; import { Args, Mutation, Query, Resolver } from '@nestjs/graphql'; import { Administrator, + Ctx, InternalServerError, mergeConfig, NativeAuthenticationMethod, PluginCommonModule, + RequestContext, + Transaction, TransactionalConnection, - UnitOfWork, User, VendurePlugin, } from '@vendure/core'; @@ -22,8 +24,8 @@ import { TEST_SETUP_TIMEOUT_MS, testConfig } from '../../../e2e-common/test-conf class TestUserService { constructor(private connection: TransactionalConnection) {} - async createUser(identifier: string) { - const authMethod = await this.connection.getRepository(NativeAuthenticationMethod).save( + async createUser(ctx: RequestContext, identifier: string) { + const authMethod = await this.connection.getRepository(ctx, NativeAuthenticationMethod).save( new NativeAuthenticationMethod({ identifier, passwordHash: 'abc', @@ -45,12 +47,12 @@ class TestUserService { class TestAdminService { constructor(private connection: TransactionalConnection, private userService: TestUserService) {} - async createAdministrator(emailAddress: string, fail: boolean) { - const user = await this.userService.createUser(emailAddress); + async createAdministrator(ctx: RequestContext, emailAddress: string, fail: boolean) { + const user = await this.userService.createUser(ctx, emailAddress); if (fail) { throw new InternalServerError('Failed!'); } - const admin = await this.connection.getRepository(Administrator).save( + const admin = await this.connection.getRepository(ctx, Administrator).save( new Administrator({ emailAddress, user, @@ -64,19 +66,18 @@ class TestAdminService { @Resolver() class TestResolver { - constructor(private uow: UnitOfWork, private testAdminService: TestAdminService) {} + constructor(private testAdminService: TestAdminService, private connection: TransactionalConnection) {} @Mutation() - createTestAdministrator(@Args() args: any) { - return this.uow.withTransaction(() => { - return this.testAdminService.createAdministrator(args.emailAddress, args.fail); - }); + @Transaction + createTestAdministrator(@Ctx() ctx: RequestContext, @Args() args: any) { + return this.testAdminService.createAdministrator(ctx, args.emailAddress, args.fail); } @Query() async verify() { - const admins = await this.uow.getConnection().getRepository(Administrator).find(); - const users = await this.uow.getConnection().getRepository(User).find(); + const admins = await this.connection.getRepository(Administrator).find(); + const users = await this.connection.getRepository(User).find(); return { admins, users, diff --git a/packages/core/src/api/common/parse-context.ts b/packages/core/src/api/common/parse-context.ts index a63df7128b..3fe0f732c9 100644 --- a/packages/core/src/api/common/parse-context.ts +++ b/packages/core/src/api/common/parse-context.ts @@ -3,13 +3,19 @@ import { GqlExecutionContext } from '@nestjs/graphql'; import { Request, Response } from 'express'; import { GraphQLResolveInfo } from 'graphql'; +export type RestContext = { req: Request; res: Response; isGraphQL: false; info: undefined }; +export type GraphQLContext = { + req: Request; + res: Response; + isGraphQL: true; + info: GraphQLResolveInfo; +}; + /** * Parses in the Nest ExecutionContext of the incoming request, accounting for both * GraphQL & REST requests. */ -export function parseContext( - context: ExecutionContext | ArgumentsHost, -): { req: Request; res: Response; isGraphQL: boolean; info?: GraphQLResolveInfo } { +export function parseContext(context: ExecutionContext | ArgumentsHost): RestContext | GraphQLContext { const graphQlContext = GqlExecutionContext.create(context as ExecutionContext); const restContext = GqlExecutionContext.create(context as ExecutionContext); const info = graphQlContext.getInfo(); diff --git a/packages/core/src/api/common/request-context.service.ts b/packages/core/src/api/common/request-context.service.ts index caff440d7e..12e2bc582b 100644 --- a/packages/core/src/api/common/request-context.service.ts +++ b/packages/core/src/api/common/request-context.service.ts @@ -12,8 +12,6 @@ import { ChannelService } from '../../service/services/channel.service'; import { getApiType } from './get-api-type'; import { RequestContext } from './request-context'; -export const REQUEST_CONTEXT_KEY = 'vendureRequestContext'; - /** * Creates new RequestContext instances. */ @@ -79,7 +77,7 @@ export class RequestContextService { if (!user || !channel) { return false; } - const permissionsOnChannel = user.channelPermissions.find((c) => idsAreEqual(c.id, channel.id)); + const permissionsOnChannel = user.channelPermissions.find(c => idsAreEqual(c.id, channel.id)); if (permissionsOnChannel) { return this.arraysIntersect(permissionsOnChannel.permissions, permissions); } diff --git a/packages/core/src/api/decorators/request-context.decorator.ts b/packages/core/src/api/decorators/request-context.decorator.ts index c9915fe08d..9e4c0e1c85 100644 --- a/packages/core/src/api/decorators/request-context.decorator.ts +++ b/packages/core/src/api/decorators/request-context.decorator.ts @@ -1,6 +1,6 @@ import { ContextType, createParamDecorator, ExecutionContext } from '@nestjs/common'; -import { REQUEST_CONTEXT_KEY } from '../common/request-context.service'; +import { REQUEST_CONTEXT_KEY } from '../../common/constants'; /** * @description diff --git a/packages/core/src/api/decorators/transaction.decorator.ts b/packages/core/src/api/decorators/transaction.decorator.ts new file mode 100644 index 0000000000..398b59144b --- /dev/null +++ b/packages/core/src/api/decorators/transaction.decorator.ts @@ -0,0 +1,15 @@ +import { applyDecorators, UseInterceptors } from '@nestjs/common'; + +import { TransactionInterceptor } from '../middleware/transaction-interceptor'; + +/** + * @description + * Runs the decorated method in a TypeORM transaction. It works by creating a transctional + * QueryRunner which gets attached to the RequestContext object. When the RequestContext + * is the passed to the {@link TransactionalConnection} `getRepository()` method, this + * QueryRunner is used to execute the queries within this transaction. + * + * @docsCategory request + * @docsPage Decorators + */ +export const Transaction = applyDecorators(UseInterceptors(TransactionInterceptor)); diff --git a/packages/core/src/api/index.ts b/packages/core/src/api/index.ts index 9dfeaf8c24..650d4a12da 100644 --- a/packages/core/src/api/index.ts +++ b/packages/core/src/api/index.ts @@ -1,6 +1,7 @@ export { ApiType } from './common/get-api-type'; export * from './common/request-context'; export * from './decorators/allow.decorator'; +export * from './decorators/transaction.decorator'; export * from './decorators/api.decorator'; export * from './decorators/request-context.decorator'; export * from './resolvers/admin/search.resolver'; diff --git a/packages/core/src/api/middleware/auth-guard.ts b/packages/core/src/api/middleware/auth-guard.ts index 77c7241d01..2b5c555e3c 100644 --- a/packages/core/src/api/middleware/auth-guard.ts +++ b/packages/core/src/api/middleware/auth-guard.ts @@ -3,13 +3,14 @@ import { Reflector } from '@nestjs/core'; import { Permission } from '@vendure/common/lib/generated-types'; import { Request, Response } from 'express'; +import { REQUEST_CONTEXT_KEY } from '../../common/constants'; import { ForbiddenError } from '../../common/error/errors'; import { ConfigService } from '../../config/config.service'; import { CachedSession } from '../../config/session-cache/session-cache-strategy'; import { SessionService } from '../../service/services/session.service'; import { extractSessionToken } from '../common/extract-session-token'; import { parseContext } from '../common/parse-context'; -import { REQUEST_CONTEXT_KEY, RequestContextService } from '../common/request-context.service'; +import { RequestContextService } from '../common/request-context.service'; import { setSessionToken } from '../common/set-session-token'; import { PERMISSIONS_METADATA_KEY } from '../decorators/allow.decorator'; diff --git a/packages/core/src/api/middleware/id-interceptor.ts b/packages/core/src/api/middleware/id-interceptor.ts index 3c22df536d..15144f7ad3 100644 --- a/packages/core/src/api/middleware/id-interceptor.ts +++ b/packages/core/src/api/middleware/id-interceptor.ts @@ -29,10 +29,9 @@ export class IdInterceptor implements NestInterceptor { constructor(private idCodecService: IdCodecService) {} intercept(context: ExecutionContext, next: CallHandler): Observable { - const { isGraphQL, req } = parseContext(context); - if (isGraphQL) { + const { isGraphQL, req, info } = parseContext(context); + if (isGraphQL && info) { const args = GqlExecutionContext.create(context).getArgs(); - const info = GqlExecutionContext.create(context).getInfo(); const transformer = this.getTransformerForSchema(info.schema); this.decodeIdArguments(transformer, info.operation, args); } diff --git a/packages/core/src/api/middleware/transaction-interceptor.ts b/packages/core/src/api/middleware/transaction-interceptor.ts new file mode 100644 index 0000000000..878a1abf33 --- /dev/null +++ b/packages/core/src/api/middleware/transaction-interceptor.ts @@ -0,0 +1,59 @@ +import { CallHandler, ExecutionContext, Injectable, NestInterceptor } from '@nestjs/common'; +import { Observable, of } from 'rxjs'; +import { tap } from 'rxjs/operators'; + +import { REQUEST_CONTEXT_KEY, TRANSACTION_MANAGER_KEY } from '../../common/constants'; +import { TransactionalConnection } from '../../service/transaction/transactional-connection'; +import { parseContext } from '../common/parse-context'; +import { RequestContext } from '../common/request-context'; + +/** + * @description + * Used by the {@link Transaction} decorator to create a transactional query runner + * and attach it to the RequestContext. + */ +@Injectable() +export class TransactionInterceptor implements NestInterceptor { + constructor(private connection: TransactionalConnection) {} + intercept(context: ExecutionContext, next: CallHandler): Observable { + const { isGraphQL, req } = parseContext(context); + const ctx = (req as any)[REQUEST_CONTEXT_KEY]; + if (ctx) { + return of(this.withTransaction(ctx, () => next.handle().toPromise())); + } else { + return next.handle(); + } + } + + /** + * @description + * Executes the `work` function within the context of a transaction. + */ + private async withTransaction(ctx: RequestContext, work: () => T): Promise { + const queryRunnerExists = !!(ctx as any)[TRANSACTION_MANAGER_KEY]; + if (queryRunnerExists) { + // If a QueryRunner already exists on the RequestContext, there must be an existing + // outer transaction in progress. In that case, we just execute the work function + // as usual without needing to further wrap in a transaction. + return work(); + } + const queryRunner = this.connection.rawConnection.createQueryRunner(); + await queryRunner.startTransaction(); + (ctx as any)[TRANSACTION_MANAGER_KEY] = queryRunner.manager; + + try { + const result = await work(); + if (queryRunner.isTransactionActive) { + await queryRunner.commitTransaction(); + } + return result; + } catch (error) { + if (queryRunner.isTransactionActive) { + await queryRunner.rollbackTransaction(); + } + throw error; + } finally { + await queryRunner.release(); + } + } +} diff --git a/packages/core/src/api/middleware/validate-custom-fields-interceptor.ts b/packages/core/src/api/middleware/validate-custom-fields-interceptor.ts index 79c264f982..9c0519105e 100644 --- a/packages/core/src/api/middleware/validate-custom-fields-interceptor.ts +++ b/packages/core/src/api/middleware/validate-custom-fields-interceptor.ts @@ -1,29 +1,20 @@ import { CallHandler, ExecutionContext, Injectable, NestInterceptor } from '@nestjs/common'; import { GqlExecutionContext } from '@nestjs/graphql'; import { LanguageCode } from '@vendure/common/lib/generated-types'; -import { assertNever } from '@vendure/common/lib/shared-utils'; import { - DefinitionNode, GraphQLInputType, GraphQLList, GraphQLNonNull, - GraphQLResolveInfo, GraphQLSchema, OperationDefinitionNode, TypeNode, } from 'graphql'; -import { UserInputError } from '../../common/error/errors'; +import { REQUEST_CONTEXT_KEY } from '../../common/constants'; import { ConfigService } from '../../config/config.service'; -import { - CustomFieldConfig, - CustomFields, - LocaleStringCustomFieldConfig, - StringCustomFieldConfig, -} from '../../config/custom-field/custom-field-types'; +import { CustomFieldConfig, CustomFields } from '../../config/custom-field/custom-field-types'; import { parseContext } from '../common/parse-context'; import { RequestContext } from '../common/request-context'; -import { REQUEST_CONTEXT_KEY } from '../common/request-context.service'; import { validateCustomFieldValue } from '../common/validate-custom-field-value'; /** @@ -44,12 +35,12 @@ export class ValidateCustomFieldsInterceptor implements NestInterceptor { } intercept(context: ExecutionContext, next: CallHandler) { - const { isGraphQL } = parseContext(context); - if (isGraphQL) { + const parsedContext = parseContext(context); + if (parsedContext.isGraphQL) { const gqlExecutionContext = GqlExecutionContext.create(context); - const { operation, schema } = gqlExecutionContext.getInfo(); + const { operation, schema } = parsedContext.info; const variables = gqlExecutionContext.getArgs(); - const ctx: RequestContext = gqlExecutionContext.getContext().req[REQUEST_CONTEXT_KEY]; + const ctx: RequestContext = (parsedContext.req as any)[REQUEST_CONTEXT_KEY]; if (operation.operation === 'mutation') { const inputTypeNames = this.getArgumentMap(operation, schema); diff --git a/packages/core/src/common/constants.ts b/packages/core/src/common/constants.ts index 27e807ad68..5b2cd33d26 100644 --- a/packages/core/src/common/constants.ts +++ b/packages/core/src/common/constants.ts @@ -5,3 +5,5 @@ import { LanguageCode } from '@vendure/common/lib/generated-types'; * VendureConfig to ensure at least a valid LanguageCode is available. */ export const DEFAULT_LANGUAGE_CODE = LanguageCode.en; +export const TRANSACTION_MANAGER_KEY = Symbol('TRANSACTION_MANAGER'); +export const REQUEST_CONTEXT_KEY = 'vendureRequestContext'; diff --git a/packages/core/src/entity/order/order.entity.ts b/packages/core/src/entity/order/order.entity.ts index c924b81e13..b0789f443d 100644 --- a/packages/core/src/entity/order/order.entity.ts +++ b/packages/core/src/entity/order/order.entity.ts @@ -3,9 +3,11 @@ import { DeepPartial, ID } from '@vendure/common/lib/shared-types'; import { Column, Entity, JoinTable, ManyToMany, ManyToOne, OneToMany } from 'typeorm'; import { Calculated } from '../../common/calculated-decorator'; +import { ChannelAware } from '../../common/types/common-types'; import { HasCustomFields } from '../../config/custom-field/custom-field-types'; import { OrderState } from '../../service/helpers/order-state-machine/order-state'; import { VendureEntity } from '../base/base.entity'; +import { Channel } from '../channel/channel.entity'; import { CustomOrderFields } from '../custom-entity-fields'; import { Customer } from '../customer/customer.entity'; import { EntityId } from '../entity-id.decorator'; @@ -14,8 +16,6 @@ import { OrderLine } from '../order-line/order-line.entity'; import { Payment } from '../payment/payment.entity'; import { Promotion } from '../promotion/promotion.entity'; import { ShippingMethod } from '../shipping-method/shipping-method.entity'; -import { ChannelAware } from '../../common/types/common-types'; -import { Channel } from '../channel/channel.entity'; /** * @description @@ -96,7 +96,7 @@ export class Order extends VendureEntity implements ChannelAware, HasCustomField @EntityId({ nullable: true }) taxZoneId?: ID; - @ManyToMany((type) => Channel) + @ManyToMany(type => Channel) @JoinTable() channels: Channel[]; @@ -134,11 +134,8 @@ export class Order extends VendureEntity implements ChannelAware, HasCustomField } getOrderItems(): OrderItem[] { - return this.lines.reduce( - (items, line) => { - return [...items, ...line.items]; - }, - [] as OrderItem[], - ); + return this.lines.reduce((items, line) => { + return [...items, ...line.items]; + }, [] as OrderItem[]); } } diff --git a/packages/core/src/service/index.ts b/packages/core/src/service/index.ts index 212ff23cc6..66ce7d17cf 100644 --- a/packages/core/src/service/index.ts +++ b/packages/core/src/service/index.ts @@ -33,5 +33,4 @@ export * from './services/tax-category.service'; export * from './services/tax-rate.service'; export * from './services/user.service'; export * from './services/user.service'; -export * from './transaction/unit-of-work'; export * from './transaction/transactional-connection'; diff --git a/packages/core/src/service/initializer.service.ts b/packages/core/src/service/initializer.service.ts new file mode 100644 index 0000000000..8c553000f5 --- /dev/null +++ b/packages/core/src/service/initializer.service.ts @@ -0,0 +1,42 @@ +import { Injectable } from '@nestjs/common'; + +import { AdministratorService } from './services/administrator.service'; +import { ChannelService } from './services/channel.service'; +import { GlobalSettingsService } from './services/global-settings.service'; +import { PaymentMethodService } from './services/payment-method.service'; +import { RoleService } from './services/role.service'; +import { ShippingMethodService } from './services/shipping-method.service'; +import { TaxRateService } from './services/tax-rate.service'; + +/** + * Only used internally to run the various service init methods in the correct + * sequence on bootstrap. + */ +@Injectable() +export class InitializerService { + constructor( + private channelService: ChannelService, + private roleService: RoleService, + private administratorService: AdministratorService, + private taxRateService: TaxRateService, + private shippingMethodService: ShippingMethodService, + private paymentMethodService: PaymentMethodService, + private globalSettingsService: GlobalSettingsService, + ) {} + + async onModuleInit() { + // IMPORTANT - why manually invoke these init methods rather than just relying on + // Nest's "onModuleInit" lifecycle hook within each individual service class? + // The reason is that the order of invokation matters. By explicitly invoking the + // methods below, we can e.g. guarantee that the default channel exists + // (channelService.initChannels()) before we try to create any roles (which assume that + // there is a default Channel to work with. + await this.globalSettingsService.initGlobalSettings(); + await this.channelService.initChannels(); + await this.roleService.initRoles(); + await this.administratorService.initAdministrators(); + await this.taxRateService.initTaxRates(); + await this.shippingMethodService.initShippingMethods(); + await this.paymentMethodService.initPaymentMethods(); + } +} diff --git a/packages/core/src/service/service.module.ts b/packages/core/src/service/service.module.ts index a731b2b1dc..a3eeb47eca 100644 --- a/packages/core/src/service/service.module.ts +++ b/packages/core/src/service/service.module.ts @@ -1,4 +1,4 @@ -import { DynamicModule, Module, OnModuleInit } from '@nestjs/common'; +import { DynamicModule, Module } from '@nestjs/common'; import { TypeOrmModule } from '@nestjs/typeorm'; import { ConnectionOptions } from 'typeorm'; @@ -25,6 +25,7 @@ import { SlugValidator } from './helpers/slug-validator/slug-validator'; import { TaxCalculator } from './helpers/tax-calculator/tax-calculator'; import { TranslatableSaver } from './helpers/translatable-saver/translatable-saver'; import { VerificationTokenGenerator } from './helpers/verification-token-generator/verification-token-generator'; +import { InitializerService } from './initializer.service'; import { AdministratorService } from './services/administrator.service'; import { AssetService } from './services/asset.service'; import { AuthService } from './services/auth.service'; @@ -55,7 +56,6 @@ import { TaxRateService } from './services/tax-rate.service'; import { UserService } from './services/user.service'; import { ZoneService } from './services/zone.service'; import { TransactionalConnection } from './transaction/transactional-connection'; -import { UnitOfWork } from './transaction/unit-of-work'; const services = [ AdministratorService, @@ -104,7 +104,6 @@ const helpers = [ ShippingConfiguration, SlugValidator, ExternalAuthenticationService, - UnitOfWork, TransactionalConnection, ]; @@ -120,36 +119,10 @@ let workerTypeOrmModule: DynamicModule; */ @Module({ imports: [ConfigModule, EventBusModule, WorkerServiceModule, JobQueueModule], - providers: [...services, ...helpers], + providers: [...services, ...helpers, InitializerService], exports: [...services, ...helpers], }) -export class ServiceCoreModule implements OnModuleInit { - constructor( - private channelService: ChannelService, - private roleService: RoleService, - private administratorService: AdministratorService, - private taxRateService: TaxRateService, - private shippingMethodService: ShippingMethodService, - private paymentMethodService: PaymentMethodService, - private globalSettingsService: GlobalSettingsService, - ) {} - - async onModuleInit() { - // IMPORTANT - why manually invoke these init methods rather than just relying on - // Nest's "onModuleInit" lifecycle hook within each individual service class? - // The reason is that the order of invokation matters. By explicitly invoking the - // methods below, we can e.g. guarantee that the default channel exists - // (channelService.initChannels()) before we try to create any roles (which assume that - // there is a default Channel to work with. - await this.globalSettingsService.initGlobalSettings(); - await this.channelService.initChannels(); - await this.roleService.initRoles(); - await this.administratorService.initAdministrators(); - await this.taxRateService.initTaxRates(); - await this.shippingMethodService.initShippingMethods(); - await this.paymentMethodService.initPaymentMethods(); - } -} +export class ServiceCoreModule {} /** * The ServiceModule is responsible for the service layer, i.e. accessing the database diff --git a/packages/core/src/service/services/collection.service.ts b/packages/core/src/service/services/collection.service.ts index af35243358..388618a0b7 100644 --- a/packages/core/src/service/services/collection.service.ts +++ b/packages/core/src/service/services/collection.service.ts @@ -1,4 +1,4 @@ -import { OnModuleInit, Optional } from '@nestjs/common'; +import { Injectable, OnModuleInit, Optional } from '@nestjs/common'; import { InjectConnection } from '@nestjs/typeorm'; import { ConfigurableOperation, @@ -49,15 +49,13 @@ import { AssetService } from './asset.service'; import { ChannelService } from './channel.service'; import { FacetValueService } from './facet-value.service'; +@Injectable() export class CollectionService implements OnModuleInit { private rootCollection: Collection | undefined; private applyFiltersQueue: JobQueue; constructor( - // Optional() allows the onModuleInit() hook to run with injected - // providers despite the request-scoped TransactionalConnection - // not yet having been created - @Optional() private connection: TransactionalConnection, + private connection: TransactionalConnection, private channelService: ChannelService, private assetService: AssetService, private facetValueService: FacetValueService, diff --git a/packages/core/src/service/services/order.service.ts b/packages/core/src/service/services/order.service.ts index a5961f4fd9..ca3051f9bb 100644 --- a/packages/core/src/service/services/order.service.ts +++ b/packages/core/src/service/services/order.service.ts @@ -1,3 +1,4 @@ +import { Injectable } from '@nestjs/common'; import { PaymentInput } from '@vendure/common/lib/generated-shop-types'; import { AddNoteToOrderInput, @@ -72,6 +73,7 @@ import { ProductVariantService } from './product-variant.service'; import { PromotionService } from './promotion.service'; import { StockMovementService } from './stock-movement.service'; +@Injectable() export class OrderService { constructor( private connection: TransactionalConnection, diff --git a/packages/core/src/service/services/tax-rate.service.ts b/packages/core/src/service/services/tax-rate.service.ts index faa8c88ab9..c0fe2c976a 100644 --- a/packages/core/src/service/services/tax-rate.service.ts +++ b/packages/core/src/service/services/tax-rate.service.ts @@ -1,3 +1,5 @@ +import { Injectable } from '@nestjs/common'; +import { InjectConnection } from '@nestjs/typeorm'; import { CreateTaxRateInput, DeletionResponse, @@ -5,6 +7,7 @@ import { UpdateTaxRateInput, } from '@vendure/common/lib/generated-types'; import { ID, PaginatedList } from '@vendure/common/lib/shared-types'; +import { Connection } from 'typeorm'; import { RequestContext } from '../../api/common/request-context'; import { EntityNotFoundError } from '../../common/error/errors'; @@ -23,6 +26,7 @@ import { patchEntity } from '../helpers/utils/patch-entity'; import { TransactionalConnection } from '../transaction/transactional-connection'; import { TaxRateUpdatedMessage } from '../types/tax-rate-messages'; +@Injectable() export class TaxRateService { /** * We cache all active TaxRates to avoid hitting the DB many times diff --git a/packages/core/src/service/transaction/transactional-connection.ts b/packages/core/src/service/transaction/transactional-connection.ts index c0dff70e1e..0a0343135f 100644 --- a/packages/core/src/service/transaction/transactional-connection.ts +++ b/packages/core/src/service/transaction/transactional-connection.ts @@ -1,27 +1,26 @@ -import { Injectable, Scope } from '@nestjs/common'; -import { Connection, ConnectionOptions, EntitySchema, getRepository, ObjectType, Repository } from 'typeorm'; +import { Injectable } from '@nestjs/common'; +import { InjectConnection } from '@nestjs/typeorm'; +import { Connection, EntitySchema, getRepository, ObjectType, Repository } from 'typeorm'; import { RepositoryFactory } from 'typeorm/repository/RepositoryFactory'; -import { UnitOfWork } from './unit-of-work'; +import { RequestContext } from '../../api/common/request-context'; +import { TRANSACTION_MANAGER_KEY } from '../../common/constants'; /** * @description * The TransactionalConnection is a wrapper around the TypeORM `Connection` object which works in conjunction - * with the {@link UnitOfWork} class to implement per-request transactions. All services which access the + * with the {@link Transaction} decorator to implement per-request transactions. All services which access the * database should use this class rather than the raw TypeORM connection, to ensure that db changes can be * easily wrapped in transactions when required. * * The service layer does not need to know about the scope of a transaction, as this is covered at the - * API level depending on the nature of the request. - * - * Based on the pattern outlined in - * [this article](https://aaronboman.com/programming/2020/05/15/per-request-database-transactions-with-nestjs-and-typeorm/) + * API by the use of the `Transaction` decorator. * * @docsCategory data-access */ @Injectable() export class TransactionalConnection { - constructor(private uow: UnitOfWork) {} + constructor(@InjectConnection() private connection: Connection) {} /** * @description @@ -30,7 +29,7 @@ export class TransactionalConnection { * transactions. */ get rawConnection(): Connection { - return this.uow.getConnection(); + return this.connection; } /** @@ -38,13 +37,26 @@ export class TransactionalConnection { * Gets a repository bound to the current transaction manager * or defaults to the current connection's call to getRepository(). */ - getRepository(target: ObjectType | EntitySchema | string): Repository { - const transactionManager = this.uow.getTransactionManager(); - if (transactionManager) { - const connection = this.uow.getConnection(); - const metadata = connection.getMetadata(target); - return new RepositoryFactory().create(transactionManager, metadata); + getRepository(target: ObjectType | EntitySchema | string): Repository; + getRepository( + ctx: RequestContext, + target: ObjectType | EntitySchema | string, + ): Repository; + getRepository( + ctxOrTarget: RequestContext | ObjectType | EntitySchema | string, + maybeTarget?: ObjectType | EntitySchema | string, + ): Repository { + if (ctxOrTarget instanceof RequestContext) { + const transactionManager = (ctxOrTarget as any)[TRANSACTION_MANAGER_KEY]; + if (transactionManager && maybeTarget) { + const metadata = this.connection.getMetadata(maybeTarget); + return new RepositoryFactory().create(transactionManager, metadata); + } else { + // tslint:disable-next-line:no-non-null-assertion + return getRepository(maybeTarget!); + } + } else { + return getRepository(ctxOrTarget); } - return getRepository(target); } } diff --git a/packages/core/src/service/transaction/unit-of-work.ts b/packages/core/src/service/transaction/unit-of-work.ts deleted file mode 100644 index c608f55dcd..0000000000 --- a/packages/core/src/service/transaction/unit-of-work.ts +++ /dev/null @@ -1,42 +0,0 @@ -import { Injectable, Scope } from '@nestjs/common'; -import { InjectConnection } from '@nestjs/typeorm'; -import { Connection, EntityManager } from 'typeorm'; - -/** - * @description - * This class is used to wrap an entire request in a database transaction. It should - * generally be injected at the API layer and wrap the service-layer call(s) so that - * all DB access within the `withTransaction()` method takes place within a transaction. - * - * @docsCategory data-access - */ -@Injectable() -export class UnitOfWork { - private transactionManager: EntityManager | null; - constructor(@InjectConnection() private connection: Connection) {} - - getTransactionManager(): EntityManager | null { - return this.transactionManager; - } - - getConnection(): Connection { - return this.connection; - } - - async withTransaction(work: () => T): Promise { - const queryRunner = this.connection.createQueryRunner(); - await queryRunner.startTransaction(); - this.transactionManager = queryRunner.manager; - try { - const result = await work(); - await queryRunner.commitTransaction(); - return result; - } catch (error) { - await queryRunner.rollbackTransaction(); - throw error; - } finally { - await queryRunner.release(); - this.transactionManager = null; - } - } -}