Commit

CLI bug fixes (loading human/persona text, azure setup, local setup) (#222)

* mark deprecated API section

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* CLI bug fixes for azure

* check azure before running

* Update README.md

* Update README.md

* bug fix with persona loading

* revert readme

* remove print
sarahwooders authored Oct 31, 2023
1 parent d7e021b commit 89f0a69
Showing 4 changed files with 54 additions and 18 deletions.
13 changes: 11 additions & 2 deletions memgpt/cli/cli.py
@@ -26,6 +26,10 @@
 from memgpt.constants import MEMGPT_DIR
 from memgpt.agent import AgentAsync
 from memgpt.embeddings import embedding_model
+from memgpt.openai_tools import (
+    configure_azure_support,
+    check_azure_embeddings,
+)
 
 
 def run(
@@ -135,14 +139,19 @@ def run(
             agent_config.preset,
             agent_config,
             agent_config.model,
-            agent_config.persona,
-            agent_config.human,
+            utils.get_persona_text(agent_config.persona),
+            utils.get_human_text(agent_config.human),
             memgpt.interface,
             persistence_manager,
         )
 
     # start event loop
     from memgpt.main import run_agent_loop
 
+    # setup azure if using
+    # TODO: cleanup this code
+    if config.model_endpoint == "azure":
+        configure_azure_support()
+
     loop = asyncio.get_event_loop()
     loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config))  # TODO: add back no_verify
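For orientation, here is a minimal sketch of what an Azure setup helper like configure_azure_support could do with the pre-1.0 openai Python client, assuming it simply points the module-level client settings at an Azure deployment using the same AZURE_OPENAI_* environment variables the configure flow reads; the function body is illustrative, not the actual memgpt.openai_tools implementation.

import os
import openai  # pre-1.0 openai client with module-level settings

def configure_azure_support_sketch():
    # Illustrative stand-in for memgpt.openai_tools.configure_azure_support:
    # route OpenAI API calls to an Azure OpenAI deployment.
    openai.api_type = "azure"
    openai.api_key = os.getenv("AZURE_OPENAI_KEY")
    openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")  # e.g. https://<resource>.openai.azure.com
    openai.api_version = os.getenv("AZURE_OPENAI_VERSION")  # e.g. 2023-05-15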
10 changes: 5 additions & 5 deletions memgpt/cli/cli_config.py
@@ -37,10 +37,10 @@ def configure():
     use_azure_deployment_ids = False
     if use_azure:
         # search for key in enviornment
-        azure_key = os.getenv("AZURE_API_KEY")
-        azure_endpoint = (os.getenv("AZURE_ENDPOINT"),)
-        azure_version = (os.getenv("AZURE_VERSION"),)
-        azure_deployment = (os.getenv("AZURE_OPENAI_DEPLOYMENT"),)
+        azure_key = os.getenv("AZURE_OPENAI_KEY")
+        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+        azure_version = os.getenv("AZURE_OPENAI_VERSION")
+        azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
         azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
 
         if all([azure_key, azure_endpoint, azure_version]):
@@ -66,7 +66,7 @@ def configure():
     endpoint_options = []
     if os.getenv("OPENAI_API_BASE") is not None:
         endpoint_options.append(os.getenv("OPENAI_API_BASE"))
-    if os.getenv("AZURE_ENDPOINT") is not None:
+    if use_azure:
         endpoint_options += ["azure"]
     if use_openai:
         endpoint_options += ["openai"]
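The replaced cli_config.py lines wrapped each os.getenv call in a one-element tuple (note the trailing commas), so the all([azure_key, azure_endpoint, azure_version]) check in configure() passed even when the variables were unset. A quick illustration, using a deliberately unset variable name:

import os

broken = (os.getenv("SOME_UNSET_VAR"),)  # trailing comma -> (None,), which is truthy
fixed = os.getenv("SOME_UNSET_VAR")      # -> None

print(all([broken]))  # True: a non-empty tuple is truthy even if it only holds None
print(all([fixed]))   # False: None correctly fails the check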
11 changes: 7 additions & 4 deletions memgpt/config.py
@@ -110,8 +110,10 @@ def load(cls) -> "MemGPTConfig":
         azure_key = config.get("azure", "key")
         azure_endpoint = config.get("azure", "endpoint")
         azure_version = config.get("azure", "version")
-        azure_deployment = config.get("azure", "deployment")
-        azure_embedding_deployment = config.get("azure", "embedding_deployment")
+        azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None
+        azure_embedding_deployment = (
+            config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None
+        )
 
         embedding_model = config.get("embedding", "model")
         embedding_dim = config.getint("embedding", "dim")
@@ -167,8 +169,9 @@ def save(self):
         config.set("azure", "key", self.azure_key)
         config.set("azure", "endpoint", self.azure_endpoint)
         config.set("azure", "version", self.azure_version)
-        config.set("azure", "deployment", self.azure_deployment)
-        config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
+        if self.azure_deployment:
+            config.set("azure", "deployment", self.azure_deployment)
+            config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
 
         # embeddings
         config.add_section("embedding")
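The config.py change guards against config files written before the azure deployment keys existed. The same has_option fallback pattern can be factored into a small helper; the helper name and config path below are illustrative:

import configparser
import os

def get_optional(config: configparser.ConfigParser, section: str, key: str, default=None):
    # Return the value only when the section/key pair is present; otherwise use the default.
    return config.get(section, key) if config.has_option(section, key) else default

config = configparser.ConfigParser()
config.read(os.path.expanduser("~/.memgpt/config"))  # illustrative path
azure_deployment = get_optional(config, "azure", "deployment")
azure_embedding_deployment = get_optional(config, "azure", "embedding_deployment")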
38 changes: 31 additions & 7 deletions memgpt/utils.py
@@ -20,6 +20,8 @@
 from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
 from llama_index.embeddings import OpenAIEmbedding
 
+from memgpt.embeddings import embedding_model
+
 
 def count_tokens(s: str, model: str = "gpt-4") -> int:
     encoding = tiktoken.encoding_for_model(model)
@@ -398,13 +400,11 @@ def get_index(name, docs):
 
     # read embedding confirguration
     # TODO: in the future, make an IngestData class that loads the config once
-    # config = MemGPTConfig.load()
-    # chunk_size = config.embedding_chunk_size
-    # model = config.embedding_model # TODO: actually use this
-    # dim = config.embedding_dim # TODO: actually use this
-    # embed_model = OpenAIEmbedding()
-    # service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=chunk_size)
-    # set_global_service_context(service_context)
+    config = MemGPTConfig.load()
+    embed_model = embedding_model(config)
+    chunk_size = config.embedding_chunk_size
+    service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=chunk_size)
+    set_global_service_context(service_context)
 
     # index documents
     index = VectorStoreIndex.from_documents(docs)
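With get_index now pulling its embedding settings from MemGPTConfig (embedding_model(config) picks the backend, so Azure embeddings work the same way as OpenAI ones), a typical ingestion call might look like the sketch below; the directory and index name are made up for illustration:

from llama_index import SimpleDirectoryReader
from memgpt.utils import get_index

# Load local text files and embed them with whatever backend the saved MemGPT config selects.
docs = SimpleDirectoryReader("./my_docs").load_data()  # illustrative path
index = get_index("my_docs_index", docs)               # illustrative index name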
@@ -481,3 +481,27 @@ def list_persona_files():
     user_added = os.listdir(user_dir)
     user_added = [os.path.join(user_dir, f) for f in user_added]
     return memgpt_defaults + user_added
+
+
+def get_human_text(name: str):
+    for file_path in list_human_files():
+        file = os.path.basename(file_path)
+        if f"{name}.txt" == file or name == file:
+            return open(file_path, "r").read().strip()
+    raise ValueError(f"Human {name} not found")
+
+
+def get_persona_text(name: str):
+    for file_path in list_persona_files():
+        file = os.path.basename(file_path)
+        if f"{name}.txt" == file or name == file:
+            return open(file_path, "r").read().strip()
+
+    raise ValueError(f"Persona {name} not found")
+
+
+def get_human_text(name: str):
+    for file_path in list_human_files():
+        file = os.path.basename(file_path)
+        if f"{name}.txt" == file or name == file:
+            return open(file_path, "r").read().strip()

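These helpers resolve a persona or human name (with or without the .txt suffix) to the full prompt text, which is what the cli.py change above now passes into the agent instead of the bare name. A minimal usage sketch; "sam" and "basic" are assumed to exist among the bundled example files:

from memgpt import utils

persona_text = utils.get_persona_text("sam")  # full contents of sam.txt, not just the name
human_text = utils.get_human_text("basic")    # full contents of basic.txt

print(persona_text[:80])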