Fix predict_stream for AgentExecutor and other non-Runnable chains (mlflow#12518)

Signed-off-by: B-Step62 <[email protected]>
Signed-off-by: Yuki Watanabe <[email protected]>
Co-authored-by: Harutaka Kawamura <[email protected]>
B-Step62 and harupy authored Jul 2, 2024
1 parent 0645dcf commit b96b36b
Showing 2 changed files with 71 additions and 24 deletions.
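
For orientation, a minimal sketch of the pyfunc streaming API this commit fixes. The model URI and input below are hypothetical placeholders; the predict_stream call itself mirrors the tests added in this commit:

import mlflow

# "models:/my_langchain_agent/1" is a hypothetical URI for a logged
# LangChain AgentExecutor model, used here purely for illustration.
loaded_model = mlflow.pyfunc.load_model("models:/my_langchain_agent/1")

# predict_stream returns a generator of output chunks. Before this fix,
# AgentExecutor and other non-Runnable chains hit a serialization step
# that consumed the stream before the caller could iterate it.
for chunk in loaded_model.predict_stream([{"input": "What is MLflow?"}]):
    print(chunk)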
34 changes: 10 additions & 24 deletions mlflow/langchain/api_request_parallel_processor.py
@@ -27,7 +27,6 @@
 
 import langchain.chains
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction
 
 import mlflow
 from mlflow.exceptions import MlflowException
@@ -114,28 +113,13 @@ def _prepare_to_serialize(self, response: dict):
 
         if "intermediate_steps" in response:
             steps = response["intermediate_steps"]
-            if (
-                isinstance(steps, tuple)
-                and len(steps) == 2
-                and isinstance(steps[0], AgentAction)
-                and isinstance(steps[1], str)
-            ):
-                response["intermediate_steps"] = [
-                    {
-                        "tool": agent.tool,
-                        "tool_input": agent.tool_input,
-                        "log": agent.log,
-                        "result": result,
-                    }
-                    for agent, result in response["intermediate_steps"]
-                ]
-            else:
-                try:
-                    # `AgentAction` objects are not yet implemented for serialization in `dumps`
-                    # https://github.com/langchain-ai/langchain/issues/8815#issuecomment-1666763710
-                    response["intermediate_steps"] = dumps(steps)
-                except Exception as e:
-                    _logger.warning(f"Failed to serialize intermediate steps: {e!r}")
+            try:
+                # `AgentAction` objects are not JSON serializable
+                # https://github.com/langchain-ai/langchain/issues/8815#issuecomment-1666763710
+                response["intermediate_steps"] = dumps(steps)
+            except Exception as e:
+                _logger.warning(f"Failed to serialize intermediate steps: {e!r}")
 
         # The `dumps` format for `Document` objects is noisy, so we will still have custom logic
         if "source_documents" in response:
             response["source_documents"] = [
@@ -218,7 +202,9 @@ def single_call_api(self, callback_handlers: Optional[List[BaseCallbackHandler]]
                 # to maintain existing code, single output chains will still return
                 # only the result
                 response = response.popitem()[1]
-            else:
+            elif not self.stream:
+                # DO NOT call _prepare_to_serialize for stream output. It will consume the generator
+                # until the end and the iterator will be empty when the user tries to consume it.
                 self._prepare_to_serialize(response)
 
         return response
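The elif not self.stream guard exists because any eager walk over a streaming response exhausts it. A standalone illustration of that failure mode (plain Python, not MLflow code):

def stream_chunks():
    # Stand-in for a chain's streaming output.
    yield {"output": "Data"}
    yield {"output": "bricks"}

response = stream_chunks()

# A serialization helper that eagerly iterates the response...
_ = list(response)

# ...leaves nothing for the caller: this prints an empty list.
print(list(response))
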
61 changes: 61 additions & 0 deletions tests/langchain/test_langchain_model_export.py
@@ -1,4 +1,5 @@
 import importlib
+import inspect
 import json
 import os
 import shutil
@@ -339,6 +340,21 @@ def test_langchain_model_predict():
         assert result == [TEST_CONTENT]
 
 
+@pytest.mark.skipif(
+    Version(langchain.__version__) < Version("0.0.354"),
+    reason="LLMChain does not support streaming before LangChain 0.0.354",
+)
+def test_langchain_model_predict_stream():
+    with _mock_request(return_value=_mock_chat_completion_response()):
+        model = create_openai_llmchain()
+        with mlflow.start_run():
+            logged_model = mlflow.langchain.log_model(model, "langchain_model")
+        loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
+        result = loaded_model.predict_stream([{"product": "MLflow"}])
+        assert inspect.isgenerator(result)
+        assert list(result) == [{"product": "MLflow", "text": "test"}]
+
+
 def test_pyfunc_spark_udf_with_langchain_model(spark):
     model = create_openai_llmchain()
     with mlflow.start_run():
@@ -473,6 +489,41 @@ def test_langchain_agent_model_predict(return_intermediate_steps):
     )
 
 
+@pytest.mark.skipif(
+    Version(langchain.__version__) < Version("0.0.354"),
+    reason="AgentExecutor does not support streaming before LangChain 0.0.354",
+)
+def test_langchain_agent_model_predict_stream():
+    langchain_agent_output = {
+        "id": "chatcmpl-123",
+        "object": "chat.completion",
+        "created": 1677652288,
+        "choices": [
+            {
+                "index": 0,
+                "finish_reason": "stop",
+                "text": f"Final Answer: {TEST_CONTENT}",
+            }
+        ],
+        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
+    }
+    model = create_openai_llmagent()
+
+    with mlflow.start_run():
+        logged_model = mlflow.langchain.log_model(model, "langchain_model")
+    loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
+    langchain_input = {"input": "foo"}
+    with _mock_request(return_value=_MockResponse(200, langchain_agent_output)):
+        response = loaded_model.predict_stream([langchain_input])
+        assert inspect.isgenerator(response)
+        assert list(response) == [
+            {
+                "output": TEST_CONTENT,
+                "messages": [AIMessage(content=f"Final Answer: {TEST_CONTENT}")],
+            }
+        ]
+
+
 def test_langchain_native_log_and_load_qaevalchain():
     # QAEvalChain is a subclass of LLMChain
     model = create_qa_eval_chain()
@@ -3356,3 +3407,13 @@ def test_agent_executor_model_with_messages_input():
     # we convert pandas dataframe back to records, and a single row will be
     # wrapped inside a list.
     assert pyfunc_model.predict(question) == ["Databricks"]
+
+    # Test stream output
+    response = pyfunc_model.predict_stream(question)
+    assert inspect.isgenerator(response)
+    assert list(response) == [
+        {
+            "output": "Databricks",
+            "messages": [AIMessage(content="Databricks")],
+        }
+    ]
