From 7b76a7c12e4e93fc1240dd619755d40606004672 Mon Sep 17 00:00:00 2001
From: Michael Bromley <michael@michaelbromley.co.uk>
Date: Mon, 30 Aug 2021 16:08:23 +0200
Subject: [PATCH] feat(core): Always pass current Order to TaxZoneStrategy
 calls

Relates to #1048. This commit also introduces the use of RequestContextCacheService to optimize the
number of calls made to the `determineTaxZone()` method, as well as allowing async return values.
---
 .../core/src/config/tax/tax-zone-strategy.ts  | 17 ++++++++++++++++-
 .../order-calculator/order-calculator.spec.ts |  2 ++
 .../order-calculator/order-calculator.ts      |  7 ++++++-
 .../helpers/order-modifier/order-modifier.ts  |  6 +++++-
 .../service/services/order-testing.service.ts |  2 +-
 .../src/service/services/order.service.ts     |  8 +++-----
 .../services/product-variant.service.ts       | 19 ++++++++++++++++---
 7 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/packages/core/src/config/tax/tax-zone-strategy.ts b/packages/core/src/config/tax/tax-zone-strategy.ts
index c915afec69..52b22dada9 100644
--- a/packages/core/src/config/tax/tax-zone-strategy.ts
+++ b/packages/core/src/config/tax/tax-zone-strategy.ts
@@ -6,8 +6,23 @@ import { Channel, Order, Zone } from '../../entity';
  * @description
  * Defines how the active {@link Zone} is determined for the purposes of calculating taxes.
  *
+ * This strategy is used in 2 scenarios:
+ *
+ * 1. To determine the applicable Zone when calculating the taxRate to apply when displaying ProductVariants. In this case the
+ * `order` argument will be undefined, as the request is not related to a specific Order.
+ * 2. To determine the applicable Zone when calculating the taxRate on the contents of a specific Order. In this case the
+ * `order` argument _will_ be defined, and can be used in the logic. For example, the shipping address can be taken into account.
+ *
+ * Note that this method is called very often in a typical user session, so any work it performs should be designed with as little
+ * performance impact as possible.
+ *
  * @docsCategory tax
  */
 export interface TaxZoneStrategy extends InjectableStrategy {
-    determineTaxZone(ctx: RequestContext, zones: Zone[], channel: Channel, order?: Order): Zone | undefined;
+    determineTaxZone(
+        ctx: RequestContext,
+        zones: Zone[],
+        channel: Channel,
+        order?: Order,
+    ): Zone | Promise<Zone> | undefined;
 }
diff --git a/packages/core/src/service/helpers/order-calculator/order-calculator.spec.ts b/packages/core/src/service/helpers/order-calculator/order-calculator.spec.ts
index d130a10340..6934c164e2 100644
--- a/packages/core/src/service/helpers/order-calculator/order-calculator.spec.ts
+++ b/packages/core/src/service/helpers/order-calculator/order-calculator.spec.ts
@@ -3,6 +3,7 @@ import { AdjustmentType, LanguageCode, TaxLine } from '@vendure/common/lib/gener
 import { summate } from '@vendure/common/lib/shared-utils';
 
 import { RequestContext } from '../../../api/common/request-context';
+import { RequestContextCacheService } from '../../../cache/request-context-cache.service';
 import { PromotionItemAction, PromotionOrderAction, PromotionShippingAction } from '../../../config';
 import { ConfigService } from '../../../config/config.service';
 import { MockConfigService } from '../../../config/config.service.mock';
