Skip to content

Commit

Permalink
Fix issue #5383: [Bug]: LLM Cost is added to the metrics twice (#5396)
Browse files Browse the repository at this point in the history
Co-authored-by: Engel Nyst <[email protected]>
  • Loading branch information
openhands-agent and enyst authored Dec 4, 2024
1 parent 9aa89e8 commit 794408c
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ def wrapper(*args, **kwargs):
)
resp.choices[0].message = fn_call_response_message

message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
if tool_calls:
for tool_call in tool_calls:
fn_name = tool_call.function.name
fn_args = tool_call.function.arguments
message_back += f'\nFunction call: {fn_name}({fn_args})'

# log the LLM response
self.log_response(message_back)

# post-process the response first to calculate cost
cost = self._post_completion(resp)

# log for evals or other scripts that need the raw completion
if self.config.log_completions:
assert self.config.log_completions_folder is not None
Expand All @@ -228,37 +242,27 @@ def wrapper(*args, **kwargs):
f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
)

# set up the dict to be logged
_d = {
'messages': messages,
'response': resp,
'args': args,
'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
'timestamp': time.time(),
'cost': self._completion_cost(resp),
'cost': cost,
}

# if non-native function calling, save messages/response separately
if mock_function_calling:
# Overwrite response as non-fncall to be consistent with `messages``
# Overwrite response as non-fncall to be consistent with messages
_d['response'] = non_fncall_response

# Save fncall_messages/response separately
_d['fncall_messages'] = original_fncall_messages
_d['fncall_response'] = resp
with open(log_file, 'w') as f:
f.write(json.dumps(_d))

message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
if tool_calls:
for tool_call in tool_calls:
fn_name = tool_call.function.name
fn_args = tool_call.function.arguments
message_back += f'\nFunction call: {fn_name}({fn_args})'

# log the LLM response
self.log_response(message_back)

# post-process the response
self._post_completion(resp)

return resp
except APIError as e:
if 'Attention Required! | Cloudflare' in str(e):
Expand Down Expand Up @@ -414,7 +418,7 @@ def is_function_calling_active(self) -> bool:
)
return model_name_supported

def _post_completion(self, response: ModelResponse) -> None:
def _post_completion(self, response: ModelResponse) -> float:
"""Post-process the completion response.
Logs the cost and usage stats of the completion call.
Expand Down Expand Up @@ -472,6 +476,8 @@ def _post_completion(self, response: ModelResponse) -> None:
if stats:
logger.debug(stats)

return cur_cost

def get_token_count(self, messages) -> int:
"""Get the number of tokens in a list of messages.
Expand Down

0 comments on commit 794408c

Please sign in to comment.