From 754c89d8514dde707255bf0e3602dfddff49f268 Mon Sep 17 00:00:00 2001 From: Yeesian Ng Date: Tue, 23 Apr 2024 11:06:25 -0700 Subject: [PATCH] fix: Parse intermediate steps from LangChain into JSON. PiperOrigin-RevId: 627444864 --- .../test_reasoning_engine_templates_langchain.py | 9 ++++++++- .../preview/reasoning_engines/templates/langchain.py | 8 ++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py index f54715a529..21bebc158f 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py +++ b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py @@ -29,6 +29,7 @@ from langchain_core import messages from langchain_core import outputs from langchain_core import tools as lc_tools +from langchain.load import dump as langchain_load_dump from langchain.tools.base import StructuredTool @@ -77,6 +78,12 @@ def vertexai_init_mock(): yield vertexai_init_mock +@pytest.fixture +def langchain_dump_mock(): + with mock.patch.object(langchain_load_dump, "dumpd") as langchain_dump_mock: + yield langchain_dump_mock + + @pytest.mark.usefixtures("google_auth_mock") class TestLangchainAgent: def setup_method(self): @@ -114,7 +121,7 @@ def test_set_up(self, vertexai_init_mock): agent.set_up() assert agent._runnable is not None - def test_query(self): + def test_query(self, langchain_dump_mock): agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL) agent._runnable = mock.Mock() mocks = mock.Mock() diff --git a/vertexai/preview/reasoning_engines/templates/langchain.py b/vertexai/preview/reasoning_engines/templates/langchain.py index cd507330db..4eadbe1ae6 100644 --- a/vertexai/preview/reasoning_engines/templates/langchain.py +++ b/vertexai/preview/reasoning_engines/templates/langchain.py @@ -18,6 +18,7 @@ TYPE_CHECKING, Any, Callable, + Dict, List, Mapping, Optional, @@ -418,7 +419,7 @@ def query( input: Union[str, Mapping[str, Any]], config: Optional["RunnableConfig"] = None, **kwargs: Any, - ) -> Mapping[str, Any]: + ) -> Dict[str, Any]: """Queries the Agent with the given input and config. Args: @@ -433,8 +434,11 @@ def query( Returns: The output of querying the Agent with the given input and config. """ + from langchain.load import dump as langchain_load_dump if isinstance(input, str): input = {"input": input} if not self._runnable: self.set_up() - return self._runnable.invoke(input=input, config=config, **kwargs) + return langchain_load_dump.dumpd( + self._runnable.invoke(input=input, config=config, **kwargs) + )