Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic committed Oct 9, 2024
1 parent 83a701e commit 9db4909
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ const mockStreamExecute = jest.fn().mockImplementation(() => {
};
});

const systemInstruction = 'Answer the following questions truthfully and as best you can.';

const callMessages = [
new SystemMessage('Answer the following questions truthfully and as best you can.'),
new SystemMessage(systemInstruction),
new HumanMessage('Question: Do you know my name?\n\n'),
] as unknown as BaseMessage[];

Expand Down Expand Up @@ -196,4 +198,29 @@ describe('ActionsClientChatVertexAI', () => {
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
});

describe('message formatting', () => {
it('Properly sorts out the system role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

await actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager);
const params = actionsClient.execute.mock.calls[0][0].params.subActionParams;
console.log(actionsClient.execute.mock.calls[0][0].params.subActionParams);
expect(params.messages.length).toEqual(1);
expect(params.messages[0].parts.length).toEqual(1);
expect(params.systemInstruction).toEqual(systemInstruction);
});
it('Handles 2 messages in a row from the same role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

await actionsClientChatVertexAI._generate(
[...callMessages, new HumanMessage('Oh boy, another')],
callOptions,
callRunManager
);
const messages = actionsClient.execute.mock.calls[0][0].params.subActionParams.messages;
expect(messages.length).toEqual(1);
expect(messages[0].parts.length).toEqual(2);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
this.caller = caller;
this.#model = fields.model;
this.temperature = fields.temperature ?? 0;
const nativeFormatData = this.formatData.bind(this);
this.formatData = async (data, options) => {
const result = await nativeFormatData(data, options);
if (result?.contents.length) {
// ensure there are not 2 messages in a row from the same role,
// if there are combine them
result.contents = result.contents.reduce((acc, currentEntry) => {
if (currentEntry.role === acc[acc.length - 1]?.role) {
acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(currentEntry.parts);
return acc;
}
return [...acc, currentEntry];
}, []);
}
return result;
};
}

async _request(
Expand Down

0 comments on commit 9db4909

Please sign in to comment.