diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
index 0b2e1408..e64e9d2d 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -50,6 +50,7 @@
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.utils.pydantic import is_basemodel_subclass
 from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
 from langchain_nvidia_ai_endpoints._statics import Model
 
@@ -679,24 +680,6 @@ class Choices(enum.Enum):
             output_parser: BaseOutputParser = JsonOutputParser()
             nvext_param: Dict[str, Any] = {"guided_json": schema}
 
-        elif issubclass(schema, BaseModel):
-            # PydanticOutputParser does not support streaming. what we do
-            # instead is ignore all inputs that are incomplete wrt the
-            # underlying Pydantic schema. if the entire input is invalid,
-            # we return None.
-            class ForgivingPydanticOutputParser(PydanticOutputParser):
-                def parse_result(
-                    self, result: List[Generation], *, partial: bool = False
-                ) -> Any:
-                    try:
-                        return super().parse_result(result, partial=partial)
-                    except OutputParserException:
-                        pass
-                    return None
-
-            output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
-            nvext_param = {"guided_json": schema.schema()}
-
         elif issubclass(schema, enum.Enum):
             # langchain's EnumOutputParser is not in langchain_core
             # and doesn't support streaming. this is a simple implementation
@@ -724,6 +707,25 @@ def parse(self, response: str) -> Any:
                 )
             output_parser = EnumOutputParser(enum=schema)
             nvext_param = {"guided_choice": choices}
+
+        elif is_basemodel_subclass(schema):
+            # PydanticOutputParser does not support streaming. what we do
+            # instead is ignore all inputs that are incomplete wrt the
+            # underlying Pydantic schema. if the entire input is invalid,
+            # we return None.
+            class ForgivingPydanticOutputParser(PydanticOutputParser):
+                def parse_result(
+                    self, result: List[Generation], *, partial: bool = False
+                ) -> Any:
+                    try:
+                        return super().parse_result(result, partial=partial)
+                    except OutputParserException:
+                        pass
+                    return None
+
+            output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
+            nvext_param = {"guided_json": schema.schema()}
+
         else:
             raise ValueError(
                 "Schema must be a Pydantic object, a dictionary "
diff --git a/libs/ai-endpoints/scripts/check_pydantic.sh b/libs/ai-endpoints/scripts/check_pydantic.sh
index 06b5bb81..d0fa31d6 100755
--- a/libs/ai-endpoints/scripts/check_pydantic.sh
+++ b/libs/ai-endpoints/scripts/check_pydantic.sh
@@ -14,7 +14,7 @@ fi
 repository_path="$1"
 
 # Search for lines matching the pattern within the specified repository
-result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
+result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic' | grep -v "# ignore: check_pydantic")
 
 # Check if any matching lines were found
 if [ -n "$result" ]; then
diff --git a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
index 0c8aa626..053b10b3 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
@@ -1,14 +1,18 @@
 import enum
 import warnings
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Type
 
 import pytest
-from langchain_core.pydantic_v1 import BaseModel, Field
+import requests_mock
+from langchain_core.pydantic_v1 import BaseModel as lc_pydanticV1BaseModel
+from langchain_core.pydantic_v1 import Field
+from pydantic import BaseModel as pydanticV2BaseModel  # ignore: check_pydantic
+from pydantic.v1 import BaseModel as pydanticV1BaseModel  # ignore: check_pydantic
 
 from langchain_nvidia_ai_endpoints import ChatNVIDIA
 
 
-class Joke(BaseModel):
+class Joke(lc_pydanticV1BaseModel):
     """Joke to tell user."""
 
     setup: str = Field(description="The setup of the joke")
@@ -136,3 +140,52 @@ def test_stream_enum_incomplete(
     for chunk in structured_llm.stream("This is ignored."):
         response = chunk
     assert response is None
+
+
+@pytest.mark.parametrize(
+    "pydanticBaseModel",
+    [
+        lc_pydanticV1BaseModel,
+        pydanticV1BaseModel,
+        pydanticV2BaseModel,
+    ],
+    ids=["lc-pydantic-v1", "pydantic-v1", "pydantic-v2"],
+)
+def test_pydantic_version(
+    requests_mock: requests_mock.Mocker,
+    pydanticBaseModel: Type,
+) -> None:
+    requests_mock.post(
+        "https://integrate.api.nvidia.com/v1/chat/completions",
+        json={
+            "id": "chatcmpl-ID",
+            "object": "chat.completion",
+            "created": 1234567890,
+            "model": "BOGUS",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": '{"name": "Sam Doe"}',
+                    },
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {
+                "prompt_tokens": 22,
+                "completion_tokens": 20,
+                "total_tokens": 42,
+            },
+            "system_fingerprint": None,
+        },
+    )
+
+    class Person(pydanticBaseModel):  # type: ignore
+        name: str
+
+    llm = ChatNVIDIA(api_key="BOGUS").with_structured_output(Person)
+    response = llm.invoke("This is ignored.")
+    assert isinstance(response, Person)
+    assert response.name == "Sam Doe"
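
Not part of the patch, just a usage sketch of what the chat_models.py change enables: because the branch now checks is_basemodel_subclass() instead of issubclass(schema, langchain_core.pydantic_v1.BaseModel), with_structured_output() accepts plain pydantic v2 (and pydantic.v1) models as well. This mirrors the new test_pydantic_version test; the prompt below is illustrative and an NVIDIA_API_KEY is assumed to be set in the environment.

from pydantic import BaseModel  # pydantic v2; pydantic.v1 and langchain_core.pydantic_v1 also work

from langchain_nvidia_ai_endpoints import ChatNVIDIA


class Person(BaseModel):
    name: str


# Returns a Person instance, or None if the model output cannot be parsed
# into the schema (see ForgivingPydanticOutputParser above).
llm = ChatNVIDIA().with_structured_output(Person)
person = llm.invoke("Extract the person mentioned: Sam Doe works at NVIDIA.")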