@@ -1537,6 +1538,7 @@ function createTestModule() {
     return Test.createTestingModule({
         providers: [
             OrderCalculator,
+            RequestContextCacheService,
             { provide: TaxRateService, useClass: MockTaxRateService },
             { provide: ShippingCalculator, useValue: { getEligibleShippingMethods: () => [] } },
             {
diff --git a/packages/core/src/service/helpers/order-calculator/order-calculator.ts b/packages/core/src/service/helpers/order-calculator/order-calculator.ts
index bc2d06e323..a9cff5fdc4 100644
--- a/packages/core/src/service/helpers/order-calculator/order-calculator.ts
+++ b/packages/core/src/service/helpers/order-calculator/order-calculator.ts
@@ -3,6 +3,7 @@ import { filterAsync } from '@vendure/common/lib/filter-async';
 import { AdjustmentType } from '@vendure/common/lib/generated-types';
 
 import { RequestContext } from '../../../api/common/request-context';
+import { RequestContextCacheService } from '../../../cache/request-context-cache.service';
 import { InternalServerError } from '../../../common/error/errors';
 import { netPriceOf } from '../../../common/tax-utils';
 import { idsAreEqual } from '../../../common/utils';
@@ -27,6 +28,7 @@ export class OrderCalculator {
         private taxRateService: TaxRateService,
         private shippingMethodService: ShippingMethodService,
         private shippingCalculator: ShippingCalculator,
+        private requestContextCache: RequestContextCacheService,
     ) {}
 
     /**
@@ -43,7 +45,10 @@ export class OrderCalculator {
     ): Promise<OrderItem[]> {
         const { taxZoneStrategy } = this.configService.taxOptions;
         const zones = this.zoneService.findAll(ctx);
-        const activeTaxZone = taxZoneStrategy.determineTaxZone(ctx, zones, ctx.channel, order);
+        const activeTaxZone = await this.requestContextCache.get(ctx, 'activeTaxZone', () =>
+            taxZoneStrategy.determineTaxZone(ctx, zones, ctx.channel, order),
+        );
+
         let taxZoneChanged = false;
         if (!activeTaxZone) {
             throw new InternalServerError(`error.no-active-tax-zone`);
diff --git a/packages/core/src/service/helpers/order-modifier/order-modifier.ts b/packages/core/src/service/helpers/order-modifier/order-modifier.ts
index c8a5fc2647..3017500792 100644
--- a/packages/core/src/service/helpers/order-modifier/order-modifier.ts
+++ b/packages/core/src/service/helpers/order-modifier/order-modifier.ts
@@ -130,7 +130,11 @@ export class OrderModifier {
             ],
         });
         lineWithRelations.productVariant = translateDeep(
-            await this.productVariantService.applyChannelPriceAndTax(lineWithRelations.productVariant, ctx),
+            await this.productVariantService.applyChannelPriceAndTax(
+                lineWithRelations.productVariant,
+                ctx,
+                order,
+            ),
             ctx.languageCode,
         );
         order.lines.push(lineWithRelations);
diff --git a/packages/core/src/service/services/order-testing.service.ts b/packages/core/src/service/services/order-testing.service.ts
index 0d04cd4eaf..af891c336e 100644
--- a/packages/core/src/service/services/order-testing.service.ts
+++ b/packages/core/src/service/services/order-testing.service.ts
@@ -119,7 +119,7 @@ export class OrderTestingService {
                 line.productVariantId,
                 { relations: ['taxCategory'] },
             );
-            await this.productVariantService.applyChannelPriceAndTax(productVariant, ctx);
+            await this.productVariantService.applyChannelPriceAndTax(productVariant, ctx, mockOrder);
             const orderLine = new OrderLine({
                 productVariant,
                 items: [],
diff --git a/packages/core/src/service/services/order.service.ts b/packages/core/src/service/services/order.service.ts
index df4423dab4..ec981e5523 100644
--- a/packages/core/src/service/services/order.service.ts
+++ b/packages/core/src/service/services/order.service.ts
@@ -213,7 +213,7 @@ export class OrderService {
         if (order) {
             for (const line of order.lines) {
                 line.productVariant = translateDeep(
-                    await this.productVariantService.applyChannelPriceAndTax(line.productVariant, ctx),
+                    await this.productVariantService.applyChannelPriceAndTax(line.productVariant, ctx, order),
                     ctx.languageCode,
                 );
             }
@@ -1348,10 +1348,8 @@ export class OrderService {
         updatedOrderLine?: OrderLine,
     ): Promise<Order> {
         if (updatedOrderLine) {
-            const {
-                orderItemPriceCalculationStrategy,
-                changedPriceHandlingStrategy,
-            } = this.configService.orderOptions;
+            const { orderItemPriceCalculationStrategy, changedPriceHandlingStrategy } =
+                this.configService.orderOptions;
             let priceResult = await orderItemPriceCalculationStrategy.calculateUnitPrice(
                 ctx,
                 updatedOrderLine.productVariant,
diff --git a/packages/core/src/service/services/product-variant.service.ts b/packages/core/src/service/services/product-variant.service.ts
index 4925b19c29..066c437d29 100644
--- a/packages/core/src/service/services/product-variant.service.ts
+++ b/packages/core/src/service/services/product-variant.service.ts
@@ -19,7 +19,14 @@ import { ListQueryOptions } from '../../common/types/common-types';
 import { Translated } from '../../common/types/locale-types';
 import { idsAreEqual } from '../../common/utils';
 import { ConfigService } from '../../config/config.service';
-import { Channel, OrderLine, ProductOptionGroup, ProductVariantPrice, TaxCategory } from '../../entity';
+import {
+    Channel,
+    Order,
+    OrderLine,
+    ProductOptionGroup,
+    ProductVariantPrice,
+    TaxCategory,
+} from '../../entity';
 import { FacetValue } from '../../entity/facet-value/facet-value.entity';
 import { ProductOption } from '../../entity/product-option/product-option.entity';
 import { ProductVariantTranslation } from '../../entity/product-variant/product-variant-translation.entity';
@@ -551,7 +558,11 @@ export class ProductVariantService {
     /**
      * Populates the `price` field with the price for the specified channel.
      */
-    async applyChannelPriceAndTax(variant: ProductVariant, ctx: RequestContext): Promise<ProductVariant> {
+    async applyChannelPriceAndTax(
+        variant: ProductVariant,
+        ctx: RequestContext,
+        order?: Order,
+    ): Promise<ProductVariant> {
         const channelPrice = variant.productVariantPrices.find(p => idsAreEqual(p.channelId, ctx.channelId));
         if (!channelPrice) {
             throw new InternalServerError(`error.no-price-found-for-channel`, {
@@ -561,7 +572,9 @@ export class ProductVariantService {
         }
         const { taxZoneStrategy } = this.configService.taxOptions;
         const zones = this.zoneService.findAll(ctx);
-        const activeTaxZone = taxZoneStrategy.determineTaxZone(ctx, zones, ctx.channel);
+        const activeTaxZone = await this.requestCache.get(ctx, 'activeTaxZone', () =>
+            taxZoneStrategy.determineTaxZone(ctx, zones, ctx.channel, order),
+        );
         if (!activeTaxZone) {
             throw new InternalServerError(`error.no-active-tax-zone`);
         }