Skip to content

Commit

Permalink
feat(adapters): add Azure OpenAI LLM adapter
Browse files Browse the repository at this point in the history
Signed-off-by: Akihiko Kuroda <[email protected]>
  • Loading branch information
akihikokuroda committed Dec 2, 2024
1 parent 74713a6 commit 5a4f489
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 24 deletions.
6 changes: 6 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
# For OpenAI LLM Adapter
# OPENAI_API_KEY=

# For Azure OpenAI LLM Adapter
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_OPENAI_API_KEY=
AZURE_OPENAI_ENDPOINT=

# For Groq LLM Adapter
# GROQ_API_KEY=

Expand Down
21 changes: 21 additions & 0 deletions examples/llms/providers/azure_openai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import "dotenv/config";
import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";

const llm = new OpenAIChatLLM({
modelId: "gpt-4o-mini",
azure: true,
parameters: {
max_tokens: 10,
stop: ["post"],
},
});

console.info("Meta", await llm.meta());
const response = await llm.generate([
BaseMessage.of({
role: "user",
text: "Hello world!",
}),
]);
console.info(response.getTextContent());
21 changes: 20 additions & 1 deletion src/adapters/openai/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,26 @@

import { verifyDeserialization } from "@tests/e2e/utils.js";
import { OpenAIChatLLM } from "@/adapters/openai/chat.js";
import { OpenAI } from "openai";
import { OpenAI, AzureOpenAI } from "openai";

describe("AzureOpenAI ChatLLM", () => {
const getInstance = () => {
return new OpenAIChatLLM({
modelId: "gpt-4o",
client: new AzureOpenAI(),
});
};

it("Serializes", async () => {
process.env["OPENAI_BASE_URL"] = "http://dummy/";
process.env["AZURE_OPENAI_API_KEY"] = "123";
process.env["OPENAI_API_VERSION"] = "version 1";
const instance = getInstance();
const serialized = instance.serialize();
const deserialized = OpenAIChatLLM.fromSerialized(serialized);
verifyDeserialization(instance, deserialized);
});
});

describe("OpenAI ChatLLM", () => {
const getInstance = () => {
Expand Down
65 changes: 42 additions & 23 deletions src/adapters/openai/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,24 @@ import { shallowCopy } from "@/serializer/utils.js";
import { ChatLLM, ChatLLMGenerateEvents, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, RoleType } from "@/llms/primitives/message.js";
import { Emitter } from "@/emitter/emitter.js";
import { ClientOptions, OpenAI as Client } from "openai";
import { ClientOptions, OpenAI, AzureOpenAI } from "openai";
import { GetRunContext } from "@/context.js";
import { promptTokensEstimate } from "openai-chat-tokens";
import { Serializer } from "@/serializer/serializer.js";
import { getProp, getPropStrict } from "@/internals/helpers/object.js";
import { isString } from "remeda";
import type {
ChatCompletionChunk,
ChatCompletionCreateParams,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
ChatCompletionAssistantMessageParam,
ChatModel,
} from "openai/resources/index";

type Parameters = Omit<Client.Chat.ChatCompletionCreateParams, "stream" | "messages" | "model">;
type Response = Omit<Client.Chat.ChatCompletionChunk, "object">;
type Parameters = Omit<ChatCompletionCreateParams, "stream" | "messages" | "model">;
type Response = Omit<ChatCompletionChunk, "object">;

export class OpenAIChatLLMOutput extends ChatLLMOutput {
public readonly responses: Response[];
Expand Down Expand Up @@ -84,11 +93,12 @@ export class OpenAIChatLLMOutput extends ChatLLMOutput {
}

interface Input {
modelId?: Client.ChatModel;
client?: Client;
modelId?: ChatModel;
client?: OpenAI | AzureOpenAI;
parameters?: Partial<Parameters>;
executionOptions?: ExecutionOptions;
cache?: LLMCache<OpenAIChatLLMOutput>;
azure?: boolean;
}

export type OpenAIChatLLMEvents = ChatLLMGenerateEvents<OpenAIChatLLMOutput>;
Expand All @@ -99,28 +109,36 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
creator: this,
});

public readonly client: Client;
public readonly client: OpenAI | AzureOpenAI;
public readonly parameters: Partial<Parameters>;

constructor({
client,
modelId = "gpt-4o",
parameters,
executionOptions = {},
cache,
}: Input = {}) {
super(modelId, executionOptions, cache);
this.client = client ?? new Client();
constructor({ client, modelId, parameters, executionOptions = {}, cache, azure }: Input = {}) {
super(modelId || "gpt-4o-mini", executionOptions, cache);
if (client) {
this.client = client;
} else if (azure) {
this.client = new AzureOpenAI();
} else {
this.client = new OpenAI();
}
this.parameters = parameters ?? { temperature: 0 };
}

static {
this.register();
Serializer.register(Client, {
Serializer.register(AzureOpenAI, {
toPlain: (value) => ({
azureADTokenProvider: getPropStrict(value, "_azureADTokenProvider"),
apiVersion: getPropStrict(value, "apiVersion"),
deployment: getPropStrict(value, "_deployment"),
}),
fromPlain: (value) => new AzureOpenAI(value.azureADTokenProvider),
});
Serializer.register(OpenAI, {
toPlain: (value) => ({
options: getPropStrict(value, "_options") as ClientOptions,
}),
fromPlain: (value) => new Client(value.options),
fromPlain: (value) => new OpenAI(value.options),
});
}

Expand Down Expand Up @@ -152,7 +170,7 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
({
role: msg.role,
content: msg.text,
}) as Client.Chat.ChatCompletionMessageParam,
}) as ChatCompletionMessageParam,
),
});

Expand All @@ -164,11 +182,11 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
protected _prepareRequest(
input: BaseMessage[],
options?: GenerateOptions,
): Client.Chat.ChatCompletionCreateParams {
): ChatCompletionCreateParams {
type OpenAIMessage =
| Client.Chat.ChatCompletionSystemMessageParam
| Client.Chat.ChatCompletionUserMessageParam
| Client.Chat.ChatCompletionAssistantMessageParam;
| ChatCompletionSystemMessageParam
| ChatCompletionUserMessageParam
| ChatCompletionAssistantMessageParam;

return {
...this.parameters,
Expand All @@ -180,6 +198,7 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
content: message.text,
}),
),

response_format: (() => {
if (options?.guided?.json) {
const schema = isString(options.guided.json)
Expand Down Expand Up @@ -232,7 +251,7 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
index: 1,
logprobs: choice.logprobs,
finish_reason: choice.finish_reason,
}) as Client.Chat.ChatCompletionChunk.Choice,
}) as ChatCompletionChunk.Choice,
),
});
}
Expand Down
1 change: 1 addition & 0 deletions tests/examples/examples.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const exclude: string[] = [
"examples/agents/bee_reusable.ts",
"examples/llms/providers/openai.ts",
],
!hasEnv("AZURE_OPENAI_API_KEY") && ["examples/llms/providers/azure_openai.ts"],
!hasEnv("IBM_VLLM_URL") && [
"examples/llms/providers/ibm-vllm.ts",
"examples/agents/granite/chat.ts",
Expand Down

0 comments on commit 5a4f489

Please sign in to comment.