diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index a151855..c877ee8 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -117,7 +117,7 @@ def _aggregate_assistant_messages(self, messages: List[UATS]) -> AssistantMessag weight: Optional[float] = None for message in messages: assert isinstance(message, self._assistant_message_class), "Expected assistant message" - if message.tool_calls is not None: + if message.tool_calls: for tool_call in message.tool_calls: normalized_tool_call = self._normalize_tool_call(tool_call) tool_calls.append(normalized_tool_call) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 9637a18..da11bf8 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -293,6 +293,10 @@ def test_normalize_funcalls(self, normalizer: InstructRequestNormalizer) -> None normalized = normalizer.from_chat_completion_request(request) assert normalized == gt + def test_normalize_empty_array_tool_calls(self, normalizer: InstructRequestNormalizer) -> None: + message = AssistantMessage(role="assistant", content="Hello", tool_calls=[]) + normalized_message = normalizer._aggregate_assistant_messages([message]) + assert normalized_message.content == "Hello" class TestFineTuningNormalizer: @pytest.fixture(autouse=True)