Update OpenAIAgent to support Runnable models (#3346)
* Update OpenAIAgent to support Runnable interface

* Add test with executor

* Call invoke for all paths and add CallOptions

* Format and fix test

---------

Co-authored-by: jacoblee93 <[email protected]>
gramliu and jacoblee93 authored Nov 22, 2023
1 parent cc09d15 commit e14539a
Showing 2 changed files with 71 additions and 11 deletions.
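
In practice, the change means the llmChain handed to OpenAIAgent can now carry a Runnable, for example a ChatOpenAI with call options bound via .bind({}), instead of requiring a bare ChatOpenAI instance. A minimal sketch of the resulting usage, condensed from the integration test added in this commit (the package entrypoints and the example question are assumptions for illustration):

import { ChatOpenAI } from "langchain/chat_models/openai";
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
import { LLMChain } from "langchain/chains";
import { AgentExecutor, OpenAIAgent } from "langchain/agents";
import { Calculator } from "langchain/tools/calculator";

// Binding call options yields a Runnable rather than a bare ChatOpenAI.
const runnableModel = new ChatOpenAI({
  modelName: "gpt-4",
  temperature: 0,
}).bind({});

const prompt = ChatPromptTemplate.fromMessages([
  ["ai", "You are a helpful assistant"],
  ["human", "{input}"],
  new MessagesPlaceholder("agent_scratchpad"),
]);

const tools = [new Calculator()];

// Before this commit, OpenAIAgent assumed llmChain.llm was a ChatOpenAI;
// it now also accepts a Runnable and calls invoke() on it.
const agent = new OpenAIAgent({
  llmChain: new LLMChain({ prompt, llm: runnableModel }),
  tools,
});
const executor = new AgentExecutor({ agent, tools });

const result = await executor.invoke({ input: "What is 2 to the power of 10?" });
console.log(result);
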
42 changes: 32 additions & 10 deletions langchain/src/agents/openai/index.ts
@@ -1,5 +1,5 @@
import { CallbackManager } from "../../callbacks/manager.js";
import { ChatOpenAI } from "../../chat_models/openai.js";
import { ChatOpenAI, ChatOpenAICallOptions } from "../../chat_models/openai.js";
import { BasePromptTemplate } from "../../prompts/base.js";
import {
AIMessage,
@@ -10,6 +10,7 @@ import {
FunctionMessage,
ChainValues,
SystemMessage,
BaseMessageChunk,
} from "../../schema/index.js";
import { StructuredTool } from "../../tools/base.js";
import { Agent, AgentArgs } from "../agent.js";
@@ -21,13 +22,20 @@ import {
MessagesPlaceholder,
SystemMessagePromptTemplate,
} from "../../prompts/chat.js";
import { BaseLanguageModel } from "../../base_language/index.js";
import {
BaseLanguageModel,
BaseLanguageModelInput,
} from "../../base_language/index.js";
import { LLMChain } from "../../chains/llm_chain.js";
import {
FunctionsAgentAction,
OpenAIFunctionsAgentOutputParser,
} from "./output_parser.js";
import { formatToOpenAIFunction } from "../../tools/convert_to_openai.js";
import { Runnable } from "../../schema/runnable/base.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type CallOptionsIfAvailable<T> = T extends { CallOptions: infer CO } ? CO : any;

/**
* Checks if the given action is a FunctionsAgentAction.
@@ -199,28 +207,42 @@ export class OpenAIAgent extends Agent {
}

// Split inputs between prompt and llm
const llm = this.llmChain.llm as ChatOpenAI;
const llm = this.llmChain.llm as
| ChatOpenAI
| Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
ChatOpenAICallOptions
>;

const valuesForPrompt = { ...newInputs };
const valuesForLLM: (typeof llm)["CallOptions"] = {
const valuesForLLM: CallOptionsIfAvailable<typeof llm> = {
functions: this.tools.map(formatToOpenAIFunction),
};
const callKeys =
"callKeys" in this.llmChain.llm ? this.llmChain.llm.callKeys : [];
for (const key of callKeys) {
if (key in inputs) {
valuesForLLM[key as keyof (typeof llm)["CallOptions"]] = inputs[key];
valuesForLLM[key as keyof CallOptionsIfAvailable<typeof llm>] =
inputs[key];
delete valuesForPrompt[key];
}
}

const promptValue = await this.llmChain.prompt.formatPromptValue(
valuesForPrompt
);
const message = await llm.predictMessages(
promptValue.toChatMessages(),
valuesForLLM,
callbackManager
);

const message = await (
llm as Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
ChatOpenAICallOptions
>
).invoke(promptValue.toChatMessages(), {
...valuesForLLM,
callbacks: callbackManager,
});
return this.outputParser.parseAIMessage(message);
}
}
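
For reference, the CallOptionsIfAvailable helper introduced above is a plain TypeScript conditional type: when the model type declares a CallOptions member (as ChatOpenAI does), that type is extracted; otherwise it widens to any, so arbitrary call keys can still be forwarded to a generic Runnable. A standalone sketch with hypothetical interfaces:

type CallOptionsIfAvailable<T> = T extends { CallOptions: infer CO } ? CO : any;

// Hypothetical model shape that declares its call options, as ChatOpenAI does.
interface ModelWithOptions {
  CallOptions: { functions?: object[]; stop?: string[] };
}

// Hypothetical generic runnable with no CallOptions member.
interface PlainRunnable {
  invoke(input: string): Promise<string>;
}

type OptionsA = CallOptionsIfAvailable<ModelWithOptions>; // { functions?: object[]; stop?: string[] }
type OptionsB = CallOptionsIfAvailable<PlainRunnable>; // any
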
40 changes: 39 additions & 1 deletion langchain/src/agents/tests/runnable.int.test.ts
@@ -14,6 +14,8 @@ import { SerpAPI } from "../../tools/serpapi.js";
import { formatToOpenAIFunction } from "../../tools/convert_to_openai.js";
import { Calculator } from "../../tools/calculator.js";
import { OpenAIFunctionsAgentOutputParser } from "../openai/output_parser.js";
import { LLMChain } from "../../chains/llm_chain.js";
import { OpenAIAgent } from "../openai/index.js";

test("Runnable variant", async () => {
const tools = [new Calculator(), new SerpAPI()];
@@ -59,8 +61,44 @@ test("Runnable variant", async () => {

const query = "What is the weather in New York?";
console.log(`Calling agent executor with query: ${query}`);
const result = await executor.call({
const result = await executor.invoke({
input: query,
});
console.log(result);
});

test("Runnable variant works with executor", async () => {
// Prepare tools
const tools = [new Calculator(), new SerpAPI()];
const runnableModel = new ChatOpenAI({
modelName: "gpt-4",
temperature: 0,
}).bind({});

const prompt = ChatPromptTemplate.fromMessages([
["ai", "You are a helpful assistant"],
["human", "{input}"],
new MessagesPlaceholder("agent_scratchpad"),
]);

// Prepare agent chain
const llmChain = new LLMChain({
prompt,
llm: runnableModel,
});
const agent = new OpenAIAgent({
llmChain,
tools,
});

// Prepare and run executor
const executor = new AgentExecutor({
agent,
tools,
});
const result = await executor.invoke({
input: "What is the weather in New York?",
});

console.log(result);
});

2 comments on commit e14539a

@vercel (vercel bot) commented on e14539a, Nov 22, 2023

Successfully deployed to the following URLs:

langchainjs-docs – ./docs/core_docs/

langchainjs-docs-git-main-langchain.vercel.app
langchainjs-docs-langchain.vercel.app
js.langchain.com
langchainjs-docs-ruddy.vercel.app
