Skip to content

Commit

Permalink
Move yield of metrics chunk after generation chunk
Browse files Browse the repository at this point in the history
  - when using mistral and streaming is enabled, the final
    chunk includes a stop_reason. There is nothing to say this final
    chunk doesn't also include some generated text. The existing
    implementation would result in that final chunk never getting
    sent back
  - this update moves the yield of the metrics chunk after the
    generation chunk
  - also included a change to include invocation metrics for cohere
    models
  • Loading branch information
ihmaws committed Sep 25, 2024
1 parent 6ad78b7 commit d653f01
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 13 deletions.
25 changes: 12 additions & 13 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,20 @@ def prepare_output_stream(
if provider == "cohere" and (
chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
):
yield _get_invocation_metrics_chunk(chunk_obj)
return

elif (
generation_chunk = _stream_response_to_generation_chunk(
chunk_obj,
provider=provider,
output_key=output_key,
messages_api=messages_api,
coerce_content_to_string=coerce_content_to_string,
)
if generation_chunk:
yield generation_chunk

if (
provider == "mistral"
and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
):
Expand All @@ -384,18 +395,6 @@ def prepare_output_stream(
yield _get_invocation_metrics_chunk(chunk_obj)
return

generation_chunk = _stream_response_to_generation_chunk(
chunk_obj,
provider=provider,
output_key=output_key,
messages_api=messages_api,
coerce_content_to_string=coerce_content_to_string,
)
if generation_chunk:
yield generation_chunk
else:
continue

@classmethod
async def aprepare_output_stream(
cls,
Expand Down
21 changes: 21 additions & 0 deletions libs/aws/tests/unit_tests/llms/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ def test__human_assistant_format() -> None:
{"chunk": {"bytes": b'{"text": " you"}'}},
]

# Mocked Bedrock streaming chunks for a Mistral model. Note the final chunk
# carries BOTH generated text ("you.") and stop_reason "stop" — the test using
# this fixture verifies that text in the terminal chunk is not dropped.
MOCK_STREAMING_RESPONSE_MISTRAL = [
    {"chunk": {"bytes": b'{"outputs": [{"text": "Thank","stop_reason": null}]}'}},
    {"chunk": {"bytes": b'{"outputs": [{"text": "you.","stop_reason": "stop"}]}'}},
]


async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
for item in MOCK_STREAMING_RESPONSE:
Expand Down Expand Up @@ -330,6 +335,13 @@ def mistral_response():

return response

@pytest.fixture
def mistral_streaming_response():
    """Bedrock-style streaming response body for a Mistral model."""
    return {"body": MOCK_STREAMING_RESPONSE_MISTRAL}


@pytest.fixture
def cohere_response():
Expand Down Expand Up @@ -411,6 +423,15 @@ def test_prepare_output_for_mistral(mistral_response):
assert result["usage"]["total_tokens"] == 46
assert result["stop_reason"] is None

def test_prepare_output_stream_for_mistral(mistral_streaming_response) -> None:
    """Every streamed Mistral chunk must be yielded, including the final one.

    The final Mistral chunk carries stop_reason "stop" *and* generated text;
    a regression previously dropped that text. Split across lines to satisfy
    the 88-char limit (Ruff E501 failed in CI on the one-line form).
    """
    stream = LLMInputOutputAdapter.prepare_output_stream(
        "mistral", mistral_streaming_response
    )
    results = [chunk.text for chunk in stream]

    assert results[0] == "Thank"
    assert results[1] == "you."


def test_prepare_output_for_cohere(cohere_response):
result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
Expand Down

0 comments on commit d653f01

Please sign in to comment.