Commit 7d74aad
Replace memgpt run flags error with warning + remove custom embedding endpoint option + add agent create time (#364)
sarahwooders authored Nov 9, 2023
1 parent 350e4af commit 7d74aad
Showing 2 changed files with 33 additions and 14 deletions.
13 changes: 10 additions & 3 deletions memgpt/cli/cli.py
@@ -107,11 +107,18 @@ def run(
     # persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load
     # TODO: load prior agent state
     if persona and persona != agent_config.persona:
-        raise ValueError(f"Cannot override {agent_config.name} existing persona {agent_config.persona} with {persona}")
+        typer.secho(f"Warning: Overriding existing persona {agent_config.persona} with {persona}", fg=typer.colors.YELLOW)
+        agent_config.persona = persona
+        # raise ValueError(f"Cannot override {agent_config.name} existing persona {agent_config.persona} with {persona}")
     if human and human != agent_config.human:
-        raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")
+        typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
+        agent_config.human = human
+        # raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")
     if model and model != agent_config.model:
-        raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
+        typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
+        agent_config.model = model
+        # raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
+    agent_config.save()

     # load existing agent
     memgpt_agent = AgentAsync.load_agent(memgpt.interface, agent_config)
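The practical effect of this hunk: invoking memgpt run with a persona, human, or model that differs from the saved agent config now prints a yellow warning and updates the config instead of raising a ValueError. Below is a minimal standalone sketch of the same warn-then-override pattern; the helper name and generic field handling are illustrative and not part of the MemGPT codebase.

import typer

def override_with_warning(config, field: str, new_value):
    # Illustrative helper: warn in yellow (as in the hunk above) instead of raising, then apply the override.
    old_value = getattr(config, field)
    if new_value and new_value != old_value:
        typer.secho(f"Warning: Overriding existing {field} {old_value} with {new_value}", fg=typer.colors.YELLOW)
        setattr(config, field, new_value)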
34 changes: 23 additions & 11 deletions memgpt/cli/cli_config.py
@@ -66,21 +66,24 @@ def configure():
     # TODO: configure local model

     # configure provider
-    use_local = not use_openai and os.getenv("OPENAI_API_BASE")
-    endpoint_options = []
+    model_endpoint_options = []
     if os.getenv("OPENAI_API_BASE") is not None:
-        endpoint_options.append(os.getenv("OPENAI_API_BASE"))
+        model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
     if use_azure:
-        endpoint_options += ["azure"]
+        model_endpoint_options += ["azure"]
     if use_openai:
-        endpoint_options += ["openai"]
+        model_endpoint_options += ["openai"]

-    assert len(endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE."
-    default_endpoint = questionary.select("Select default inference endpoint:", endpoint_options).ask()
+    assert len(model_endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE."
+    default_endpoint = questionary.select("Select default inference endpoint:", model_endpoint_options).ask()

     # configure embedding provider
-    endpoint_options.append("local") # can compute embeddings locally
-    default_embedding_endpoint = questionary.select("Select default embedding endpoint:", endpoint_options).ask()
+    embedding_endpoint_options = ["local"] # cannot configure custom endpoint (too confusing)
+    if use_azure:
+        embedding_endpoint_options += ["azure"]
+    if use_openai:
+        embedding_endpoint_options += ["openai"]
+    default_embedding_endpoint = questionary.select("Select default embedding endpoint:", embedding_endpoint_options).ask()

     # configure embedding dimensions
     default_embedding_dim = 1536
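For context, questionary.select takes a prompt message and a list of choices, and .ask() returns the selected string (or None if the user cancels the prompt). A small sketch of the prompt pattern used above, with an illustrative option list rather than the one built in configure():

import questionary

# Illustrative options; in configure() the list depends on OpenAI/Azure availability.
embedding_endpoint_options = ["local", "openai"]
choice = questionary.select("Select default embedding endpoint:", embedding_endpoint_options).ask()
if choice is None:
    # .ask() returns None when the prompt is cancelled (e.g. Ctrl-C).
    raise SystemExit("Configuration cancelled.")
print(f"Selected embedding endpoint: {choice}")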
@@ -159,11 +162,20 @@ def list(option: str):
     if option == "agents":
         """List all agents"""
         table = PrettyTable()
-        table.field_names = ["Name", "Model", "Persona", "Human", "Data Source"]
+        table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"]
         for agent_file in utils.list_agent_config_files():
             agent_name = os.path.basename(agent_file).replace(".json", "")
             agent_config = AgentConfig.load(agent_name)
-            table.add_row([agent_name, agent_config.model, agent_config.persona, agent_config.human, ",".join(agent_config.data_sources)])
+            table.add_row(
+                [
+                    agent_name,
+                    agent_config.model,
+                    agent_config.persona,
+                    agent_config.human,
+                    ",".join(agent_config.data_sources),
+                    agent_config.create_time,
+                ]
+            )
         print(table)
     elif option == "humans":
         """List all humans"""