diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 901d0b09f3a3d..f9f6841aaf092 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -1140,11 +1140,13 @@ def with_structured_output( The method for steering model generation, one of: - "function_calling": - Uses OpenAI's tool-calling (formerly called function calling) API: - https://platform.openai.com/docs/guides/function-calling + Uses OpenAI's tool-calling (formerly called function calling) + API: https://platform.openai.com/docs/guides/function-calling - "json_schema": Uses OpenAI's Structured Output API: - https://platform.openai.com/docs/guides/structured-outputs + https://platform.openai.com/docs/guides/structured-outputs. + Supported for "gpt-4o-mini", "gpt-4o-2024-08-06", and later + models. - "json_mode": Uses OpenAI's JSON mode. Note that if using JSON mode then you must include instructions for formatting the output into the diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 1624363ed0e18..3e03755b765bd 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1,7 +1,7 @@ """Test ChatOpenAI chat model.""" import base64 -from typing import Any, AsyncIterator, List, Optional, cast +from typing import Any, AsyncIterator, List, Literal, Optional, cast import httpx import openai @@ -796,13 +796,21 @@ class magic_function(BaseModel): next(model_with_invalid_tool_schema.stream(query)) -def test_structured_output_strict() -> None: +@pytest.mark.parametrize( + ("model", "method", "strict"), + [("gpt-4o", "function_calling", True), ("gpt-4o-2024-08-06", "json_schema", None)], +) +def test_structured_output_strict( + model: str, + method: Literal["function_calling", "json_schema"], + strict: Optional[bool], +) -> None: """Test to verify structured output with strict=True.""" from pydantic import BaseModel as BaseModelProper from pydantic import Field as FieldProper - model = ChatOpenAI(model="gpt-4o", temperature=0) + llm = ChatOpenAI(model=model, temperature=0) class Joke(BaseModelProper): """Joke to tell user.""" @@ -814,7 +822,7 @@ class Joke(BaseModelProper): # Type ignoring since the interface only officially supports pydantic 1 # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2. # We'll need to do a pass updating the type signatures. - chat = model.with_structured_output(Joke, strict=True) # type: ignore[arg-type] + chat = llm.with_structured_output(Joke, method=method, strict=strict) result = chat.invoke("Tell me a joke about cats.") assert isinstance(result, Joke) @@ -822,7 +830,9 @@ class Joke(BaseModelProper): assert isinstance(chunk, Joke) # Schema - chat = model.with_structured_output(Joke.model_json_schema(), strict=True) + chat = llm.with_structured_output( + Joke.model_json_schema(), method=method, strict=strict + ) result = chat.invoke("Tell me a joke about cats.") assert isinstance(result, dict) assert set(result.keys()) == {"setup", "punchline"} @@ -831,3 +841,27 @@ class Joke(BaseModelProper): assert isinstance(chunk, dict) assert isinstance(chunk, dict) # for mypy assert set(chunk.keys()) == {"setup", "punchline"} + + # Invalid schema with optional fields: + class InvalidJoke(BaseModelProper): + """Joke to tell user.""" + + setup: str = FieldProper(description="question to set up a joke") + # Invalid field, can't have default value. + punchline: str = FieldProper( + default="foo", description="answer to resolve the joke" + ) + + chat = llm.with_structured_output(InvalidJoke, method=method, strict=strict) + with pytest.raises(openai.BadRequestError): + chat.invoke("Tell me a joke about cats.") + with pytest.raises(openai.BadRequestError): + next(chat.stream("Tell me a joke about cats.")) + + chat = llm.with_structured_output( + InvalidJoke.model_json_schema(), method=method, strict=strict + ) + with pytest.raises(openai.BadRequestError): + chat.invoke("Tell me a joke about cats.") + with pytest.raises(openai.BadRequestError): + next(chat.stream("Tell me a joke about cats."))