chore: upgrade to ai SDK v4 (#459)
stefl authored Jan 7, 2025
1 parent 79972b4 commit e61ba52
Showing 9 changed files with 121 additions and 320 deletions.
2 changes: 1 addition & 1 deletion apps/nextjs/package.json
@@ -68,7 +68,7 @@
"@trpc/server": "10.45.2",
"@types/ramda": "^0.30.2",
"@uidotdev/usehooks": "^2.4.1",
- "ai": "^3.3.26",
+ "ai": "^4.0.20",
"american-british-english-translator": "^0.2.1",
"archiver": "^7.0.1",
"axios": "^1.6.8",
50 changes: 32 additions & 18 deletions apps/nextjs/src/app/api/chat/chatHandler.ts
@@ -18,10 +18,6 @@ import { withTelemetry } from "@oakai/core/src/tracing/serverTracing";
import type { PrismaClientWithAccelerate } from "@oakai/db";
import { prisma as globalPrisma } from "@oakai/db/client";
import { aiLogger } from "@oakai/logger";
- // #TODO StreamingTextResponse is deprecated. If we choose to adopt the "ai" package
- // more fully, we should refactor to support its approach to streaming
- // but this could be a significant change given we have our record-separator approach
- import { StreamingTextResponse } from "ai";
import type { NextRequest } from "next/server";
import invariant from "tiny-invariant";

@@ -124,23 +120,41 @@ function handleConnectionAborted(req: NextRequest) {
async function generateChatStream(
aila: Aila,
abortController: AbortController,
- ) {
+ ): Promise<Response> {
return await withTelemetry(
"chat-aila-generate",
{ chat_id: aila.chatId, user_id: aila.userId },
async () => {
- invariant(aila, "Aila instance is required");
- const result = await aila.generate({ abortController });
- const transformStream = new TransformStream({
- transform(chunk, controller) {
- const formattedChunk = new TextEncoder().encode(
- `0:${JSON.stringify(chunk)}\n`,
- );
- controller.enqueue(formattedChunk);
- },
- });
-
- return result.pipeThrough(transformStream);
+ try {
+ invariant(aila, "Aila instance is required");
+ const result = await aila.generate({ abortController });
+ const transformStream = new TransformStream({
+ transform(chunk, controller) {
+ const formattedChunk = new TextEncoder().encode(
+ `0:${JSON.stringify(chunk)}\n`,
+ );
+ controller.enqueue(formattedChunk);
+ },
+ });
+
+ const stream = result.pipeThrough(transformStream);
+ return new Response(stream, {
+ headers: {
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ Connection: "keep-alive",
+ },
+ });
+ } catch (error) {
+ log.error("Error generating chat stream", { error });
+ return new Response(
+ JSON.stringify({ error: "Stream generation failed" }),
+ {
+ status: 500,
+ headers: { "Content-Type": "application/json" },
+ },
+ );
+ }
},
);
}
@@ -199,7 +213,7 @@ export async function handleChatPostRequest(

const abortController = handleConnectionAborted(req);
const stream = await generateChatStream(aila, abortController);
- return new StreamingTextResponse(stream);
+ return stream;
} catch (e) {
return handleChatException(span, e, chatId, prisma);
} finally {
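AI SDK v4 removes the StreamingTextResponse helper that v3 exported, so generateChatStream now wraps the transformed stream in a plain Response itself, keeping the `0:<json>\n` framing that the useChat data stream expects. A minimal standalone sketch of that same pattern, using only web-standard APIs (the helper name and chunk values are illustrative, not from the repo):

```ts
// Sketch: a plain Response around a ReadableStream replaces StreamingTextResponse.
// Each chunk is framed the same way as above: `0:` + JSON + newline.
function streamResponseFromChunks(chunks: string[]): Response {
  const encoder = new TextEncoder();
  const stream = new ReadableStream<Uint8Array>({
    start(controller) {
      for (const chunk of chunks) {
        controller.enqueue(encoder.encode(`0:${JSON.stringify(chunk)}\n`));
      }
      controller.close();
    },
  });

  return new Response(stream, {
    headers: {
      "Content-Type": "text/event-stream",
      "Cache-Control": "no-cache",
      Connection: "keep-alive",
    },
  });
}

// Illustrative usage
const res = streamResponseFromChunks(["Hello", " world"]);
console.log(res.headers.get("Content-Type")); // "text/event-stream"
```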
3 changes: 1 addition & 2 deletions apps/nextjs/src/app/api/chat/protocol.ts
@@ -2,15 +2,14 @@ import type {
ActionDocument,
ErrorDocument,
} from "@oakai/aila/src/protocol/jsonPatchProtocol";
- import { StreamingTextResponse } from "ai";

export function streamingJSON(message: ErrorDocument | ActionDocument) {
const jsonContent = JSON.stringify(message);
const errorMessage = `0:"${jsonContent.replace(/"/g, '\\"')}"`;

const errorEncoder = new TextEncoder();

- return new StreamingTextResponse(
+ return new Response(
new ReadableStream({
start(controller) {
controller.enqueue(errorEncoder.encode(errorMessage));
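With the StreamingTextResponse import gone, streamingJSON builds the Response directly from a one-chunk ReadableStream. A hedged sketch of how a caller could drain that single framed chunk back out of the response (the function name is illustrative):

```ts
// Sketch: read the body of a streaming Response like the one streamingJSON returns.
async function readStreamedBody(res: Response): Promise<string> {
  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  let text = "";
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    if (value) text += decoder.decode(value, { stream: true });
  }
  return text; // the `0:"..."` framed payload produced above
}
```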
11 changes: 10 additions & 1 deletion apps/nextjs/src/components/ContextProviders/ChatProvider.tsx
@@ -109,6 +109,13 @@ function getModerationFromMessage(message?: { content: string }) {
return moderation;
}

+ function isValidMessageRole(role: unknown): role is Message["role"] {
+ return (
+ typeof role === "string" &&
+ ["system", "assistant", "user", "data"].includes(role)
+ );
+ }
+
export function ChatProvider({ id, children }: Readonly<ChatProviderProps>) {
const {
data: chat,
@@ -179,7 +186,9 @@ export function ChatProvider({ id, children }: Readonly<ChatProviderProps>) {
setMessages,
} = useChat({
sendExtraMessageFields: true,
- initialMessages: chat?.messages ?? [],
+ initialMessages: (chat?.messages ?? []).filter((m) =>
+ isValidMessageRole(m.role),
+ ) as Message[],
generateId: () => generateMessageId({ role: "user" }),
id,
body: {
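The new isValidMessageRole guard exists because persisted chat rows store role as a plain string, while the v4 Message["role"] union no longer includes every value that may have been saved under v3 (for example "function"). A small illustrative sketch of the narrowing, with made-up data:

```ts
type Role = "system" | "assistant" | "user" | "data";

// Mirrors the guard added above, against a local Role union for the sketch.
function isValidMessageRole(role: unknown): role is Role {
  return (
    typeof role === "string" &&
    ["system", "assistant", "user", "data"].includes(role)
  );
}

// Made-up persisted rows: the second one is dropped before reaching useChat.
const stored = [
  { id: "1", role: "user", content: "Plan a lesson on fractions" },
  { id: "2", role: "function", content: "{}" },
];

const initialMessages = stored.filter((m) => isValidMessageRole(m.role));
console.log(initialMessages.map((m) => m.id)); // ["1"]
```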
2 changes: 1 addition & 1 deletion packages/aila/package.json
@@ -32,7 +32,7 @@
"@oakai/logger": "*",
"@sentry/nextjs": "^8.35.0",
"@vercel/kv": "^0.2.2",
- "ai": "^3.3.26",
+ "ai": "^4.0.20",
"american-british-english-translator": "^0.2.1",
"cloudinary": "^1.41.1",
"dedent": "^1.5.3",
2 changes: 1 addition & 1 deletion packages/aila/src/core/chat/types.ts
@@ -2,7 +2,7 @@ import { z } from "zod";

export const MessageSchema = z.object({
content: z.string(),
- role: z.enum(["system", "assistant", "user"]),
+ role: z.enum(["system", "assistant", "user", "data"]),
id: z.string(),
});

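MessageSchema is widened to accept the "data" role that the v4 useChat message type carries. A quick sketch of the effect, runnable on its own with zod (the sample payload is made up):

```ts
import { z } from "zod";

const MessageSchema = z.object({
  content: z.string(),
  role: z.enum(["system", "assistant", "user", "data"]),
  id: z.string(),
});

// A "data" message now parses instead of failing validation.
const parsed = MessageSchema.parse({
  id: "msg-1",
  role: "data",
  content: JSON.stringify({ note: "illustrative payload" }),
});
console.log(parsed.role); // "data"
```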
4 changes: 2 additions & 2 deletions packages/aila/src/core/llm/OpenAIService.ts
@@ -29,7 +29,7 @@ export class OpenAIService implements LLMService {
messages: Message[];
temperature: number;
}): Promise<ReadableStreamDefaultReader<string>> {
- const { textStream: stream } = await streamText({
+ const { textStream: stream } = streamText({
model: this._openAIProvider(params.model),
messages: params.messages.map((m) => ({
role: m.role,
@@ -53,7 +53,7 @@
return this.createChatCompletionStream({ model, messages, temperature });
}
const startTime = Date.now();
- const { textStream: stream } = await streamObject({
+ const { textStream: stream } = streamObject({
model: this._openAIProvider(model, { structuredOutputs: true }),
output: "object",
schema,
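In v4, streamText and streamObject are no longer awaited: they return a result object synchronously, and the model call runs as the stream is consumed. A hedged standalone sketch of the calling convention (model id and prompt are placeholders; OpenAIService above wires up its own provider instance):

```ts
import { openai } from "@ai-sdk/openai";
import { streamText } from "ai";

// Requires OPENAI_API_KEY in the environment.
async function demo() {
  // No await here: the call returns immediately with the stream handles.
  const { textStream } = streamText({
    model: openai("gpt-4o-mini"),
    messages: [{ role: "user", content: "Say hello in five words." }],
  });

  for await (const chunk of textStream) {
    process.stdout.write(chunk);
  }
}

void demo();
```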
28 changes: 9 additions & 19 deletions packages/aila/src/helpers/chat/getLastAssistantMessage.ts
@@ -1,22 +1,12 @@
- import type { Message as AiMessage } from "ai";
- import { findLast } from "remeda";
-
- import type { Message as AilaMessage } from "../../core/chat/types";
-
- interface AssistantMessage extends AilaMessage {
- role: "assistant";
+ interface HasRole {
+ role: string;
+ id: string;
}

- /**
- * This function takes an array of messages, and returns the last message from the assistant.
- */
- export function getLastAssistantMessage(
- messages: AiMessage[],
- ): AssistantMessage | undefined {
- const lastAssistantMessage = findLast(
- messages,
- (m): m is AssistantMessage => m.role === "assistant",
- ) as AssistantMessage | undefined;
-
- return lastAssistantMessage;
+ export function getLastAssistantMessage<T extends HasRole>(
+ messages: T[],
+ ): (T & { role: "assistant" }) | undefined {
+ return messages.filter((m) => m.role === "assistant").pop() as
+ | (T & { role: "assistant" })
+ | undefined;
}
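getLastAssistantMessage is now generic over anything carrying a role and an id, so the ai package's Message and Aila's own message type both satisfy it without the previous casts, and the findLast import from remeda is no longer needed. An illustrative call with made-up messages:

```ts
interface HasRole {
  role: string;
  id: string;
}

function getLastAssistantMessage<T extends HasRole>(
  messages: T[],
): (T & { role: "assistant" }) | undefined {
  return messages.filter((m) => m.role === "assistant").pop() as
    | (T & { role: "assistant" })
    | undefined;
}

const messages = [
  { id: "1", role: "user", content: "hi" },
  { id: "2", role: "assistant", content: "hello" },
  { id: "3", role: "user", content: "thanks" },
];

console.log(getLastAssistantMessage(messages)?.id); // "2"
```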