Skip to content

Commit

Permalink
Merge pull request #43 from atomiechen/fix_tool_call
Browse files Browse the repository at this point in the history
Fix tool call
  • Loading branch information
atomiechen authored Aug 6, 2024
2 parents cdf8400 + 3fe005c commit c62c088
Show file tree
Hide file tree
Showing 4 changed files with 640 additions and 5 deletions.
11 changes: 11 additions & 0 deletions src/handyllm/hprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,10 @@ def _stream_with_client(
run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk)
run_config.on_chunk(*ret)
yield chat_chunk
ret = producer.send(None) # signal the end of the stream
if run_config.on_chunk and ret:
run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk)
run_config.on_chunk(*ret)
producer.close()

@classmethod
Expand All @@ -756,6 +760,13 @@ async def _astream_with_client(
run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk)
run_config.on_chunk(*ret)
yield chat_chunk
ret = producer.send(None) # signal the end of the stream
if run_config.on_chunk and ret:
if inspect.iscoroutinefunction(run_config.on_chunk):
await run_config.on_chunk(*ret)
else:
run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk)
run_config.on_chunk(*ret)
producer.close()

@classmethod
Expand Down
24 changes: 19 additions & 5 deletions src/handyllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
]

import base64
import copy
from pathlib import Path
from typing import (
IO,
Expand Down Expand Up @@ -69,14 +70,16 @@ def download_binary(download_url, file_path=None, dir="."):

def trans_stream_chat(
consumer: Generator[YieldType, ShortChatChunk, None],
) -> Generator[Optional[YieldType], ChatChunk, None]:
) -> Generator[Optional[YieldType], Optional[ChatChunk], None]:
next(consumer) # prime the generator
role = ""
tool_call = ToolCallDelta()
ret = None
try:
while True:
data = yield ret
if data is None:
break
ret = None
try:
message = data["choices"][0]["delta"]
Expand All @@ -93,19 +96,24 @@ def trans_stream_chat(
"arguments"
]
else:
# this is a new tool call, yield the previous one
ret = consumer.send((role, content, tool_call))
if tool_call:
# this is a new tool call, yield the previous one
ret = consumer.send((role, content, tool_call))
# reset the tool call
tool_call = ToolCallDelta(chunk)
tool_call = copy.deepcopy(chunk)
elif content:
ret = consumer.send((role, content, tool_call))
except (KeyError, IndexError):
pass
except GeneratorExit:
if tool_call:
# yield the last tool call
ret = consumer.send((role, None, tool_call))
yield ret
else:
yield None
consumer.close()
except GeneratorExit:
pass


def echo_consumer():
Expand All @@ -123,6 +131,9 @@ def stream_chat_all(
ret = producer.send(data)
if ret is not None:
yield ret
ret = producer.send(None) # signal the end of the stream
if ret is not None:
yield ret
producer.close()


Expand Down Expand Up @@ -154,6 +165,9 @@ async def astream_chat_all(
ret = producer.send(data)
if ret is not None:
yield ret
ret = producer.send(None) # signal the end of the stream
if ret is not None:
yield ret
producer.close()


Expand Down
37 changes: 37 additions & 0 deletions tests/assets/chat_tool.hprompt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
---
meta:
var_map_path: var_map.txt
output_path: tmp-out/%Y-%m-%d/result.%H-%M-%S.hprompt
output_evaled_prompt_path: tmp-evaled/%Y-%m-%d/result.%H-%M-%S.hprompt
model: gpt-4o
temperature: 0.2
tools: [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location", "unit"]
}
}
}
]
---

$system$
You are a helpful assistant.

$user$
What's the weather like in San Francisco and New York?
Loading

0 comments on commit c62c088

Please sign in to comment.