Add Azure OpenAI to OpenAI adapter #201

Merged · 1 commit · Dec 3, 2024
6 changes: 6 additions & 0 deletions .env.template
@@ -17,6 +17,12 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
# For OpenAI LLM Adapter
# OPENAI_API_KEY=

+# For Azure OpenAI LLM Adapter
+OPENAI_API_VERSION=
+AZURE_DEPLOYMENT_NAME=
+AZURE_OPENAI_API_KEY=
+AZURE_OPENAI_ENDPOINT=

# For Groq LLM Adapter
# GROQ_API_KEY=

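For context (not part of the diff): the openai SDK's AzureOpenAI client reads AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and OPENAI_API_VERSION from the environment when the matching constructor options are omitted; AZURE_DEPLOYMENT_NAME appears to have no SDK-side fallback and is a framework-level convention here. A minimal sketch of fully explicit construction, mirroring those fallbacks:

import { AzureOpenAI } from "openai";

// Explicit options mirror the env fallbacks listed above; deployment is
// passed by hand since the SDK does not read AZURE_DEPLOYMENT_NAME itself.
const client = new AzureOpenAI({
  apiKey: process.env.AZURE_OPENAI_API_KEY,
  endpoint: process.env.AZURE_OPENAI_ENDPOINT,
  apiVersion: process.env.OPENAI_API_VERSION,
  deployment: process.env.AZURE_DEPLOYMENT_NAME,
});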
21 changes: 21 additions & 0 deletions examples/llms/providers/azure_openai.ts
@@ -0,0 +1,21 @@
import "dotenv/config";
import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";

const llm = new OpenAIChatLLM({
  modelId: "gpt-4o-mini",
  azure: true,
  parameters: {
    max_tokens: 10,
    stop: ["post"],
  },
});

console.info("Meta", await llm.meta());
const response = await llm.generate([
  BaseMessage.of({
    role: "user",
    text: "Hello world!",
  }),
]);
console.info(response.getTextContent());
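
A hedged variant of this example (not in the PR): instead of azure: true, a preconfigured client can be passed through the client option of the Input interface in src/adapters/openai/chat.ts below, useful when the Azure deployment should be pinned in code rather than via AZURE_DEPLOYMENT_NAME:

import { AzureOpenAI } from "openai";
import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";

// apiKey, endpoint, and apiVersion are still picked up from the env vars in
// .env.template; only the deployment is pinned here (hypothetical value).
const llmExplicit = new OpenAIChatLLM({
  modelId: "gpt-4o-mini",
  client: new AzureOpenAI({ deployment: "my-gpt-4o-mini-deployment" }),
});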
21 changes: 20 additions & 1 deletion src/adapters/openai/chat.test.ts
@@ -16,7 +16,26 @@

import { verifyDeserialization } from "@tests/e2e/utils.js";
import { OpenAIChatLLM } from "@/adapters/openai/chat.js";
-import { OpenAI } from "openai";
+import { OpenAI, AzureOpenAI } from "openai";

describe("AzureOpenAI ChatLLM", () => {
const getInstance = () => {
return new OpenAIChatLLM({
modelId: "gpt-4o",
client: new AzureOpenAI(),
+    });
+  };
+
+  it("Serializes", async () => {
+    process.env["OPENAI_BASE_URL"] = "http://dummy/";
+    process.env["AZURE_OPENAI_API_KEY"] = "123";
+    process.env["OPENAI_API_VERSION"] = "version 1";
+    const instance = getInstance();
+    const serialized = instance.serialize();
+    const deserialized = OpenAIChatLLM.fromSerialized(serialized);
+    verifyDeserialization(instance, deserialized);
+  });
+});

describe("OpenAI ChatLLM", () => {
  const getInstance = () => {
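A note on the test above: constructing new AzureOpenAI() with no arguments appears to require AZURE_OPENAI_API_KEY (or an Azure AD credential) and OPENAI_API_VERSION to be present, which is presumably why dummy values are seeded into process.env before getInstance() is called.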
65 changes: 42 additions & 23 deletions src/adapters/openai/chat.ts
@@ -27,15 +27,24 @@ import { shallowCopy } from "@/serializer/utils.js";
import { ChatLLM, ChatLLMGenerateEvents, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, RoleType } from "@/llms/primitives/message.js";
import { Emitter } from "@/emitter/emitter.js";
-import { ClientOptions, OpenAI as Client } from "openai";
+import { ClientOptions, OpenAI, AzureOpenAI } from "openai";
import { GetRunContext } from "@/context.js";
import { promptTokensEstimate } from "openai-chat-tokens";
import { Serializer } from "@/serializer/serializer.js";
import { getProp, getPropStrict } from "@/internals/helpers/object.js";
import { isString } from "remeda";
+import type {
+  ChatCompletionChunk,
+  ChatCompletionCreateParams,
+  ChatCompletionMessageParam,
+  ChatCompletionSystemMessageParam,
+  ChatCompletionUserMessageParam,
+  ChatCompletionAssistantMessageParam,
+  ChatModel,
+} from "openai/resources/index";

-type Parameters = Omit<Client.Chat.ChatCompletionCreateParams, "stream" | "messages" | "model">;
-type Response = Omit<Client.Chat.ChatCompletionChunk, "object">;
+type Parameters = Omit<ChatCompletionCreateParams, "stream" | "messages" | "model">;
+type Response = Omit<ChatCompletionChunk, "object">;

export class OpenAIChatLLMOutput extends ChatLLMOutput {
  public readonly responses: Response[];
@@ -84,11 +93,12 @@ export class OpenAIChatLLMOutput extends ChatLLMOutput {
}

interface Input {
-  modelId?: Client.ChatModel;
-  client?: Client;
+  modelId?: ChatModel;
+  client?: OpenAI | AzureOpenAI;
  parameters?: Partial<Parameters>;
  executionOptions?: ExecutionOptions;
  cache?: LLMCache<OpenAIChatLLMOutput>;
+  azure?: boolean;
}

export type OpenAIChatLLMEvents = ChatLLMGenerateEvents<OpenAIChatLLMOutput>;
@@ -99,28 +109,36 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
    creator: this,
  });

-  public readonly client: Client;
+  public readonly client: OpenAI | AzureOpenAI;
  public readonly parameters: Partial<Parameters>;

-  constructor({
-    client,
-    modelId = "gpt-4o",
-    parameters,
-    executionOptions = {},
-    cache,
-  }: Input = {}) {
-    super(modelId, executionOptions, cache);
-    this.client = client ?? new Client();
+  constructor({ client, modelId, parameters, executionOptions = {}, cache, azure }: Input = {}) {
+    super(modelId || "gpt-4o-mini", executionOptions, cache);
+    if (client) {
+      this.client = client;
+    } else if (azure) {
+      this.client = new AzureOpenAI();
+    } else {
+      this.client = new OpenAI();
+    }
    this.parameters = parameters ?? { temperature: 0 };
  }

  static {
    this.register();
-    Serializer.register(Client, {
+    Serializer.register(AzureOpenAI, {
+      toPlain: (value) => ({
+        azureADTokenProvider: getPropStrict(value, "_azureADTokenProvider"),
+        apiVersion: getPropStrict(value, "apiVersion"),
+        deployment: getPropStrict(value, "_deployment"),
+      }),
+      fromPlain: (value) => new AzureOpenAI(value.azureADTokenProvider),
+    });
+    Serializer.register(OpenAI, {
      toPlain: (value) => ({
        options: getPropStrict(value, "_options") as ClientOptions,
      }),
-      fromPlain: (value) => new Client(value.options),
+      fromPlain: (value) => new OpenAI(value.options),
    });
  }

@@ -152,7 +170,7 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
        ({
          role: msg.role,
          content: msg.text,
-        }) as Client.Chat.ChatCompletionMessageParam,
+        }) as ChatCompletionMessageParam,
      ),
    });

@@ -164,11 +182,11 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
  protected _prepareRequest(
    input: BaseMessage[],
    options?: GenerateOptions,
-  ): Client.Chat.ChatCompletionCreateParams {
+  ): ChatCompletionCreateParams {
    type OpenAIMessage =
-      | Client.Chat.ChatCompletionSystemMessageParam
-      | Client.Chat.ChatCompletionUserMessageParam
-      | Client.Chat.ChatCompletionAssistantMessageParam;
+      | ChatCompletionSystemMessageParam
+      | ChatCompletionUserMessageParam
+      | ChatCompletionAssistantMessageParam;

    return {
      ...this.parameters,
@@ -180,6 +198,7 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
          content: message.text,
        }),
      ),
+
      response_format: (() => {
        if (options?.guided?.json) {
          const schema = isString(options.guided.json)
@@ -232,7 +251,7 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
            index: 1,
            logprobs: choice.logprobs,
            finish_reason: choice.finish_reason,
-          }) as Client.Chat.ChatCompletionChunk.Choice,
+          }) as ChatCompletionChunk.Choice,
        ),
      });
  }
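For context, a minimal sketch (not from the PR) of the three construction paths the reworked constructor supports, plus the serialization round-trip that the new AzureOpenAI registration enables:

import { AzureOpenAI } from "openai";
import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";

// 1) An explicit client always wins, whether OpenAI or AzureOpenAI.
const explicit = new OpenAIChatLLM({ client: new AzureOpenAI() });

// 2) azure: true builds an AzureOpenAI client from the env vars in .env.template.
const viaFlag = new OpenAIChatLLM({ azure: true });

// 3) Default: a plain OpenAI client; note the default model is now "gpt-4o-mini".
const vanilla = new OpenAIChatLLM();

// Round-trip via the new Serializer.register(AzureOpenAI, ...) entry. The
// restored client is rebuilt from its AAD token provider, so key-based setups
// presumably fall back to env vars after deserialization.
const restored = OpenAIChatLLM.fromSerialized(viaFlag.serialize());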
1 change: 1 addition & 0 deletions tests/examples/examples.test.ts
@@ -49,6 +49,7 @@ const exclude: string[] = [
"examples/agents/bee_reusable.ts",
"examples/llms/providers/openai.ts",
],
!hasEnv("AZURE_OPENAI_API_KEY") && ["examples/llms/providers/azure_openai.ts"],
!hasEnv("IBM_VLLM_URL") && [
"examples/llms/providers/ibm-vllm.ts",
"examples/agents/granite/chat.ts",
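The hasEnv guard used in this exclusion list is not shown in the hunk; it presumably amounts to something like:

const hasEnv = (name: string): boolean => Boolean(process.env[name]);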