Skip to content

Commit

Permalink
feat(helper/streaming): Support Promise<string> or (async) JSX.Elemen…
Browse files Browse the repository at this point in the history
…t in streamSSE (#3344)

* feat(helper/streaming): Support Promise<string> or (async) JSX.Element in streamSSE

* refactor(context): enable to pass Promise<string> (includes async JSX.Element) to resolveCallback
  • Loading branch information
usualoma authored Sep 8, 2024
1 parent c8d5e34 commit 8ca155e
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 17 deletions.
17 changes: 5 additions & 12 deletions src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -844,18 +844,11 @@ export class Context<
this.#preparedHeaders['content-type'] = 'text/html; charset=UTF-8'

if (typeof html === 'object') {
if (!(html instanceof Promise)) {
html = (html as string).toString() // HtmlEscapedString object to string
}
if ((html as string | Promise<string>) instanceof Promise) {
return (html as unknown as Promise<string>)
.then((html) => resolveCallback(html, HtmlEscapedCallbackPhase.Stringify, false, {}))
.then((html) => {
return typeof arg === 'number'
? this.newResponse(html, arg, headers)
: this.newResponse(html, arg)
})
}
return resolveCallback(html, HtmlEscapedCallbackPhase.Stringify, false, {}).then((html) => {
return typeof arg === 'number'
? this.newResponse(html, arg, headers)
: this.newResponse(html, arg)
})
}

return typeof arg === 'number'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
/** @jsxImportSource ../../jsx */
import { ErrorBoundary } from '../../jsx'
import { Context } from '../../context'
import { streamSSE } from '.'

Expand Down Expand Up @@ -145,4 +147,90 @@ describe('SSE Streaming helper', () => {
expect(onError).toBeCalledTimes(1)
expect(onError).toBeCalledWith(new Error('Test error'), expect.anything()) // 2nd argument is StreamingApi instance
})

it('Check streamSSE Response via Promise<string>', async () => {
const res = streamSSE(c, async (stream) => {
await stream.writeSSE({ data: Promise.resolve('Async Message') })
})

expect(res).not.toBeNull()
expect(res.status).toBe(200)

if (!res.body) {
throw new Error('Body is null')
}
const reader = res.body.getReader()
const decoder = new TextDecoder()
const { value } = await reader.read()
const decodedValue = decoder.decode(value)
expect(decodedValue).toBe('data: Async Message\n\n')
})

it('Check streamSSE Response via JSX.Element', async () => {
const res = streamSSE(c, async (stream) => {
await stream.writeSSE({ data: <div>Hello</div> })
})

expect(res).not.toBeNull()
expect(res.status).toBe(200)

if (!res.body) {
throw new Error('Body is null')
}
const reader = res.body.getReader()
const decoder = new TextDecoder()
const { value } = await reader.read()
const decodedValue = decoder.decode(value)
expect(decodedValue).toBe('data: <div>Hello</div>\n\n')
})

it('Check streamSSE Response via ErrorBoundary in success case', async () => {
const AsyncComponent = async () => Promise.resolve(<div>Async Hello</div>)
const res = streamSSE(c, async (stream) => {
await stream.writeSSE({
data: (
<ErrorBoundary fallback={<div>Error</div>}>
<AsyncComponent />
</ErrorBoundary>
),
})
})

expect(res).not.toBeNull()
expect(res.status).toBe(200)

if (!res.body) {
throw new Error('Body is null')
}
const reader = res.body.getReader()
const decoder = new TextDecoder()
const { value } = await reader.read()
const decodedValue = decoder.decode(value)
expect(decodedValue).toBe('data: <div>Async Hello</div>\n\n')
})

it('Check streamSSE Response via ErrorBoundary in error case', async () => {
const AsyncComponent = async () => Promise.reject()
const res = streamSSE(c, async (stream) => {
await stream.writeSSE({
data: (
<ErrorBoundary fallback={<div>Error</div>}>
<AsyncComponent />
</ErrorBoundary>
),
})
})

expect(res).not.toBeNull()
expect(res.status).toBe(200)

if (!res.body) {
throw new Error('Body is null')
}
const reader = res.body.getReader()
const decoder = new TextDecoder()
const { value } = await reader.read()
const decodedValue = decoder.decode(value)
expect(decodedValue).toBe('data: <div>Error</div>\n\n')
})
})
8 changes: 5 additions & 3 deletions src/helper/streaming/sse.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import type { Context } from '../../context'
import { StreamingApi } from '../../utils/stream'
import { HtmlEscapedCallbackPhase, resolveCallback } from '../../utils/html'

export interface SSEMessage {
data: string
data: string | Promise<string>
event?: string
id?: string
retry?: number
Expand All @@ -14,7 +15,8 @@ export class SSEStreamingApi extends StreamingApi {
}

async writeSSE(message: SSEMessage) {
const data = message.data
const data = await resolveCallback(message.data, HtmlEscapedCallbackPhase.Stringify, false, {})
const dataLines = (data as string)
.split('\n')
.map((line) => {
return `data: ${line}`
Expand All @@ -24,7 +26,7 @@ export class SSEStreamingApi extends StreamingApi {
const sseData =
[
message.event && `event: ${message.event}`,
data,
dataLines,
message.id && `id: ${message.id}`,
message.retry && `retry: ${message.retry}`,
]
Expand Down
13 changes: 11 additions & 2 deletions src/utils/html.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,29 @@ export const resolveCallbackSync = (str: string | HtmlEscapedString): string =>
}

export const resolveCallback = async (
str: string | HtmlEscapedString,
str: string | HtmlEscapedString | Promise<string>,
phase: (typeof HtmlEscapedCallbackPhase)[keyof typeof HtmlEscapedCallbackPhase],
preserveCallbacks: boolean,
context: object,
buffer?: [string]
): Promise<string> => {
if (typeof str === 'object' && !(str instanceof String)) {
if (!((str as unknown) instanceof Promise)) {
str = (str as unknown as string).toString() // HtmlEscapedString object to string
}
if ((str as string | Promise<string>) instanceof Promise) {
str = await (str as unknown as Promise<string>)
}
}

const callbacks = (str as HtmlEscapedString).callbacks as HtmlEscapedCallback[]
if (!callbacks?.length) {
return Promise.resolve(str)
}
if (buffer) {
buffer[0] += str
} else {
buffer = [str]
buffer = [str as string]
}

const resStr = Promise.all(callbacks.map((c) => c({ phase, buffer, context }))).then((res) =>
Expand Down

0 comments on commit 8ca155e

Please sign in to comment.