Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to manually check rate limit (#346) #392

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 101 additions & 51 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,17 @@ async function fastifyRateLimit (fastify, settings) {

fastify.decorateRequest(pluginComponent.rateLimitRan, false)

if (!fastify.hasDecorator('createRateLimit')) {
fastify.decorate('createRateLimit', (options) => {
const args = createLimiterArgs(pluginComponent, globalParams, options)
return (req) => applyRateLimit(...args, req)
})
}

if (!fastify.hasDecorator('rateLimit')) {
fastify.decorate('rateLimit', (options) => {
if (typeof options === 'object') {
const newPluginComponent = Object.create(pluginComponent)
const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} })
newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams)
return rateLimitRequestHandler(newPluginComponent, mergedRateLimitParams)
}

return rateLimitRequestHandler(pluginComponent, globalParams)
const args = createLimiterArgs(pluginComponent, globalParams, options)
return rateLimitRequestHandler(...args)
})
}

Expand Down Expand Up @@ -189,6 +190,17 @@ function mergeParams (...params) {
return result
}

function createLimiterArgs (pluginComponent, globalParams, options) {
if (typeof options === 'object') {
const newPluginComponent = Object.create(pluginComponent)
const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} })
newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams)
return [newPluginComponent, mergedRateLimitParams]
}

return [pluginComponent, globalParams]
}

function addRouteRateHook (pluginComponent, params, routeOptions) {
const hook = params.hook
const hookHandler = rateLimitRequestHandler(pluginComponent, params)
Expand All @@ -201,8 +213,72 @@ function addRouteRateHook (pluginComponent, params, routeOptions) {
}
}

async function applyRateLimit (pluginComponent, params, req) {
const { store } = pluginComponent

// Retrieve the key from the generator (the global one or the one defined in the endpoint)
let key = await params.keyGenerator(req)
const groupId = req.routeOptions.config?.rateLimit?.groupId

if (groupId) {
key += groupId
}

// Don't apply any rate limiting if in the allow list
if (params.allowList) {
if (typeof params.allowList === 'function') {
if (await params.allowList(req, key)) {
return {
isAllowed: true,
key
}
}
} else if (params.allowList.indexOf(key) !== -1) {
return {
isAllowed: true,
key
}
}
}

const max = typeof params.max === 'number' ? params.max : await params.max(req, key)
const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key)
let current = 0
let ttl = 0
let ttlInSeconds = 0

// We increment the rate limit for the current request
try {
const res = await new Promise((resolve, reject) => {
store.incr(key, (err, res) => {
err ? reject(err) : resolve(res)
}, timeWindow, max)
})

current = res.current
ttl = res.ttl
ttlInSeconds = Math.ceil(res.ttl / 1000)
} catch (err) {
if (!params.skipOnError) {
throw err
}
}

return {
isAllowed: false,
key,
max,
timeWindow,
remaining: Math.max(0, max - current),
ttl,
ttlInSeconds,
isExceeded: current > max,
isBanned: params.ban !== -1 && current - max > params.ban
}
}

