diff --git a/mlflow/langchain/api_request_parallel_processor.py b/mlflow/langchain/api_request_parallel_processor.py
index a38196040d81f..29c19b92a2802 100644
--- a/mlflow/langchain/api_request_parallel_processor.py
+++ b/mlflow/langchain/api_request_parallel_processor.py
@@ -27,7 +27,6 @@
 
 import langchain.chains
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction
 
 import mlflow
 from mlflow.exceptions import MlflowException
@@ -114,28 +113,13 @@ def _prepare_to_serialize(self, response: dict):
 
         if "intermediate_steps" in response:
             steps = response["intermediate_steps"]
-            if (
-                isinstance(steps, tuple)
-                and len(steps) == 2
-                and isinstance(steps[0], AgentAction)
-                and isinstance(steps[1], str)
-            ):
-                response["intermediate_steps"] = [
-                    {
-                        "tool": agent.tool,
-                        "tool_input": agent.tool_input,
-                        "log": agent.log,
-                        "result": result,
-                    }
-                    for agent, result in response["intermediate_steps"]
-                ]
-            else:
-                try:
-                    # `AgentAction` objects are not yet implemented for serialization in `dumps`
-                    # https://github.com/langchain-ai/langchain/issues/8815#issuecomment-1666763710
-                    response["intermediate_steps"] = dumps(steps)
-                except Exception as e:
-                    _logger.warning(f"Failed to serialize intermediate steps: {e!r}")
+            try:
+                # `AgentAction` objects are not JSON serializable
+                # https://github.com/langchain-ai/langchain/issues/8815#issuecomment-1666763710
+                response["intermediate_steps"] = dumps(steps)
+            except Exception as e:
+                _logger.warning(f"Failed to serialize intermediate steps: {e!r}")
+
         # The `dumps` format for `Document` objects is noisy, so we will still have custom logic
         if "source_documents" in response:
             response["source_documents"] = [
@@ -218,7 +202,9 @@ def single_call_api(self, callback_handlers: Optional[List[BaseCallbackHandler]]
             # to maintain existing code, single output chains will still return
             # only the result
             response = response.popitem()[1]
-        else:
+        elif not self.stream:
+            # DO NOT call _prepare_to_serialize for stream output. It will consume the generator
+            # until the end and the iterator will be empty when the user tries to consume it.
             self._prepare_to_serialize(response)
         return response
 
diff --git a/tests/langchain/test_langchain_model_export.py b/tests/langchain/test_langchain_model_export.py
index 6335fa7cd7b45..56cec9254cfc1 100644
--- a/tests/langchain/test_langchain_model_export.py
+++ b/tests/langchain/test_langchain_model_export.py
@@ -1,4 +1,5 @@
 import importlib
+import inspect
 import json
 import os
 import shutil
@@ -339,6 +340,21 @@ def test_langchain_model_predict():
         assert result == [TEST_CONTENT]
 
 
+@pytest.mark.skipif(
+    Version(langchain.__version__) < Version("0.0.354"),
+    reason="LLMChain does not support streaming before LangChain 0.0.354",
+)
+def test_langchain_model_predict_stream():
+    with _mock_request(return_value=_mock_chat_completion_response()):
+        model = create_openai_llmchain()
+        with mlflow.start_run():
+            logged_model = mlflow.langchain.log_model(model, "langchain_model")
+        loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
+        result = loaded_model.predict_stream([{"product": "MLflow"}])
+        assert inspect.isgenerator(result)
+        assert list(result) == [{"product": "MLflow", "text": "test"}]
+
+
 def test_pyfunc_spark_udf_with_langchain_model(spark):
     model = create_openai_llmchain()
     with mlflow.start_run():
@@ -473,6 +489,41 @@ def test_langchain_agent_model_predict(return_intermediate_steps):
     )
 
 
+@pytest.mark.skipif(
+    Version(langchain.__version__) < Version("0.0.354"),
+    reason="AgentExecutor does not support streaming before LangChain 0.0.354",
+)
+def test_langchain_agent_model_predict_stream():
+    langchain_agent_output = {
+        "id": "chatcmpl-123",
+        "object": "chat.completion",
+        "created": 1677652288,
+        "choices": [
+            {
+                "index": 0,
+                "finish_reason": "stop",
+                "text": f"Final Answer: {TEST_CONTENT}",
+            }
+        ],
+        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
+    }
+    model = create_openai_llmagent()
+
+    with mlflow.start_run():
+        logged_model = mlflow.langchain.log_model(model, "langchain_model")
+    loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
+    langchain_input = {"input": "foo"}
+    with _mock_request(return_value=_MockResponse(200, langchain_agent_output)):
+        response = loaded_model.predict_stream([langchain_input])
+        assert inspect.isgenerator(response)
+        assert list(response) == [
+            {
+                "output": TEST_CONTENT,
+                "messages": [AIMessage(content=f"Final Answer: {TEST_CONTENT}")],
+            }
+        ]
+
+
 def test_langchain_native_log_and_load_qaevalchain():
     # QAEvalChain is a subclass of LLMChain
     model = create_qa_eval_chain()
@@ -3356,3 +3407,13 @@ def test_agent_executor_model_with_messages_input():
     # we convert pandas dataframe back to records, and a single row will be
     # wrapped inside a list.
     assert pyfunc_model.predict(question) == ["Databricks"]
+
+    # Test stream output
+    response = pyfunc_model.predict_stream(question)
+    assert inspect.isgenerator(response)
+    assert list(response) == [
+        {
+            "output": "Databricks",
+            "messages": [AIMessage(content="Databricks")],
+        }
+    ]
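Why the new `elif not self.stream:` branch skips `_prepare_to_serialize`: serializing the response would iterate the stream to the end, so the caller of `predict_stream` would receive an already-exhausted generator. A minimal standalone sketch of that pitfall under assumed names (`fake_stream` is hypothetical, not MLflow or LangChain code):

import json


def fake_stream():
    # Hypothetical stand-in for the chunks a streaming chain yields one by one.
    yield {"text": "Hello"}
    yield {"text": " world"}


chunks = fake_stream()
# Eagerly serializing the generator's contents (as a serialization helper would
# do for a non-stream response) consumes every chunk...
_ = json.dumps(list(chunks))
# ...so the "stream" handed back to the user afterwards is already empty.
assert list(chunks) == []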