From cacc739c5dab7a6a15cde228b122b79da982caea Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 24 Jul 2024 09:30:18 -0400 Subject: [PATCH] add with_structured_output support for Pydantic models, dicts and Enums (only include_raw=False) --- .../docs/chat/nvidia_ai_endpoints.ipynb | 113 +++++++++ .../langchain_nvidia_ai_endpoints/_statics.py | 6 + .../chat_models.py | 230 +++++++++++++++++- libs/ai-endpoints/pyproject.toml | 2 +- .../tests/integration_tests/conftest.py | 18 ++ .../test_structured_output.py | 186 ++++++++++++++ .../ai-endpoints/tests/unit_tests/conftest.py | 42 ++++ .../unit_tests/test_structured_output.py | 132 ++++++++++ 8 files changed, 716 insertions(+), 13 deletions(-) create mode 100644 libs/ai-endpoints/tests/integration_tests/test_structured_output.py create mode 100644 libs/ai-endpoints/tests/unit_tests/test_structured_output.py diff --git a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb index e1025759..a9a58452 100644 --- a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb +++ b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb @@ -617,6 +617,119 @@ "source": [ "See [How to use chat models to call tools](https://python.langchain.com/v0.2/docs/how_to/tool_calling/) for additional examples." ] + }, + { + "cell_type": "markdown", + "id": "8d249662", + "metadata": {}, + "source": [ + "## Structured output\n", + "\n", + "Starting in v0.2.1, `ChatNVIDIA` supports [with_structured_output](https://api.python.langchain.com/en/latest/language_models/langchain_core.language_models.chat_models.BaseChatModel.html#langchain_core.language_models.chat_models.BaseChatModel.with_structured_output).\n", + "\n", + "`ChatNVIDIA` provides integration with the variety of models on [build.nvidia.com](https://build.nvidia.com) as well as local NIMs. Not all these model endpoints implement the structured output features. Be sure to select a model that does have structured output features for your experimention and applications.\n", + "\n", + "Note: `include_raw` is not supported. You can get raw output from your LLM and use a [PydanticOutputParser](https://python.langchain.com/v0.2/docs/how_to/structured_output/#using-pydanticoutputparser) or [JsonOutputParser](https://python.langchain.com/v0.2/docs/how_to/output_parser_json/#without-pydantic)." + ] + }, + { + "cell_type": "markdown", + "id": "a94e0e69", + "metadata": {}, + "source": [ + "You can get a list of models that are known to support structured output with," + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0515f558", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "structured_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_structured_output]\n", + "structured_models" + ] + }, + { + "cell_type": "markdown", + "id": "21e56187", + "metadata": {}, + "source": [ + "### Pydantic style" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "482c37e8", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "\n", + "class Person(BaseModel):\n", + " first_name: str = Field(..., description=\"The person's first name.\")\n", + " last_name: str = Field(..., description=\"The person's last name.\")\n", + "\n", + "llm = ChatNVIDIA(model=structured_models[0].id).with_structured_output(Person)\n", + "response = llm.invoke(\"Who is Michael Jeffrey Jordon?\")\n", + "response" + ] + }, + { + "cell_type": "markdown", + "id": "a25ce43f", + "metadata": {}, + "source": [ + "### Enum style" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f802912", + "metadata": {}, + "outputs": [], + "source": [ + "from enum import Enum\n", + "\n", + "class Choices(Enum):\n", + " A = \"A\"\n", + " B = \"B\"\n", + " C = \"C\"\n", + "\n", + "llm = ChatNVIDIA(model=structured_models[2].id).with_structured_output(Choices)\n", + "response = llm.invoke(\"\"\"\n", + " What does 1+1 equal?\n", + " A. -100\n", + " B. 2\n", + " C. doorstop\n", + " \"\"\"\n", + ")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02b7ef29", + "metadata": {}, + "outputs": [], + "source": [ + "model = structured_models[3].id\n", + "llm = ChatNVIDIA(model=model).with_structured_output(Choices)\n", + "print(model)\n", + "response = llm.invoke(\"\"\"\n", + " What does 1+1 equal?\n", + " A. -100\n", + " B. 2\n", + " C. doorstop\n", + " \"\"\"\n", + ")\n", + "response" + ] } ], "metadata": { diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py index 7c51466b..82bf5786 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py @@ -15,6 +15,7 @@ class Model(BaseModel): endpoint: custom endpoint for the model aliases: list of aliases for the model supports_tools: whether the model supports tool calling + supports_structured_output: whether the model supports structured output All aliases are deprecated and will trigger a warning when used. """ @@ -28,6 +29,7 @@ class Model(BaseModel): endpoint: Optional[str] = None aliases: Optional[list] = None supports_tools: Optional[bool] = False + supports_structured_output: Optional[bool] = False base_model: Optional[str] = None def __hash__(self) -> int: @@ -284,24 +286,28 @@ def validate_client(cls, client: str, values: dict) -> str: id="nv-mistralai/mistral-nemo-12b-instruct", model_type="chat", client="ChatNVIDIA", + supports_structured_output=True, ), "meta/llama-3.1-8b-instruct": Model( id="meta/llama-3.1-8b-instruct", model_type="chat", client="ChatNVIDIA", supports_tools=True, + supports_structured_output=True, ), "meta/llama-3.1-70b-instruct": Model( id="meta/llama-3.1-70b-instruct", model_type="chat", client="ChatNVIDIA", supports_tools=True, + supports_structured_output=True, ), "meta/llama-3.1-405b-instruct": Model( id="meta/llama-3.1-405b-instruct", model_type="chat", client="ChatNVIDIA", supports_tools=True, + supports_structured_output=True, ), } diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index b1738536..a0263820 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -3,6 +3,7 @@ from __future__ import annotations import base64 +import enum import io import logging import os @@ -28,16 +29,23 @@ AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.exceptions import OutputParserException from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, ) +from langchain_core.output_parsers import ( + BaseOutputParser, + JsonOutputParser, + PydanticOutputParser, +) from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, ChatResult, + Generation, ) from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr from langchain_core.runnables import Runnable @@ -48,8 +56,8 @@ from langchain_nvidia_ai_endpoints._statics import Model _CallbackManager = Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] -_DictOrPydanticClass = Union[Dict[str, Any], Type[BaseModel]] -_DictOrPydantic = Union[Dict, BaseModel] +_DictOrPydanticOrEnumClass = Union[Dict[str, Any], Type[BaseModel], Type[enum.Enum]] +_DictOrPydanticOrEnum = Union[Dict, BaseModel, enum.Enum] try: import PIL.Image @@ -510,16 +518,214 @@ def bind_functions( ) -> Runnable[LanguageModelInput, BaseMessage]: raise NotImplementedError("Not implemented, use `bind_tools` instead.") - def with_structured_output( + # we have an Enum extension to BaseChatModel.with_structured_output and + # as a result need to type ignore for the schema parameter and return type. + def with_structured_output( # type: ignore self, - schema: _DictOrPydanticClass, + schema: _DictOrPydanticOrEnumClass, *, - method: Literal["function_calling", "json_mode"] = "function_calling", - return_type: Literal["parsed", "all"] = "parsed", + include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, _DictOrPydantic]: - raise NotImplementedError( - "Not implemented, awaiting server-side function-recieving API" - " Consider following open-source LLM agent spec techniques:" - " https://huggingface.co/blog/open-source-llms-as-agents" - ) + ) -> Runnable[LanguageModelInput, _DictOrPydanticOrEnum]: + """ + Bind a structured output schema to the model. + + The schema can be - + 0. a dictionary representing a JSON schema + 1. a Pydantic object + 2. an Enum + + 0. If a dictionary is provided, the model will return a dictionary. Example: + ``` + json_schema = { + "title": "joke", + "description": "Joke to tell user.", + "type": "object", + "properties": { + "setup": { + "type": "string", + "description": "The setup of the joke", + }, + "punchline": { + "type": "string", + "description": "The punchline to the joke", + }, + }, + "required": ["setup", "punchline"], + } + + structured_llm = llm.with_structured_output(json_schema) + structured_llm.invoke("Tell me a joke about NVIDIA") + # Output: {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.', + # 'punchline': 'It took a big bite out of their main board.'} + ``` + + 1. If a Pydantic schema is provided, the model will return a Pydantic object. + Example: + ``` + from langchain_core.pydantic_v1 import BaseModel, Field + class Joke(BaseModel): + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline to the joke") + + structured_llm = llm.with_structured_output(Joke) + structured_llm.invoke("Tell me a joke about NVIDIA") + # Output: Joke(setup='Why did NVIDIA go broke? The hardware ate all the software.', + # punchline='It took a big bite out of their main board.') + ``` + + 2. If an Enum is provided, all values must be strings, and the model will return + an Enum object. Example: + ``` + import enum + class Choices(enum.Enum): + A = "A" + B = "B" + C = "C" + + structured_llm = llm.with_structured_output(Choices) + structured_llm.invoke("What is the first letter in this list? [X, Y, Z, C]") + # Output: + ``` + + Note about streaming: Unlike other streaming responses, the streamed chunks + will be increasingly complete. They will not be deltas. The last chunk will + contain the complete response. + + For instance with a dictionary schema, the chunks will be: + ``` + structured_llm = llm.with_structured_output(json_schema) + for chunk in structured_llm.stream("Tell me a joke about NVIDIA"): + print(chunk) + + # Output: + # {} + # {'setup': ''} + # {'setup': 'Why'} + # {'setup': 'Why did'} + # {'setup': 'Why did N'} + # {'setup': 'Why did NVID'} + # ... + # {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.', 'punchline': 'It took a big bite out of their main board'} + # {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.', 'punchline': 'It took a big bite out of their main board.'} + ``` + + For instnace with a Pydantic schema, the chunks will be: + ``` + structured_llm = llm.with_structured_output(Joke) + for chunk in structured_llm.stream("Tell me a joke about NVIDIA"): + print(chunk) + + # Output: + # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='' + # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It' + # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took' + # ... + # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took a big bite out of their main board' + # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took a big bite out of their main board.' + ``` + + For Pydantic schema and Enum, the output will be None if the response is + insufficient to construct the object or otherwise invalid. For instance, + ``` + llm = ChatNVIDIA(max_tokens=1) + structured_llm = llm.with_structured_output(Joke) + print(structured_llm.invoke("Tell me a joke about NVIDIA")) + + # Output: None + ``` + + For more, see https://python.langchain.com/v0.2/docs/how_to/structured_output/ + """ # noqa: E501 + + if "method" in kwargs: + warnings.warn( + "The 'method' parameter is unnecessary and is ignored. " + "The appropriate method will be chosen automatically depending " + "on the type of schema provided." + ) + + if include_raw: + raise NotImplementedError( + "include_raw=True is not implemented, consider " + "https://python.langchain.com/v0.2/docs/how_to/" + "structured_output/#prompting-and-parsing-model" + "-outputs-directly or rely on the structured response " + "being None when the LLM produces an incomplete response." + ) + + # check if the model supports structured output, warn if it does not + known_good = False + # todo: we need to store model: Model in this class + # instead of model: str (= Model.id) + # this should be: if not self.model.supports_tools: warnings.warn... + candidates = [ + model for model in self.available_models if model.id == self.model + ] + if not candidates: # user must have specified the model themselves + known_good = False + else: + assert len(candidates) == 1, "Multiple models with the same id" + known_good = candidates[0].supports_structured_output is True + if not known_good: + warnings.warn( + f"Model '{self.model}' is not known to support structured output. " + "Your output may fail at inference time." + ) + + if isinstance(schema, dict): + output_parser: BaseOutputParser = JsonOutputParser() + nvext_param: Dict[str, Any] = {"guided_json": schema} + + elif issubclass(schema, BaseModel): + # PydanticOutputParser does not support streaming. what we do + # instead is ignore all inputs that are incomplete wrt the + # underlying Pydantic schema. if the entire input is invalid, + # we return None. + class ForgivingPydanticOutputParser(PydanticOutputParser): + def parse_result( + self, result: List[Generation], *, partial: bool = False + ) -> Any: + try: + return super().parse_result(result, partial=partial) + except OutputParserException: + pass + return None + + output_parser = ForgivingPydanticOutputParser(pydantic_object=schema) + nvext_param = {"guided_json": schema.schema()} + + elif issubclass(schema, enum.Enum): + # langchain's EnumOutputParser is not in langchain_core + # and doesn't support streaming. this is a simple implementation + # that supports streaming with our semantics of returning None + # if no complete object can be constructed. + class EnumOutputParser(BaseOutputParser): + enum: Type[enum.Enum] + + def parse(self, response: str) -> Any: + try: + return self.enum(response.strip()) + except ValueError: + pass + return None + + # guided_choice only supports string choices + choices = [choice.value for choice in schema] + if not all(isinstance(choice, str) for choice in choices): + # instead of erroring out we could coerce the enum values to + # strings, but would then need to coerce them back to their + # original type for Enum construction. + raise ValueError( + "Enum schema must only contain string choices. " + "Use StrEnum or ensure all member values are strings." + ) + output_parser = EnumOutputParser(enum=schema) + nvext_param = {"guided_choice": choices} + else: + raise ValueError( + "Schema must be a Pydantic object, a dictionary " + "representing a JSON schema, or an Enum." + ) + + return super().bind(nvext=nvext_param) | output_parser diff --git a/libs/ai-endpoints/pyproject.toml b/libs/ai-endpoints/pyproject.toml index 0632779b..a1b32bca 100644 --- a/libs/ai-endpoints/pyproject.toml +++ b/libs/ai-endpoints/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-nvidia-ai-endpoints" -version = "0.2.0" +version = "0.2.1" description = "An integration package connecting NVIDIA AI Endpoints and LangChain" authors = [] readme = "README.md" diff --git a/libs/ai-endpoints/tests/integration_tests/conftest.py b/libs/ai-endpoints/tests/integration_tests/conftest.py index 2c45386f..b766610e 100644 --- a/libs/ai-endpoints/tests/integration_tests/conftest.py +++ b/libs/ai-endpoints/tests/integration_tests/conftest.py @@ -27,6 +27,12 @@ def pytest_addoption(parser: pytest.Parser) -> None: nargs="+", help="Run tests for a specific chat models that support tool calling", ) + parser.addoption( + "--structured-model-id", + action="store", + nargs="+", + help="Run tests for a specific models that support structured output", + ) parser.addoption( "--qa-model-id", action="store", @@ -92,6 +98,18 @@ def get_all_known_models() -> List[Model]: ] metafunc.parametrize("tool_model", models, ids=models) + if "structured_model" in metafunc.fixturenames: + models = [] + if model_list := metafunc.config.getoption("structured_model_id"): + models = model_list + if metafunc.config.getoption("all_models"): + models = [ + model.id + for model in ChatNVIDIA(**mode).available_models + if model.supports_structured_output + ] + metafunc.parametrize("structured_model", models, ids=models) + if "rerank_model" in metafunc.fixturenames: models = [NVIDIARerank._default_model_name] if model_list := metafunc.config.getoption("rerank_model_id"): diff --git a/libs/ai-endpoints/tests/integration_tests/test_structured_output.py b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py new file mode 100644 index 00000000..3f4f5aa8 --- /dev/null +++ b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py @@ -0,0 +1,186 @@ +import enum +from typing import Any, Callable, Optional, Union + +import pytest +from langchain_core.messages import HumanMessage +from langchain_core.pydantic_v1 import BaseModel, Field + +from langchain_nvidia_ai_endpoints import ChatNVIDIA + + +def do_invoke(llm: ChatNVIDIA, message: str) -> Any: + return llm.invoke(message) + + +def do_stream(llm: ChatNVIDIA, message: str) -> Any: + # the way streaming works is to progressively grow the response + # so we just return the last chunk. this is different from other + # streaming results, which are *Chunks that can be concatenated. + result = [chunk for chunk in llm.stream(message)] + return result[-1] if result else None + + +@pytest.mark.xfail(reason="Accuracy is not guaranteed") +def test_accuracy(structured_model: str, mode: dict) -> None: + class Person(BaseModel): + name: str = Field(description="The name of the person") + age: Optional[int] = Field(description="The age of the person") + birthdate: Optional[str] = Field(description="The birthdate of the person") + occupation: Optional[str] = Field(description="The occupation of the person") + birthplace: Optional[str] = Field(description="The birthplace of the person") + + messages = [ + HumanMessage( + """ + Jen-Hsun Huang was born in Tainan, Taiwan, on February 17, 1963. His family + moved to Thailand when he was five; when he was nine, he and his brother were + sent to the United States to live with an uncle in Tacoma, Washington. When he + was ten, he lived in the boys' dormitory with his brother at Oneida Baptist + Institute while attending Oneida Elementary school in Oneida, Kentucky—his + uncle had mistaken what was actually a religious reform academy for a + prestigious boarding school. Several years later, their parents also moved to + the United States and settled in Oregon, where Huang graduated from Aloha + High School in Aloha, Oregon. He skipped two years and graduated at sixteen. + While growing up in Oregon in the 1980s, Huang got his first job at a local + Denny's restaurant, where he worked as a busboy and waiter. + Huang received his undergraduate degree in electrical engineering from Oregon + State University in 1984, and his master's degree in electrical engineering + from Stanford University in 1992. + + The current date is July 2034. + """ + ), + HumanMessage("Who is Jensen?"), + ] + + llm = ChatNVIDIA(model=structured_model, **mode) + structured_llm = llm.with_structured_output(Person) + person = structured_llm.invoke(messages) + assert isinstance(person, Person) + assert person.name in ["Jen-Hsun Huang", "Jensen"] + # assert person.age == 71 # this is too hard + assert person.birthdate == "February 17, 1963" + assert person.occupation and ( + "founder" in person.occupation.lower() or "CEO" in person.occupation.upper() + ) + assert person.birthplace == "Tainan, Taiwan" + + +class Joke(BaseModel): + """Joke to tell user.""" + + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline to the joke") + rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") + + +@pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"]) +def test_pydantic(structured_model: str, mode: dict, func: Callable) -> None: + llm = ChatNVIDIA(model=structured_model, temperature=0, **mode) + structured_llm = llm.with_structured_output(Joke) + result = func(structured_llm, "Tell me a joke about cats") + assert isinstance(result, Joke) + + +@pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"]) +def test_dict(structured_model: str, mode: dict, func: Callable) -> None: + json_schema = { + "title": "joke", + "description": "Joke to tell user.", + "type": "object", + "properties": { + "setup": { + "type": "string", + "description": "The setup of the joke", + }, + "punchline": { + "type": "string", + "description": "The punchline to the joke", + }, + "rating": { + "type": "integer", + "description": "How funny the joke is, from 1 to 10", + }, + }, + "required": ["setup", "punchline"], + } + + llm = ChatNVIDIA(model=structured_model, temperature=0, **mode) + structured_llm = llm.with_structured_output(json_schema) + result = func(structured_llm, "Tell me a joke about cats") + assert isinstance(result, dict) + assert "setup" in result + assert "punchline" in result + + +@pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"]) +def test_enum(structured_model: str, mode: dict, func: Callable) -> None: + class Choices(enum.Enum): + A = "A is an option" + B = "B is an option" + C = "C is an option" + + llm = ChatNVIDIA(model=structured_model, temperature=0, **mode) + structured_llm = llm.with_structured_output(Choices) + result = func( + structured_llm, + """ + What does 1+1 equal? + A. -100 + B. 2 + C. doorstop + """, + ) + assert isinstance(result, Choices) + assert result in Choices + + +@pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"]) +def test_enum_incomplete(structured_model: str, mode: dict, func: Callable) -> None: + class Choices(enum.Enum): + A = "A is an option you can pick" + B = "B is an option you can pick" + C = "C is an option you can pick" + + llm = ChatNVIDIA(model=structured_model, temperature=0, max_tokens=3, **mode) + structured_llm = llm.with_structured_output(Choices) + result = func( + structured_llm, + """ + What does 1+1 equal? + A. -100 + B. 2 + C. doorstop + """, + ) + assert result is None + + +@pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"]) +def test_multiple_schema(structured_model: str, mode: dict, func: Callable) -> None: + class ConversationalResponse(BaseModel): + """Respond in a conversational manner. Be kind and helpful.""" + + response: str = Field( + description="A conversational response to the user's query" + ) + + class Response(BaseModel): + output: Union[Joke, ConversationalResponse] + + llm = ChatNVIDIA(model=structured_model, temperature=0, **mode) + structured_llm = llm.with_structured_output(Response) + response = func(structured_llm, "Tell me a joke about cats") + assert isinstance(response, Response) + assert isinstance(response.output, Joke) or isinstance( + response.output, ConversationalResponse + ) + + +@pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"]) +def test_pydantic_incomplete(structured_model: str, mode: dict, func: Callable) -> None: + # 3 tokens is not enough to construct a Joke + llm = ChatNVIDIA(model=structured_model, temperature=0, max_tokens=3, **mode) + structured_llm = llm.with_structured_output(Joke) + result = func(structured_llm, "Tell me a joke about cats") + assert result is None diff --git a/libs/ai-endpoints/tests/unit_tests/conftest.py b/libs/ai-endpoints/tests/unit_tests/conftest.py index 8925a6ad..f0790214 100644 --- a/libs/ai-endpoints/tests/unit_tests/conftest.py +++ b/libs/ai-endpoints/tests/unit_tests/conftest.py @@ -1,3 +1,5 @@ +from typing import Callable, List + import pytest import requests_mock @@ -18,3 +20,43 @@ def public_class(request: pytest.FixtureRequest) -> type: @pytest.fixture def empty_v1_models(requests_mock: requests_mock.Mocker) -> None: requests_mock.get("https://integrate.api.nvidia.com/v1/models", json={"data": []}) + + +@pytest.fixture +def mock_model() -> str: + return "mock-model" + + +@pytest.fixture(autouse=True) +def mock_v1_models(requests_mock: requests_mock.Mocker, mock_model: str) -> None: + requests_mock.get( + "https://integrate.api.nvidia.com/v1/models", + json={ + "data": [ + {"id": mock_model}, + ] + }, + ) + + +@pytest.fixture +def mock_streaming_response( + requests_mock: requests_mock.Mocker, mock_model: str +) -> Callable: + def builder(chunks: List[str]) -> None: + requests_mock.post( + "https://integrate.api.nvidia.com/v1/chat/completions", + text="\n\n".join( + [ + 'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"bogus","choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + *[ + f'data: {{"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"bogus","choices":[{{"index":0,"delta":{{"role":null,"content":"{content}"}},"logprobs":null,"finish_reason":null}}]}}' # noqa: E501 + for content in chunks + ], + 'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"bogus","choices":[{"index":0,"delta":{"role":null,"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}]}', # noqa: E501 + "data: [DONE]", + ] + ), + ) + + return builder diff --git a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py new file mode 100644 index 00000000..39a70089 --- /dev/null +++ b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py @@ -0,0 +1,132 @@ +import enum +import warnings +from typing import Callable, List, Optional + +import pytest +from langchain_core.pydantic_v1 import BaseModel, Field + +from langchain_nvidia_ai_endpoints import ChatNVIDIA + + +class Joke(BaseModel): + """Joke to tell user.""" + + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline to the joke") + rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") + + +def test_method() -> None: + with pytest.warns(UserWarning) as record: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*not known to support structured output.*", + ) + ChatNVIDIA().with_structured_output(Joke, method="json_mode") + assert len(record) == 1 + assert "unnecessary" in str(record[0].message) + + +def test_include_raw() -> None: + with pytest.raises(NotImplementedError): + ChatNVIDIA().with_structured_output(Joke, include_raw=True) + + with pytest.raises(NotImplementedError): + ChatNVIDIA().with_structured_output(Joke.schema(), include_raw=True) + + +def test_known_does_not_warn(empty_v1_models: None) -> None: + structured_model = [ + model + for model in ChatNVIDIA.get_available_models() + if model.supports_structured_output + ] + assert structured_model, "No models support structured output" + + with warnings.catch_warnings(): + warnings.simplefilter("error") + ChatNVIDIA(model=structured_model[0].id).with_structured_output(Joke) + + +def test_unknown_warns(empty_v1_models: None) -> None: + unstructured_model = [ + model + for model in ChatNVIDIA.get_available_models() + if not model.supports_structured_output + ] + assert unstructured_model, "All models support structured output" + + with pytest.warns(UserWarning) as record: + ChatNVIDIA(model=unstructured_model[0].id).with_structured_output(Joke) + assert len(record) == 1 + assert "not known to support structured output" in str(record[0].message) + + +def test_enum_negative() -> None: + class Choices(enum.Enum): + A = "A" + B = "2" + C = 3 + + llm = ChatNVIDIA() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*not known to support structured output.*", + ) + with pytest.raises(ValueError) as e: + llm.with_structured_output(Choices) + assert "only contain string choices" in str(e.value) + + +class Choices(enum.Enum): + YES = "Yes it is" + NO = "No it is not" + + +@pytest.mark.parametrize( + "chunks", + [ + ["Y", "es", " it", " is"], + ["N", "o", " it", " is", " not"], + ], + ids=["YES", "NO"], +) +def test_stream_enum( + mock_streaming_response: Callable, + chunks: List[str], +) -> None: + mock_streaming_response(chunks) + + warnings.filterwarnings("ignore", r".*not known to support structured output.*") + structured_llm = ChatNVIDIA().with_structured_output(Choices) + # chunks are progressively more complete, so we only consider the last + for chunk in structured_llm.stream("This is ignored."): + response = chunk + assert isinstance(response, Choices) + assert response in Choices + + +@pytest.mark.parametrize( + "chunks", + [ + ["Y", "es", " it"], + ["N", "o", " it", " is"], + ], + ids=["YES", "NO"], +) +def test_stream_enum_incomplete( + mock_streaming_response: Callable, + chunks: List[str], +) -> None: + mock_streaming_response(chunks) + + warnings.filterwarnings("ignore", r".*not known to support structured output.*") + structured_llm = ChatNVIDIA().with_structured_output(Choices) + # chunks are progressively more complete, so we only consider the last + for chunk in structured_llm.stream("This is ignored."): + response = chunk + assert response is None