function rateLimitRequestHandler (pluginComponent, params) {
const { rateLimitRan, store } = pluginComponent
const { rateLimitRan } = pluginComponent

let timeWindowString
if (typeof params.timeWindow === 'number') {
Expand All @@ -216,51 +292,25 @@ function rateLimitRequestHandler (pluginComponent, params) {

req[rateLimitRan] = true

// Retrieve the key from the generator (the global one or the one defined in the endpoint)
let key = await params.keyGenerator(req)
const groupId = req.routeOptions.config?.rateLimit?.groupId

if (groupId) {
key += groupId
}

// Don't apply any rate limiting if in the allow list
if (params.allowList) {
if (typeof params.allowList === 'function') {
if (await params.allowList(req, key)) {
return
}
} else if (params.allowList.indexOf(key) !== -1) {
return
}
const rateLimit = await applyRateLimit(pluginComponent, params, req)
if (rateLimit.isAllowed) {
return
}

const max = typeof params.max === 'number' ? params.max : await params.max(req, key)
const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key)
let current = 0
let ttl = 0
let ttlInSeconds = 0

// We increment the rate limit for the current request
try {
const res = await new Promise((resolve, reject) => {
store.incr(key, (err, res) => {
err ? reject(err) : resolve(res)
}, timeWindow, max)
})

current = res.current
ttl = res.ttl
ttlInSeconds = Math.ceil(res.ttl / 1000)
} catch (err) {
if (!params.skipOnError) {
throw err
}
}
const {
key,
max,
timeWindow,
remaining,
ttl,
ttlInSeconds,
isExceeded,
isBanned
} = rateLimit

if (current <= max) {
if (!isExceeded) {
if (params.addHeadersOnExceeding[params.labels.rateLimit]) { res.header(params.labels.rateLimit, max) }
if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, max - current) }
if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, remaining) }
if (params.addHeadersOnExceeding[params.labels.rateReset]) { res.header(params.labels.rateReset, ttlInSeconds) }

params.onExceeding(req, key)
Expand All @@ -283,7 +333,7 @@ function rateLimitRequestHandler (pluginComponent, params) {
after: timeWindowString ?? ms.format(timeWindow, true)
}

if (params.ban !== -1 && current - max > params.ban) {
if (isBanned) {
respCtx.statusCode = 403
respCtx.ban = true
params.onBanReach(req, key)
Expand Down
51 changes: 36 additions & 15 deletions types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ import {

declare module 'fastify' {
interface FastifyInstance<RawServer, RawRequest, RawReply, Logger, TypeProvider> {
createRateLimit(options?: fastifyRateLimit.CreateRateLimitOptions): (req: FastifyRequest) => Promise<
| {
isAllowed: true
key: string
}
| {
isAllowed: false
key: string
max: number
timeWindow: number
remaining: number
ttl: number
ttlInSeconds: number
isExceeded: boolean
isBanned: boolean
}
>

rateLimit<
RouteGeneric extends RouteGenericInterface = RouteGenericInterface,
ContextConfig = ContextConfigDefault,
Expand Down Expand Up @@ -89,13 +107,9 @@ declare namespace fastifyRateLimit {
'ratelimit-reset'?: boolean;
}

export type RateLimitHook =
| 'onRequest'
| 'preParsing'
| 'preValidation'
| 'preHandler'

export interface RateLimitOptions {
export interface CreateRateLimitOptions {
store?: FastifyRateLimitStoreCtor;
skipOnError?: boolean;
max?:
| number
| ((req: FastifyRequest, key: string) => number)
Expand All @@ -105,19 +119,26 @@ declare namespace fastifyRateLimit {
| string
| ((req: FastifyRequest, key: string) => number)
| ((req: FastifyRequest, key: string) => Promise<number>);
hook?: RateLimitHook;
cache?: number;
store?: FastifyRateLimitStoreCtor;
/**
* @deprecated Use `allowList` property
*/
* @deprecated Use `allowList` property
*/
whitelist?: string[] | ((req: FastifyRequest, key: string) => boolean);
allowList?: string[] | ((req: FastifyRequest, key: string) => boolean | Promise<boolean>);
continueExceeding?: boolean;
skipOnError?: boolean;
keyGenerator?: (req: FastifyRequest) => string | number | Promise<string | number>;
ban?: number;
}

export type RateLimitHook =
| 'onRequest'
| 'preParsing'
| 'preValidation'
| 'preHandler'

export interface RateLimitOptions extends CreateRateLimitOptions {
hook?: RateLimitHook;
cache?: number;
continueExceeding?: boolean;
onBanReach?: (req: FastifyRequest, key: string) => void;
keyGenerator?: (req: FastifyRequest) => string | number | Promise<string | number>;
groupId?: string;
errorResponseBuilder?: (
req: FastifyRequest,
Expand Down
Loading