From e22b6443d7bcd06b48bb90dacc21d99dcff24151 Mon Sep 17 00:00:00 2001 From: Patrick Siegle Date: Sun, 10 Nov 2024 20:06:21 +0100 Subject: [PATCH] feat: add function to manually check rate limit (#346) --- index.js | 152 +++++++++++++++++++++++++++++++---------------- types/index.d.ts | 51 +++++++++++----- 2 files changed, 137 insertions(+), 66 deletions(-) diff --git a/index.js b/index.js index fb3f625..ad210c7 100644 --- a/index.js +++ b/index.js @@ -124,16 +124,17 @@ async function fastifyRateLimit (fastify, settings) { fastify.decorateRequest(pluginComponent.rateLimitRan, false) + if (!fastify.hasDecorator('createRateLimit')) { + fastify.decorate('createRateLimit', (options) => { + const args = createLimiterArgs(pluginComponent, globalParams, options) + return (req) => applyRateLimit(...args, req) + }) + } + if (!fastify.hasDecorator('rateLimit')) { fastify.decorate('rateLimit', (options) => { - if (typeof options === 'object') { - const newPluginComponent = Object.create(pluginComponent) - const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} }) - newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams) - return rateLimitRequestHandler(newPluginComponent, mergedRateLimitParams) - } - - return rateLimitRequestHandler(pluginComponent, globalParams) + const args = createLimiterArgs(pluginComponent, globalParams, options) + return rateLimitRequestHandler(...args) }) } @@ -189,6 +190,17 @@ function mergeParams (...params) { return result } +function createLimiterArgs (pluginComponent, globalParams, options) { + if (typeof options === 'object') { + const newPluginComponent = Object.create(pluginComponent) + const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} }) + newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams) + return [newPluginComponent, mergedRateLimitParams] + } + + return [pluginComponent, globalParams] +} + function addRouteRateHook (pluginComponent, params, routeOptions) { const hook = params.hook const hookHandler = rateLimitRequestHandler(pluginComponent, params) @@ -201,8 +213,72 @@ function addRouteRateHook (pluginComponent, params, routeOptions) { } } +async function applyRateLimit (pluginComponent, params, req) { + const { store } = pluginComponent + + // Retrieve the key from the generator (the global one or the one defined in the endpoint) + let key = await params.keyGenerator(req) + const groupId = req.routeOptions.config?.rateLimit?.groupId + + if (groupId) { + key += groupId + } + + // Don't apply any rate limiting if in the allow list + if (params.allowList) { + if (typeof params.allowList === 'function') { + if (await params.allowList(req, key)) { + return { + isAllowed: true, + key + } + } + } else if (params.allowList.indexOf(key) !== -1) { + return { + isAllowed: true, + key + } + } + } + + const max = typeof params.max === 'number' ? params.max : await params.max(req, key) + const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key) + let current = 0 + let ttl = 0 + let ttlInSeconds = 0 + + // We increment the rate limit for the current request + try { + const res = await new Promise((resolve, reject) => { + store.incr(key, (err, res) => { + err ? reject(err) : resolve(res) + }, timeWindow, max) + }) + + current = res.current + ttl = res.ttl + ttlInSeconds = Math.ceil(res.ttl / 1000) + } catch (err) { + if (!params.skipOnError) { + throw err + } + } + + return { + isAllowed: false, + key, + max, + timeWindow, + remaining: Math.max(0, max - current), + ttl, + ttlInSeconds, + isExceeded: current > max, + isBanned: params.ban !== -1 && current - max > params.ban + } +} + function rateLimitRequestHandler (pluginComponent, params) { - const { rateLimitRan, store } = pluginComponent + const { rateLimitRan } = pluginComponent let timeWindowString if (typeof params.timeWindow === 'number') { @@ -216,51 +292,25 @@ function rateLimitRequestHandler (pluginComponent, params) { req[rateLimitRan] = true - // Retrieve the key from the generator (the global one or the one defined in the endpoint) - let key = await params.keyGenerator(req) - const groupId = req.routeOptions.config?.rateLimit?.groupId - - if (groupId) { - key += groupId - } - - // Don't apply any rate limiting if in the allow list - if (params.allowList) { - if (typeof params.allowList === 'function') { - if (await params.allowList(req, key)) { - return - } - } else if (params.allowList.indexOf(key) !== -1) { - return - } + const rateLimit = await applyRateLimit(pluginComponent, params, req) + if (rateLimit.isAllowed) { + return } - const max = typeof params.max === 'number' ? params.max : await params.max(req, key) - const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key) - let current = 0 - let ttl = 0 - let ttlInSeconds = 0 - - // We increment the rate limit for the current request - try { - const res = await new Promise((resolve, reject) => { - store.incr(key, (err, res) => { - err ? reject(err) : resolve(res) - }, timeWindow, max) - }) - - current = res.current - ttl = res.ttl - ttlInSeconds = Math.ceil(res.ttl / 1000) - } catch (err) { - if (!params.skipOnError) { - throw err - } - } + const { + key, + max, + timeWindow, + remaining, + ttl, + ttlInSeconds, + isExceeded, + isBanned + } = rateLimit - if (current <= max) { + if (!isExceeded) { if (params.addHeadersOnExceeding[params.labels.rateLimit]) { res.header(params.labels.rateLimit, max) } - if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, max - current) } + if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, remaining) } if (params.addHeadersOnExceeding[params.labels.rateReset]) { res.header(params.labels.rateReset, ttlInSeconds) } params.onExceeding(req, key) @@ -283,7 +333,7 @@ function rateLimitRequestHandler (pluginComponent, params) { after: timeWindowString ?? ms.format(timeWindow, true) } - if (params.ban !== -1 && current - max > params.ban) { + if (isBanned) { respCtx.statusCode = 403 respCtx.ban = true params.onBanReach(req, key) diff --git a/types/index.d.ts b/types/index.d.ts index 81ce22b..29d557e 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -12,6 +12,24 @@ import { declare module 'fastify' { interface FastifyInstance { + createRateLimit(options?: fastifyRateLimit.CreateRateLimitOptions): (req: FastifyRequest) => Promise< + | { + isAllowed: true + key: string + } + | { + isAllowed: false + key: string + max: number + timeWindow: number + remaining: number + ttl: number + ttlInSeconds: number + isExceeded: boolean + isBanned: boolean + } + > + rateLimit< RouteGeneric extends RouteGenericInterface = RouteGenericInterface, ContextConfig = ContextConfigDefault, @@ -89,13 +107,9 @@ declare namespace fastifyRateLimit { 'ratelimit-reset'?: boolean; } - export type RateLimitHook = - | 'onRequest' - | 'preParsing' - | 'preValidation' - | 'preHandler' - - export interface RateLimitOptions { + export interface CreateRateLimitOptions { + store?: FastifyRateLimitStoreCtor; + skipOnError?: boolean; max?: | number | ((req: FastifyRequest, key: string) => number) @@ -105,19 +119,26 @@ declare namespace fastifyRateLimit { | string | ((req: FastifyRequest, key: string) => number) | ((req: FastifyRequest, key: string) => Promise); - hook?: RateLimitHook; - cache?: number; - store?: FastifyRateLimitStoreCtor; /** - * @deprecated Use `allowList` property - */ + * @deprecated Use `allowList` property + */ whitelist?: string[] | ((req: FastifyRequest, key: string) => boolean); allowList?: string[] | ((req: FastifyRequest, key: string) => boolean | Promise); - continueExceeding?: boolean; - skipOnError?: boolean; + keyGenerator?: (req: FastifyRequest) => string | number | Promise; ban?: number; + } + + export type RateLimitHook = + | 'onRequest' + | 'preParsing' + | 'preValidation' + | 'preHandler' + + export interface RateLimitOptions extends CreateRateLimitOptions { + hook?: RateLimitHook; + cache?: number; + continueExceeding?: boolean; onBanReach?: (req: FastifyRequest, key: string) => void; - keyGenerator?: (req: FastifyRequest) => string | number | Promise; groupId?: string; errorResponseBuilder?: ( req: FastifyRequest,