Skip to content

Commit

Permalink
fix: Message annotations are compatible for any message type (#959)
Browse files Browse the repository at this point in the history
Co-authored-by: Max Leiter <[email protected]>
  • Loading branch information
nick-inkeep and MaxLeiter authored Feb 9, 2024
1 parent 2076ae6 commit ed1e278
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 45 deletions.
5 changes: 5 additions & 0 deletions .changeset/stupid-news-greet.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

Message annotations handling for all Message types
2 changes: 2 additions & 0 deletions examples/next-openai/app/api/chat-with-functions/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ export async function POST(req: Request) {
text: 'Some custom data',
});

data.appendMessageAnnotation({ current_weather: weatherData });

const newMessages = createFunctionCallMessages(weatherData);
return openai.chat.completions.create({
messages: [...messages, ...newMessages],
Expand Down
6 changes: 6 additions & 0 deletions examples/next-openai/app/function-calling/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ export default function Chat() {
>
<strong>{`${m.role}: `}</strong>
{m.content || JSON.stringify(m.function_call)}
{m.annotations ? (
<div>
<br />
<em>Annotations:</em> {JSON.stringify(m.annotations)}
</div>
) : null}
<br />
<br />
</div>
Expand Down
93 changes: 89 additions & 4 deletions packages/core/shared/parse-complex-response.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ describe('parseComplexResponse function', () => {
// Execute the parser function
const result = await parseComplexResponse({
reader: createTestReader([
'0:"Sample text message."\n',
'8:[{"key":"value"}, 2]\n',
'0:"Sample text message."\n',
]),
abortControllerRef: { current: new AbortController() },
update: mockUpdate,
Expand All @@ -243,9 +243,7 @@ describe('parseComplexResponse function', () => {
// check the mockUpdate call:
expect(mockUpdate).toHaveBeenCalledTimes(2);

expect(mockUpdate.mock.calls[0][0]).toEqual([
assistantTextMessage('Sample text message.'),
]);
expect(mockUpdate.mock.calls[0][0]).toEqual([]);

expect(mockUpdate.mock.calls[1][0]).toEqual([
{
Expand All @@ -265,4 +263,91 @@ describe('parseComplexResponse function', () => {
data: [],
});
});

it('should parse a combination of a function_call and message annotations', async () => {
const mockUpdate = vi.fn();

// Execute the parser function
const result = await parseComplexResponse({
reader: createTestReader([
'1:{"function_call":{"name":"get_current_weather","arguments":"{\\n\\"location\\": \\"Charlottesville, Virginia\\",\\n\\"format\\": \\"celsius\\"\\n}"}}\n',
'8:[{"key":"value"}, 2]\n',
'8:[null,false,"text"]\n',
]),
abortControllerRef: { current: new AbortController() },
update: mockUpdate,
generateId: () => 'test-id',
getCurrentDate: () => new Date(0),
});

// check the mockUpdate call:
expect(mockUpdate).toHaveBeenCalledTimes(3);

expect(mockUpdate.mock.calls[0][0]).toEqual([
{
content: '',
createdAt: new Date(0),
id: 'test-id',
role: 'assistant',
function_call: {
name: 'get_current_weather',
arguments:
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
},
name: 'get_current_weather',
},
]);

expect(mockUpdate.mock.calls[1][0]).toEqual([
{
content: '',
createdAt: new Date(0),
id: 'test-id',
role: 'assistant',
function_call: {
name: 'get_current_weather',
arguments:
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
},
name: 'get_current_weather',
annotations: [{ key: 'value' }, 2],
},
]);

expect(mockUpdate.mock.calls[2][0]).toEqual([
{
content: '',
createdAt: new Date(0),
id: 'test-id',
role: 'assistant',
function_call: {
name: 'get_current_weather',
arguments:
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
},
name: 'get_current_weather',
annotations: [{ key: 'value' }, 2, null, false, 'text'],
},
]);

// check the result
expect(result).toEqual({
messages: [
{
content: '',
createdAt: new Date(0),
id: 'test-id',
role: 'assistant',
function_call: {
name: 'get_current_weather',
arguments:
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
},
name: 'get_current_weather',
annotations: [{ key: 'value' }, 2, null, false, 'text'],
},
],
data: [],
});
});
});
92 changes: 53 additions & 39 deletions packages/core/shared/parse-complex-response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,12 @@ type PrefixMap = {
data: JSONValue[];
};

function initializeMessage({
generateId,
...rest
}: {
generateId: () => string;
content: string;
createdAt: Date;
annotations?: JSONValue[];
}): Message {
return {
id: generateId(),
role: 'assistant',
...rest,
};
function assignAnnotationsToMessage<T extends Message | null | undefined>(
message: T,
annotations: JSONValue[] | undefined,
): T {
if (!message || !annotations || !annotations.length) return message;
return { ...message, annotations: [...annotations] } as T;
}

export async function parseComplexResponse({
Expand All @@ -53,6 +45,9 @@ export async function parseComplexResponse({
data: [],
};

// keep list of current message annotations for message
let message_annotations: JSONValue[] | undefined = undefined;

// we create a map of each prefix, and for each prefixed message we push to the map
for await (const { type, value } of readDataStream(reader, {
isAborted: () => abortControllerRef?.current === null,
Expand All @@ -73,24 +68,7 @@ export async function parseComplexResponse({
}
}

if (type == 'message_annotations') {
if (prefixMap['text']) {
prefixMap['text'] = {
...prefixMap['text'],
annotations: [...(prefixMap['text'].annotations || []), ...value],
};
} else {
prefixMap['text'] = {
id: generateId(),
role: 'assistant',
content: '',
annotations: [...value],
createdAt,
};
}
}

let functionCallMessage: Message | null = null;
let functionCallMessage: Message | null | undefined = null;

if (type === 'function_call') {
prefixMap['function_call'] = {
Expand All @@ -105,7 +83,7 @@ export async function parseComplexResponse({
functionCallMessage = prefixMap['function_call'];
}

let toolCallMessage: Message | null = null;
let toolCallMessage: Message | null | undefined = null;

if (type === 'tool_calls') {
prefixMap['tool_calls'] = {
Expand All @@ -123,14 +101,50 @@ export async function parseComplexResponse({
prefixMap['data'].push(...value);
}

const responseMessage = prefixMap['text'];
let responseMessage = prefixMap['text'];

if (type === 'message_annotations') {
if (!message_annotations) {
message_annotations = [...value];
} else {
message_annotations.push(...value);
}

// Update any existing message with the latest annotations
functionCallMessage = assignAnnotationsToMessage(
prefixMap['function_call'],
message_annotations,
);
toolCallMessage = assignAnnotationsToMessage(
prefixMap['tool_calls'],
message_annotations,
);
responseMessage = assignAnnotationsToMessage(
prefixMap['text'],
message_annotations,
);
}

// keeps the prefixMap up to date with the latest annotations, even if annotations preceded the message
if (message_annotations?.length) {
const messagePrefixKeys: (keyof PrefixMap)[] = [
'text',
'function_call',
'tool_calls',
];
messagePrefixKeys.forEach(key => {
if (prefixMap[key]) {
(prefixMap[key] as Message).annotations = [...message_annotations!];
}
});
}

// We add function & tool calls and response messages to the messages[], but data is its own thing
const merged = [
functionCallMessage,
toolCallMessage,
responseMessage,
].filter(Boolean) as Message[];
const merged = [functionCallMessage, toolCallMessage, responseMessage]
.filter(Boolean)
.map(message => ({
...assignAnnotationsToMessage(message, message_annotations),
})) as Message[];

update(merged, [...prefixMap['data']]); // make a copy of the data array
}
Expand Down
5 changes: 3 additions & 2 deletions packages/core/streams/stream-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ export class experimental_StreamData {
}

if (self.messageAnnotations.length) {
const encodedmessageAnnotations = self.encoder.encode(
const encodedMessageAnnotations = self.encoder.encode(
formatStreamPart('message_annotations', self.messageAnnotations),
);
controller.enqueue(encodedmessageAnnotations);
self.messageAnnotations = [];
controller.enqueue(encodedMessageAnnotations);
}

controller.enqueue(chunk);
Expand Down

0 comments on commit ed1e278

Please sign in to comment.