Skip to content

Commit

Permalink
Merge pull request #140 from cpacker/azure-patch
Browse files Browse the repository at this point in the history
Patch azure support

Co-Authored-By: rivms <[email protected]>
  • Loading branch information
vivi and rivms authored Oct 26, 2023
2 parents f55858b + 6e24464 commit 571d15b
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 38 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ If you're using Azure OpenAI, set these variables instead:
export AZURE_OPENAI_KEY = ...
export AZURE_OPENAI_ENDPOINT = ...
export AZURE_OPENAI_VERSION = ...

# set the below if you are using deployment ids
export AZURE_OPENAI_DEPLOYMENT = ...
export AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT = ...

# then use the --use_azure_openai flag
memgpt --use_azure_openai
Expand Down
50 changes: 18 additions & 32 deletions memgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@

from memgpt.config import Config
from memgpt.constants import MEMGPT_DIR
from memgpt.openai_tools import (
configure_azure_support,
check_azure_embeddings,
get_set_azure_env_vars,
)

import asyncio

app = typer.Typer()
Expand Down Expand Up @@ -187,6 +193,18 @@ async def main(
if debug:
logging.getLogger().setLevel(logging.DEBUG)

# Azure OpenAI support
if use_azure_openai:
configure_azure_support()
check_azure_embeddings()
else:
azure_vars = get_set_azure_env_vars()
if len(azure_vars) > 0:
print(
f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
)
return

if any(
(
persona,
Expand Down Expand Up @@ -285,38 +303,6 @@ async def main(
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
)

# Azure OpenAI support
if use_azure_openai:
azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if None in [
azure_openai_key,
azure_openai_endpoint,
azure_openai_version,
azure_openai_deployment,
]:
print(
f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
)
return

import openai

openai.api_type = "azure"
openai.api_key = azure_openai_key
openai.api_base = azure_openai_endpoint
openai.api_version = azure_openai_version
# deployment gets passed into chatcompletion
else:
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
print(
f"Error: AZURE_OPENAI_DEPLOYMENT should not be set if --use_azure_openai is False"
)
return

if cfg.index:
persistence_manager = InMemoryStateManagerWithFaiss(
cfg.index, cfg.archival_database
Expand Down
80 changes: 74 additions & 6 deletions memgpt/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,26 @@ async def acompletions_with_backoff(**kwargs):

# OpenAI / Azure model
else:
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs["deployment_id"] = azure_openai_deployment
if using_azure():
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs["deployment_id"] = azure_openai_deployment
else:
kwargs["engine"] = MODEL_TO_AZURE_ENGINE[kwargs["model"]]
kwargs.pop("model")
return await openai.ChatCompletion.acreate(**kwargs)


@aretry_with_exponential_backoff
async def acreate_embedding_with_backoff(**kwargs):
"""Wrapper around Embedding.acreate w/ backoff"""
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs["deployment_id"] = azure_openai_deployment
if using_azure():
azure_openai_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs["deployment_id"] = azure_openai_deployment
else:
kwargs["engine"] = kwargs["model"]
kwargs.pop("model")
return await openai.Embedding.acreate(**kwargs)


Expand All @@ -138,3 +146,63 @@ async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002")
response = await acreate_embedding_with_backoff(input=[text], model=model)
embedding = response["data"][0]["embedding"]
return embedding


MODEL_TO_AZURE_ENGINE = {
"gpt-4": "gpt-4",
"gpt-4-32k": "gpt-4-32k",
"gpt-3.5": "gpt-35-turbo",
"gpt-3.5-turbo": "gpt-35-turbo",
"gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
}


def get_set_azure_env_vars():
azure_env_variables = [
("AZURE_OPENAI_KEY", os.getenv("AZURE_OPENAI_KEY")),
("AZURE_OPENAI_ENDPOINT", os.getenv("AZURE_OPENAI_ENDPOINT")),
("AZURE_OPENAI_VERSION", os.getenv("AZURE_OPENAI_VERSION")),
("AZURE_OPENAI_DEPLOYMENT", os.getenv("AZURE_OPENAI_DEPLOYMENT")),
(
"AZURE_OPENAI_EMBEDDING_DEPLOYMENT",
os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"),
),
]
return [x for x in azure_env_variables if x[1] is not None]


def using_azure():
return len(get_set_azure_env_vars()) > 0


def configure_azure_support():
azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
if None in [
azure_openai_key,
azure_openai_endpoint,
azure_openai_version,
]:
print(
f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
)
return

openai.api_type = "azure"
openai.api_key = azure_openai_key
openai.api_base = azure_openai_endpoint
openai.api_version = azure_openai_version
# deployment gets passed into chatcompletion


def check_azure_embeddings():
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
azure_openai_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
if (
azure_openai_deployment is not None
and azure_openai_embedding_deployment is None
):
raise ValueError(
f"Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
)

0 comments on commit 571d15b

Please sign in to comment.