From 909927bff6fa1cca93b6d7120682094862a69027 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 01:34:50 +0800 Subject: [PATCH 1/5] test: add test for chat tool calls --- tests/assets/chat_tool.hprompt | 37 +++++++++++++++++++++++ tests/test_tool_call.py | 54 ++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 tests/assets/chat_tool.hprompt create mode 100644 tests/test_tool_call.py diff --git a/tests/assets/chat_tool.hprompt b/tests/assets/chat_tool.hprompt new file mode 100644 index 0000000..372925b --- /dev/null +++ b/tests/assets/chat_tool.hprompt @@ -0,0 +1,37 @@ +--- +meta: + var_map_path: var_map.txt + output_path: tmp-out/%Y-%m-%d/result.%H-%M-%S.hprompt + output_evaled_prompt_path: tmp-evaled/%Y-%m-%d/result.%H-%M-%S.hprompt +model: gpt-4o +temperature: 0.2 +tools: [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location", "unit"] + } + } + } + ] +--- + +$system$ +You are a helpful assistant. + +$user$ +What's the weather like in San Francisco and New York? diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py new file mode 100644 index 0000000..ec0fa6e --- /dev/null +++ b/tests/test_tool_call.py @@ -0,0 +1,54 @@ +from pathlib import Path +import re +from handyllm import load_from, ChatPrompt +import responses + + +tests_dir = Path(__file__).parent + +mock_fetch_data = { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion", + "created": 1722818900, + "model": "gpt-4o-2024-05-13", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_xxxxxxxxxxxxxxxxxxxxxxxx", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": '{"location": "San Francisco, CA", "unit": "celsius"}', + }, + }, + { + "id": "call_yyyyyyyyyyyyyyyyyyyyyyyy", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": '{"location": "New York, NY", "unit": "celsius"}', + }, + }, + ], + }, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 89, "completion_tokens": 62, "total_tokens": 151}, + "system_fingerprint": "fp_3cd8b62c3b", +} + + +@responses.activate +def test_tool_call(): + responses.add(responses.POST, url=re.compile(r".*"), json=mock_fetch_data) + prompt_file = tests_dir / "assets" / "chat_tool.hprompt" + prompt = load_from(prompt_file, cls=ChatPrompt) + response = prompt.fetch(api_key="fake-key") + assert "tool_calls" in response.choices[0].message From e5a0bad1224446e405db27e57f62552d661cc891 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 01:37:22 +0800 Subject: [PATCH 2/5] test: add more asserts for tool call results --- tests/test_tool_call.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py index ec0fa6e..5001284 100644 --- a/tests/test_tool_call.py +++ b/tests/test_tool_call.py @@ -52,3 +52,19 @@ def test_tool_call(): prompt = load_from(prompt_file, cls=ChatPrompt) response = prompt.fetch(api_key="fake-key") assert "tool_calls" in response.choices[0].message + assert ( + response.choices[0].message["tool_calls"][0]["function"]["name"] + == "get_current_weather" + ) + assert ( + response.choices[0].message["tool_calls"][0]["function"]["arguments"] + == '{"location": "San Francisco, CA", "unit": "celsius"}' + ) + assert ( + response.choices[0].message["tool_calls"][1]["function"]["name"] + == "get_current_weather" + ) + assert ( + response.choices[0].message["tool_calls"][1]["function"]["arguments"] + == '{"location": "New York, NY", "unit": "celsius"}' + ) From 58157023e0f4f0e5fc89ea594a7f78f4733f7ed3 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 22:54:08 +0800 Subject: [PATCH 3/5] test: add test for stream tool call --- tests/test_tool_call.py | 499 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 498 insertions(+), 1 deletion(-) diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py index 5001284..bb19c45 100644 --- a/tests/test_tool_call.py +++ b/tests/test_tool_call.py @@ -1,6 +1,8 @@ +import json from pathlib import Path import re -from handyllm import load_from, ChatPrompt +from handyllm import load_from, ChatPrompt, stream_chat_all +from pytest import CaptureFixture import responses @@ -44,6 +46,478 @@ "system_fingerprint": "fp_3cd8b62c3b", } +mock_stream_data = [ + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": None}, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_xxxxxxxxxxxxxxxxxxxxxxxx", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "", + }, + } + ] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": '{"lo'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": "catio"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": 'n": "S'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": "an F"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": "ranci"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": "sco, C"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": 'A", '}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": '"unit'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": '": "fa'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": "hren"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 0, "function": {"arguments": 'heit"'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "}"}}]}, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 1, + "id": "call_yyyyyyyyyyyyyyyyyyyyyyyy", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "", + }, + } + ] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": '{"lo'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": "catio"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": 'n": "N'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": "ew Y"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": "ork, "}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": 'NY", "'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": "unit"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": '": "f'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": "ahrenh"}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{"index": 1, "function": {"arguments": 'eit"'}}] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + { + "index": 0, + "delta": {"tool_calls": [{"index": 1, "function": {"arguments": "}"}}]}, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxx", + "object": "chat.completion.chunk", + "created": 1722879508, + "model": "gpt-4o-2024-05-13", + "system_fingerprint": "fp_c832e4513b", + "choices": [ + {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "tool_calls"} + ], + }, +] +tmp = ["data: " + json.dumps(data) for data in mock_stream_data] +tmp.append("data: [DONE]") +stream_body = "\n".join(tmp) + @responses.activate def test_tool_call(): @@ -68,3 +542,26 @@ def test_tool_call(): response.choices[0].message["tool_calls"][1]["function"]["arguments"] == '{"location": "New York, NY", "unit": "celsius"}' ) + + +@responses.activate +def test_tool_call_stream(capsys: CaptureFixture[str]): + responses.add(responses.POST, url=re.compile(r".*"), body=stream_body) + prompt_file = tests_dir / "assets" / "chat_tool.hprompt" + prompt = load_from(prompt_file, cls=ChatPrompt) + response = prompt.stream(api_key="fake-key") + tool_calls = [] + for chunk in stream_chat_all(response): + role, content, tool_call = chunk + tool_calls.append(tool_call) + assert role == "assistant" + assert content is None + assert len(tool_calls) == 2 + assert tool_calls[0]["function"]["name"] == "get_current_weather" + assert tool_calls[0]["function"]["arguments"] == '{"location": "San Francisco, CA", "unit": "fahrenheit"}' + assert tool_calls[1]["function"]["name"] == "get_current_weather" + assert tool_calls[1]["function"]["arguments"] == '{"location": "New York, NY", "unit": "fahrenheit"}' + + # make sure no debug prints + captured = capsys.readouterr() + assert captured.out == "" From a096d88b6846daea514336d270f28ef5eaa55535 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 23:24:11 +0800 Subject: [PATCH 4/5] fix(utils.trans_stream_chat): yield last tool call --- src/handyllm/hprompt.py | 11 +++++++++++ src/handyllm/utils.py | 24 +++++++++++++++++++----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index 0729bcd..174d554 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -731,6 +731,10 @@ def _stream_with_client( run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk) run_config.on_chunk(*ret) yield chat_chunk + ret = producer.send(None) # signal the end of the stream + if run_config.on_chunk and ret: + run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk) + run_config.on_chunk(*ret) producer.close() @classmethod @@ -756,6 +760,13 @@ async def _astream_with_client( run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk) run_config.on_chunk(*ret) yield chat_chunk + ret = producer.send(None) # signal the end of the stream + if run_config.on_chunk and ret: + if inspect.iscoroutinefunction(run_config.on_chunk): + await run_config.on_chunk(*ret) + else: + run_config.on_chunk = cast(SyncHandlerChat, run_config.on_chunk) + run_config.on_chunk(*ret) producer.close() @classmethod diff --git a/src/handyllm/utils.py b/src/handyllm/utils.py index a780a08..c72e940 100644 --- a/src/handyllm/utils.py +++ b/src/handyllm/utils.py @@ -21,6 +21,7 @@ ] import base64 +import copy from pathlib import Path from typing import ( IO, @@ -69,7 +70,7 @@ def download_binary(download_url, file_path=None, dir="."): def trans_stream_chat( consumer: Generator[YieldType, ShortChatChunk, None], -) -> Generator[Optional[YieldType], ChatChunk, None]: +) -> Generator[Optional[YieldType], Optional[ChatChunk], None]: next(consumer) # prime the generator role = "" tool_call = ToolCallDelta() @@ -77,6 +78,8 @@ def trans_stream_chat( try: while True: data = yield ret + if data is None: + break ret = None try: message = data["choices"][0]["delta"] @@ -93,19 +96,24 @@ def trans_stream_chat( "arguments" ] else: - # this is a new tool call, yield the previous one - ret = consumer.send((role, content, tool_call)) + if tool_call: + # this is a new tool call, yield the previous one + ret = consumer.send((role, content, tool_call)) # reset the tool call - tool_call = ToolCallDelta(chunk) + tool_call = copy.deepcopy(chunk) elif content: ret = consumer.send((role, content, tool_call)) except (KeyError, IndexError): pass - except GeneratorExit: if tool_call: # yield the last tool call ret = consumer.send((role, None, tool_call)) + yield ret + else: + yield None consumer.close() + except GeneratorExit: + pass def echo_consumer(): @@ -123,6 +131,9 @@ def stream_chat_all( ret = producer.send(data) if ret is not None: yield ret + ret = producer.send(None) # signal the end of the stream + if ret is not None: + yield ret producer.close() @@ -154,6 +165,9 @@ async def astream_chat_all( ret = producer.send(data) if ret is not None: yield ret + ret = producer.send(None) # signal the end of the stream + if ret is not None: + yield ret producer.close() From 3fe005c2feeeccf26814b312ae17618d21f14c13 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 23:25:13 +0800 Subject: [PATCH 5/5] test(test_tool_call): lint format script --- tests/test_tool_call.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py index bb19c45..c7dba93 100644 --- a/tests/test_tool_call.py +++ b/tests/test_tool_call.py @@ -558,9 +558,15 @@ def test_tool_call_stream(capsys: CaptureFixture[str]): assert content is None assert len(tool_calls) == 2 assert tool_calls[0]["function"]["name"] == "get_current_weather" - assert tool_calls[0]["function"]["arguments"] == '{"location": "San Francisco, CA", "unit": "fahrenheit"}' + assert ( + tool_calls[0]["function"]["arguments"] + == '{"location": "San Francisco, CA", "unit": "fahrenheit"}' + ) assert tool_calls[1]["function"]["name"] == "get_current_weather" - assert tool_calls[1]["function"]["arguments"] == '{"location": "New York, NY", "unit": "fahrenheit"}' + assert ( + tool_calls[1]["function"]["arguments"] + == '{"location": "New York, NY", "unit": "fahrenheit"}' + ) # make sure no debug prints captured = capsys.readouterr()