diff --git a/docs/source/en/main_classes/agent.mdx b/docs/source/en/main_classes/agent.mdx index ecf1747d12b..5f4b3df306c 100644 --- a/docs/source/en/main_classes/agent.mdx +++ b/docs/source/en/main_classes/agent.mdx @@ -38,6 +38,10 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens [[autodoc]] OpenAiAgent +### AzureOpenAiAgent + +[[autodoc]] AzureOpenAiAgent + ### Agent [[autodoc]] Agent diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c7deecf7552..a1d995670f6 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -619,6 +619,7 @@ ], "tools": [ "Agent", + "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent", @@ -4410,6 +4411,7 @@ # Tools from .tools import ( Agent, + AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent, diff --git a/src/transformers/tools/__init__.py b/src/transformers/tools/__init__.py index a5b3c7dc05e..68d66eb275e 100644 --- a/src/transformers/tools/__init__.py +++ b/src/transformers/tools/__init__.py @@ -24,7 +24,7 @@ _import_structure = { - "agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"], + "agents": ["Agent", "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent"], "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"], } @@ -46,7 +46,7 @@ _import_structure["translation"] = ["TranslationTool"] if TYPE_CHECKING: - from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent + from .agents import Agent, AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool try: diff --git a/src/transformers/tools/agents.py b/src/transformers/tools/agents.py index 0ecb600d921..66563c3529f 100644 --- a/src/transformers/tools/agents.py +++ b/src/transformers/tools/agents.py @@ -453,6 +453,132 @@ def _completion_generate(self, prompts, stop): return [answer["text"] for answer in result["choices"]] +class AzureOpenAiAgent(Agent): + """ + Agent that uses Azure OpenAI to generate code. See the [official + documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI + model on Azure + + + + The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like + `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version. + + + + Args: + deployment_id (`str`): + The name of the deployed Azure openAI model to use. + api_key (`str`, *optional*): + The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`. + resource_name (`str`, *optional*): + The name of your Azure OpenAI Resource. If unset, will look for the environment variable + `"AZURE_OPENAI_RESOURCE_NAME"`. + api_version (`str`, *optional*, default to `"2022-12-01"`): + The API version to use for this agent. + is_chat_mode (`bool`, *optional*): + Whether you are using a completion model or a chat model (see note above, chat models won't be as + efficient). Will default to `gpt` being in the `deployment_id` or not. + chat_prompt_template (`str`, *optional*): + Pass along your own prompt if you want to override the default template for the `chat` method. Can be the + actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named + `chat_prompt_template.txt` in this repo in this case. + run_prompt_template (`str`, *optional*): + Pass along your own prompt if you want to override the default template for the `run` method. Can be the + actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named + `run_prompt_template.txt` in this repo in this case. + additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*): + Any additional tools to include on top of the default ones. If you pass along a tool with the same name as + one of the default tools, that default tool will be overridden. + + Example: + + ```py + from transformers import AzureOpenAiAgent + + agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy) + agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!") + ``` + """ + + def __init__( + self, + deployment_id, + api_key=None, + resource_name=None, + api_version="2022-12-01", + is_chat_model=None, + chat_prompt_template=None, + run_prompt_template=None, + additional_tools=None, + ): + if not is_openai_available(): + raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.") + + self.deployment_id = deployment_id + openai.api_type = "azure" + if api_key is None: + api_key = os.environ.get("AZURE_OPENAI_API_KEY", None) + if api_key is None: + raise ValueError( + "You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with " + "`os.environ['AZURE_OPENAI_API_KEY'] = xxx." + ) + else: + openai.api_key = api_key + if resource_name is None: + resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None) + if resource_name is None: + raise ValueError( + "You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with " + "`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx." + ) + else: + openai.api_base = f"https://{resource_name}.openai.azure.com" + openai.api_version = api_version + + if is_chat_model is None: + is_chat_model = "gpt" in deployment_id.lower() + self.is_chat_model = is_chat_model + + super().__init__( + chat_prompt_template=chat_prompt_template, + run_prompt_template=run_prompt_template, + additional_tools=additional_tools, + ) + + def generate_many(self, prompts, stop): + if self.is_chat_model: + return [self._chat_generate(prompt, stop) for prompt in prompts] + else: + return self._completion_generate(prompts, stop) + + def generate_one(self, prompt, stop): + if self.is_chat_model: + return self._chat_generate(prompt, stop) + else: + return self._completion_generate([prompt], stop)[0] + + def _chat_generate(self, prompt, stop): + result = openai.ChatCompletion.create( + engine=self.deployment_id, + messages=[{"role": "user", "content": prompt}], + temperature=0, + stop=stop, + ) + return result["choices"][0]["message"]["content"] + + def _completion_generate(self, prompts, stop): + result = openai.Completion.create( + engine=self.deployment_id, + prompt=prompts, + temperature=0, + stop=stop, + max_tokens=200, + ) + return [answer["text"] for answer in result["choices"]] + + class HfAgent(Agent): """ Agent that uses an inference endpoint to generate code.