From 8668502a1e12e67a0b06c2d2beebac043796e33d Mon Sep 17 00:00:00 2001 From: Jay McDoniel Date: Thu, 6 Jul 2023 10:47:34 -0700 Subject: [PATCH] feat: allowfor multiple throttler contexts This is a bit of something that I've wanted to do for a while and inspired by [this pr][pr]. With the new appraoch, we're now able to let users define scales at which they would like the throttling to work over, and let it work for any number of configuratins, from a single 10 requests in 5 seconds to scales of months, or milliseconds BREAKING CHANGES: It's worth noting there are quite a few breaking changes in this which will be reflected in the changelog as well, but better to have multiple mentions in my opinion * ttl is now in milliseconds, not seconds, but there are time helper exposed to ease the migration to that * the module options is now either an array or an object with a `throttlers` array property * `@Throttle()` now takes in an object instead of two parameters, to allow for setting multiple throttle contexts at once in a more readable manner * `@ThrottleSkip()` now takes in an object with string boolean to say which throttler should be skipped pr: https://github.com/nestjs/throttler/pull/1522 ref: #1369 ref: #1522 --- package.json | 3 +- pnpm-lock.yaml | 110 ++++++++++++++++++++- src/index.ts | 1 + src/throttler-module-options.interface.ts | 42 ++++++-- src/throttler.decorator.ts | 41 ++++---- src/throttler.guard.ts | 112 ++++++++++++++-------- src/throttler.module.ts | 2 +- src/throttler.providers.ts | 4 +- src/throttler.service.ts | 2 +- src/utilities.ts | 5 + test/app/controllers/app.controller.ts | 4 +- test/app/controllers/controller.module.ts | 14 +-- test/app/controllers/limit.controller.ts | 6 +- test/controller.e2e-spec.ts | 10 +- test/multi/app.module.ts | 33 +++++++ test/multi/multi-throttler.controller.ts | 22 +++++ test/multi/multi-throttler.e2e-spec.ts | 92 ++++++++++++++++++ 17 files changed, 422 insertions(+), 81 deletions(-) create mode 100644 src/utilities.ts create mode 100644 test/multi/app.module.ts create mode 100644 test/multi/multi-throttler.controller.ts create mode 100644 test/multi/multi-throttler.e2e-spec.ts diff --git a/package.json b/package.json index 7254cd64c..d4d07116e 100644 --- a/package.json +++ b/package.json @@ -48,6 +48,7 @@ "md5": "^2.2.1" }, "devDependencies": { + "@apollo/server": "4.7.5", "@changesets/cli": "2.26.2", "@commitlint/cli": "17.6.6", "@commitlint/config-angular": "17.6.6", @@ -71,7 +72,6 @@ "@types/supertest": "2.0.12", "@typescript-eslint/eslint-plugin": "5.61.0", "@typescript-eslint/parser": "5.61.0", - "@apollo/server": "4.7.5", "apollo-server-fastify": "3.12.0", "conventional-changelog-cli": "3.0.0", "cz-conventional-changelog": "3.3.0", @@ -84,6 +84,7 @@ "jest": "29.6.1", "lint-staged": "13.2.3", "nodemon": "2.0.22", + "pactum": "^3.4.1", "pinst": "3.0.0", "prettier": "3.0.0", "reflect-metadata": "0.1.13", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index dd9c5a33a..8638f12f4 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.0' +lockfileVersion: '6.1' settings: autoInstallPeers: true @@ -118,6 +118,9 @@ devDependencies: nodemon: specifier: 2.0.22 version: 2.0.22 + pactum: + specifier: ^3.4.1 + version: 3.4.1 pinst: specifier: 3.0.0 version: 3.0.0 @@ -571,6 +574,11 @@ packages: xss: 1.0.14 dev: true + /@arr/every@1.0.1: + resolution: {integrity: sha512-UQFQ6SgyJ6LX42W8rHCs8KVc0JS0tzVL9ct4XYedJukskYVWTo49tNiMEK9C2HTyarbNiT/RVIRSY82vH+6sTg==} + engines: {node: '>=4'} + dev: true + /@babel/code-frame@7.22.5: resolution: {integrity: sha512-Xmwn266vad+6DAqEB2A6V/CcZVp62BbwVmcOJc2RPuwih1kw02TjQvWVWlcKGbBPd+8/0V5DEkOcizRGYsspYQ==} engines: {node: '>=6.9.0'} @@ -1319,6 +1327,10 @@ packages: engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} dev: true + /@exodus/schemasafe@1.0.1: + resolution: {integrity: sha512-PQdbF8dGd4LnbwBlcc4ML8RKYdplm+e9sUeWBTr4zgF13/Shiuov9XznvM4T8cb1CfyKK21yTUkuAIIh/DAH/g==} + dev: true + /@fastify/accepts@3.0.0: resolution: {integrity: sha512-+ldBB3O59p/z9Uc1LSZqAA4/oZaNbRtCVMwjgJOahl+PKmx4ciRRoWVht77kFOb36lRE5MPEba4Vt78H7PKfQw==} engines: {node: '>=10'} @@ -2351,6 +2363,10 @@ packages: config-chain: 1.1.13 dev: true + /@polka/url@0.5.0: + resolution: {integrity: sha512-oZLYFEAzUKyi3SKnXvj32ZCEGH6RDnao7COuCVhDydMS9NrCSVXhM79VaKyP5+Zc33m0QXEd2DN3UkU7OsHcfw==} + dev: true + /@protobufjs/aspromise@1.1.2: resolution: {integrity: sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==} dev: true @@ -3808,6 +3824,10 @@ packages: redeyed: 2.1.1 dev: true + /centra@2.6.0: + resolution: {integrity: sha512-dgh+YleemrT8u85QL11Z6tYhegAs3MMxsaWAq/oXeAmYJ7VxL3SI9TZtnfaEvNDMAPolj25FXIb3S+HCI4wQaQ==} + dev: true + /chalk@2.4.2: resolution: {integrity: sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==} engines: {node: '>=4'} @@ -4472,6 +4492,10 @@ packages: resolution: {integrity: sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==} dev: true + /deep-override@1.0.2: + resolution: {integrity: sha512-+bAuLuYqaVVUWPaq8rmU8NLTX85p4I5k5/cVdhBioEfH7k+5NlGdv4NoJVQcJRByqzzTWWzTpih+pU1wBTmMow==} + dev: true + /deepmerge@4.3.1: resolution: {integrity: sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==} engines: {node: '>=0.10.0'} @@ -5434,6 +5458,14 @@ packages: webpack: 5.88.1 dev: true + /form-data-lite@1.0.3: + resolution: {integrity: sha512-P7xPqAiOPKzC9Q9aywAZJCQq4QOE5WokPb3HrcWRh7C57RKytueJzoORZAVgHBNvK/lL7E+FxjQjd4X/zbecEQ==} + dependencies: + asynckit: 0.4.0 + combined-stream: 1.0.8 + mime-lite: 1.0.3 + dev: true + /form-data@3.0.1: resolution: {integrity: sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg==} engines: {node: '>= 6'} @@ -6977,6 +7009,10 @@ packages: engines: {node: ^14.17.0 || ^16.13.0 || >=18.0.0} dev: true + /json-query@2.2.2: + resolution: {integrity: sha512-y+IcVZSdqNmS4fO8t1uZF6RMMs0xh3SrTjJr9bp1X3+v0Q13+7Cyv12dSmKwDswp/H427BVtpkLWhGxYu3ZWRA==} + dev: true + /json-schema-traverse@0.4.1: resolution: {integrity: sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==} dev: true @@ -7044,6 +7080,11 @@ packages: engines: {node: '>=6'} dev: true + /klona@2.0.6: + resolution: {integrity: sha512-dhG34DXATL5hSxJbIexCft8FChFXtmskoZYnoPWjXQuebWYCNkVeV3KkGegCK9CP1oswI/vQibS2GY7Em/sJJA==} + engines: {node: '>= 8'} + dev: true + /leven@3.1.0: resolution: {integrity: sha512-qsda+H8jTaUaN/x5vzW2rzc+8Rw4TAQ/4KjB46IwK5VH+IlVeeeje/EoZRpiXvIqjFgK84QffqPztGI3VBLG1A==} engines: {node: '>=6'} @@ -7074,6 +7115,10 @@ packages: set-cookie-parser: 2.6.0 dev: true + /lightcookie@1.0.25: + resolution: {integrity: sha512-SrY/+eBPaKAMnsn7mCsoOMZzoQyCyHHHZlFCu2fjo28XxSyCLjlooKiTxyrXTg8NPaHp1YzWi0lcGG1gDi6KHw==} + dev: true + /lilconfig@2.1.0: resolution: {integrity: sha512-utWOt/GHzuUxnLKxB6dk81RoOeoNeHgbrXiuGk4yyF5qlRz+iIVWu56E2fqGHFrXz0QNUhLB/8nKqvRH66JKGQ==} engines: {node: '>=10'} @@ -7399,6 +7444,13 @@ packages: hasBin: true dev: true + /matchit@1.1.0: + resolution: {integrity: sha512-+nGYoOlfHmxe5BW5tE0EMJppXEwdSf8uBA1GTZC7Q77kbT35+VKLYJMzVNWCHSsga1ps1tPYFtFyvxvKzWVmMA==} + engines: {node: '>=6'} + dependencies: + '@arr/every': 1.0.1 + dev: true + /md5@2.2.1: resolution: {integrity: sha512-PlGG4z5mBANDGCKsYQe0CaUYHdZYZt8ZPZLmEt+Urf0W4GlpTX4HescwHU+dc9+Z/G/vZKYZYFrwgm9VxK6QOQ==} dependencies: @@ -7488,6 +7540,10 @@ packages: engines: {node: '>= 0.6'} dev: true + /mime-lite@1.0.3: + resolution: {integrity: sha512-V85l97zJSTG8FEvmdTlmNYb0UMrVBwvRjw7bWTf/aT6KjFwtz3iTz8D2tuFIp7lwiaO2C5ecnrEmSkkMRCrqVw==} + dev: true + /mime-types@2.1.35: resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} engines: {node: '>= 0.6'} @@ -7900,6 +7956,12 @@ packages: mimic-fn: 4.0.0 dev: true + /openapi-fuzzer-core@1.0.6: + resolution: {integrity: sha512-FJNJIfgUFuv4NmVGq9MYdoKra2GrkDy2uhIjE2YGlw30UA1glf4SXLMhI4UwdcJ8jisKdIxi7lXrfej8GvNW5w==} + dependencies: + klona: 2.0.6 + dev: true + /optimism@0.16.2: resolution: {integrity: sha512-zWNbgWj+3vLEjZNIh/okkY2EUfX+vB9TJopzIZwT1xxaMqC5hRLLraePod4c5n4He08xuXNH+zhKFFCu390wiQ==} dependencies: @@ -8071,6 +8133,27 @@ packages: engines: {node: '>=6'} dev: true + /pactum-matchers@1.1.5: + resolution: {integrity: sha512-Bc9VVmCYeTFLRw+Du/i8MJl8jA0iBP1pbhYg45+shrKkZJbhSTsrL8ZyAKLuGeFnO0e1HG1c7u6O9Cua30a9cg==} + dev: true + + /pactum@3.4.1: + resolution: {integrity: sha512-rRo2qtrUCdzjjKC+o5GeET+OsIMxxejimPh0tz7w9MB6ZDQgI0WJaXjVbZ0+/CtWPrKD7cyWIWdIJMNVJxsQxg==} + engines: {node: '>=10'} + dependencies: + '@exodus/schemasafe': 1.0.1 + deep-override: 1.0.2 + form-data-lite: 1.0.3 + json-query: 2.2.2 + klona: 2.0.6 + lightcookie: 1.0.25 + openapi-fuzzer-core: 1.0.6 + pactum-matchers: 1.1.5 + parse-graphql: 1.0.0 + phin: 3.7.0 + polka: 0.5.2 + dev: true + /parent-module@1.0.1: resolution: {integrity: sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==} engines: {node: '>=6'} @@ -8078,6 +8161,10 @@ packages: callsites: 3.1.0 dev: true + /parse-graphql@1.0.0: + resolution: {integrity: sha512-NjvQHHaiPCxPZrhm/kKnorxOv7r/eA+tE0VW5E8iJMH9wTqFA1V0YK/7nbpxVu3JdXUxyWTKMez9lsHUtAwa0w==} + dev: true + /parse-json@4.0.0: resolution: {integrity: sha512-aOIos8bujGN93/8Ox/jPLh7RwVnPEysynVFE+fQZyg6jKELEHwzgKdLRFHUgXJL6kylijVSBC4BvN9OmsB48Rw==} engines: {node: '>=4'} @@ -8183,6 +8270,13 @@ packages: engines: {node: '>=8'} dev: true + /phin@3.7.0: + resolution: {integrity: sha512-DqnVNrpYhKGBZppNKprD+UJylMeEKOZxHgPB+ZP6mGzf3uA2uox4Ep9tUm+rUc8WLIdHT3HcAE4X8fhwQA9JKg==} + engines: {node: '>= 8'} + dependencies: + centra: 2.6.0 + dev: true + /picocolors@1.0.0: resolution: {integrity: sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==} dev: true @@ -8289,6 +8383,13 @@ packages: engines: {node: '>=4'} dev: true + /polka@0.5.2: + resolution: {integrity: sha512-FVg3vDmCqP80tOrs+OeNlgXYmFppTXdjD5E7I4ET1NjvtNmQrb1/mJibybKkb/d4NA7YWAr1ojxuhpL3FHqdlw==} + dependencies: + '@polka/url': 0.5.0 + trouter: 2.0.1 + dev: true + /preferred-pm@3.0.3: resolution: {integrity: sha512-+wZgbxNES/KlJs9q40F/1sfOd/j7f1O9JaHcW5Dsn3aUUOZg3L2bjpVUcKV2jvtElYfoTuQiNeMfQJ4kwUAhCQ==} engines: {node: '>=10'} @@ -9652,6 +9753,13 @@ packages: engines: {node: '>=8'} dev: true + /trouter@2.0.1: + resolution: {integrity: sha512-kr8SKKw94OI+xTGOkfsvwZQ8mWoikZDd2n8XZHjJVZUARZT+4/VV6cacRS6CLsH9bNm+HFIPU1Zx4CnNnb4qlQ==} + engines: {node: '>=6'} + dependencies: + matchit: 1.1.0 + dev: true + /ts-invariant@0.10.3: resolution: {integrity: sha512-uivwYcQaxAucv1CzRp2n/QdYPo4ILf9VXgH19zEIjFx2EJufV16P0JtJVpYHy89DItG6Kwj2oIUjrcK5au+4tQ==} engines: {node: '>=8'} diff --git a/src/index.ts b/src/index.ts index 678b61f87..bbbc1d510 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,3 +6,4 @@ export * from './throttler.guard'; export * from './throttler.module'; export { getOptionsToken, getStorageToken } from './throttler.providers'; export * from './throttler.service'; +export * from './utilities'; diff --git a/src/throttler-module-options.interface.ts b/src/throttler-module-options.interface.ts index d0b406787..4aa33f82d 100644 --- a/src/throttler-module-options.interface.ts +++ b/src/throttler-module-options.interface.ts @@ -1,4 +1,5 @@ import { ExecutionContext, ModuleMetadata, Type } from '@nestjs/common/interfaces'; +import { ThrottlerStorage } from './throttler-storage.interface'; export type Resolvable = | T @@ -7,14 +8,21 @@ export type Resolvable = /** * @publicApi */ -export interface ThrottlerModuleOptions { +export interface ThrottlerOptions { + /** + * The name for the rate limit to be used. + * This can be left blank and it will be tracked as "default" internally. + * If this is set, it will be added to the return headers. + * e.g. x-ratelimit-remaining-long: 5 + */ + name?: string; /** * The amount of requests that are allowed within the ttl's time window. */ limit?: Resolvable; /** - * The amount of seconds of how many requests are allowed within this time. + * The number of milliseconds the limit of requests are allowed */ ttl?: Resolvable; @@ -23,11 +31,6 @@ export interface ThrottlerModuleOptions { */ ignoreUserAgents?: RegExp[]; - /** - * The storage class to use where all the record will be stored in. - */ - storage?: any; - /** * A factory method to determine if throttling should be skipped. * This can be based on the incoming context, or something like an env value. @@ -35,6 +38,31 @@ export interface ThrottlerModuleOptions { skipIf?: (context: ExecutionContext) => boolean; } +/** + * @publicApi + */ +export type ThrottlerModuleOptions = + | Array + | { + /** + * A factory method to determine if throttling should be skipped. + * This can be based on the incoming context, or something like an env value. + */ + skipIf?: (context: ExecutionContext) => boolean; + /** + * The user agents that should be ignored (checked against the User-Agent header). + */ + ignoreUserAgents?: RegExp[]; + /** + * The storage class to use where all the record will be stored in. + */ + storage?: Type; + /** + * The named throttlers to use + */ + throttlers: Array; + }; + /** * @publicApi */ diff --git a/src/throttler.decorator.ts b/src/throttler.decorator.ts index 09e405060..b7d35eb7c 100644 --- a/src/throttler.decorator.ts +++ b/src/throttler.decorator.ts @@ -3,25 +3,30 @@ import { THROTTLER_LIMIT, THROTTLER_SKIP, THROTTLER_TTL } from './throttler.cons import { getOptionsToken, getStorageToken } from './throttler.providers'; import { Resolvable } from './throttler-module-options.interface'; +interface ThrottlerMethodOrControllerOptions { + limit?: Resolvable; + ttl?: Resolvable; +} + function setThrottlerMetadata( target: any, - limit: Resolvable, - ttl: Resolvable, + options: Record, ): void { - Reflect.defineMetadata(THROTTLER_TTL, ttl, target); - Reflect.defineMetadata(THROTTLER_LIMIT, limit, target); + for (const name in options) { + Reflect.defineMetadata(THROTTLER_TTL + name, options[name].ttl, target); + Reflect.defineMetadata(THROTTLER_LIMIT + name, options[name].limit, target); + } } /** * Adds metadata to the target which will be handled by the ThrottlerGuard to * handle incoming requests based on the given metadata. - * @example @Throttle(2, 10) - * @example @Throttle(() => 2, () => 10) + * @example @Throttle({ default: { limit: 2, ttl: 10 }}) + * @example @Throttle({default: { limit: () => 20, ttl: () => 60 }}) * @publicApi */ export const Throttle = ( - limit: Resolvable = 20, - ttl: Resolvable = 60, + options: Record, ): MethodDecorator & ClassDecorator => { return ( target: any, @@ -29,10 +34,10 @@ export const Throttle = ( descriptor?: TypedPropertyDescriptor, ) => { if (descriptor) { - setThrottlerMetadata(descriptor.value, limit, ttl); + setThrottlerMetadata(descriptor.value, options); return descriptor; } - setThrottlerMetadata(target, limit, ttl); + setThrottlerMetadata(target, options); return target; }; }; @@ -44,18 +49,22 @@ export const Throttle = ( * @example @SkipThrottle(false) * @publicApi */ -export const SkipThrottle = (skip = true): MethodDecorator & ClassDecorator => { +export const SkipThrottle = ( + skip: Record = { default: true }, +): MethodDecorator & ClassDecorator => { return ( target: any, propertyKey?: string | symbol, descriptor?: TypedPropertyDescriptor, ) => { - if (descriptor) { - Reflect.defineMetadata(THROTTLER_SKIP, skip, descriptor.value); - return descriptor; + for (const key in skip) { + if (descriptor) { + Reflect.defineMetadata(THROTTLER_SKIP + key, skip[key], descriptor.value); + return descriptor; + } + Reflect.defineMetadata(THROTTLER_SKIP + key, skip[key], target); + return target; } - Reflect.defineMetadata(THROTTLER_SKIP, skip, target); - return target; }; }; diff --git a/src/throttler.guard.ts b/src/throttler.guard.ts index dfa193911..ba5b6ae26 100644 --- a/src/throttler.guard.ts +++ b/src/throttler.guard.ts @@ -1,7 +1,7 @@ import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common'; import { Reflector } from '@nestjs/core'; import * as md5 from 'md5'; -import { Resolvable, ThrottlerModuleOptions } from './throttler-module-options.interface'; +import { ThrottlerModuleOptions, ThrottlerOptions } from './throttler-module-options.interface'; import { ThrottlerStorage } from './throttler-storage.interface'; import { THROTTLER_LIMIT, THROTTLER_SKIP, THROTTLER_TTL } from './throttler.constants'; import { InjectThrottlerOptions, InjectThrottlerStorage } from './throttler.decorator'; @@ -15,12 +15,36 @@ import { ThrottlerLimitDetail } from './throttler.guard.interface'; export class ThrottlerGuard implements CanActivate { protected headerPrefix = 'X-RateLimit'; protected errorMessage = throttlerMessage; + protected throttlers: Array; + protected commonOptions: Pick; constructor( @InjectThrottlerOptions() protected readonly options: ThrottlerModuleOptions, @InjectThrottlerStorage() protected readonly storageService: ThrottlerStorage, protected readonly reflector: Reflector, ) {} + async onModuleInit() { + this.throttlers = (Array.isArray(this.options) ? this.options : this.options.throttlers) + .sort((first, second) => { + if (typeof first.ttl === 'function') { + return 1; + } + if (typeof second.ttl === 'function') { + return 0; + } + return first.ttl - second.ttl; + }) + .map((opt) => ({ ...opt, name: opt.name ?? 'default' })); + if (Array.isArray(this.options)) { + this.commonOptions = {}; + } else { + this.commonOptions = { + skipIf: this.options.skipIf, + ignoreUserAgents: this.options.ignoreUserAgents, + }; + } + } + /** * Throttle requests against their TTL limit and whether to allow or deny it. * Based on the context type different handlers will be called. @@ -30,28 +54,41 @@ export class ThrottlerGuard implements CanActivate { const handler = context.getHandler(); const classRef = context.getClass(); - // Return early if the current route should be skipped. - if ( - this.reflector.getAllAndOverride(THROTTLER_SKIP, [handler, classRef]) || - this.options.skipIf?.(context) - ) { - return true; - } + const continues: boolean[] = []; + for (const namedThrottler of this.throttlers) { + // Return early if the current route should be skipped. + const skip = this.reflector.getAllAndOverride(THROTTLER_SKIP + namedThrottler.name, [ + handler, + classRef, + ]); + const skipIf = namedThrottler.skipIf || this.commonOptions.skipIf; + if (skip || skipIf?.(context)) { + continues.push(true); + continue; + } - // Return early when we have no limit or ttl data. - const routeOrClassLimit = this.reflector.getAllAndOverride>( - THROTTLER_LIMIT, - [handler, classRef], - ); - const routeOrClassTtl = this.reflector.getAllAndOverride>(THROTTLER_TTL, [ - handler, - classRef, - ]); + // Return early when we have no limit or ttl data. + const routeOrClassLimit = this.reflector.getAllAndOverride( + THROTTLER_LIMIT + namedThrottler.name, + [handler, classRef], + ); + const routeOrClassTtl = this.reflector.getAllAndOverride( + THROTTLER_TTL + namedThrottler.name, + [handler, classRef], + ); - // Check if specific limits are set at class or route level, otherwise use global options. - const limit = await this.resolveValue(context, routeOrClassLimit || this.options.limit); - const ttl = await this.resolveValue(context, routeOrClassTtl || this.options.ttl); - return this.handleRequest(context, limit, ttl); + // Check if specific limits are set at class or route level, otherwise use global options. + let limit = routeOrClassLimit || namedThrottler.limit; + let ttl = routeOrClassTtl || namedThrottler.ttl; + if (typeof limit === 'function') { + limit = await limit(context); + } + if (typeof ttl === 'function') { + ttl = await ttl(context); + } + continues.push(await this.handleRequest(context, limit, ttl, namedThrottler)); + } + return continues.every((cont) => cont); } /** @@ -64,25 +101,28 @@ export class ThrottlerGuard implements CanActivate { context: ExecutionContext, limit: number, ttl: number, + throttler: ThrottlerOptions, ): Promise { // Here we start to check the amount of requests being done against the ttl. const { req, res } = this.getRequestResponse(context); - + const ignoreUserAgents = throttler.ignoreUserAgents ?? this.commonOptions.ignoreUserAgents; // Return early if the current user agent should be ignored. - if (Array.isArray(this.options.ignoreUserAgents)) { - for (const pattern of this.options.ignoreUserAgents) { + if (Array.isArray(ignoreUserAgents)) { + for (const pattern of ignoreUserAgents) { if (pattern.test(req.headers['user-agent'])) { return true; } } } const tracker = this.getTracker(req); - const key = this.generateKey(context, tracker); + const key = this.generateKey(context, tracker, throttler.name); const { totalHits, timeToExpire } = await this.storageService.increment(key, ttl); + const getThrottlerSuffix = (name: string) => (name === 'default' ? '' : `-${name}`); + // Throw an error when the user reached their limit. if (totalHits > limit) { - res.header('Retry-After', timeToExpire); + res.header(`Retry-After${getThrottlerSuffix(throttler.name)}`, timeToExpire); this.throwThrottlingException(context, { limit, ttl, @@ -93,11 +133,14 @@ export class ThrottlerGuard implements CanActivate { }); } - res.header(`${this.headerPrefix}-Limit`, limit); + res.header(`${this.headerPrefix}-Limit${getThrottlerSuffix(throttler.name)}`, limit); // We're about to add a record so we need to take that into account here. // Otherwise the header says we have a request left when there are none. - res.header(`${this.headerPrefix}-Remaining`, Math.max(0, limit - totalHits)); - res.header(`${this.headerPrefix}-Reset`, timeToExpire); + res.header( + `${this.headerPrefix}-Remaining${getThrottlerSuffix(throttler.name)}`, + Math.max(0, limit - totalHits), + ); + res.header(`${this.headerPrefix}-Reset${getThrottlerSuffix(throttler.name)}`, timeToExpire); return true; } @@ -118,8 +161,8 @@ export class ThrottlerGuard implements CanActivate { * Generate a hashed key that will be used as a storage key. * The key will always be a combination of the current context and IP. */ - protected generateKey(context: ExecutionContext, suffix: string): string { - const prefix = `${context.getClass().name}-${context.getHandler().name}`; + protected generateKey(context: ExecutionContext, suffix: string, name: string): string { + const prefix = `${context.getClass().name}-${context.getHandler().name}-${name}`; return md5(`${prefix}-${suffix}`); } @@ -137,11 +180,4 @@ export class ThrottlerGuard implements CanActivate { ): void { throw new ThrottlerException(this.errorMessage); } - - private async resolveValue( - context: ExecutionContext, - resolvableValue: Resolvable, - ): Promise { - return typeof resolvableValue === 'function' ? resolvableValue(context) : resolvableValue; - } } diff --git a/src/throttler.module.ts b/src/throttler.module.ts index 935e886e5..7116ce811 100644 --- a/src/throttler.module.ts +++ b/src/throttler.module.ts @@ -16,7 +16,7 @@ export class ThrottlerModule { /** * Register the module synchronously. */ - static forRoot(options: ThrottlerModuleOptions = {}): DynamicModule { + static forRoot(options: ThrottlerModuleOptions = [{}]): DynamicModule { const providers = [...createThrottlerProviders(options), ThrottlerStorageProvider]; return { module: ThrottlerModule, diff --git a/src/throttler.providers.ts b/src/throttler.providers.ts index aecb7558f..6b42f7d45 100644 --- a/src/throttler.providers.ts +++ b/src/throttler.providers.ts @@ -16,7 +16,9 @@ export function createThrottlerProviders(options: ThrottlerModuleOptions): Provi export const ThrottlerStorageProvider = { provide: ThrottlerStorage, useFactory: (options: ThrottlerModuleOptions) => { - return options.storage ? options.storage : new ThrottlerStorageService(); + return !Array.isArray(options) && options.storage + ? options.storage + : new ThrottlerStorageService(); }, inject: [THROTTLER_OPTIONS], }; diff --git a/src/throttler.service.ts b/src/throttler.service.ts index 7bf18477a..7f79da441 100644 --- a/src/throttler.service.ts +++ b/src/throttler.service.ts @@ -35,7 +35,7 @@ export class ThrottlerStorageService implements ThrottlerStorage, OnApplicationS } async increment(key: string, ttl: number): Promise { - const ttlMilliseconds = ttl * 1000; + const ttlMilliseconds = ttl; if (!this.storage[key]) { this.storage[key] = { totalHits: 0, expiresAt: Date.now() + ttlMilliseconds }; } diff --git a/src/utilities.ts b/src/utilities.ts new file mode 100644 index 000000000..ff980d15f --- /dev/null +++ b/src/utilities.ts @@ -0,0 +1,5 @@ +export const seconds = (howMany: number) => howMany * 1000; +export const minutes = (howMany: number) => seconds(howMany) * 60; +export const hours = (howMany: number) => minutes(howMany) * 60; +export const days = (howMany: number) => hours(howMany) * 24; +export const weeks = (howMany: number) => days(howMany) * 7; diff --git a/test/app/controllers/app.controller.ts b/test/app/controllers/app.controller.ts index 4a322c5c3..5fd17f29c 100644 --- a/test/app/controllers/app.controller.ts +++ b/test/app/controllers/app.controller.ts @@ -1,9 +1,9 @@ import { Controller, Get } from '@nestjs/common'; -import { SkipThrottle, Throttle } from '../../../src'; +import { SkipThrottle, Throttle, seconds } from '../../../src'; import { AppService } from '../app.service'; @Controller() -@Throttle(2, 10) +@Throttle({ default: { limit: 2, ttl: seconds(10) } }) export class AppController { constructor(private readonly appService: AppService) {} diff --git a/test/app/controllers/controller.module.ts b/test/app/controllers/controller.module.ts index 36cc4bab4..cd6e8034d 100644 --- a/test/app/controllers/controller.module.ts +++ b/test/app/controllers/controller.module.ts @@ -1,5 +1,5 @@ import { Module } from '@nestjs/common'; -import { ThrottlerModule } from '../../../src'; +import { ThrottlerModule, seconds } from '../../../src'; import { AppService } from '../app.service'; import { AppController } from './app.controller'; import { DefaultController } from './default.controller'; @@ -7,11 +7,13 @@ import { LimitController } from './limit.controller'; @Module({ imports: [ - ThrottlerModule.forRoot({ - limit: 5, - ttl: 60, - ignoreUserAgents: [/throttler-test/g], - }), + ThrottlerModule.forRoot([ + { + limit: 5, + ttl: seconds(60), + ignoreUserAgents: [/throttler-test/g], + }, + ]), ], controllers: [AppController, DefaultController, LimitController], providers: [AppService], diff --git a/test/app/controllers/limit.controller.ts b/test/app/controllers/limit.controller.ts index c45d4cad3..857e2bb6e 100644 --- a/test/app/controllers/limit.controller.ts +++ b/test/app/controllers/limit.controller.ts @@ -1,8 +1,8 @@ import { Controller, Get } from '@nestjs/common'; -import { Throttle } from '../../../src'; +import { Throttle, seconds } from '../../../src'; import { AppService } from '../app.service'; -@Throttle(2, 10) +@Throttle({ default: { limit: 2, ttl: seconds(10) } }) @Controller('limit') export class LimitController { constructor(private readonly appService: AppService) {} @@ -11,7 +11,7 @@ export class LimitController { return this.appService.success(); } - @Throttle(5, 10) + @Throttle({ default: { limit: 5, ttl: seconds(10) } }) @Get('higher') getHigher() { return this.appService.success(); diff --git a/test/controller.e2e-spec.ts b/test/controller.e2e-spec.ts index b17a24a1a..d957e5b47 100644 --- a/test/controller.e2e-spec.ts +++ b/test/controller.e2e-spec.ts @@ -131,10 +131,12 @@ describe('SkipIf suite', () => { ], }) .overrideProvider(THROTTLER_OPTIONS) - .useValue({ - skipIf: () => true, - limit: 5, - }) + .useValue([ + { + skipIf: () => true, + limit: 5, + }, + ]) .compile(); const app = moduleFixture.createNestApplication(); diff --git a/test/multi/app.module.ts b/test/multi/app.module.ts new file mode 100644 index 000000000..88a0a1c8f --- /dev/null +++ b/test/multi/app.module.ts @@ -0,0 +1,33 @@ +import { Module } from '@nestjs/common'; +import { APP_GUARD } from '@nestjs/core'; +import { ThrottlerGuard, ThrottlerModule, seconds, minutes } from '../../src'; +import { MultiThrottlerController } from './multi-throttler.controller'; + +@Module({ + imports: [ + ThrottlerModule.forRoot([ + { + ttl: seconds(5), + limit: 2, + }, + { + name: 'long', + ttl: minutes(1), + limit: 5, + }, + { + name: 'short', + limit: 1, + ttl: seconds(1), + }, + ]), + ], + controllers: [MultiThrottlerController], + providers: [ + { + provide: APP_GUARD, + useClass: ThrottlerGuard, + }, + ], +}) +export class MultiThrottlerAppModule {} diff --git a/test/multi/multi-throttler.controller.ts b/test/multi/multi-throttler.controller.ts new file mode 100644 index 000000000..f81a15244 --- /dev/null +++ b/test/multi/multi-throttler.controller.ts @@ -0,0 +1,22 @@ +import { Controller, Get } from '@nestjs/common'; +import { SkipThrottle } from '../../src'; + +@Controller() +export class MultiThrottlerController { + @Get() + simpleRoute() { + return { success: true }; + } + + @SkipThrottle({ short: true }) + @Get('skip-short') + skipShort() { + return { success: true }; + } + + @SkipThrottle({ default: true, long: true }) + @Get('skip-default-and-long') + skipDefAndLong() { + return { success: true }; + } +} diff --git a/test/multi/multi-throttler.e2e-spec.ts b/test/multi/multi-throttler.e2e-spec.ts new file mode 100644 index 000000000..440d11fd7 --- /dev/null +++ b/test/multi/multi-throttler.e2e-spec.ts @@ -0,0 +1,92 @@ +import { INestApplication, Type } from '@nestjs/common'; +import { AbstractHttpAdapter } from '@nestjs/core'; +import { ExpressAdapter } from '@nestjs/platform-express'; +import { FastifyAdapter } from '@nestjs/platform-fastify'; +import { Test } from '@nestjs/testing'; +import { setTimeout } from 'node:timers/promises'; +import { request, spec } from 'pactum'; +import { MultiThrottlerAppModule } from './app.module'; + +jest.setTimeout(10000); + +const commonHeader = (prefix: string, name?: string) => `${prefix}${name ? '-' + name : ''}`; + +const remainingHeader = (name?: string) => commonHeader('x-ratelimit-remaining', name); +const limitHeader = (name?: string) => commonHeader('x-ratelimit-limit', name); +const retryHeader = (name?: string) => commonHeader('retry-after', name); + +const short = 'short'; +const long = 'long'; + +describe.each` + adapter | name + ${ExpressAdapter} | ${'express'} + ${FastifyAdapter} | ${'fastify'} +`('Mutli-Throttler Named Usage - $name', ({ adapter }: { adapter: Type }) => { + let app: INestApplication; + beforeAll(async () => { + const modRef = await Test.createTestingModule({ + imports: [MultiThrottlerAppModule], + }).compile(); + app = modRef.createNestApplication(new adapter()); + await app.listen(0); + request.setBaseUrl(await app.getUrl()); + }); + afterAll(async () => { + await app.close(); + }); + + describe('Default Route: 1/s, 2/5s, 5/min', () => { + it('should receive an exception when firing 2 request swithin a second', async () => { + await spec() + .get('/') + .expectStatus(200) + .expectHeader(remainingHeader(short), '0') + .expectHeader(limitHeader(short), '1'); + await spec().get('/').expectStatus(429).expectHeaderContains(retryHeader(short), /^\d+$/); + await setTimeout(1000); + }); + it('should get an error if we send two more requests within the first five seconds', async () => { + await spec() + .get('/') + .expectStatus(200) + .expectHeader(remainingHeader(), '0') + .expectHeader(limitHeader(), '2'); + await setTimeout(1000); + await spec().get('/').expectStatus(429).expectHeaderContains(retryHeader(), /^\d+$/); + await setTimeout(5000); + }); + it('should get an error if we smartly send 4 more requests within the minute', async () => { + await spec() + .get('/') + .expectStatus(200) + .expectHeader(limitHeader(long), '5') + .expectHeader(remainingHeader(long), '2') + .expectHeader(remainingHeader(short), '0'); + await setTimeout(1000); + await spec().get('/').expectStatus(200).expectHeader(remainingHeader(), '0'); + console.log('waiting 5 seconds'); + await setTimeout(5000); + await spec() + .get('/') + .expectStatus(200) + .expectHeader(remainingHeader(long), '0') + .expectHeader(remainingHeader(short), '0') + .expectHeader(remainingHeader(), '1'); + await setTimeout(1000); + await spec().get('/').expectStatus(429).expectHeaderContains(retryHeader(long), /^\d+$/); + }); + }); + describe('skips', () => { + it('should skip theshort throttler', async () => { + await spec().get('/skip-short').expectStatus(200).expectHeader(remainingHeader(), '1'); + await spec().get('/skip-short').expectStatus(200).expectHeader(remainingHeader(), '0'); + }); + it('should skip the default and long trackers', async () => { + await spec() + .get('/skip-default-and-long') + .expectStatus(200) + .expectHeader(remainingHeader(short), '0'); + }); + }); +});