Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax verify_first_message_correctness to accept any function call #340

Merged
merged 4 commits into from
Nov 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,25 +466,37 @@ def load_from_json_file_inplace(self, json_file):
state = json.load(file)
self.load_inplace(state)

def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False):
def verify_first_message_correctness(self, response, require_function_call=True, require_monologue=False):
cpacker marked this conversation as resolved.
Show resolved Hide resolved
"""Can be used to enforce that the first message always uses send_message"""
response_message = response.choices[0].message

# First message should be a call to send_message with a non-empty content
if require_send_message and not response_message.get("function_call"):
if require_function_call and not response_message.get("function_call"):
printd(f"First message didn't include function call: {response_message}")
return False

expected_function_calls = [x["name"] for x in self.functions]
function_name = response_message["function_call"]["name"]
if require_send_message and function_name != "send_message":
printd(f"First message function call wasn't send_message: {response_message}")
if require_function_call and function_name not in expected_function_calls:
printd(f"First message function call wasn't one of {expected_function_calls}: {response_message}")
return False

if require_monologue and (
not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""
):
printd(f"First message missing internal monologue: {response_message}")
return False
if function_name in expected_function_calls:
try:
raw_function_args = response_message["function_call"]["arguments"]
function_args = parse_json(raw_function_args)
except Exception as e:
printd(f"First message missing internal monologue and has badly formed arguments: {response_message}")
return False
if function_args.get("request_heartbeat") != True:
printd(f"First message missing internal monologue and does not chain further functions: {response_message}")
return False
else:
printd(f"First message missing internal monologue: {response_message}")
return False

if response_message.get("content"):
### Extras
Expand Down
Loading