From 794408cd3103d5854d83ea2c8e11ef0c19b61a0f Mon Sep 17 00:00:00 2001
From: OpenHands
Date: Wed, 4 Dec 2024 15:32:08 -0500
Subject: [PATCH] Fix issue #5383: [Bug]: LLM Cost is added to the `metrics` twice (#5396)

Co-authored-by: Engel Nyst
---
 openhands/llm/llm.py | 40 +++++++++++++++++++++++-----------------
 1 file changed, 23 insertions(+), 17 deletions(-)

diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 700c3827fda0..85010b3fec73 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -219,6 +219,20 @@ def wrapper(*args, **kwargs):
                     )
                     resp.choices[0].message = fn_call_response_message

+                message_back: str = resp['choices'][0]['message']['content'] or ''
+                tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
+                if tool_calls:
+                    for tool_call in tool_calls:
+                        fn_name = tool_call.function.name
+                        fn_args = tool_call.function.arguments
+                        message_back += f'\nFunction call: {fn_name}({fn_args})'
+
+                # log the LLM response
+                self.log_response(message_back)
+
+                # post-process the response first to calculate cost
+                cost = self._post_completion(resp)
+
                 # log for evals or other scripts that need the raw completion
                 if self.config.log_completions:
                     assert self.config.log_completions_folder is not None
@@ -228,37 +242,27 @@ def wrapper(*args, **kwargs):
                         f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
                     )

+                    # set up the dict to be logged
                     _d = {
                         'messages': messages,
                         'response': resp,
                         'args': args,
                         'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
                         'timestamp': time.time(),
-                        'cost': self._completion_cost(resp),
+                        'cost': cost,
                     }
+
+                    # if non-native function calling, save messages/response separately
                     if mock_function_calling:
-                        # Overwrite response as non-fncall to be consistent with `messages``
+                        # Overwrite response as non-fncall to be consistent with messages
                         _d['response'] = non_fncall_response
+                        # Save fncall_messages/response separately
                         _d['fncall_messages'] = original_fncall_messages
                         _d['fncall_response'] = resp
                     with open(log_file, 'w') as f:
                         f.write(json.dumps(_d))

-                message_back: str = resp['choices'][0]['message']['content'] or ''
-                tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
-                if tool_calls:
-                    for tool_call in tool_calls:
-                        fn_name = tool_call.function.name
-                        fn_args = tool_call.function.arguments
-                        message_back += f'\nFunction call: {fn_name}({fn_args})'
-
-                # log the LLM response
-                self.log_response(message_back)
-
-                # post-process the response
-                self._post_completion(resp)
-
                 return resp
             except APIError as e:
                 if 'Attention Required! | Cloudflare' in str(e):
@@ -414,7 +418,7 @@ def is_function_calling_active(self) -> bool:
         )
         return model_name_supported

-    def _post_completion(self, response: ModelResponse) -> None:
+    def _post_completion(self, response: ModelResponse) -> float:
         """Post-process the completion response.

         Logs the cost and usage stats of the completion call.
@@ -472,6 +476,8 @@ def _post_completion(self, response: ModelResponse) -> None:
         if stats:
             logger.debug(stats)

+        return cur_cost
+
     def get_token_count(self, messages) -> int:
         """Get the number of tokens in a list of messages.
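
Note: the following is a minimal sketch, in plain Python, of the double-counting this patch addresses; it is not part of the patch and the `Metrics` and `FakeLLM` classes below are hypothetical stand-ins, not the real OpenHands implementations. The only assumption taken from the diff is that `_completion_cost()` records the computed cost in `self.metrics` as a side effect, so the old logging path (`'cost': self._completion_cost(resp)`) accumulated each completion's cost a second time. The patch computes the cost once in `_post_completion()` and returns it so the logging code can reuse the value.

class Metrics:
    """Hypothetical stand-in for the metrics object referenced in the diff."""

    def __init__(self) -> None:
        self.accumulated_cost = 0.0

    def add_cost(self, value: float) -> None:
        self.accumulated_cost += value


class FakeLLM:
    """Hypothetical stand-in illustrating the cost-accounting flow."""

    def __init__(self) -> None:
        self.metrics = Metrics()

    def _completion_cost(self, resp: dict) -> float:
        # Side effect: every call adds the cost to the metrics.
        cost = float(resp.get('cost', 0.0))
        self.metrics.add_cost(cost)
        return cost

    def _post_completion(self, resp: dict) -> float:
        # After the patch: compute and record the cost once, then return it
        # so callers can log the value without re-triggering the side effect.
        return self._completion_cost(resp)


llm = FakeLLM()
resp = {'cost': 0.01}

# Old flow: _post_completion(resp) followed by a second _completion_cost(resp)
# call just for logging would have left accumulated_cost at 0.02 for one completion.
cost_for_log = llm._post_completion(resp)
assert llm.metrics.accumulated_cost == 0.01  # counted exactly once
print(f'logged cost: {cost_for_log}, accumulated: {llm.metrics.accumulated_cost}')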