Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce Body Limit Middleware using stream #2103

Merged
merged 7 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions deno_dist/middleware/body-limit/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import type { Context } from '../../context.ts'
import { HTTPException } from '../../http-exception.ts'
import type { MiddlewareHandler } from '../../types.ts'

const ERROR_MESSAGE = 'Payload Too Large'

type OnError = (c: Context) => Response | Promise<Response>
type BodyLimitOptions = {
maxSize: number
onError?: OnError
}

class BodyLimitError extends Error {
constructor(message: string) {
super(message)
this.name = 'BodyLimitError'
}
}

/**
* Body Limit Middleware
*
* @example
* ```ts
* app.post(
* '/hello',
* bodyLimit({
* maxSize: 100 * 1024, // 100kb
* onError: (c) => {
* return c.text('overflow :(', 413)
* }
* }),
* (c) => {
* return c.text('pass :)')
* }
* )
* ```
*/
export const bodyLimit = (options: BodyLimitOptions): MiddlewareHandler => {
const onError: OnError =
options.onError ||
(() => {
const res = new Response(ERROR_MESSAGE, {
status: 413,
})
throw new HTTPException(413, { res })
})
const maxSize = options.maxSize

return async function bodyLimit(c, next) {
if (!c.req.raw.body) {
// maybe GET or HEAD request
return next()
}

if (c.req.raw.headers.has('content-length')) {
// we can trust content-length header because it's already validated by server
const contentLength = parseInt(c.req.raw.headers.get('content-length') || '0', 10)
return contentLength > maxSize ? onError(c) : next()
}

// maybe chunked transfer encoding

let size = 0
const rawReader = c.req.raw.body.getReader()
const reader = new ReadableStream({
async start(controller) {
try {
for (;;) {
const { done, value } = await rawReader.read()
if (done) {
break
}
size += value.length
if (size > maxSize) {
controller.error(new BodyLimitError(ERROR_MESSAGE))
break
}

controller.enqueue(value)
}
} finally {
controller.close()
}
},
})

c.req.raw = new Request(c.req.raw, { body: reader })

await next()

if (c.error instanceof BodyLimitError) {
c.res = await onError(c)
}
}
}
8 changes: 8 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
"import": "./dist/middleware/bearer-auth/index.js",
"require": "./dist/cjs/middleware/bearer-auth/index.js"
},
"./body-limit": {
"types": "./dist/types/middleware/body-limit/index.d.ts",
"import": "./dist/middleware/body-limit/index.js",
"require": "./dist/cjs/middleware/body-limit/index.js"
},
"./cache": {
"types": "./dist/types/middleware/cache/index.d.ts",
"import": "./dist/middleware/cache/index.js",
Expand Down Expand Up @@ -338,6 +343,9 @@
"bearer-auth": [
"./dist/types/middleware/bearer-auth"
],
"body-limit": [
"./dist/types/middleware/body-limit"
],
"cache": [
"./dist/types/middleware/cache"
],
Expand Down
143 changes: 143 additions & 0 deletions src/middleware/body-limit/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import { Hono } from '../../hono'
import { bodyLimit } from '.'

const GlobalRequest = globalThis.Request
globalThis.Request = class Request extends GlobalRequest {
constructor(input: Request | string, init: RequestInit) {
if (init) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
;(init as any).duplex ??= 'half'
}
super(input, init)
}
} as typeof GlobalRequest

const buildRequestInit = (init: RequestInit = {}): RequestInit => {
const headers: Record<string, string> = {
'Content-Type': 'text/plain',
}
if (typeof init.body === 'string') {
headers['Content-Length'] = init.body.length.toString()
}
return {
method: 'POST',
headers,
body: null,
...init,
}
}

