diff --git a/jsr.json b/jsr.json index bfeae8b22..e6af4ed62 100644 --- a/jsr.json +++ b/jsr.json @@ -26,6 +26,7 @@ "./cookie": "./src/helper/cookie/index.ts", "./accepts": "./src/helper/accepts/index.ts", "./compress": "./src/middleware/compress/index.ts", + "./context-storage": "./src/middleware/context-storage/index.ts", "./cors": "./src/middleware/cors/index.ts", "./csrf": "./src/middleware/csrf/index.ts", "./etag": "./src/middleware/etag/index.ts", diff --git a/package.json b/package.json index f7647825b..363aaaffa 100644 --- a/package.json +++ b/package.json @@ -104,6 +104,11 @@ "import": "./dist/middleware/compress/index.js", "require": "./dist/cjs/middleware/compress/index.js" }, + "./context-storage": { + "types": "./dist/types/middleware/context-storage/index.d.ts", + "import": "./dist/middleware/context-storage/index.js", + "require": "./dist/cjs/middleware/context-storage/index.js" + }, "./cors": { "types": "./dist/types/middleware/cors/index.d.ts", "import": "./dist/middleware/cors/index.js", @@ -421,6 +426,9 @@ "compress": [ "./dist/types/middleware/compress" ], + "context-storage": [ + "./dist/types/middleware/context-storage" + ], "cors": [ "./dist/types/middleware/cors" ], diff --git a/runtime_tests/node/index.test.ts b/runtime_tests/node/index.test.ts index 92892ef8f..b34cf23a3 100644 --- a/runtime_tests/node/index.test.ts +++ b/runtime_tests/node/index.test.ts @@ -1,12 +1,14 @@ -import { createAdaptorServer } from '@hono/node-server' +import type { Server } from 'node:http' +import { createAdaptorServer, serve } from '@hono/node-server' import request from 'supertest' import { Hono } from '../../src' import { Context } from '../../src/context' import { env, getRuntimeKey } from '../../src/helper/adapter' import { basicAuth } from '../../src/middleware/basic-auth' import { jwt } from '../../src/middleware/jwt' -import { HonoRequest } from '../../src/request' +import { compress } from '../../src/middleware/compress' import { stream, streamSSE } from '../../src/helper/streaming' +import type { AddressInfo } from 'node:net' // Test only minimal patterns. // See for more tests and information. @@ -38,7 +40,7 @@ describe('Basic', () => { describe('Environment Variables', () => { it('Should return the environment variable', async () => { - const c = new Context(new HonoRequest(new Request('http://localhost/'))) + const c = new Context(new Request('http://localhost/')) const { NAME } = env<{ NAME: string }>(c) expect(NAME).toBe('Node') }) @@ -203,3 +205,44 @@ describe('streamSSE', () => { expect(aborted).toBe(false) }) }) + +describe('compress', async () => { + const cssContent = Array.from({ length: 60 }, () => 'body { color: red; }').join('\n') + const [externalServer, serverInfo] = await new Promise<[Server, AddressInfo]>((resolve) => { + const externalApp = new Hono() + externalApp.get('/style.css', (c) => + c.text(cssContent, { + headers: { + 'Content-Type': 'text/css', + }, + }) + ) + const server = serve( + { + fetch: externalApp.fetch, + port: 0, + }, + (serverInfo) => { + resolve([server as Server, serverInfo]) + } + ) + }) + + const app = new Hono() + app.use(compress()) + app.get('/fetch/:file', (c) => { + return fetch(`http://${serverInfo.address}:${serverInfo.port}/${c.req.param('file')}`) + }) + const server = createAdaptorServer(app) + + afterAll(() => { + externalServer.close() + }) + + it('Should be compressed a fetch response', async () => { + const res = await request(server).get('/fetch/style.css') + expect(res.status).toBe(200) + expect(res.headers['content-encoding']).toBe('gzip') + expect(res.text).toBe(cssContent) + }) +}) diff --git a/src/adapter/bun/websocket.ts b/src/adapter/bun/websocket.ts index 3f236f2b9..daae76479 100644 --- a/src/adapter/bun/websocket.ts +++ b/src/adapter/bun/websocket.ts @@ -19,11 +19,9 @@ interface BunWebSocketHandler { close(ws: BunServerWebSocket, code?: number, reason?: string): void message(ws: BunServerWebSocket, message: string | Uint8Array): void } -interface CreateWebSocket { - (): { - upgradeWebSocket: UpgradeWebSocket - websocket: BunWebSocketHandler - } +interface CreateWebSocket { + upgradeWebSocket: UpgradeWebSocket + websocket: BunWebSocketHandler } export interface BunWebSocketData { connId: number @@ -49,10 +47,11 @@ const createWSContext = (ws: BunServerWebSocket): WSContext => } } -export const createBunWebSocket: CreateWebSocket = () => { +export const createBunWebSocket = (): CreateWebSocket => { const websocketConns: WSEvents[] = [] - const upgradeWebSocket: UpgradeWebSocket = (createEvents) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const upgradeWebSocket: UpgradeWebSocket = (createEvents) => { return async (c, next) => { const server = getBunServer(c) if (!server) { diff --git a/src/adapter/cloudflare-pages/handler.test.ts b/src/adapter/cloudflare-pages/handler.test.ts index d08e31dca..747d11e07 100644 --- a/src/adapter/cloudflare-pages/handler.test.ts +++ b/src/adapter/cloudflare-pages/handler.test.ts @@ -230,6 +230,18 @@ describe('Middleware adapter for Cloudflare Pages', () => { await expect(handler(eventContext)).rejects.toThrowError('Something went wrong') expect(next).not.toHaveBeenCalled() }) + + it('Should set the data in eventContext.data', async () => { + const next = vi.fn() + const eventContext = createEventContext({ next }) + const handler = handleMiddleware(async (c, next) => { + c.env.eventContext.data.user = 'Joe' + await next() + }) + expect(eventContext.data.user).toBeUndefined() + await handler(eventContext) + expect(eventContext.data.user).toBe('Joe') + }) }) describe('serveStatic()', () => { diff --git a/src/adapter/cloudflare-pages/handler.ts b/src/adapter/cloudflare-pages/handler.ts index a3b988574..6eb7df406 100644 --- a/src/adapter/cloudflare-pages/handler.ts +++ b/src/adapter/cloudflare-pages/handler.ts @@ -9,7 +9,7 @@ import type { BlankSchema, Env, Input, MiddlewareHandler, Schema } from '../../t type Params

= Record // eslint-disable-next-line @typescript-eslint/no-explicit-any -export type EventContext = { +export type EventContext> = { request: Request functionPath: string waitUntil: (promise: Promise) => void @@ -43,12 +43,20 @@ export const handle = } // eslint-disable-next-line @typescript-eslint/no-explicit-any -export function handleMiddleware( - middleware: MiddlewareHandler +export function handleMiddleware( + middleware: MiddlewareHandler< + E & { + Bindings: { + eventContext: EventContext + } + }, + P, + I + > ): PagesFunction { return async (executionCtx) => { const context = new Context(executionCtx.request, { - env: executionCtx.env, + env: { ...executionCtx.env, eventContext: executionCtx }, executionCtx, }) diff --git a/src/adapter/cloudflare-workers/websocket.ts b/src/adapter/cloudflare-workers/websocket.ts index c7b982659..312e34786 100644 --- a/src/adapter/cloudflare-workers/websocket.ts +++ b/src/adapter/cloudflare-workers/websocket.ts @@ -1,7 +1,7 @@ import type { UpgradeWebSocket, WSContext, WSReadyState } from '../../helper/websocket' // Based on https://github.com/honojs/hono/issues/1153#issuecomment-1767321332 -export const upgradeWebSocket: UpgradeWebSocket = (createEvents) => async (c, next) => { +export const upgradeWebSocket: UpgradeWebSocket = (createEvents) => async (c, next) => { const events = await createEvents(c) const upgradeHeader = c.req.header('Upgrade') @@ -14,7 +14,7 @@ export const upgradeWebSocket: UpgradeWebSocket = (createEvents) => async (c, ne const client: WebSocket = webSocketPair[0] const server: WebSocket = webSocketPair[1] - const wsContext: WSContext = { + const wsContext: WSContext = { binaryType: 'arraybuffer', close: (code, reason) => server.close(code, reason), get protocol() { diff --git a/src/adapter/deno/websocket.ts b/src/adapter/deno/websocket.ts index 6dc3db111..1eda3e151 100644 --- a/src/adapter/deno/websocket.ts +++ b/src/adapter/deno/websocket.ts @@ -20,7 +20,7 @@ export interface UpgradeWebSocketOptions { idleTimeout?: number } -export const upgradeWebSocket: UpgradeWebSocket = +export const upgradeWebSocket: UpgradeWebSocket = (createEvents, options) => async (c, next) => { if (c.req.header('upgrade') !== 'websocket') { return await next() @@ -29,7 +29,7 @@ export const upgradeWebSocket: UpgradeWebSocket = const events = await createEvents(c) const { response, socket } = Deno.upgradeWebSocket(c.req.raw, options || {}) - const wsContext: WSContext = { + const wsContext: WSContext = { binaryType: 'arraybuffer', close: (code, reason) => socket.close(code, reason), get protocol() { diff --git a/src/context.test.ts b/src/context.test.ts index 8503220fb..ee02296bd 100644 --- a/src/context.test.ts +++ b/src/context.test.ts @@ -1,6 +1,29 @@ import { Context } from './context' import { setCookie } from './helper/cookie' +const makeResponseHeaderImmutable = (res: Response) => { + Object.defineProperty(res, 'headers', { + value: new Proxy(res.headers, { + set(target, prop, value) { + if (prop === 'set') { + throw new TypeError('Cannot modify headers: Headers are immutable') + } + return Reflect.set(target, prop, value) + }, + get(target, prop) { + if (prop === 'set') { + return function () { + throw new TypeError('Cannot modify headers: Headers are immutable') + } + } + return Reflect.get(target, prop) + }, + }), + writable: false, + }) + return res +} + describe('Context', () => { const req = new Request('http://localhost/') @@ -360,6 +383,28 @@ describe('Context header', () => { const res = c.text('Hi') expect(res.headers.get('set-cookie')).toBe('a, b, c') }) + + it('Should be able to overwrite a fetch response with a new response.', async () => { + c.res = makeResponseHeaderImmutable(new Response('bar')) + c.res = new Response('foo', { + headers: { + 'X-Custom': 'Message', + }, + }) + expect(c.res.text()).resolves.toBe('foo') + expect(c.res.headers.get('X-Custom')).toBe('Message') + }) + + it('Should be able to overwrite a response with a fetch response.', async () => { + c.res = new Response('foo', { + headers: { + 'X-Custom': 'Message', + }, + }) + c.res = makeResponseHeaderImmutable(new Response('bar')) + expect(c.res.text()).resolves.toBe('bar') + expect(c.res.headers.get('X-Custom')).toBe('Message') + }) }) describe('Pass a ResponseInit to respond methods', () => { diff --git a/src/context.ts b/src/context.ts index e3dbc3c5c..beb133210 100644 --- a/src/context.ts +++ b/src/context.ts @@ -465,16 +465,31 @@ export class Context< set res(_res: Response | undefined) { this.#isFresh = false if (this.#res && _res) { - this.#res.headers.delete('content-type') - for (const [k, v] of this.#res.headers.entries()) { - if (k === 'set-cookie') { - const cookies = this.#res.headers.getSetCookie() - _res.headers.delete('set-cookie') - for (const cookie of cookies) { - _res.headers.append('set-cookie', cookie) + try { + for (const [k, v] of this.#res.headers.entries()) { + if (k === 'content-type') { + continue } + if (k === 'set-cookie') { + const cookies = this.#res.headers.getSetCookie() + _res.headers.delete('set-cookie') + for (const cookie of cookies) { + _res.headers.append('set-cookie', cookie) + } + } else { + _res.headers.set(k, v) + } + } + } catch (e) { + if (e instanceof TypeError && e.message.includes('immutable')) { + // `_res` is immutable (probably a response from a fetch API), so retry with a new response. + this.res = new Response(_res.body, { + headers: _res.headers, + status: _res.status, + }) + return } else { - _res.headers.set(k, v) + throw e } } } @@ -844,18 +859,11 @@ export class Context< this.#preparedHeaders['content-type'] = 'text/html; charset=UTF-8' if (typeof html === 'object') { - if (!(html instanceof Promise)) { - html = (html as string).toString() // HtmlEscapedString object to string - } - if ((html as string | Promise) instanceof Promise) { - return (html as unknown as Promise) - .then((html) => resolveCallback(html, HtmlEscapedCallbackPhase.Stringify, false, {})) - .then((html) => { - return typeof arg === 'number' - ? this.newResponse(html, arg, headers) - : this.newResponse(html, arg) - }) - } + return resolveCallback(html, HtmlEscapedCallbackPhase.Stringify, false, {}).then((html) => { + return typeof arg === 'number' + ? this.newResponse(html, arg, headers) + : this.newResponse(html, arg) + }) } return typeof arg === 'number' diff --git a/src/helper/streaming/sse.test.ts b/src/helper/streaming/sse.test.tsx similarity index 62% rename from src/helper/streaming/sse.test.ts rename to src/helper/streaming/sse.test.tsx index eb7bbb897..df77bb3cb 100644 --- a/src/helper/streaming/sse.test.ts +++ b/src/helper/streaming/sse.test.tsx @@ -1,3 +1,5 @@ +/** @jsxImportSource ../../jsx */ +import { ErrorBoundary } from '../../jsx' import { Context } from '../../context' import { streamSSE } from '.' @@ -145,4 +147,90 @@ describe('SSE Streaming helper', () => { expect(onError).toBeCalledTimes(1) expect(onError).toBeCalledWith(new Error('Test error'), expect.anything()) // 2nd argument is StreamingApi instance }) + + it('Check streamSSE Response via Promise', async () => { + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ data: Promise.resolve('Async Message') }) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data: Async Message\n\n') + }) + + it('Check streamSSE Response via JSX.Element', async () => { + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ data:

Hello
}) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data:
Hello
\n\n') + }) + + it('Check streamSSE Response via ErrorBoundary in success case', async () => { + const AsyncComponent = async () => Promise.resolve(
Async Hello
) + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ + data: ( + Error}> + + + ), + }) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data:
Async Hello
\n\n') + }) + + it('Check streamSSE Response via ErrorBoundary in error case', async () => { + const AsyncComponent = async () => Promise.reject() + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ + data: ( + Error}> + + + ), + }) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data:
Error
\n\n') + }) }) diff --git a/src/helper/streaming/sse.ts b/src/helper/streaming/sse.ts index 1ed96e13d..fb38f3d45 100644 --- a/src/helper/streaming/sse.ts +++ b/src/helper/streaming/sse.ts @@ -1,8 +1,9 @@ import type { Context } from '../../context' import { StreamingApi } from '../../utils/stream' +import { HtmlEscapedCallbackPhase, resolveCallback } from '../../utils/html' export interface SSEMessage { - data: string + data: string | Promise event?: string id?: string retry?: number @@ -14,7 +15,8 @@ export class SSEStreamingApi extends StreamingApi { } async writeSSE(message: SSEMessage) { - const data = message.data + const data = await resolveCallback(message.data, HtmlEscapedCallbackPhase.Stringify, false, {}) + const dataLines = (data as string) .split('\n') .map((line) => { return `data: ${line}` @@ -24,7 +26,7 @@ export class SSEStreamingApi extends StreamingApi { const sseData = [ message.event && `event: ${message.event}`, - data, + dataLines, message.id && `id: ${message.id}`, message.retry && `retry: ${message.retry}`, ] diff --git a/src/helper/websocket/index.ts b/src/helper/websocket/index.ts index 512650d83..2c27b63de 100644 --- a/src/helper/websocket/index.ts +++ b/src/helper/websocket/index.ts @@ -10,19 +10,19 @@ import type { MiddlewareHandler } from '../../types' /** * WebSocket Event Listeners type */ -export interface WSEvents { - onOpen?: (evt: Event, ws: WSContext) => void - onMessage?: (evt: MessageEvent, ws: WSContext) => void - onClose?: (evt: CloseEvent, ws: WSContext) => void - onError?: (evt: Event, ws: WSContext) => void +export interface WSEvents { + onOpen?: (evt: Event, ws: WSContext) => void + onMessage?: (evt: MessageEvent, ws: WSContext) => void + onClose?: (evt: CloseEvent, ws: WSContext) => void + onError?: (evt: Event, ws: WSContext) => void } /** * Upgrade WebSocket Type */ -export type UpgradeWebSocket = ( - createEvents: (c: Context) => WSEvents | Promise, - options?: T +export type UpgradeWebSocket = ( + createEvents: (c: Context) => WSEvents | Promise>, + options?: U ) => MiddlewareHandler< any, string, @@ -33,14 +33,14 @@ export type UpgradeWebSocket = ( export type WSReadyState = 0 | 1 | 2 | 3 -export type WSContext = { +export type WSContext = { send( source: string | ArrayBuffer | Uint8Array, options?: { compress: boolean } ): void - raw?: unknown + raw?: T binaryType: BinaryType readyState: WSReadyState url: URL | null diff --git a/src/middleware/basic-auth/index.test.ts b/src/middleware/basic-auth/index.test.ts index 147b3c95a..1fc7015a5 100644 --- a/src/middleware/basic-auth/index.test.ts +++ b/src/middleware/basic-auth/index.test.ts @@ -83,6 +83,42 @@ describe('Basic Auth by Middleware', () => { return auth(c, next) }) + app.use( + '/auth-custom-invalid-user-message-string/*', + basicAuth({ + username, + password, + invalidUserMessage: 'Custom unauthorized message as string', + }) + ) + + app.use( + '/auth-custom-invalid-user-message-object/*', + basicAuth({ + username, + password, + invalidUserMessage: { message: 'Custom unauthorized message as object' }, + }) + ) + + app.use( + '/auth-custom-invalid-user-message-function-string/*', + basicAuth({ + username, + password, + invalidUserMessage: () => 'Custom unauthorized message as function string', + }) + ) + + app.use( + '/auth-custom-invalid-user-message-function-object/*', + basicAuth({ + username, + password, + invalidUserMessage: () => ({ message: 'Custom unauthorized message as function object' }), + }) + ) + app.get('/auth/*', (c) => { handlerExecuted = true return c.text('auth') @@ -110,6 +146,24 @@ describe('Basic Auth by Middleware', () => { return c.text('verify-user') }) + app.get('/auth-custom-invalid-user-message-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + app.get('/auth-custom-invalid-user-message-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + app.get('/auth-custom-invalid-user-message-function-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.get('/auth-custom-invalid-user-message-function-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + it('Should not authorize', async () => { const req = new Request('http://localhost/auth/a') const res = await app.request(req) @@ -226,4 +280,42 @@ describe('Basic Auth by Middleware', () => { expect(res.status).toBe(401) expect(await res.text()).toBe('Unauthorized') }) + + it('Should not authorize - custom invalid user message as string', async () => { + const req = new Request('http://localhost/auth-custom-invalid-user-message-string') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom unauthorized message as string') + }) + + it('Should not authorize - custom invalid user message as object', async () => { + const req = new Request('http://localhost/auth-custom-invalid-user-message-object') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('{"message":"Custom unauthorized message as object"}') + }) + + it('Should not authorize - custom invalid user message as function string', async () => { + const req = new Request('http://localhost/auth-custom-invalid-user-message-function-string') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom unauthorized message as function string') + }) + + it('Should not authorize - custom invalid user message as function object', async () => { + const req = new Request('http://localhost/auth-custom-invalid-user-message-function-object') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('{"message":"Custom unauthorized message as function object"}') + }) }) diff --git a/src/middleware/basic-auth/index.ts b/src/middleware/basic-auth/index.ts index f1a2c3797..aa4ca72d0 100644 --- a/src/middleware/basic-auth/index.ts +++ b/src/middleware/basic-auth/index.ts @@ -9,17 +9,21 @@ import type { MiddlewareHandler } from '../../types' import { auth } from '../../utils/basic-auth' import { timingSafeEqual } from '../../utils/buffer' +type MessageFunction = (c: Context) => string | object | Promise + type BasicAuthOptions = | { username: string password: string realm?: string hashFunction?: Function + invalidUserMessage?: string | object | MessageFunction } | { verifyUser: (username: string, password: string, c: Context) => boolean | Promise realm?: string hashFunction?: Function + invalidUserMessage?: string | object | MessageFunction } /** @@ -33,6 +37,7 @@ type BasicAuthOptions = * @param {string} [options.realm="Secure Area"] - The realm attribute for the WWW-Authenticate header. * @param {Function} [options.hashFunction] - The hash function used for secure comparison. * @param {Function} [options.verifyUser] - The function to verify user credentials. + * @param {string | object | MessageFunction} [options.invalidUserMessage="Unauthorized"] - The invalid user message. * @returns {MiddlewareHandler} The middleware handler function. * @throws {HTTPException} If neither "username and password" nor "verifyUser" options are provided. * @@ -70,6 +75,10 @@ export const basicAuth = ( options.realm = 'Secure Area' } + if (!options.invalidUserMessage) { + options.invalidUserMessage = 'Unauthorized' + } + if (usernamePasswordInOptions) { users.unshift({ username: options.username, password: options.password }) } @@ -95,12 +104,25 @@ export const basicAuth = ( } } } - const res = new Response('Unauthorized', { - status: 401, - headers: { - 'WWW-Authenticate': 'Basic realm="' + options.realm?.replace(/"/g, '\\"') + '"', - }, - }) - throw new HTTPException(401, { res }) + // Invalid user. + const status = 401 + const headers = { + 'WWW-Authenticate': 'Basic realm="' + options.realm?.replace(/"/g, '\\"') + '"', + } + const responseMessage = + typeof options.invalidUserMessage === 'function' + ? await options.invalidUserMessage(ctx) + : options.invalidUserMessage + const res = + typeof responseMessage === 'string' + ? new Response(responseMessage, { status, headers }) + : new Response(JSON.stringify(responseMessage), { + status, + headers: { + ...headers, + 'content-type': 'application/json; charset=UTF-8', + }, + }) + throw new HTTPException(status, { res }) } } diff --git a/src/middleware/bearer-auth/index.test.ts b/src/middleware/bearer-auth/index.test.ts index 632d9c9af..4dca87f22 100644 --- a/src/middleware/bearer-auth/index.test.ts +++ b/src/middleware/bearer-auth/index.test.ts @@ -68,6 +68,163 @@ describe('Bearer Auth by Middleware', () => { handlerExecuted = true return c.text('auth-custom-header') }) + + app.use( + '/auth-custom-no-authentication-header-message-string/*', + bearerAuth({ + token, + noAuthenticationHeaderMessage: 'Custom no authentication header message as string', + }) + ) + app.get('/auth-custom-no-authentication-header-message-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-no-authentication-header-message-object/*', + bearerAuth({ + token, + noAuthenticationHeaderMessage: { + message: 'Custom no authentication header message as object', + }, + }) + ) + app.get('/auth-custom-no-authentication-header-message-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-no-authentication-header-message-function-string/*', + bearerAuth({ + token, + noAuthenticationHeaderMessage: () => + 'Custom no authentication header message as function string', + }) + ) + app.get('/auth-custom-no-authentication-header-message-function-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-no-authentication-header-message-function-object/*', + bearerAuth({ + token, + noAuthenticationHeaderMessage: () => ({ + message: 'Custom no authentication header message as function object', + }), + }) + ) + app.get('/auth-custom-no-authentication-header-message-function-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-authentication-header-message-string/*', + bearerAuth({ + token, + invalidAuthenticationHeaderMessage: + 'Custom invalid authentication header message as string', + }) + ) + app.get('/auth-custom-invalid-authentication-header-message-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-authentication-header-message-object/*', + bearerAuth({ + token, + invalidAuthenticationHeaderMessage: { + message: 'Custom invalid authentication header message as object', + }, + }) + ) + app.get('/auth-custom-invalid-authentication-header-message-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-authentication-header-message-function-string/*', + bearerAuth({ + token, + invalidAuthenticationHeaderMessage: () => + 'Custom invalid authentication header message as function string', + }) + ) + app.get('/auth-custom-invalid-authentication-header-message-function-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-authentication-header-message-function-object/*', + bearerAuth({ + token, + invalidAuthenticationHeaderMessage: () => ({ + message: 'Custom invalid authentication header message as function object', + }), + }) + ) + app.get('/auth-custom-invalid-authentication-header-message-function-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-token-message-string/*', + bearerAuth({ + token, + invalidTokenMessage: 'Custom invalid token message as string', + }) + ) + app.get('/auth-custom-invalid-token-message-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-token-message-object/*', + bearerAuth({ + token, + invalidTokenMessage: { message: 'Custom invalid token message as object' }, + }) + ) + app.get('/auth-custom-invalid-token-message-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-token-message-function-string/*', + bearerAuth({ + token, + invalidTokenMessage: () => 'Custom invalid token message as function string', + }) + ) + app.get('/auth-custom-invalid-token-message-function-string/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) + + app.use( + '/auth-custom-invalid-token-message-function-object/*', + bearerAuth({ + token, + invalidTokenMessage: () => ({ + message: 'Custom invalid token message as function object', + }), + }) + ) + app.get('/auth-custom-invalid-token-message-function-object/*', (c) => { + handlerExecuted = true + return c.text('auth') + }) }) it('Should authorize', async () => { @@ -228,4 +385,144 @@ describe('Bearer Auth by Middleware', () => { expect(res.status).toBe(401) expect(await res.text()).toBe('Unauthorized') }) + + it('Should not authorize - custom no authorization header message as string', async () => { + const req = new Request('http://localhost/auth-custom-no-authentication-header-message-string') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom no authentication header message as string') + }) + + it('Should not authorize - custom no authorization header message as object', async () => { + const req = new Request('http://localhost/auth-custom-no-authentication-header-message-object') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('{"message":"Custom no authentication header message as object"}') + }) + + it('Should not authorize - custom no authorization header message as function string', async () => { + const req = new Request( + 'http://localhost/auth-custom-no-authentication-header-message-function-string' + ) + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom no authentication header message as function string') + }) + + it('Should not authorize - custom no authorization header message as function object', async () => { + const req = new Request( + 'http://localhost/auth-custom-no-authentication-header-message-function-object' + ) + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe( + '{"message":"Custom no authentication header message as function object"}' + ) + }) + + it('Should not authorize - custom invalid authentication header message as string', async () => { + const req = new Request( + 'http://localhost/auth-custom-invalid-authentication-header-message-string' + ) + req.headers.set('Authorization', 'Beare abcdefg12345-._~+/=') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(400) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom invalid authentication header message as string') + }) + + it('Should not authorize - custom invalid authentication header message as object', async () => { + const req = new Request( + 'http://localhost/auth-custom-invalid-authentication-header-message-object' + ) + req.headers.set('Authorization', 'Beare abcdefg12345-._~+/=') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(400) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe( + '{"message":"Custom invalid authentication header message as object"}' + ) + }) + + it('Should not authorize - custom invalid authentication header message as function string', async () => { + const req = new Request( + 'http://localhost/auth-custom-invalid-authentication-header-message-function-string' + ) + req.headers.set('Authorization', 'Beare abcdefg12345-._~+/=') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(400) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom invalid authentication header message as function string') + }) + + it('Should not authorize - custom invalid authentication header message as function object', async () => { + const req = new Request( + 'http://localhost/auth-custom-invalid-authentication-header-message-function-object' + ) + req.headers.set('Authorization', 'Beare abcdefg12345-._~+/=') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(400) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe( + '{"message":"Custom invalid authentication header message as function object"}' + ) + }) + + it('Should not authorize - custom invalid token message as string', async () => { + const req = new Request('http://localhost/auth-custom-invalid-token-message-string') + req.headers.set('Authorization', 'Bearer invalid-token') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom invalid token message as string') + }) + + it('Should not authorize - custom invalid token message as object', async () => { + const req = new Request('http://localhost/auth-custom-invalid-token-message-object') + req.headers.set('Authorization', 'Bearer invalid-token') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('{"message":"Custom invalid token message as object"}') + }) + + it('Should not authorize - custom invalid token message as function string', async () => { + const req = new Request('http://localhost/auth-custom-invalid-token-message-function-string') + req.headers.set('Authorization', 'Bearer invalid-token') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('Custom invalid token message as function string') + }) + + it('Should not authorize - custom invalid token message as function object', async () => { + const req = new Request('http://localhost/auth-custom-invalid-token-message-function-object') + req.headers.set('Authorization', 'Bearer invalid-token') + const res = await app.request(req) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(res.headers.get('Content-Type')).toMatch('application/json; charset=UTF-8') + expect(handlerExecuted).toBeFalsy() + expect(await res.text()).toBe('{"message":"Custom invalid token message as function object"}') + }) }) diff --git a/src/middleware/bearer-auth/index.ts b/src/middleware/bearer-auth/index.ts index 6f6208ba4..0e86d4d47 100644 --- a/src/middleware/bearer-auth/index.ts +++ b/src/middleware/bearer-auth/index.ts @@ -7,11 +7,14 @@ import type { Context } from '../../context' import { HTTPException } from '../../http-exception' import type { MiddlewareHandler } from '../../types' import { timingSafeEqual } from '../../utils/buffer' +import type { StatusCode } from '../../utils/http-status' const TOKEN_STRINGS = '[A-Za-z0-9._~+/-]+=*' const PREFIX = 'Bearer' const HEADER = 'Authorization' +type MessageFunction = (c: Context) => string | object | Promise + type BearerAuthOptions = | { token: string | string[] @@ -19,6 +22,9 @@ type BearerAuthOptions = prefix?: string headerName?: string hashFunction?: Function + noAuthenticationHeaderMessage?: string | object | MessageFunction + invalidAuthenticationHeaderMessage?: string | object | MessageFunction + invalidTokenMessage?: string | object | MessageFunction } | { realm?: string @@ -26,6 +32,9 @@ type BearerAuthOptions = headerName?: string verifyToken: (token: string, c: Context) => boolean | Promise hashFunction?: Function + noAuthenticationHeaderMessage?: string | object | MessageFunction + invalidAuthenticationHeaderMessage?: string | object | MessageFunction + invalidTokenMessage?: string | object | MessageFunction } /** @@ -40,6 +49,9 @@ type BearerAuthOptions = * @param {string} [options.prefix="Bearer"] - The prefix (or known as `schema`) for the Authorization header value. If set to the empty string, no prefix is expected. * @param {string} [options.headerName=Authorization] - The header name. * @param {Function} [options.hashFunction] - A function to handle hashing for safe comparison of authentication tokens. + * @param {string | object | MessageFunction} [options.noAuthenticationHeaderMessage="Unauthorized"] - The no authentication header message. + * @param {string | object | MessageFunction} [options.invalidAuthenticationHeaderMeasage="Bad Request"] - The invalid authentication header message. + * @param {string | object | MessageFunction} [options.invalidTokenMessage="Unauthorized"] - The invalid token message. * @returns {MiddlewareHandler} The middleware handler function. * @throws {Error} If neither "token" nor "verifyToken" options are provided. * @throws {HTTPException} If authentication fails, with 401 status code for missing or invalid token, or 400 status code for invalid request. @@ -73,28 +85,50 @@ export const bearerAuth = (options: BearerAuthOptions): MiddlewareHandler => { const regexp = new RegExp(`^${prefixRegexStr}(${TOKEN_STRINGS}) *$`) const wwwAuthenticatePrefix = options.prefix === '' ? '' : `${options.prefix} ` + const throwHTTPException = async ( + c: Context, + status: StatusCode, + wwwAuthenticateHeader: string, + messageOption: string | object | MessageFunction + ): Promise => { + const headers = { + 'WWW-Authenticate': wwwAuthenticateHeader, + } + const responseMessage = + typeof messageOption === 'function' ? await messageOption(c) : messageOption + const res = + typeof responseMessage === 'string' + ? new Response(responseMessage, { status, headers }) + : new Response(JSON.stringify(responseMessage), { + status, + headers: { + ...headers, + 'content-type': 'application/json; charset=UTF-8', + }, + }) + throw new HTTPException(status, { res }) + } + return async function bearerAuth(c, next) { const headerToken = c.req.header(options.headerName || HEADER) if (!headerToken) { // No Authorization header - const res = new Response('Unauthorized', { - status: 401, - headers: { - 'WWW-Authenticate': `${wwwAuthenticatePrefix}realm="` + realm + '"', - }, - }) - throw new HTTPException(401, { res }) + await throwHTTPException( + c, + 401, + `${wwwAuthenticatePrefix}realm="${realm}"`, + options.noAuthenticationHeaderMessage || 'Unauthorized' + ) } else { const match = regexp.exec(headerToken) if (!match) { // Invalid Request - const res = new Response('Bad Request', { - status: 400, - headers: { - 'WWW-Authenticate': `${wwwAuthenticatePrefix}error="invalid_request"`, - }, - }) - throw new HTTPException(400, { res }) + await throwHTTPException( + c, + 400, + `${wwwAuthenticatePrefix}error="invalid_request"`, + options.invalidAuthenticationHeaderMessage || 'Bad Request' + ) } else { let equal = false if ('verifyToken' in options) { @@ -111,13 +145,12 @@ export const bearerAuth = (options: BearerAuthOptions): MiddlewareHandler => { } if (!equal) { // Invalid Token - const res = new Response('Unauthorized', { - status: 401, - headers: { - 'WWW-Authenticate': `${wwwAuthenticatePrefix}error="invalid_token"`, - }, - }) - throw new HTTPException(401, { res }) + await throwHTTPException( + c, + 401, + `${wwwAuthenticatePrefix}error="invalid_token"`, + options.invalidTokenMessage || 'Unauthorized' + ) } } } diff --git a/src/middleware/context-storage/index.test.ts b/src/middleware/context-storage/index.test.ts new file mode 100644 index 000000000..d9713b975 --- /dev/null +++ b/src/middleware/context-storage/index.test.ts @@ -0,0 +1,30 @@ +import { Hono } from '../../hono' +import { contextStorage, getContext } from '.' + +describe('Context Storage Middleware', () => { + type Env = { + Variables: { + message: string + } + } + + const app = new Hono() + + app.use(contextStorage()) + app.use(async (c, next) => { + c.set('message', 'Hono is cool!!') + await next() + }) + app.get('/', (c) => { + return c.text(getMessage()) + }) + + const getMessage = () => { + return getContext().var.message + } + + it('Should get context', async () => { + const res = await app.request('/') + expect(await res.text()).toBe('Hono is cool!!') + }) +}) diff --git a/src/middleware/context-storage/index.ts b/src/middleware/context-storage/index.ts new file mode 100644 index 000000000..2545d3cab --- /dev/null +++ b/src/middleware/context-storage/index.ts @@ -0,0 +1,55 @@ +/** + * @module + * Context Storage Middleware for Hono. + */ + +import type { Context } from '../../context' +import type { Env, MiddlewareHandler } from '../../types' +import { AsyncLocalStorage } from 'node:async_hooks' + +const asyncLocalStorage = new AsyncLocalStorage() + +/** + * Context Storage Middleware for Hono. + * + * @see {@link https://hono.dev/docs/middleware/builtin/context-storage} + * + * @returns {MiddlewareHandler} The middleware handler function. + * + * @example + * ```ts + * type Env = { + * Variables: { + * message: string + * } + * } + * + * const app = new Hono() + * + * app.use(contextStorage()) + * + * app.use(async (c, next) => { + * c.set('message', 'Hono is cool!!) + * await next() + * }) + * + * app.get('/', async (c) => { c.text(getMessage()) }) + * + * const getMessage = () => { + * return getContext().var.message + * } + * ``` + */ +export const contextStorage = (): MiddlewareHandler => { + return async function contextStorage(c, next) { + await asyncLocalStorage.run(c, next) + } +} + +export const getContext = (): Context => { + const context = asyncLocalStorage.getStore() + if (!context) { + throw new Error('Context is not available') + } + return context +} diff --git a/src/middleware/jsx-renderer/index.test.tsx b/src/middleware/jsx-renderer/index.test.tsx index 17f979999..042a1b633 100644 --- a/src/middleware/jsx-renderer/index.test.tsx +++ b/src/middleware/jsx-renderer/index.test.tsx @@ -238,6 +238,7 @@ describe('JSX renderer', () => { expect(res.status).toBe(200) expect(res.headers.get('Transfer-Encoding')).toEqual('chunked') expect(res.headers.get('Content-Type')).toEqual('text/html; charset=UTF-8') + expect(res.headers.get('Content-Encoding')).toEqual('Identity') if (!res.body) { throw new Error('Body is null') diff --git a/src/middleware/jsx-renderer/index.ts b/src/middleware/jsx-renderer/index.ts index 4eee20f65..195a02293 100644 --- a/src/middleware/jsx-renderer/index.ts +++ b/src/middleware/jsx-renderer/index.ts @@ -60,6 +60,7 @@ const createRenderer = if (options.stream === true) { c.header('Transfer-Encoding', 'chunked') c.header('Content-Type', 'text/html; charset=UTF-8') + c.header('Content-Encoding', 'Identity') } else { for (const [key, value] of Object.entries(options.stream)) { c.header(key, value) diff --git a/src/middleware/secure-headers/index.test.ts b/src/middleware/secure-headers/index.test.ts index 915d6e7de..5d055830b 100644 --- a/src/middleware/secure-headers/index.test.ts +++ b/src/middleware/secure-headers/index.test.ts @@ -36,6 +36,7 @@ describe('Secure Headers Middleware', () => { expect(res.headers.get('Cross-Origin-Resource-Policy')).toEqual('same-origin') expect(res.headers.get('Cross-Origin-Opener-Policy')).toEqual('same-origin') expect(res.headers.get('Origin-Agent-Cluster')).toEqual('?1') + expect(res.headers.get('Permissions-Policy')).toBeNull() expect(res.headers.get('Content-Security-Policy')).toBeFalsy() }) @@ -48,6 +49,9 @@ describe('Secure Headers Middleware', () => { defaultSrc: ["'self'"], }, crossOriginEmbedderPolicy: true, + permissionsPolicy: { + camera: [], + }, }) ) app.get('/test', async (ctx) => { @@ -72,6 +76,7 @@ describe('Secure Headers Middleware', () => { expect(res.headers.get('Cross-Origin-Opener-Policy')).toEqual('same-origin') expect(res.headers.get('Origin-Agent-Cluster')).toEqual('?1') expect(res.headers.get('Cross-Origin-Embedder-Policy')).toEqual('require-corp') + expect(res.headers.get('Permissions-Policy')).toEqual('camera=()') expect(res.headers.get('Content-Security-Policy')).toEqual("default-src 'self'") }) @@ -98,6 +103,7 @@ describe('Secure Headers Middleware', () => { expect(res.headers.get('X-Permitted-Cross-Domain-Policies')).toEqual('none') expect(res.headers.get('Cross-Origin-Resource-Policy')).toEqual('same-origin') expect(res.headers.get('Cross-Origin-Opener-Policy')).toEqual('same-origin') + expect(res.headers.get('Permissions-Policy')).toBeNull() expect(res.headers.get('Origin-Agent-Cluster')).toEqual('?1') }) @@ -154,6 +160,35 @@ describe('Secure Headers Middleware', () => { expect(res.headers.get('X-XSS-Protection')).toEqual('1') }) + it('should set Permission-Policy header correctly', async () => { + const app = new Hono() + app.use( + '/test', + secureHeaders({ + permissionsPolicy: { + fullscreen: ['self'], + bluetooth: ['none'], + payment: ['self', 'example.com'], + syncXhr: [], + camera: false, + microphone: true, + geolocation: ['*'], + usb: ['self', 'https://a.example.com', 'https://b.example.com'], + accelerometer: ['https://*.example.com'], + gyroscope: ['src'], + magnetometer: ['https://a.example.com', 'https://b.example.com'], + }, + }) + ) + + const res = await app.request('/test') + expect(res.headers.get('Permissions-Policy')).toEqual( + 'fullscreen=(self), bluetooth=none, payment=(self "example.com"), sync-xhr=(), camera=none, microphone=*, ' + + 'geolocation=*, usb=(self "https://a.example.com" "https://b.example.com"), ' + + 'accelerometer=("https://*.example.com"), gyroscope=(src), ' + + 'magnetometer=("https://a.example.com" "https://b.example.com")' + ) + }) it('CSP Setting', async () => { const app = new Hono() app.use( diff --git a/src/middleware/secure-headers/permissions-policy.ts b/src/middleware/secure-headers/permissions-policy.ts new file mode 100644 index 000000000..f63593ec5 --- /dev/null +++ b/src/middleware/secure-headers/permissions-policy.ts @@ -0,0 +1,86 @@ +// https://github.com/w3c/webappsec-permissions-policy/blob/main/features.md + +export type PermissionsPolicyDirective = + | StandardizedFeatures + | ProposedFeatures + | ExperimentalFeatures + +/** + * These features have been declared in a published version of the respective specification. + */ +type StandardizedFeatures = + | 'accelerometer' + | 'ambientLightSensor' + | 'attributionReporting' + | 'autoplay' + | 'battery' + | 'bluetooth' + | 'camera' + | 'chUa' + | 'chUaArch' + | 'chUaBitness' + | 'chUaFullVersion' + | 'chUaFullVersionList' + | 'chUaMobile' + | 'chUaModel' + | 'chUaPlatform' + | 'chUaPlatformVersion' + | 'chUaWow64' + | 'computePressure' + | 'crossOriginIsolated' + | 'directSockets' + | 'displayCapture' + | 'encryptedMedia' + | 'executionWhileNotRendered' + | 'executionWhileOutOfViewport' + | 'fullscreen' + | 'geolocation' + | 'gyroscope' + | 'hid' + | 'identityCredentialsGet' + | 'idleDetection' + | 'keyboardMap' + | 'magnetometer' + | 'microphone' + | 'midi' + | 'navigationOverride' + | 'payment' + | 'pictureInPicture' + | 'publickeyCredentialsGet' + | 'screenWakeLock' + | 'serial' + | 'storageAccess' + | 'syncXhr' + | 'usb' + | 'webShare' + | 'windowManagement' + | 'xrSpatialTracking' + +/** + * These features have been proposed, but the definitions have not yet been integrated into their respective specs. + */ +type ProposedFeatures = + | 'clipboardRead' + | 'clipboardWrite' + | 'gemepad' + | 'sharedAutofill' + | 'speakerSelection' + +/** + * These features generally have an explainer only, but may be available for experimentation by web developers. + */ +type ExperimentalFeatures = + | 'allScreensCapture' + | 'browsingTopics' + | 'capturedSurfaceControl' + | 'conversionMeasurement' + | 'digitalCredentialsGet' + | 'focusWithoutUserActivation' + | 'joinAdInterestGroup' + | 'localFonts' + | 'runAdAuction' + | 'smartCard' + | 'syncScript' + | 'trustTokenRedemption' + | 'unload' + | 'verticalScroll' diff --git a/src/middleware/secure-headers/secure-headers.ts b/src/middleware/secure-headers/secure-headers.ts index 3ef52aa00..c29515d56 100644 --- a/src/middleware/secure-headers/secure-headers.ts +++ b/src/middleware/secure-headers/secure-headers.ts @@ -6,6 +6,7 @@ import type { Context } from '../../context' import type { MiddlewareHandler } from '../../types' import { encodeBase64 } from '../../utils/encode' +import type { PermissionsPolicyDirective } from './permissions-policy' export type SecureHeadersVariables = { secureHeadersNonce?: string @@ -54,6 +55,12 @@ interface ReportingEndpointOptions { url: string } +type PermissionsPolicyValue = '*' | 'self' | 'src' | 'none' | string + +type PermissionsPolicyOptions = Partial< + Record +> + type overridableHeader = boolean | string interface SecureHeadersOptions { @@ -73,6 +80,7 @@ interface SecureHeadersOptions { xPermittedCrossDomainPolicies?: overridableHeader xXssProtection?: overridableHeader removePoweredBy?: boolean + permissionsPolicy?: PermissionsPolicyOptions } type HeadersMap = { @@ -108,6 +116,7 @@ const DEFAULT_OPTIONS: SecureHeadersOptions = { xPermittedCrossDomainPolicies: true, xXssProtection: true, removePoweredBy: true, + permissionsPolicy: {}, } type SecureHeadersCallback = ( @@ -154,6 +163,7 @@ export const NONCE: ContentSecurityPolicyOptionHandler = (ctx) => { * @param {overridableHeader} [customOptions.xPermittedCrossDomainPolicies=true] - Settings for the X-Permitted-Cross-Domain-Policies header. * @param {overridableHeader} [customOptions.xXssProtection=true] - Settings for the X-XSS-Protection header. * @param {boolean} [customOptions.removePoweredBy=true] - Settings for remove X-Powered-By header. + * @param {PermissionsPolicyOptions} [customOptions.permissionsPolicy] - Settings for the Permissions-Policy header. * @returns {MiddlewareHandler} The middleware handler function. * * @example @@ -175,6 +185,13 @@ export const secureHeaders = (customOptions?: SecureHeadersOptions): MiddlewareH headersToSet.push(['Content-Security-Policy', value as string]) } + if (options.permissionsPolicy && Object.keys(options.permissionsPolicy).length > 0) { + headersToSet.push([ + 'Permissions-Policy', + getPermissionsPolicyDirectives(options.permissionsPolicy), + ]) + } + if (options.reportingEndpoints) { headersToSet.push(['Reporting-Endpoints', getReportingEndpoints(options.reportingEndpoints)]) } @@ -255,6 +272,36 @@ function getCSPDirectives( ] } +function getPermissionsPolicyDirectives(policy: PermissionsPolicyOptions): string { + return Object.entries(policy) + .map(([directive, value]) => { + const kebabDirective = camelToKebab(directive) + + if (typeof value === 'boolean') { + return `${kebabDirective}=${value ? '*' : 'none'}` + } + + if (Array.isArray(value)) { + if (value.length === 0) { + return `${kebabDirective}=()` + } + if (value.length === 1 && (value[0] === '*' || value[0] === 'none')) { + return `${kebabDirective}=${value[0]}` + } + const allowlist = value.map((item) => (['self', 'src'].includes(item) ? item : `"${item}"`)) + return `${kebabDirective}=(${allowlist.join(' ')})` + } + + return '' + }) + .filter(Boolean) + .join(', ') +} + +function camelToKebab(str: string): string { + return str.replace(/([a-z\d])([A-Z])/g, '$1-$2').toLowerCase() +} + function getReportingEndpoints( reportingEndpoints: SecureHeadersOptions['reportingEndpoints'] = [] ): string { diff --git a/src/middleware/serve-static/index.test.ts b/src/middleware/serve-static/index.test.ts index efae1aeb5..b0bd1e3ef 100644 --- a/src/middleware/serve-static/index.test.ts +++ b/src/middleware/serve-static/index.test.ts @@ -18,6 +18,11 @@ describe('Serve Static Middleware', () => { isDir: (path) => { return path === 'static/hello.world' }, + onFound: (path, c) => { + if (path.endsWith('hello.html')) { + c.header('X-Custom', `Found the file at ${path}`) + } + }, }) app.get('/static/*', serveStatic) @@ -29,8 +34,10 @@ describe('Serve Static Middleware', () => { it('Should return 200 response - /static/hello.html', async () => { const res = await app.request('/static/hello.html') expect(res.status).toBe(200) + expect(res.headers.get('Content-Encoding')).toBeNull() expect(res.headers.get('Content-Type')).toMatch(/^text\/html/) expect(await res.text()).toBe('Hello in ./static/hello.html') + expect(res.headers.get('X-Custom')).toBe('Found the file at ./static/hello.html') }) it('Should return 200 response - /static/sub', async () => { @@ -57,12 +64,15 @@ describe('Serve Static Middleware', () => { it('Should decode URI strings - /static/%E7%82%8E.txt', async () => { const res = await app.request('/static/%E7%82%8E.txt') expect(res.status).toBe(200) + expect(res.headers.get('Content-Type')).toMatch(/^text\/plain/) expect(await res.text()).toBe('Hello in ./static/炎.txt') }) - it('Should return 404 response - /static/not-found', async () => { + it('Should return 404 response - /static/not-found.txt', async () => { const res = await app.request('/static/not-found.txt') expect(res.status).toBe(404) + expect(res.headers.get('Content-Encoding')).toBeNull() + expect(res.headers.get('Content-Type')).toMatch(/^text\/plain/) expect(await res.text()).toBe('404 Not Found') expect(getContent).toBeCalledTimes(1) }) @@ -73,9 +83,90 @@ describe('Serve Static Middleware', () => { url: 'http://localhost/static/%2e%2e/static/hello.html', } as Request) expect(res.status).toBe(404) + expect(res.headers.get('Content-Type')).toMatch(/^text\/plain/) expect(await res.text()).toBe('404 Not Found') }) + it('Should return a pre-compressed zstd response - /static/hello.html', async () => { + const app = new Hono().use( + '*', + baseServeStatic({ + getContent, + precompressed: true, + }) + ) + + const res = await app.request('/static/hello.html', { + headers: { 'Accept-Encoding': 'zstd' }, + }) + + expect(res.status).toBe(200) + expect(res.headers.get('Content-Encoding')).toBe('zstd') + expect(res.headers.get('Vary')).toBe('Accept-Encoding') + expect(res.headers.get('Content-Type')).toMatch(/^text\/html/) + expect(await res.text()).toBe('Hello in static/hello.html.zst') + }) + + it('Should return a pre-compressed brotli response - /static/hello.html', async () => { + const app = new Hono().use( + '*', + baseServeStatic({ + getContent, + precompressed: true, + }) + ) + + const res = await app.request('/static/hello.html', { + headers: { 'Accept-Encoding': 'wompwomp, gzip, br, deflate, zstd' }, + }) + + expect(res.status).toBe(200) + expect(res.headers.get('Content-Encoding')).toBe('br') + expect(res.headers.get('Vary')).toBe('Accept-Encoding') + expect(res.headers.get('Content-Type')).toMatch(/^text\/html/) + expect(await res.text()).toBe('Hello in static/hello.html.br') + }) + + it('Should not return a pre-compressed response - /static/not-found.txt', async () => { + const app = new Hono().use( + '*', + baseServeStatic({ + getContent, + precompressed: true, + }) + ) + + const res = await app.request('/static/not-found.txt', { + headers: { 'Accept-Encoding': 'gzip, zstd, br' }, + }) + + expect(res.status).toBe(404) + expect(res.headers.get('Content-Encoding')).toBeNull() + expect(res.headers.get('Vary')).toBeNull() + expect(res.headers.get('Content-Type')).toMatch(/^text\/plain/) + expect(await res.text()).toBe('404 Not Found') + }) + + it('Should not return a pre-compressed response - /static/hello.html', async () => { + const app = new Hono().use( + '*', + baseServeStatic({ + getContent, + precompressed: true, + }) + ) + + const res = await app.request('/static/hello.html', { + headers: { 'Accept-Encoding': 'wompwomp, unknown' }, + }) + + expect(res.status).toBe(200) + expect(res.headers.get('Content-Encoding')).toBeNull() + expect(res.headers.get('Vary')).toBeNull() + expect(res.headers.get('Content-Type')).toMatch(/^text\/html/) + expect(await res.text()).toBe('Hello in static/hello.html') + }) + it('Should return response object content as-is', async () => { const body = new ReadableStream() const response = new Response(body) diff --git a/src/middleware/serve-static/index.ts b/src/middleware/serve-static/index.ts index 1e3e78569..56c601da5 100644 --- a/src/middleware/serve-static/index.ts +++ b/src/middleware/serve-static/index.ts @@ -11,11 +11,19 @@ import { getMimeType } from '../../utils/mime' export type ServeStaticOptions = { root?: string path?: string + precompressed?: boolean mimes?: Record rewriteRequestPath?: (path: string) => string + onFound?: (path: string, c: Context) => void | Promise onNotFound?: (path: string, c: Context) => void | Promise } +const ENCODINGS = { + br: '.br', + zstd: '.zst', + gzip: '.gz', +} as const + const DEFAULT_DOCUMENT = 'index.html' const defaultPathResolve = (path: string) => path @@ -47,7 +55,7 @@ export const serveStatic = ( root, }) if (path && (await options.isDir(path))) { - filename = filename + '/' + filename += '/' } } @@ -63,24 +71,23 @@ export const serveStatic = ( const getContent = options.getContent const pathResolve = options.pathResolve ?? defaultPathResolve - path = pathResolve(path) let content = await getContent(path, c) if (!content) { - let pathWithOutDefaultDocument = getFilePathWithoutDefaultDocument({ + let pathWithoutDefaultDocument = getFilePathWithoutDefaultDocument({ filename, root, }) - if (!pathWithOutDefaultDocument) { + if (!pathWithoutDefaultDocument) { return await next() } - pathWithOutDefaultDocument = pathResolve(pathWithOutDefaultDocument) + pathWithoutDefaultDocument = pathResolve(pathWithoutDefaultDocument) - if (pathWithOutDefaultDocument !== path) { - content = await getContent(pathWithOutDefaultDocument, c) + if (pathWithoutDefaultDocument !== path) { + content = await getContent(pathWithoutDefaultDocument, c) if (content) { - path = pathWithOutDefaultDocument + path = pathWithoutDefaultDocument } } } @@ -89,16 +96,40 @@ export const serveStatic = ( return c.newResponse(content.body, content) } + const mimeType = options.mimes + ? getMimeType(path, options.mimes) ?? getMimeType(path) + : getMimeType(path) + + if (mimeType) { + c.header('Content-Type', mimeType) + } + if (content) { - let mimeType: string | undefined - if (options.mimes) { - mimeType = getMimeType(path, options.mimes) ?? getMimeType(path) - } else { - mimeType = getMimeType(path) - } - if (mimeType) { - c.header('Content-Type', mimeType) + if (options.precompressed) { + const acceptEncodings = + c.req + .header('Accept-Encoding') + ?.split(',') + .map((encoding) => encoding.trim()) + .filter((encoding): encoding is keyof typeof ENCODINGS => + Object.hasOwn(ENCODINGS, encoding) + ) + .sort( + (a, b) => Object.keys(ENCODINGS).indexOf(a) - Object.keys(ENCODINGS).indexOf(b) + ) ?? [] + + for (const encoding of acceptEncodings) { + const compressedContent = (await getContent(path + ENCODINGS[encoding], c)) as Data | null + + if (compressedContent) { + content = compressedContent + c.header('Content-Encoding', encoding) + c.header('Vary', 'Accept-Encoding', { append: true }) + break + } + } } + await options.onFound?.(path, c) return c.body(content) } diff --git a/src/utils/html.ts b/src/utils/html.ts index d35572634..7731e565c 100644 --- a/src/utils/html.ts +++ b/src/utils/html.ts @@ -140,12 +140,21 @@ export const resolveCallbackSync = (str: string | HtmlEscapedString): string => } export const resolveCallback = async ( - str: string | HtmlEscapedString, + str: string | HtmlEscapedString | Promise, phase: (typeof HtmlEscapedCallbackPhase)[keyof typeof HtmlEscapedCallbackPhase], preserveCallbacks: boolean, context: object, buffer?: [string] ): Promise => { + if (typeof str === 'object' && !(str instanceof String)) { + if (!((str as unknown) instanceof Promise)) { + str = (str as unknown as string).toString() // HtmlEscapedString object to string + } + if ((str as string | Promise) instanceof Promise) { + str = await (str as unknown as Promise) + } + } + const callbacks = (str as HtmlEscapedString).callbacks as HtmlEscapedCallback[] if (!callbacks?.length) { return Promise.resolve(str) @@ -153,7 +162,7 @@ export const resolveCallback = async ( if (buffer) { buffer[0] += str } else { - buffer = [str] + buffer = [str as string] } const resStr = Promise.all(callbacks.map((c) => c({ phase, buffer, context }))).then((res) =>