From a1c2747950b72d21d514b5f52ab61e6975cb19aa Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Mon, 6 Nov 2023 16:38:10 -0800 Subject: [PATCH 1/4] Relax verify_first_message_correctness to accept any function call --- memgpt/agent.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index ecf13e86e1..60afbb6a2e 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -466,18 +466,19 @@ def load_from_json_file_inplace(self, json_file): state = json.load(file) self.load_inplace(state) - def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False): + def verify_first_message_correctness(self, response, require_function_call=True, require_monologue=False): """Can be used to enforce that the first message always uses send_message""" response_message = response.choices[0].message # First message should be a call to send_message with a non-empty content - if require_send_message and not response_message.get("function_call"): + if require_function_call and not response_message.get("function_call"): printd(f"First message didn't include function call: {response_message}") return False + expected_function_calls = [x["name"] for x in self.functions] function_name = response_message["function_call"]["name"] - if require_send_message and function_name != "send_message": - printd(f"First message function call wasn't send_message: {response_message}") + if require_function_call and function_name not in expected_function_calls: + printd(f"First message function call wasn't one of {expected_function_calls}: {response_message}") return False if require_monologue and ( From a73c2ed74e2c677b50fd63007ad4d90314d1e0b4 Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Mon, 6 Nov 2023 16:53:15 -0800 Subject: [PATCH 2/4] Also allow missing internal monologue if request_heartbeat --- memgpt/agent.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index 60afbb6a2e..e5de6312f7 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -484,8 +484,20 @@ def verify_first_message_correctness(self, response, require_function_call=True, if require_monologue and ( not response_message.get("content") or response_message["content"] is None or response_message["content"] == "" ): - printd(f"First message missing internal monologue: {response_message}") - return False + if function_name in expected_function_calls: + try: + raw_function_args = response_message["function_call"]["arguments"] + function_args = parse_json(raw_function_args) + except Exception as e: + printd(f"First message missing internal monologue and has badly formed arguments: {response_message}") + return False + if function_args["request_heartbeat"] != True: + printd(f"First message missing internal monologue and does not chain further functions: {response_message}") + return False + print("No internal monologue but that's ok bc request_heartbeats") + else: + printd(f"First message missing internal monologue: {response_message}") + return False if response_message.get("content"): ### Extras From 248961c6280765e67c29cdf999d788f4688afdb7 Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Mon, 6 Nov 2023 16:57:17 -0800 Subject: [PATCH 3/4] Cleanup --- memgpt/agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index e5de6312f7..53a316ea65 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -494,7 +494,6 @@ def verify_first_message_correctness(self, response, require_function_call=True, if function_args["request_heartbeat"] != True: printd(f"First message missing internal monologue and does not chain further functions: {response_message}") return False - print("No internal monologue but that's ok bc request_heartbeats") else: printd(f"First message missing internal monologue: {response_message}") return False From cebd10fae116e834b899ce0ca604c7eebf38a749 Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Mon, 6 Nov 2023 17:26:48 -0800 Subject: [PATCH 4/4] get instead of raw dict access --- memgpt/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index 53a316ea65..ec156e53c3 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -491,7 +491,7 @@ def verify_first_message_correctness(self, response, require_function_call=True, except Exception as e: printd(f"First message missing internal monologue and has badly formed arguments: {response_message}") return False - if function_args["request_heartbeat"] != True: + if function_args.get("request_heartbeat") != True: printd(f"First message missing internal monologue and does not chain further functions: {response_message}") return False else: