Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support pydantic v2 and v1 for with_structured_output #84

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
Expand Down Expand Up @@ -679,24 +680,6 @@ class Choices(enum.Enum):
output_parser: BaseOutputParser = JsonOutputParser()
nvext_param: Dict[str, Any] = {"guided_json": schema}

elif issubclass(schema, BaseModel):
# PydanticOutputParser does not support streaming. what we do
# instead is ignore all inputs that are incomplete wrt the
# underlying Pydantic schema. if the entire input is invalid,
# we return None.
class ForgivingPydanticOutputParser(PydanticOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Any:
try:
return super().parse_result(result, partial=partial)
except OutputParserException:
pass
return None

output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
nvext_param = {"guided_json": schema.schema()}

elif issubclass(schema, enum.Enum):
# langchain's EnumOutputParser is not in langchain_core
# and doesn't support streaming. this is a simple implementation
Expand Down Expand Up @@ -724,6 +707,25 @@ def parse(self, response: str) -> Any:
)
output_parser = EnumOutputParser(enum=schema)
nvext_param = {"guided_choice": choices}

elif is_basemodel_subclass(schema):
# PydanticOutputParser does not support streaming. what we do
# instead is ignore all inputs that are incomplete wrt the
# underlying Pydantic schema. if the entire input is invalid,
# we return None.
class ForgivingPydanticOutputParser(PydanticOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Any:
try:
return super().parse_result(result, partial=partial)
except OutputParserException:
pass
return None

output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
nvext_param = {"guided_json": schema.schema()}

else:
raise ValueError(
"Schema must be a Pydantic object, a dictionary "
Expand Down
2 changes: 1 addition & 1 deletion libs/ai-endpoints/scripts/check_pydantic.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fi
repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic' | grep -v "# ignore: check_pydantic")

# Check if any matching lines were found
if [ -n "$result" ]; then
Expand Down
59 changes: 56 additions & 3 deletions libs/ai-endpoints/tests/unit_tests/test_structured_output.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import enum
import warnings
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Type

import pytest
from langchain_core.pydantic_v1 import BaseModel, Field
import requests_mock
from langchain_core.pydantic_v1 import BaseModel as lc_pydanticV1BaseModel
from langchain_core.pydantic_v1 import Field
from pydantic import BaseModel as pydanticV2BaseModel # ignore: check_pydantic
from pydantic.v1 import BaseModel as pydanticV1BaseModel # ignore: check_pydantic

from langchain_nvidia_ai_endpoints import ChatNVIDIA


class Joke(BaseModel):
class Joke(lc_pydanticV1BaseModel):
"""Joke to tell user."""

setup: str = Field(description="The setup of the joke")
Expand Down Expand Up @@ -136,3 +140,52 @@ def test_stream_enum_incomplete(
for chunk in structured_llm.stream("This is ignored."):
response = chunk
assert response is None


@pytest.mark.parametrize(
    "pydanticBaseModel",
    [
        lc_pydanticV1BaseModel,
        pydanticV1BaseModel,
        pydanticV2BaseModel,
    ],
    ids=["lc-pydantic-v1", "pydantic-v1", "pydantic-v2"],
)
def test_pydantic_version(
    requests_mock: requests_mock.Mocker,
    pydanticBaseModel: Type,
) -> None:
    """with_structured_output accepts schemas from every pydantic flavor.

    Exercises langchain_core's pydantic v1 shim, pydantic.v1, and native
    pydantic v2 BaseModel, asserting each parses the mocked completion.
    """
    # Canned OpenAI-style chat completion; the assistant content is a JSON
    # object that matches the Person schema defined below.
    completion_payload = {
        "id": "chatcmpl-ID",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "BOGUS",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": '{"name": "Sam Doe"}',
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 22,
            "completion_tokens": 20,
            "total_tokens": 42,
        },
        "system_fingerprint": None,
    }
    requests_mock.post(
        "https://integrate.api.nvidia.com/v1/chat/completions",
        json=completion_payload,
    )

    class Person(pydanticBaseModel):  # type: ignore
        name: str

    structured_llm = ChatNVIDIA(api_key="BOGUS").with_structured_output(Person)
    result = structured_llm.invoke("This is ignored.")
    assert isinstance(result, Person)
    assert result.name == "Sam Doe"
Loading