Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAI api refactor #391

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions rdagent/oai/backends/az.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
TODO:
It is not complete now.

Please refer to rdagent/oai/llm_utils.py:APIBackend for the future design
"""

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import openai
from pydantic_settings import BaseSettings


class AzureConf(BaseSettings):
"""
TODO: move more settings here
"""
use_azure_token_provider: bool = False
managed_identity_client_id: str | None = None
chat_model: str = "gpt-4-turbo"

chat_azure_api_base: str = ""
chat_azure_api_version: str = ""


class BaseAPI:
"""
TOOD: there may be some more shared methods in the BaseAPI
"""
pass


class AzureAPI(BaseAPI):

def _get_credential(self):
dac_kwargs = {}
if AZURE_CONF.managed_identity_client_id is not None:
dac_kwargs["managed_identity_client_id"] = self.managed_identity_client_id
credential = DefaultAzureCredential(**dac_kwargs)
return credential

def _get_client(self):
kwargs = {}
if AZURE_CONF.use_azure_token_provider:
kwargs["azure_ad_token_provider"]= get_bearer_token_provider(
self._get_credential(),
"https://cognitiveservices.azure.com/.default",
)
return openai.AzureOpenAI(
api_version=AZURE_CONF.chat_azure_api_version,
azure_endpoint=AZURE_CONF.chat_azure_api_base,
**kwargs,
)

# def list_deployments(self):
# client = self._get_client()
# try:
# deployments = client.deployments.list()
# return [deployment for deployment in deployments]
# except Exception as e:
# print(f"An error occurred while listing deployments: {e}")
# return []

AZURE_CONF = AzureConf()


# if __name__ == "__main__":
# api = AzureAPI()
# deployments = api.list_deployments()
# print(deployments)
8 changes: 8 additions & 0 deletions rdagent/oai/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def display_history(self) -> None:


class APIBackend:
"""
This is a unified interface for different backends.

(xiao) thinks integerate all kinds of API in a single class is not a good design.
So we should split them into different classes in `oai/backends/` in the future.
"""
# FIXME: (xiao) We should avoid using self.xxxx.
# Instead, we can use self.cfg directly. If it's difficult to support different backend settings, we can split them into multiple BaseSettings.
def __init__( # noqa: C901, PLR0912, PLR0915
self,
*,
Expand Down
Loading