diff --git a/memgpt/agent.py b/memgpt/agent.py index ecf13e86e1..ec156e53c3 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -466,25 +466,37 @@ def load_from_json_file_inplace(self, json_file): state = json.load(file) self.load_inplace(state) - def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False): + def verify_first_message_correctness(self, response, require_function_call=True, require_monologue=False): """Can be used to enforce that the first message always uses send_message""" response_message = response.choices[0].message # First message should be a call to send_message with a non-empty content - if require_send_message and not response_message.get("function_call"): + if require_function_call and not response_message.get("function_call"): printd(f"First message didn't include function call: {response_message}") return False + expected_function_calls = [x["name"] for x in self.functions] function_name = response_message["function_call"]["name"] - if require_send_message and function_name != "send_message": - printd(f"First message function call wasn't send_message: {response_message}") + if require_function_call and function_name not in expected_function_calls: + printd(f"First message function call wasn't one of {expected_function_calls}: {response_message}") return False if require_monologue and ( not response_message.get("content") or response_message["content"] is None or response_message["content"] == "" ): - printd(f"First message missing internal monologue: {response_message}") - return False + if function_name in expected_function_calls: + try: + raw_function_args = response_message["function_call"]["arguments"] + function_args = parse_json(raw_function_args) + except Exception as e: + printd(f"First message missing internal monologue and has badly formed arguments: {response_message}") + return False + if function_args.get("request_heartbeat") != True: + printd(f"First message missing internal monologue and does not chain further functions: {response_message}") + return False + else: + printd(f"First message missing internal monologue: {response_message}") + return False if response_message.get("content"): ### Extras