diff --git a/README.md b/README.md
index dd6ad66321..8f68e31ede 100644
--- a/README.md
+++ b/README.md
@@ -132,7 +132,10 @@ If you're using Azure OpenAI, set these variables instead:
 export AZURE_OPENAI_KEY = ...
 export AZURE_OPENAI_ENDPOINT = ...
 export AZURE_OPENAI_VERSION = ...
+
+# set the below if you are using deployment ids
 export AZURE_OPENAI_DEPLOYMENT = ...
+export AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT = ...
 
 # then use the --use_azure_openai flag
 memgpt --use_azure_openai
diff --git a/memgpt/main.py b/memgpt/main.py
index e0bc5bc43e..ec2d73ee04 100644
--- a/memgpt/main.py
+++ b/memgpt/main.py
@@ -28,6 +28,12 @@ from memgpt.config import Config
 
 from memgpt.constants import MEMGPT_DIR
 
+from memgpt.openai_tools import (
+    configure_azure_support,
+    check_azure_embeddings,
+    get_set_azure_env_vars,
+)
+
 import asyncio
 
 app = typer.Typer()
@@ -187,6 +193,24 @@ async def main(
     if debug:
         logging.getLogger().setLevel(logging.DEBUG)
 
+    # Azure OpenAI support
+    if use_azure_openai:
+        try:
+            configure_azure_support()
+            check_azure_embeddings()
+        except ValueError as err:
+            # Both helpers raise ValueError on a bad env configuration; report
+            # it and stop instead of continuing with an unconfigured client.
+            print(err)
+            return
+    else:
+        azure_vars = get_set_azure_env_vars()
+        if len(azure_vars) > 0:
+            print(
+                f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
+            )
+            return
+
     if any(
         (
             persona,
@@ -285,38 +309,6 @@ async def main(
             f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
         )
 
-    # Azure OpenAI support
-    if use_azure_openai:
-        azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
-        azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
-        azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
-        azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
-        if None in [
-            azure_openai_key,
-            azure_openai_endpoint,
-            azure_openai_version,
-            azure_openai_deployment,
-        ]:
-            print(
-                f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
-            )
-            return
-
-        import openai
-
-        openai.api_type = "azure"
-        openai.api_key = azure_openai_key
-        openai.api_base = azure_openai_endpoint
-        openai.api_version = azure_openai_version
-        # deployment gets passed into chatcompletion
-    else:
-        azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
-        if azure_openai_deployment is not None:
-            print(
-                f"Error: AZURE_OPENAI_DEPLOYMENT should not be set if --use_azure_openai is False"
-            )
-            return
-
     if cfg.index:
         persistence_manager = InMemoryStateManagerWithFaiss(
             cfg.index, cfg.archival_database
diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py
index 3d63d13464..20ad5f9a46 100644
--- a/memgpt/openai_tools.py
+++ b/memgpt/openai_tools.py
@@ -116,18 +116,26 @@ async def acompletions_with_backoff(**kwargs):
 
     # OpenAI / Azure model
     else:
-        azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
-        if azure_openai_deployment is not None:
-            kwargs["deployment_id"] = azure_openai_deployment
+        if using_azure():
+            azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+            if azure_openai_deployment is not None:
+                kwargs["deployment_id"] = azure_openai_deployment
+            else:
+                kwargs["engine"] = MODEL_TO_AZURE_ENGINE[kwargs["model"]]
+                kwargs.pop("model")
         return await openai.ChatCompletion.acreate(**kwargs)
 
 
 @aretry_with_exponential_backoff
 async def acreate_embedding_with_backoff(**kwargs):
     """Wrapper around Embedding.acreate w/ backoff"""
-    azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
-    if azure_openai_deployment is not None:
-        kwargs["deployment_id"] = azure_openai_deployment
+    if using_azure():
+        azure_openai_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
+        if azure_openai_deployment is not None:
+            kwargs["deployment_id"] = azure_openai_deployment
+        else:
+            kwargs["engine"] = kwargs["model"]
+            kwargs.pop("model")
     return await openai.Embedding.acreate(**kwargs)
 
 
@@ -138,3 +146,79 @@ async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002")
     response = await acreate_embedding_with_backoff(input=[text], model=model)
     embedding = response["data"][0]["embedding"]
     return embedding
+
+
+# Maps OpenAI model names to the `engine` value sent to Azure when no
+# deployment id is configured.
+MODEL_TO_AZURE_ENGINE = {
+    "gpt-4": "gpt-4",
+    "gpt-4-32k": "gpt-4-32k",
+    "gpt-3.5": "gpt-35-turbo",
+    "gpt-3.5-turbo": "gpt-35-turbo",
+    "gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
+}
+
+
+def get_set_azure_env_vars():
+    """Return (name, value) pairs for the Azure OpenAI env vars that are set."""
+    azure_env_variables = [
+        ("AZURE_OPENAI_KEY", os.getenv("AZURE_OPENAI_KEY")),
+        ("AZURE_OPENAI_ENDPOINT", os.getenv("AZURE_OPENAI_ENDPOINT")),
+        ("AZURE_OPENAI_VERSION", os.getenv("AZURE_OPENAI_VERSION")),
+        ("AZURE_OPENAI_DEPLOYMENT", os.getenv("AZURE_OPENAI_DEPLOYMENT")),
+        (
+            "AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT",
+            os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT"),
+        ),
+    ]
+    return [x for x in azure_env_variables if x[1] is not None]
+
+
+def using_azure():
+    """Return True if any Azure OpenAI environment variable is set."""
+    return len(get_set_azure_env_vars()) > 0
+
+
+def configure_azure_support():
+    """Point the openai module at Azure using the AZURE_OPENAI_* env vars.
+
+    Raises:
+        ValueError: if the key, endpoint, or version variable is unset.
+    """
+    azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
+    azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+    azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
+    if None in [
+        azure_openai_key,
+        azure_openai_endpoint,
+        azure_openai_version,
+    ]:
+        # Raise rather than print-and-return so the caller cannot carry on
+        # with a half-configured client.
+        raise ValueError(
+            f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
+        )
+
+    openai.api_type = "azure"
+    openai.api_key = azure_openai_key
+    openai.api_base = azure_openai_endpoint
+    openai.api_version = azure_openai_version
+    # deployment gets passed into chatcompletion
+
+
+def check_azure_embeddings():
+    """Require an embeddings deployment id whenever a chat one is set.
+
+    Raises:
+        ValueError: if AZURE_OPENAI_DEPLOYMENT is set but
+            AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT is not.
+    """
+    azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+    azure_openai_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
+    if (
+        azure_openai_deployment is not None
+        and azure_openai_embedding_deployment is None
+    ):
+        raise ValueError(
+            f"Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
+        )