describe('Body Limit Middleware', () => {
let app: Hono

const exampleText = 'hono is so cool' // 15byte
const exampleText2 = 'hono is so cool and cute' // 24byte

beforeEach(() => {
app = new Hono()
app.use('*', bodyLimit({ maxSize: 15 }))
app.get('/', (c) => c.text('index'))
app.post('/body-limit-15byte', async (c) => {
return c.text(await c.req.raw.text())
})
})

describe('GET request', () => {
it('should return 200 response', async () => {
const res = await app.request('/')
expect(res).not.toBeNull()
expect(res.status).toBe(200)
expect(await res.text()).toBe('index')
})
})

describe('POST request', () => {
describe('string body', () => {
it('should return 200 response', async () => {
const res = await app.request('/body-limit-15byte', buildRequestInit({ body: exampleText }))

expect(res).not.toBeNull()
expect(res.status).toBe(200)
expect(await res.text()).toBe(exampleText)
})

it('should return 413 response', async () => {
const res = await app.request(
'/body-limit-15byte',
buildRequestInit({ body: exampleText2 })
)

expect(res).not.toBeNull()
expect(res.status).toBe(413)
expect(await res.text()).toBe('Payload Too Large')
})
})

describe('ReadableStream body', () => {
it('should return 200 response', async () => {
const contents = ['a', 'b', 'c']
const stream = new ReadableStream({
start(controller) {
while (contents.length) {
controller.enqueue(new TextEncoder().encode(contents.shift() as string))
}
controller.close()
},
})
const res = await app.request('/body-limit-15byte', buildRequestInit({ body: stream }))

expect(res).not.toBeNull()
expect(res.status).toBe(200)
expect(await res.text()).toBe('abc')
})

it('should return 413 response', async () => {
const readSpy = vi.fn().mockImplementation(() => {
return {
done: false,
value: new TextEncoder().encode(exampleText),
}
})
const stream = new ReadableStream()
vi.spyOn(stream, 'getReader').mockReturnValue({
read: readSpy,
} as unknown as ReadableStreamDefaultReader)
const res = await app.request('/body-limit-15byte', buildRequestInit({ body: stream }))

expect(res).not.toBeNull()
expect(res.status).toBe(413)
expect(readSpy).toHaveBeenCalledTimes(2)
expect(await res.text()).toBe('Payload Too Large')
})
})
})

describe('custom error handler', () => {
beforeEach(() => {
app = new Hono()
app.post(
'/text-limit-15byte-custom',
bodyLimit({
maxSize: 15,
onError: (c) => {
return c.text('no', 413)
},
}),
(c) => {
return c.text('yes')
}
)
})

it('should return the custom error handler', async () => {
const res = await app.request(
'/text-limit-15byte-custom',
buildRequestInit({ body: exampleText2 })
)

expect(res).not.toBeNull()
expect(res.status).toBe(413)
expect(await res.text()).toBe('no')
})
})
})
96 changes: 96 additions & 0 deletions src/middleware/body-limit/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import type { Context } from '../../context'
import { HTTPException } from '../../http-exception'
import type { MiddlewareHandler } from '../../types'

const ERROR_MESSAGE = 'Payload Too Large'

type OnError = (c: Context) => Response | Promise<Response>
type BodyLimitOptions = {
maxSize: number
onError?: OnError
}

class BodyLimitError extends Error {
constructor(message: string) {
super(message)
this.name = 'BodyLimitError'
}
}

/**
* Body Limit Middleware
*
* @example
* ```ts
* app.post(
* '/hello',
* bodyLimit({
* maxSize: 100 * 1024, // 100kb
* onError: (c) => {
* return c.text('overflow :(', 413)
* }
* }),
* (c) => {
* return c.text('pass :)')
* }
* )
* ```
*/
export const bodyLimit = (options: BodyLimitOptions): MiddlewareHandler => {
const onError: OnError =
options.onError ||
(() => {
const res = new Response(ERROR_MESSAGE, {
status: 413,
})
throw new HTTPException(413, { res })
})
const maxSize = options.maxSize

return async function bodyLimit(c, next) {
if (!c.req.raw.body) {
// maybe GET or HEAD request
return next()
}

if (c.req.raw.headers.has('content-length')) {
// we can trust content-length header because it's already validated by server
const contentLength = parseInt(c.req.raw.headers.get('content-length') || '0', 10)
return contentLength > maxSize ? onError(c) : next()
}

// maybe chunked transfer encoding

let size = 0
const rawReader = c.req.raw.body.getReader()
const reader = new ReadableStream({
async start(controller) {
try {
for (;;) {
const { done, value } = await rawReader.read()
if (done) {
break
}
size += value.length
if (size > maxSize) {
controller.error(new BodyLimitError(ERROR_MESSAGE))
break
}

controller.enqueue(value)
}
} finally {
controller.close()
}
},
})

c.req.raw = new Request(c.req.raw, { body: reader })

await next()

if (c.error instanceof BodyLimitError) {
c.res = await onError(c)
}
}
}