Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Aug 7, 2024
1 parent 7416720 commit 13378c6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
8 changes: 5 additions & 3 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,11 +1140,13 @@ def with_structured_output(
The method for steering model generation, one of:
- "function_calling":
Uses OpenAI's tool-calling (formerly called function calling) API:
https://platform.openai.com/docs/guides/function-calling
Uses OpenAI's tool-calling (formerly called function calling)
API: https://platform.openai.com/docs/guides/function-calling
- "json_schema":
Uses OpenAI's Structured Output API:
https://platform.openai.com/docs/guides/structured-outputs
https://platform.openai.com/docs/guides/structured-outputs.
Supported for "gpt-4o-mini", "gpt-4o-2024-08-06", and later
models.
- "json_mode":
Uses OpenAI's JSON mode. Note that if using JSON mode then you
must include instructions for formatting the output into the
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test ChatOpenAI chat model."""

import base64
from typing import Any, AsyncIterator, List, Optional, cast
from typing import Any, AsyncIterator, List, Literal, Optional, cast

import httpx
import openai
Expand Down Expand Up @@ -796,13 +796,21 @@ class magic_function(BaseModel):
next(model_with_invalid_tool_schema.stream(query))


def test_structured_output_strict() -> None:
@pytest.mark.parametrize(
("model", "method", "strict"),
[("gpt-4o", "function_calling", True), ("gpt-4o-2024-08-06", "json_schema", None)],
)
def test_structured_output_strict(
model: str,
method: Literal["function_calling", "json_schema"],
strict: Optional[bool],
) -> None:
"""Test to verify structured output with strict=True."""

from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper

model = ChatOpenAI(model="gpt-4o", temperature=0)
llm = ChatOpenAI(model=model, temperature=0)

class Joke(BaseModelProper):
"""Joke to tell user."""
Expand All @@ -814,15 +822,17 @@ class Joke(BaseModelProper):
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
# We'll need to do a pass updating the type signatures.
chat = model.with_structured_output(Joke, strict=True) # type: ignore[arg-type]
chat = llm.with_structured_output(Joke, method=method, strict=strict)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)

for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)

# Schema
chat = model.with_structured_output(Joke.model_json_schema(), strict=True)
chat = llm.with_structured_output(
Joke.model_json_schema(), method=method, strict=strict
)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
Expand All @@ -831,3 +841,27 @@ class Joke(BaseModelProper):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}

# Invalid schema with optional fields:
class InvalidJoke(BaseModelProper):
"""Joke to tell user."""

setup: str = FieldProper(description="question to set up a joke")
# Invalid field, can't have default value.
punchline: str = FieldProper(
default="foo", description="answer to resolve the joke"
)

chat = llm.with_structured_output(InvalidJoke, method=method, strict=strict)
with pytest.raises(openai.BadRequestError):
chat.invoke("Tell me a joke about cats.")
with pytest.raises(openai.BadRequestError):
next(chat.stream("Tell me a joke about cats."))

chat = llm.with_structured_output(
InvalidJoke.model_json_schema(), method=method, strict=strict
)
with pytest.raises(openai.BadRequestError):
chat.invoke("Tell me a joke about cats.")
with pytest.raises(openai.BadRequestError):
next(chat.stream("Tell me a joke about cats."))

0 comments on commit 13378c6

Please sign in to comment.