diff --git a/.env.template b/.env.template
index a3e3933f..ecd238eb 100644
--- a/.env.template
+++ b/.env.template
@@ -17,6 +17,12 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
 # For OpenAI LLM Adapter
 # OPENAI_API_KEY=
 
+# For Azure OpenAI LLM Adapter
+OPENAI_API_VERSION=
+AZURE_DEPLOYMENT_NAME=
+AZURE_OPENAI_API_KEY=
+AZURE_OPENAI_ENDPOINT=
+
 # For Groq LLM Adapter
 # GROQ_API_KEY=
 
diff --git a/examples/llms/providers/azureopenai.ts b/examples/llms/providers/azureopenai.ts
new file mode 100644
index 00000000..3f9fc5e4
--- /dev/null
+++ b/examples/llms/providers/azureopenai.ts
@@ -0,0 +1,21 @@
+import "dotenv/config";
+import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
+import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";
+
+const llm = new OpenAIChatLLM({
+  modelId: "gpt-4o-mini",
+  azure: true,
+  parameters: {
+    max_tokens: 10,
+    stop: ["post"],
+  },
+});
+
+console.info("Meta", await llm.meta());
+const response = await llm.generate([
+  BaseMessage.of({
+    role: "user",
+    text: "Hello world!",
+  }),
+]);
+console.info(response.getTextContent());
diff --git a/src/adapters/openai/chat.test.ts b/src/adapters/openai/chat.test.ts
index 0337889c..61127be8 100644
--- a/src/adapters/openai/chat.test.ts
+++ b/src/adapters/openai/chat.test.ts
@@ -16,7 +16,26 @@
 
 import { verifyDeserialization } from "@tests/e2e/utils.js";
 import { OpenAIChatLLM } from "@/adapters/openai/chat.js";
-import { OpenAI } from "openai";
+import { OpenAI, AzureOpenAI } from "openai";
+
+describe("AzureOpenAI ChatLLM", () => {
+  const getInstance = () => {
+    return new OpenAIChatLLM({
+      modelId: "gpt-4o",
+      client: new AzureOpenAI(),
+    });
+  };
+
+  it("Serializes", async () => {
+    process.env["OPENAI_BASE_URL"] = "http://dummy/";
+    process.env["AZURE_OPENAI_API_KEY"] = "123";
+    process.env["OPENAI_API_VERSION"] = "version 1";
+    const instance = getInstance();
+    const serialized = instance.serialize();
+    const deserialized = OpenAIChatLLM.fromSerialized(serialized);
+    verifyDeserialization(instance, deserialized);
+  });
+});
 
 describe("OpenAI ChatLLM", () => {
   const getInstance = () => {
diff --git a/src/adapters/openai/chat.ts b/src/adapters/openai/chat.ts
index 714fed48..67e33666 100644
--- a/src/adapters/openai/chat.ts
+++ b/src/adapters/openai/chat.ts
@@ -28,15 +28,24 @@ import { shallowCopy } from "@/serializer/utils.js";
 import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
 import { BaseMessage, RoleType } from "@/llms/primitives/message.js";
 import { Emitter } from "@/emitter/emitter.js";
-import { ClientOptions, OpenAI as Client } from "openai";
+import { ClientOptions, OpenAI, AzureOpenAI } from "openai";
 import { GetRunContext } from "@/context.js";
 import { promptTokensEstimate } from "openai-chat-tokens";
 import { Serializer } from "@/serializer/serializer.js";
 import { getProp, getPropStrict } from "@/internals/helpers/object.js";
 import { isString } from "remeda";
+import type {
+  ChatCompletionChunk,
+  ChatCompletionCreateParams,
+  ChatCompletionMessageParam,
+  ChatCompletionSystemMessageParam,
+  ChatCompletionUserMessageParam,
+  ChatCompletionAssistantMessageParam,
+  ChatModel,
+} from "openai/resources/index";
 
-type Parameters = Omit<Client.Chat.ChatCompletionCreateParams, "messages">;
-type Response = Omit<Client.Chat.ChatCompletionChunk, "object">;
+type Parameters = Omit<ChatCompletionCreateParams, "messages">;
+type Response = Omit<ChatCompletionChunk, "object">;
 
 export class OpenAIChatLLMOutput extends ChatLLMOutput {
   public readonly responses: Response[];
@@ -85,11 +94,12 @@
 }
 
 interface Input {
-  modelId?: Client.ChatModel;
-  client?: Client;
+  modelId?: ChatModel;
+  client?: OpenAI | AzureOpenAI;
   parameters?: Partial<Parameters>;
   executionOptions?: ExecutionOptions;
   cache?: LLMCache;
+  azure?: boolean;
 }
 
 export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
@@ -98,28 +108,38 @@
     creator: this,
   });
 
-  public readonly client: Client;
+  public readonly client: OpenAI | AzureOpenAI;
   public readonly parameters: Partial<Parameters>;
 
-  constructor({
-    client,
-    modelId = "gpt-4o",
-    parameters,
-    executionOptions = {},
-    cache,
-  }: Input = {}) {
-    super(modelId, executionOptions, cache);
-    this.client = client ?? new Client();
+  constructor({ client, modelId, parameters, executionOptions = {}, cache, azure }: Input = {}) {
+    super(modelId || "gpt-4o-mini", executionOptions, cache);
+    if (client) {
+      this.client = client;
+    } else if (azure) {
+      this.client = new AzureOpenAI();
+    } else {
+      this.client = new OpenAI();
+    }
     this.parameters = parameters ?? { temperature: 0 };
   }
 
   static {
     this.register();
-    Serializer.register(Client, {
+    Serializer.register(AzureOpenAI, {
+      toPlain: (value) => ({
+        azureADTokenProvider: getPropStrict(value, "_azureADTokenProvider"),
+        apiVersion: getPropStrict(value, "apiVersion"),
+        deployment: getPropStrict(value, "_deployment"),
+      }),
+      fromPlain: (value) => {
+        return new AzureOpenAI(value.azureADTokenProvider);
+      },
+    });
+    Serializer.register(OpenAI, {
       toPlain: (value) => ({
         options: getPropStrict(value, "_options") as ClientOptions,
       }),
-      fromPlain: (value) => new Client(value.options),
+      fromPlain: (value) => new OpenAI(value.options),
     });
   }
 
@@ -151,7 +171,7 @@
         ({
           role: msg.role,
          content: msg.text,
-        }) as Client.Chat.ChatCompletionMessageParam,
+        }) as ChatCompletionMessageParam,
      ),
    });
 
@@ -163,11 +183,11 @@
   protected _prepareRequest(
     input: BaseMessage[],
     options?: GenerateOptions,
-  ): Client.Chat.ChatCompletionCreateParams {
+  ): ChatCompletionCreateParams {
     type OpenAIMessage =
-      | Client.Chat.ChatCompletionSystemMessageParam
-      | Client.Chat.ChatCompletionUserMessageParam
-      | Client.Chat.ChatCompletionAssistantMessageParam;
+      | ChatCompletionSystemMessageParam
+      | ChatCompletionUserMessageParam
+      | ChatCompletionAssistantMessageParam;
 
     return {
       ...this.parameters,
@@ -179,6 +199,7 @@
           content: message.text,
         }),
       ),
+
       response_format: (() => {
         if (options?.guided?.json) {
           const schema = isString(options.guided.json)
@@ -231,7 +252,7 @@
           index: 1,
           logprobs: choice.logprobs,
           finish_reason: choice.finish_reason,
-        }) as Client.Chat.ChatCompletionChunk.Choice,
+        }) as ChatCompletionChunk.Choice,
       ),
     });
   }
diff --git a/tests/examples/examples.test.ts b/tests/examples/examples.test.ts
index 5f5efa0f..8214b8bd 100644
--- a/tests/examples/examples.test.ts
+++ b/tests/examples/examples.test.ts
@@ -49,6 +49,7 @@ const exclude: string[] = [
     "examples/agents/bee_reusable.ts",
     "examples/llms/providers/openai.ts",
   ],
+  !hasEnv("AZURE_OPENAI_API_KEY") && ["examples/llms/providers/azureopenai.ts"],
   !hasEnv("IBM_VLLM_URL") && [
     "examples/llms/providers/ibm-vllm.ts",
     "examples/agents/granite/chat.ts",
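
Usage note (not part of the diff): with `azure: true`, the adapter constructs `new AzureOpenAI()` and lets the openai SDK pick up AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and OPENAI_API_VERSION from the environment, as wired up in .env.template above. The other supported path is passing an explicitly configured client. A minimal sketch of that path follows; the endpoint/apiKey/apiVersion/deployment option names come from the openai SDK's AzureOpenAI client options, and feeding AZURE_DEPLOYMENT_NAME into `deployment` is an assumption about how that template variable is intended to be used:

import "dotenv/config";
import { AzureOpenAI } from "openai";
import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";

// Sketch: configure the Azure client explicitly instead of `azure: true`.
// Passing AZURE_DEPLOYMENT_NAME as `deployment` is an assumption; when no
// deployment is given, the SDK derives it from the model name instead.
const llm = new OpenAIChatLLM({
  modelId: "gpt-4o-mini",
  client: new AzureOpenAI({
    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
    apiKey: process.env.AZURE_OPENAI_API_KEY,
    apiVersion: process.env.OPENAI_API_VERSION,
    deployment: process.env.AZURE_DEPLOYMENT_NAME,
  }),
});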