From aa201fa4fec1673eceabebb9330327fc72a06f4b Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 12 Jun 2024 11:44:42 -0400 Subject: [PATCH 01/11] add tests for tool calling --- .../langchain_nvidia_ai_endpoints/_statics.py | 2 + .../tests/integration_tests/conftest.py | 18 + .../integration_tests/test_bind_tools.py | 382 ++++++++++++++++++ 3 files changed, 402 insertions(+) create mode 100644 libs/ai-endpoints/tests/integration_tests/test_bind_tools.py diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py index 1151143e..d5649623 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py @@ -13,6 +13,7 @@ class Model(BaseModel): client: client name, e.g. ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank endpoint: custom endpoint for the model aliases: list of aliases for the model + supports_tools: whether the model supports tool calling All aliases are deprecated and will trigger a warning when used. """ @@ -25,6 +26,7 @@ class Model(BaseModel): client: Optional[Literal["ChatNVIDIA", "NVIDIAEmbeddings", "NVIDIARerank"]] = None endpoint: Optional[str] = None aliases: Optional[list] = None + supports_tools: Optional[bool] = False base_model: Optional[str] = None def __hash__(self) -> int: diff --git a/libs/ai-endpoints/tests/integration_tests/conftest.py b/libs/ai-endpoints/tests/integration_tests/conftest.py index 1b658767..5240a89c 100644 --- a/libs/ai-endpoints/tests/integration_tests/conftest.py +++ b/libs/ai-endpoints/tests/integration_tests/conftest.py @@ -21,6 +21,12 @@ def pytest_addoption(parser: pytest.Parser) -> None: nargs="+", help="Run tests for a specific chat model or list of models", ) + parser.addoption( + "--tool-model-id", + action="store", + nargs="+", + help="Run tests for a specific chat models that support tool calling", + ) parser.addoption( "--qa-model-id", action="store", @@ -74,6 +80,18 @@ def get_all_known_models() -> List[Model]: ] metafunc.parametrize("chat_model", models, ids=models) + if "tool_model" in metafunc.fixturenames: + models = [] + if model_list := metafunc.config.getoption("tool_model_id"): + models = model_list + if metafunc.config.getoption("all_models"): + models = [ + model.id + for model in ChatNVIDIA(**mode).available_models + if model.model_type == "chat" and model.supports_tools + ] + metafunc.parametrize("tool_model", models, ids=models) + if "rerank_model" in metafunc.fixturenames: models = [NVIDIARerank._default_model_name] if model_list := metafunc.config.getoption("rerank_model_id"): diff --git a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py new file mode 100644 index 00000000..83dff6b0 --- /dev/null +++ b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py @@ -0,0 +1,382 @@ +import json +import warnings +from typing import Any, List, Literal, Optional, Union + +import pytest +from langchain_core.messages import AIMessage, ChatMessage +from langchain_core.pydantic_v1 import Field +from langchain_core.tools import tool + +from langchain_nvidia_ai_endpoints import ChatNVIDIA + +# +# ways to specify tools: +# 0. bind_tools +# ways to specify tool_choice: +# 1. invoke +# 2. bind_tools +# 3. stream +# tool_choice levels: +# 4. "none" +# 5. "auto" (accuracy only) +# 6. None (accuracy only) +# 7. "required" +# 8. {"function": {"name": tool name}} (partial function) +# 9. 
{"type": "function", "function": {"name": tool name}} +# 10. "any" (bind_tools only) +# 11. tool name (bind_tools only) +# 12. True (bind_tools only) +# 13. False (bind_tools only) +# tools levels: +# 14. no tools +# 15. one tool +# 16. multiple tools (accuracy only) +# test types: +# 17. deterministic (minimial accuracy tests; relies on basic tool calling skills) +# 18. accuracy (proper tool; proper arguments) +# negative tests: +# 19. require unknown named tool (invoke/stream only) +# 20. partial function (invoke/stream only) +# + +# todo: streaming +# todo: test tool with no arguments +# todo: parallel_tool_calls + + +@tool +def xxyyzz( + a: int = Field(..., description="First number"), + b: int = Field(..., description="Second number"), +) -> int: + """xxyyzz two numbers""" + return (a**b) % (b - a) + + +@tool +def zzyyxx( + a: int = Field(..., description="First number"), + b: int = Field(..., description="Second number"), +) -> int: + """zzyyxx two numbers""" + return (b**a) % (a - b) + + +def check_response_structure(response: AIMessage) -> None: + assert not response.content # should be `response.content is None` but + # AIMessage.content: Union[str, List[Union[str, Dict]]] cannot be None. + for tool_call in response.tool_calls: + assert tool_call["id"] is not None + assert response.response_metadata is not None + assert isinstance(response.response_metadata, dict) + assert "finish_reason" in response.response_metadata + assert response.response_metadata["finish_reason"] in [ + "tool_calls", + "stop", + ] # todo: remove "stop" + assert len(response.tool_calls) > 0 + + +# users can also get at the tool calls from the response.additional_kwargs +@pytest.mark.xfail(reason="Accuracy test") +def test_accuracy_default_invoke_additional_kwargs(tool_model: str, mode: dict) -> None: + llm = ChatNVIDIA(temperature=0, model=tool_model, **mode).bind_tools([xxyyzz]) + response = llm.invoke("What is 11 xxyyzz 3?") + assert not response.content # should be `response.content is None` but + # AIMessage.content: Union[str, List[Union[str, Dict]]] cannot be None. 
+ assert response.additional_kwargs is not None + assert "tool_calls" in response.additional_kwargs + assert isinstance(response.additional_kwargs["tool_calls"], list) + assert response.additional_kwargs["tool_calls"] + for tool_call in response.additional_kwargs["tool_calls"]: + assert "id" in tool_call + assert tool_call["id"] is not None + assert "type" in tool_call + assert tool_call["type"] == "function" + assert "function" in tool_call + assert response.response_metadata is not None + assert isinstance(response.response_metadata, dict) + assert "content" in response.response_metadata + assert response.response_metadata["content"] is None + assert "finish_reason" in response.response_metadata + assert response.response_metadata["finish_reason"] in [ + "tool_calls", + "stop", + ] # todo: remove "stop" + assert len(response.additional_kwargs["tool_calls"]) > 0 + tool_call = response.additional_kwargs["tool_calls"][0] + assert tool_call["function"]["name"] == "xxyyzz" + assert json.loads(tool_call["function"]["arguments"]) == {"a": 11, "b": 3} + + +@pytest.mark.parametrize( + "tool_choice", + [ + "none", + "required", + {"function": {"name": "xxyyzz"}}, + {"type": "function", "function": {"name": "xxyyzz"}}, + ], + ids=["none", "required", "partial", "function"], +) +def test_invoke_tool_choice_with_no_tool( + tool_model: str, mode: dict, tool_choice: Any +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode) + with pytest.raises(Exception) as e: + llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) + assert "400" in str(e.value) or "###" in str( + e.value + ) # todo: stop transforming 400 -> ### + assert ( + "Value error, When using `tool_choice`, `tools` must be set." in str(e.value) + or ( + "Value error, Invalid value for `tool_choice`: `tool_choice` is only " + "allowed when `tools` are specified." 
+ ) + in str(e.value) + or "invalid_request_error" in str(e.value) + ) + + +def test_invoke_tool_choice_none(tool_model: str, mode: dict) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools(tools=[xxyyzz]) + response = llm.invoke("What is 11 xxyyzz 3?", tool_choice="none") # type: ignore + assert isinstance(response, ChatMessage) + assert "tool_calls" not in response.additional_kwargs + + +@pytest.mark.parametrize( + "tool_choice", + [ + {"function": {"name": "xxyyzz"}}, + ], + ids=["partial"], +) +def test_invoke_tool_choice_negative( + tool_model: str, + mode: dict, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ], +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([xxyyzz]) + with pytest.raises(Exception) as e: + llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) # type: ignore + assert "400" in str(e.value) or "###" in str( + e.value + ) # todo: stop transforming 400 -> ### + assert "invalid_request_error" in str(e.value) or "value_error" in str(e.value) + + +@pytest.mark.parametrize( + "tool_choice", + [ + "required", + {"type": "function", "function": {"name": "xxyyzz"}}, + ], + ids=["required", "function"], +) +def test_invoke_tool_choice( + tool_model: str, + mode: dict, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ], +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([xxyyzz]) + response = llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) # type: ignore + assert isinstance(response, AIMessage) + check_response_structure(response) + + +@pytest.mark.parametrize( + "tool_choice", + [ + "auto", + None, + "required", + {"type": "function", "function": {"name": "xxyyzz"}}, + ], + ids=["auto", "absent", "required", "function"], +) +@pytest.mark.parametrize( + "tools", + [[xxyyzz], [xxyyzz, zzyyxx], [zzyyxx, xxyyzz]], + ids=["xxyyzz", "xxyyzz_and_zzyyxx", "zzyyxx_and_xxyyzz"], +) +@pytest.mark.xfail(reason="Accuracy test") +def test_accuracy_invoke_tool_choice( + tool_model: str, + mode: dict, + tools: List, + tool_choice: Any, +) -> None: + llm = ChatNVIDIA(temperature=0, model=tool_model, **mode).bind_tools(tools) + response = llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) # type: ignore + assert isinstance(response, AIMessage) + check_response_structure(response) + tool_call = response.tool_calls[0] + assert tool_call["name"] == "xxyyzz" + assert tool_call["args"] == {"b": 3, "a": 11} + + +def test_invoke_tool_choice_with_unknown_tool(tool_model: str, mode: dict) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools(tools=[xxyyzz]) + with pytest.raises(Exception) as e: + llm.invoke( + "What is 11 xxyyzz 3?", + tool_choice={"type": "function", "function": {"name": "zzyyxx"}}, + ) # type: ignore + assert ( + "not found in the tools list" in str(e.value) + or "no function named" in str(e.value) + or "does not match any of the specified" in str(e.value) + ) + + +@pytest.mark.parametrize( + "tool_choice", + [ + {"function": {"name": "xxyyzz"}}, + {"type": "function", "function": {"name": "xxyyzz"}}, + "xxyyzz", + ], + ids=["partial", "function", "name"], +) +def test_bind_tool_tool_choice_with_no_tool_client( + tool_model: str, + mode: dict, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ], +) -> None: + with pytest.raises(ValueError) as e: + ChatNVIDIA(model=tool_model, **mode).bind_tools( + tools=[], tool_choice=tool_choice + ) + assert "not found in the 
tools list" in str(e.value) + + +@pytest.mark.parametrize( + "tool_choice", + [ + "none", + "required", + "any", + True, + False, + ], + ids=["none", "required", "any", "True", "False"], +) +def test_bind_tool_tool_choice_with_no_tool_server( + tool_model: str, mode: dict, tool_choice: Any +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([], tool_choice=tool_choice) + with pytest.raises(Exception) as e: + llm.invoke("What is 11 xxyyzz 3?") + assert "400" in str(e.value) or "###" in str( + e.value + ) # todo: stop transforming 400 -> ### + assert ( + "Value error, When using `tool_choice`, `tools` must be set." in str(e.value) + or ( + "Value error, Invalid value for `tool_choice`: `tool_choice` is only " + "allowed when `tools` are specified." + ) + in str(e.value) + or "Expected an array with minimum length" in str(e.value) + or "should be non-empty" in str(e.value) + ) + + +@pytest.mark.parametrize( + "tool_choice", + ["none", False], +) +def test_bind_tool_tool_choice_none( + tool_model: str, mode: dict, tool_choice: Any +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools( + tools=[xxyyzz], tool_choice=tool_choice + ) + response = llm.invoke("What is 11 xxyyzz 3?") + assert isinstance(response, ChatMessage) + assert "tool_calls" not in response.additional_kwargs + + +@pytest.mark.parametrize( + "tool_choice", + [ + "required", + {"function": {"name": "xxyyzz"}}, + {"type": "function", "function": {"name": "xxyyzz"}}, + "any", + "xxyyzz", + True, + ], + ids=["required", "partial", "function", "any", "name", "True"], +) +def test_bind_tool_tool_choice( + tool_model: str, + mode: dict, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ], +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools( + [xxyyzz], tool_choice=tool_choice + ) + response = llm.invoke("What is 11 xxyyzz 3?") + assert isinstance(response, AIMessage) + check_response_structure(response) + + +@pytest.mark.parametrize( + "tool_choice", + [ + "auto", + None, + "required", + {"function": {"name": "xxyyzz"}}, + {"type": "function", "function": {"name": "xxyyzz"}}, + "any", + "xxyyzz", + True, + ], + ids=["auto", "absent", "required", "partial", "function", "any", "name", "True"], +) +@pytest.mark.parametrize( + "tools", + [[xxyyzz], [xxyyzz, zzyyxx], [zzyyxx, xxyyzz]], + ids=["xxyyzz", "xxyyzz_and_zzyyxx", "zzyyxx_and_xxyyzz"], +) +@pytest.mark.xfail(reason="Accuracy test") +def test_accuracy_bind_tool_tool_choice( + tool_model: str, + mode: dict, + tools: List, + tool_choice: Any, +) -> None: + llm = ChatNVIDIA(temperature=0, model=tool_model, **mode).bind_tools( + tools=tools, tool_choice=tool_choice + ) + response = llm.invoke("What is 11 xxyyzz 3?") + assert isinstance(response, AIMessage) + check_response_structure(response) + tool_call = response.tool_calls[0] + assert tool_call["name"] == "xxyyzz" + assert tool_call["args"] == {"b": 3, "a": 11} + + +def test_known_does_not_warn(tool_model: str, mode: dict) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("error") + ChatNVIDIA(model=tool_model, **mode).bind_tools([xxyyzz]) + + +def test_unknown_warns(mode: dict) -> None: + with pytest.warns(UserWarning) as record: + ChatNVIDIA(model="mock-model", **mode).bind_tools([xxyyzz]) + assert len(record) == 1 + assert "not known to support tools" in str(record[0].message) From 54dec90f55d8b41ba8c633488065d54b227c44f9 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 17 Jun 2024 18:59:52 -0400 Subject: [PATCH 
02/11] add langchain-community dep for openai convert_message_to_dict --- libs/ai-endpoints/poetry.lock | 378 +++++++++++++++++- libs/ai-endpoints/pyproject.toml | 1 + .../integration_tests/test_vlm_models.py | 7 + 3 files changed, 375 insertions(+), 11 deletions(-) diff --git a/libs/ai-endpoints/poetry.lock b/libs/ai-endpoints/poetry.lock index a4464a70..77f7b29c 100644 --- a/libs/ai-endpoints/poetry.lock +++ b/libs/ai-endpoints/poetry.lock @@ -292,6 +292,21 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "dataclasses-json" +version = "0.6.7" +description = "Easily serialize dataclasses to and from JSON." +optional = false +python-versions = "<4.0,>=3.7" +files = [ + {file = "dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a"}, + {file = "dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0"}, +] + +[package.dependencies] +marshmallow = ">=3.18.0,<4.0.0" +typing-inspect = ">=0.4.0,<1" + [[package]] name = "exceptiongroup" version = "1.2.1" @@ -421,6 +436,77 @@ files = [ {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, ] +[[package]] +name = "greenlet" +version = "3.0.3" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.7" +files = [ + {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d353cadd6083fdb056bb46ed07e4340b0869c305c8ca54ef9da3421acbdf6881"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dca1e2f3ca00b84a396bc1bce13dd21f680f035314d2379c4160c98153b2059b"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ed7fb269f15dc662787f4119ec300ad0702fa1b19d2135a37c2c4de6fadfd4a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd4f49ae60e10adbc94b45c0b5e6a179acc1736cf7a90160b404076ee283cf83"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:73a411ef564e0e097dbe7e866bb2dda0f027e072b04da387282b02c308807405"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7f362975f2d179f9e26928c5b517524e89dd48530a0202570d55ad6ca5d8a56f"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:649dde7de1a5eceb258f9cb00bdf50e978c9db1b996964cd80703614c86495eb"}, + {file = "greenlet-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:68834da854554926fbedd38c76e60c4a2e3198c6fbed520b106a8986445caaf9"}, + {file = "greenlet-3.0.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1b5667cced97081bf57b8fa1d6bfca67814b0afd38208d52538316e9422fc61"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52f59dd9c96ad2fc0d5724107444f76eb20aaccb675bf825df6435acb7703559"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afaff6cf5200befd5cec055b07d1c0a5a06c040fe5ad148abcd11ba6ab9b114e"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2797aa5aedac23af156bbb5a6aa2cd3427ada2972c828244eb7d1b9255846379"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7f009caad047246ed379e1c4dbcb8b020f0a390667ea74d2387be2998f58a22"}, + {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c5e1536de2aad7bf62e27baf79225d0d64360d4168cf2e6becb91baf1ed074f3"}, + {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:894393ce10ceac937e56ec00bb71c4c2f8209ad516e96033e4b3b1de270e200d"}, + {file = "greenlet-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:1ea188d4f49089fc6fb283845ab18a2518d279c7cd9da1065d7a84e991748728"}, + {file = "greenlet-3.0.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:70fb482fdf2c707765ab5f0b6655e9cfcf3780d8d87355a063547b41177599be"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4d1ac74f5c0c0524e4a24335350edad7e5f03b9532da7ea4d3c54d527784f2e"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149e94a2dd82d19838fe4b2259f1b6b9957d5ba1b25640d2380bea9c5df37676"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15d79dd26056573940fcb8c7413d84118086f2ec1a8acdfa854631084393efcc"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b7db1ebff4ba09aaaeae6aa491daeb226c8150fc20e836ad00041bcb11230"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcd2469d6a2cf298f198f0487e0a5b1a47a42ca0fa4dfd1b6862c999f018ebbf"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f672519db1796ca0d8753f9e78ec02355e862d0998193038c7073045899f305"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2516a9957eed41dd8f1ec0c604f1cdc86758b587d964668b5b196a9db5bfcde6"}, + {file = "greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2"}, + {file = "greenlet-3.0.3-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:5b51e85cb5ceda94e79d019ed36b35386e8c37d22f07d6a751cb659b180d5274"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf3cb43b7cf2ba96d614252ce1684c1bccee6b2183a01328c98d36fcd7d5cb0"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99bf650dc5d69546e076f413a87481ee1d2d09aaaaaca058c9251b6d8c14783f"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dd6e660effd852586b6a8478a1d244b8dc90ab5b1321751d2ea15deb49ed414"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3391d1e16e2a5a1507d83e4a8b100f4ee626e8eca43cf2cadb543de69827c4c"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1f145462f1fa6e4a4ae3c0f782e580ce44d57c8f2c7aae1b6fa88c0b2efdb41"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1a7191e42732df52cb5f39d3527217e7ab73cae2cb3694d241e18f53d84ea9a7"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:0448abc479fab28b00cb472d278828b3ccca164531daab4e970a0458786055d6"}, + {file = "greenlet-3.0.3-cp37-cp37m-win32.whl", hash = "sha256:b542be2440edc2d48547b5923c408cbe0fc94afb9f18741faa6ae970dbcb9b6d"}, + {file = "greenlet-3.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:01bc7ea167cf943b4c802068e178bbf70ae2e8c080467070d01bfa02f337ee67"}, + {file = "greenlet-3.0.3-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:1996cb9306c8595335bb157d133daf5cf9f693ef413e7673cb07e3e5871379ca"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc0f794e6ad661e321caa8d2f0a55ce01213c74722587256fb6566049a8b04"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9db1c18f0eaad2f804728c67d6c610778456e3e1cc4ab4bbd5eeb8e6053c6fc"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7170375bcc99f1a2fbd9c306f5be8764eaf3ac6b5cb968862cad4c7057756506"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b66c9c1e7ccabad3a7d037b2bcb740122a7b17a53734b7d72a344ce39882a1b"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:098d86f528c855ead3479afe84b49242e174ed262456c342d70fc7f972bc13c4"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:81bb9c6d52e8321f09c3d165b2a78c680506d9af285bfccbad9fb7ad5a5da3e5"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da"}, + {file = "greenlet-3.0.3-cp38-cp38-win32.whl", hash = "sha256:d46677c85c5ba00a9cb6f7a00b2bfa6f812192d2c9f7d9c4f6a55b60216712f3"}, + {file = "greenlet-3.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:419b386f84949bf0e7c73e6032e3457b82a787c1ab4a0e43732898a761cc9dbf"}, + {file = "greenlet-3.0.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:da70d4d51c8b306bb7a031d5cff6cc25ad253affe89b70352af5f1cb68e74b53"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086152f8fbc5955df88382e8a75984e2bb1c892ad2e3c80a2508954e52295257"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d73a9fe764d77f87f8ec26a0c85144d6a951a6c438dfe50487df5595c6373eac"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7dcbe92cc99f08c8dd11f930de4d99ef756c3591a5377d1d9cd7dd5e896da71"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1551a8195c0d4a68fac7a4325efac0d541b48def35feb49d803674ac32582f61"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:64d7675ad83578e3fc149b617a444fab8efdafc9385471f868eb5ff83e446b8b"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b37eef18ea55f2ffd8f00ff8fe7c8d3818abd3e25fb73fae2ca3b672e333a7a6"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:77457465d89b8263bca14759d7c1684df840b6811b2499838cc5b040a8b5b113"}, + {file = "greenlet-3.0.3-cp39-cp39-win32.whl", hash = "sha256:57e8974f23e47dac22b83436bdcf23080ade568ce77df33159e019d161ce1d1e"}, + {file = "greenlet-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c5ee858cfe08f34712f548c3c363e807e7186f03ad7a5039ebadb29e8c6be067"}, + {file = "greenlet-3.0.3.tar.gz", hash = 
"sha256:43374442353259554ce33599da8b692d5aa96f8976d567d4badf263371fbe491"}, +] + +[package.extras] +docs = ["Sphinx", "furo"] +test = ["objgraph", "psutil"] + [[package]] name = "idna" version = "3.7" @@ -468,9 +554,62 @@ files = [ {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] +[[package]] +name = "langchain" +version = "0.2.5" +description = "Building applications with LLMs through composability" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain-0.2.5-py3-none-any.whl", hash = "sha256:9aded9a65348254e1c93dcdaacffe4d1b6a5e7f74ef80c160c88ff78ad299228"}, + {file = "langchain-0.2.5.tar.gz", hash = "sha256:ffdbf4fcea46a10d461bcbda2402220fcfd72a0c70e9f4161ae0510067b9b3bd"}, +] + +[package.dependencies] +aiohttp = ">=3.8.3,<4.0.0" +async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} +langchain-core = ">=0.2.7,<0.3.0" +langchain-text-splitters = ">=0.2.0,<0.3.0" +langsmith = ">=0.1.17,<0.2.0" +numpy = [ + {version = ">=1,<2", markers = "python_version < \"3.12\""}, + {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, +] +pydantic = ">=1,<3" +PyYAML = ">=5.3" +requests = ">=2,<3" +SQLAlchemy = ">=1.4,<3" +tenacity = ">=8.1.0,<9.0.0" + +[[package]] +name = "langchain-community" +version = "0.2.5" +description = "Community contributed LangChain integrations." +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_community-0.2.5-py3-none-any.whl", hash = "sha256:bf37a334952e42c7676d083cf2d2c4cbfbb7de1949c4149fe19913e2b06c485f"}, + {file = "langchain_community-0.2.5.tar.gz", hash = "sha256:476787b8c8c213b67e7b0eceb53346e787f00fbae12d8e680985bd4f93b0bf64"}, +] + +[package.dependencies] +aiohttp = ">=3.8.3,<4.0.0" +dataclasses-json = ">=0.5.7,<0.7" +langchain = ">=0.2.5,<0.3.0" +langchain-core = ">=0.2.7,<0.3.0" +langsmith = ">=0.1.0,<0.2.0" +numpy = [ + {version = ">=1,<2", markers = "python_version < \"3.12\""}, + {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, +] +PyYAML = ">=5.3" +requests = ">=2,<3" +SQLAlchemy = ">=1.4,<3" +tenacity = ">=8.1.0,<9.0.0" + [[package]] name = "langchain-core" -version = "0.2.0rc1" +version = "0.2.8" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -479,31 +618,45 @@ develop = false [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.0" -packaging = "^23.2" +langsmith = "^0.1.75" +packaging = ">=23.2,<25" pydantic = ">=1,<3" PyYAML = ">=5.3" tenacity = "^8.1.0" -[package.extras] -extended-testing = ["jinja2 (>=3,<4)"] - [package.source] type = "git" url = "https://github.com/langchain-ai/langchain.git" reference = "HEAD" -resolved_reference = "cd1879f5e75fc9e6a8c04ac839909e0d6f2fb541" +resolved_reference = "c2b2e3266ce97ea647d4b86eedadbb7cd77d0381" subdirectory = "libs/core" +[[package]] +name = "langchain-text-splitters" +version = "0.2.1" +description = "LangChain text splitting utilities" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_text_splitters-0.2.1-py3-none-any.whl", hash = "sha256:c2774a85f17189eaca50339629d2316d13130d4a8d9f1a1a96f3a03670c4a138"}, + {file = "langchain_text_splitters-0.2.1.tar.gz", hash = "sha256:06853d17d7241ecf5c97c7b6ef01f600f9b0fb953dd997838142a527a4f32ea4"}, +] + +[package.dependencies] +langchain-core = ">=0.2.0,<0.3.0" + +[package.extras] +extended-testing = ["beautifulsoup4 (>=4.12.3,<5.0.0)", "lxml 
(>=4.9.3,<6.0)"] + [[package]] name = "langsmith" -version = "0.1.50" +version = "0.1.78" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.50-py3-none-any.whl", hash = "sha256:a81e9809fcaa277bfb314d729e58116554f186d1478fcfdf553b1c2ccce54b85"}, - {file = "langsmith-0.1.50.tar.gz", hash = "sha256:9fd22df8c689c044058536ea5af66f5302067e7551b60d7a335fede8d479572b"}, + {file = "langsmith-0.1.78-py3-none-any.whl", hash = "sha256:87bc5d9072bfcb6392d7552cbcd6089dcc1faed36d688b1587d80bd48a1acba2"}, + {file = "langsmith-0.1.78.tar.gz", hash = "sha256:d9112d2e9298ec6b02d3b1afec6ed557df9db3746c79d34ef3b448fc18e116cd"}, ] [package.dependencies] @@ -511,6 +664,25 @@ orjson = ">=3.9.14,<4.0.0" pydantic = ">=1,<3" requests = ">=2,<3" +[[package]] +name = "marshmallow" +version = "3.21.3" +description = "A lightweight library for converting complex datatypes to and from native Python datatypes." +optional = false +python-versions = ">=3.8" +files = [ + {file = "marshmallow-3.21.3-py3-none-any.whl", hash = "sha256:86ce7fb914aa865001a4b2092c4c2872d13bc347f3d42673272cabfdbad386f1"}, + {file = "marshmallow-3.21.3.tar.gz", hash = "sha256:4f57c5e050a54d66361e826f94fba213eb10b67b2fdb02c3e0343ce207ba1662"}, +] + +[package.dependencies] +packaging = ">=17.0" + +[package.extras] +dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"] +docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] +tests = ["pytest", "pytz", "simplejson"] + [[package]] name = "multidict" version = "6.0.5" @@ -671,6 +843,88 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "numpy" +version = "1.24.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file 
= "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, +] + +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = 
"numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = 
"numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + [[package]] name = "orjson" version = "3.10.1" @@ -1174,6 +1428,93 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sqlalchemy" +version = "2.0.30" +description = "Database Abstraction Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "SQLAlchemy-2.0.30-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3b48154678e76445c7ded1896715ce05319f74b1e73cf82d4f8b59b46e9c0ddc"}, + {file = "SQLAlchemy-2.0.30-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2753743c2afd061bb95a61a51bbb6a1a11ac1c44292fad898f10c9839a7f75b2"}, + {file = "SQLAlchemy-2.0.30-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7bfc726d167f425d4c16269a9a10fe8630ff6d14b683d588044dcef2d0f6be7"}, + {file = "SQLAlchemy-2.0.30-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4f61ada6979223013d9ab83a3ed003ded6959eae37d0d685db2c147e9143797"}, + {file = "SQLAlchemy-2.0.30-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3a365eda439b7a00732638f11072907c1bc8e351c7665e7e5da91b169af794af"}, + {file = "SQLAlchemy-2.0.30-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bba002a9447b291548e8d66fd8c96a6a7ed4f2def0bb155f4f0a1309fd2735d5"}, + {file = "SQLAlchemy-2.0.30-cp310-cp310-win32.whl", hash = "sha256:0138c5c16be3600923fa2169532205d18891b28afa817cb49b50e08f62198bb8"}, + {file = "SQLAlchemy-2.0.30-cp310-cp310-win_amd64.whl", hash = "sha256:99650e9f4cf3ad0d409fed3eec4f071fadd032e9a5edc7270cd646a26446feeb"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:955991a09f0992c68a499791a753523f50f71a6885531568404fa0f231832aa0"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f69e4c756ee2686767eb80f94c0125c8b0a0b87ede03eacc5c8ae3b54b99dc46"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69c9db1ce00e59e8dd09d7bae852a9add716efdc070a3e2068377e6ff0d6fdaa"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1429a4b0f709f19ff3b0cf13675b2b9bfa8a7e79990003207a011c0db880a13"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:efedba7e13aa9a6c8407c48facfdfa108a5a4128e35f4c68f20c3407e4376aa9"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16863e2b132b761891d6c49f0a0f70030e0bcac4fd208117f6b7e053e68668d0"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-win32.whl", 
hash = "sha256:2ecabd9ccaa6e914e3dbb2aa46b76dede7eadc8cbf1b8083c94d936bcd5ffb49"}, + {file = "SQLAlchemy-2.0.30-cp311-cp311-win_amd64.whl", hash = "sha256:0b3f4c438e37d22b83e640f825ef0f37b95db9aa2d68203f2c9549375d0b2260"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5a79d65395ac5e6b0c2890935bad892eabb911c4aa8e8015067ddb37eea3d56c"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9a5baf9267b752390252889f0c802ea13b52dfee5e369527da229189b8bd592e"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cb5a646930c5123f8461f6468901573f334c2c63c795b9af350063a736d0134"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:296230899df0b77dec4eb799bcea6fbe39a43707ce7bb166519c97b583cfcab3"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c62d401223f468eb4da32627bffc0c78ed516b03bb8a34a58be54d618b74d472"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3b69e934f0f2b677ec111b4d83f92dc1a3210a779f69bf905273192cf4ed433e"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-win32.whl", hash = "sha256:77d2edb1f54aff37e3318f611637171e8ec71472f1fdc7348b41dcb226f93d90"}, + {file = "SQLAlchemy-2.0.30-cp312-cp312-win_amd64.whl", hash = "sha256:b6c7ec2b1f4969fc19b65b7059ed00497e25f54069407a8701091beb69e591a5"}, + {file = "SQLAlchemy-2.0.30-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5a8e3b0a7e09e94be7510d1661339d6b52daf202ed2f5b1f9f48ea34ee6f2d57"}, + {file = "SQLAlchemy-2.0.30-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b60203c63e8f984df92035610c5fb76d941254cf5d19751faab7d33b21e5ddc0"}, + {file = "SQLAlchemy-2.0.30-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1dc3eabd8c0232ee8387fbe03e0a62220a6f089e278b1f0aaf5e2d6210741ad"}, + {file = "SQLAlchemy-2.0.30-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:40ad017c672c00b9b663fcfcd5f0864a0a97828e2ee7ab0c140dc84058d194cf"}, + {file = "SQLAlchemy-2.0.30-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e42203d8d20dc704604862977b1470a122e4892791fe3ed165f041e4bf447a1b"}, + {file = "SQLAlchemy-2.0.30-cp37-cp37m-win32.whl", hash = "sha256:2a4f4da89c74435f2bc61878cd08f3646b699e7d2eba97144030d1be44e27584"}, + {file = "SQLAlchemy-2.0.30-cp37-cp37m-win_amd64.whl", hash = "sha256:b6bf767d14b77f6a18b6982cbbf29d71bede087edae495d11ab358280f304d8e"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bc0c53579650a891f9b83fa3cecd4e00218e071d0ba00c4890f5be0c34887ed3"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:311710f9a2ee235f1403537b10c7687214bb1f2b9ebb52702c5aa4a77f0b3af7"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:408f8b0e2c04677e9c93f40eef3ab22f550fecb3011b187f66a096395ff3d9fd"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37a4b4fb0dd4d2669070fb05b8b8824afd0af57587393015baee1cf9890242d9"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a943d297126c9230719c27fcbbeab57ecd5d15b0bd6bfd26e91bfcfe64220621"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0a089e218654e740a41388893e090d2e2c22c29028c9d1353feb38638820bbeb"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-win32.whl", hash = 
"sha256:fa561138a64f949f3e889eb9ab8c58e1504ab351d6cf55259dc4c248eaa19da6"}, + {file = "SQLAlchemy-2.0.30-cp38-cp38-win_amd64.whl", hash = "sha256:7d74336c65705b986d12a7e337ba27ab2b9d819993851b140efdf029248e818e"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ae8c62fe2480dd61c532ccafdbce9b29dacc126fe8be0d9a927ca3e699b9491a"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2383146973a15435e4717f94c7509982770e3e54974c71f76500a0136f22810b"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8409de825f2c3b62ab15788635ccaec0c881c3f12a8af2b12ae4910a0a9aeef6"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0094c5dc698a5f78d3d1539853e8ecec02516b62b8223c970c86d44e7a80f6c7"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:edc16a50f5e1b7a06a2dcc1f2205b0b961074c123ed17ebda726f376a5ab0953"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f7703c2010355dd28f53deb644a05fc30f796bd8598b43f0ba678878780b6e4c"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-win32.whl", hash = "sha256:1f9a727312ff6ad5248a4367358e2cf7e625e98b1028b1d7ab7b806b7d757513"}, + {file = "SQLAlchemy-2.0.30-cp39-cp39-win_amd64.whl", hash = "sha256:a0ef36b28534f2a5771191be6edb44cc2673c7b2edf6deac6562400288664221"}, + {file = "SQLAlchemy-2.0.30-py3-none-any.whl", hash = "sha256:7108d569d3990c71e26a42f60474b4c02c8586c4681af5fd67e51a044fdea86a"}, + {file = "SQLAlchemy-2.0.30.tar.gz", hash = "sha256:2b1708916730f4830bc69d6f49d37f7698b5bd7530aca7f04f785f8849e95255"}, +] + +[package.dependencies] +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +typing-extensions = ">=4.6.0" + +[package.extras] +aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] +aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] +aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] +asyncio = ["greenlet (!=0.4.17)"] +asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] +mssql = ["pyodbc"] +mssql-pymssql = ["pymssql"] +mssql-pyodbc = ["pyodbc"] +mypy = ["mypy (>=0.910)"] +mysql = ["mysqlclient (>=1.4.0)"] +mysql-connector = ["mysql-connector-python"] +oracle = ["cx_oracle (>=8)"] +oracle-oracledb = ["oracledb (>=1.0.1)"] +postgresql = ["psycopg2 (>=2.7)"] +postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"] +postgresql-pg8000 = ["pg8000 (>=1.29.1)"] +postgresql-psycopg = ["psycopg (>=3.0.7)"] +postgresql-psycopg2binary = ["psycopg2-binary"] +postgresql-psycopg2cffi = ["psycopg2cffi"] +postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] +pymysql = ["pymysql"] +sqlcipher = ["sqlcipher3_binary"] + [[package]] name = "syrupy" version = "4.6.1" @@ -1249,6 +1590,21 @@ files = [ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +description = "Runtime inspection utilities for typing module." 
+optional = false +python-versions = "*" +files = [ + {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"}, + {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"}, +] + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "urllib3" version = "2.2.1" @@ -1413,4 +1769,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "540028a924d26a41cf40e1109b2540a642dc668618f338f6bb35d1b40bcaeed4" +content-hash = "e9a538766aca94f7cf2fa991936319857ac32d78e79b0b815f691b81476a27a6" diff --git a/libs/ai-endpoints/pyproject.toml b/libs/ai-endpoints/pyproject.toml index ba34df24..89c36ab1 100644 --- a/libs/ai-endpoints/pyproject.toml +++ b/libs/ai-endpoints/pyproject.toml @@ -15,6 +15,7 @@ python = ">=3.8.1,<4.0" langchain-core = ">=0.1.27,<0.3" aiohttp = "^3.9.1" pillow = ">=10.0.0,<11.0.0" +langchain-community = "^0.2.5" [tool.poetry.group.test] optional = true diff --git a/libs/ai-endpoints/tests/integration_tests/test_vlm_models.py b/libs/ai-endpoints/tests/integration_tests/test_vlm_models.py index c181e175..260073c6 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_vlm_models.py +++ b/libs/ai-endpoints/tests/integration_tests/test_vlm_models.py @@ -19,7 +19,11 @@ # - openai api supports server-side image download, api catalog does not # - ChatNVIDIA does client side download to simulate the same behavior # - ChatNVIDIA will automatically read local files and convert them to base64 +# - openai api uses {"image_url": {"url": "..."}} +# where api catalog uses {"image_url": "..."} # + + @pytest.mark.parametrize( "content", [ @@ -54,3 +58,6 @@ def test_vlm_model( response = chat.invoke([HumanMessage(content=content)]) assert isinstance(response, BaseMessage) assert isinstance(response.content, str) + + for token in chat.stream([HumanMessage(content=content)]): + assert isinstance(token.content, str) From a549a400d286e272b8eab15eed68a53034e3b3d5 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Jul 2024 20:33:10 -0400 Subject: [PATCH 03/11] add tool calling implementation (invoke; no streaming) --- .../langchain_nvidia_ai_endpoints/_common.py | 8 +- .../chat_models.py | 220 ++++++++++++------ 2 files changed, 162 insertions(+), 66 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index dd65ad6f..2a9e673b 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -368,13 +368,17 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]: content_buffer: Dict[str, Any] = dict() content_holder: Dict[Any, Any] = dict() usage_holder: Dict[Any, Any] = dict() #### + finish_reason_holder: Optional[str] = None is_stopped = False for msg in msg_list: usage_holder = msg.get("usage", {}) #### if "choices" in msg: ## Tease out ['choices'][0]...['delta'/'message'] msg = msg.get("choices", [{}])[0] - is_stopped = msg.get("finish_reason", "") == "stop" + # todo: this meeds to be fixed, the fact we only + # use the first choice breaks the interface + finish_reason_holder = msg.get("finish_reason", None) + is_stopped = finish_reason_holder == "stop" msg = msg.get("delta", msg.get("message", msg.get("text", ""))) if not isinstance(msg, dict): msg = {"content": msg} @@ 
-392,6 +396,8 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]: content_holder = {**content_holder, **content_buffer} if usage_holder: content_holder.update(token_usage=usage_holder) #### + if finish_reason_holder: + content_holder.update(finish_reason=finish_reason_holder) return content_holder, is_stopped #################################################################################### diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 09400aa3..9c13c924 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -8,6 +8,7 @@ import os import sys import urllib.parse +import warnings from typing import ( Any, Callable, @@ -15,7 +16,6 @@ Iterator, List, Literal, - Mapping, Optional, Sequence, Type, @@ -23,12 +23,14 @@ ) import requests +from langchain_community.adapters.openai import convert_message_to_dict from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.messages import ( + AIMessage, BaseMessage, ChatMessage, ChatMessageChunk, @@ -41,6 +43,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_nvidia_ai_endpoints._common import _NVIDIAClient from langchain_nvidia_ai_endpoints._statics import Model @@ -116,6 +119,53 @@ def _url_to_b64_string(image_source: str) -> str: raise ValueError(f"Unable to process the provided image source: {e}") +def _nv_vlm_adjust_input(message_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + The NVIDIA VLM API input message.content: + { + "role": "user", + "content": [ + ..., + { + "type": "image_url", + "image_url": "{data}" + }, + ... + ] + } + where OpenAI VLM API input message.content: + { + "role": "user", + "content": [ + ..., + { + "type": "image_url", + "image_url": { + "url": "{url | data}" + } + }, + ... + ] + } + + This function converts the OpenAI VLM API input message to + NVIDIA VLM API input message, in place. + + In the process, it accepts a url or file and converts them to + data urls. + """ + if content := message_dict.get("content"): + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and "image_url" in part: + if ( + isinstance(part["image_url"], dict) + and "url" in part["image_url"] + ): + part["image_url"] = _url_to_b64_string(part["image_url"]["url"]) + return message_dict + + class ChatNVIDIA(BaseChatModel): """NVIDIA chat model. 
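# Illustrative sketch (not part of the patch): how the _nv_vlm_adjust_input helper
# added in the hunk above is expected to rewrite an OpenAI-style image part into the
# API-catalog form, per its docstring. The message and URL below are made-up examples.
#
#   message = {
#       "role": "user",
#       "content": [
#           {"type": "text", "text": "Describe this image."},
#           {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
#       ],
#   }
#   _nv_vlm_adjust_input(message)
#   # The nested {"url": ...} dict is replaced in place with a base64 data URL string,
#   # e.g. {"type": "image_url", "image_url": "data:image/png;base64,..."},
#   # which is the shape the NVIDIA VLM endpoints accept.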
@@ -209,12 +259,22 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - inputs = self._custom_preprocess(messages) + inputs = [ + _nv_vlm_adjust_input(message) + for message in [convert_message_to_dict(message) for message in messages] + ] payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs) response = self._client.client.get_req(payload=payload) responses, _ = self._client.client.postprocess(response) self._set_callback_out(responses, run_manager) - message = ChatMessage(**self._custom_postprocess(responses)) + parsed_response = self._custom_postprocess(responses) + # arguably we should always return an AIMessage, but to maintain + # API compatibility, we only return it for tool_calls. we can + # change this for an API breaking 1.0. + if "tool_calls" in parsed_response["additional_kwargs"]: + message: BaseMessage = AIMessage(**parsed_response) + else: + message = ChatMessage(**parsed_response) generation = ChatGeneration(message=message) return ChatResult(generations=[generation], llm_output=responses) @@ -226,10 +286,14 @@ def _stream( **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Allows streaming to model!""" - inputs = self._custom_preprocess(messages) + inputs = [ + _nv_vlm_adjust_input(message) + for message in [convert_message_to_dict(message) for message in messages] + ] payload = self._get_payload(inputs=inputs, stop=stop, stream=True, **kwargs) for response in self._client.client.get_req_stream(payload=payload): self._set_callback_out(response, run_manager) + # todo: AIMessageChunk for tool_calls chunk = ChatGenerationChunk( message=ChatMessageChunk(**self._custom_postprocess(response)) ) @@ -248,54 +312,6 @@ def _set_callback_out( if hasattr(cb, "llm_output"): cb.llm_output = result - def _custom_preprocess( # todo: remove - self, msg_list: Sequence[BaseMessage] - ) -> List[Dict[str, str]]: - def _preprocess_msg(msg: BaseMessage) -> Dict[str, str]: - if isinstance(msg, BaseMessage): - role_convert = {"ai": "assistant", "human": "user"} - if isinstance(msg, ChatMessage): - role = msg.role - else: - role = msg.type - role = role_convert.get(role, role) - content = self._process_content(msg.content) - return {"role": role, "content": content} - raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}") - - return [_preprocess_msg(m) for m in msg_list] - - def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str: - if isinstance(content, str): - return content - string_array: list = [] - - for part in content: - if isinstance(part, str): - string_array.append(part) - elif isinstance(part, Mapping): - # OpenAI Format - if "type" in part: - if part["type"] == "text": - string_array.append(str(part["text"])) - elif part["type"] == "image_url": - img_url = part["image_url"] - if isinstance(img_url, dict): - if "url" not in img_url: - raise ValueError( - f"Unrecognized message image format: {img_url}" - ) - img_url = img_url["url"] - b64_string = _url_to_b64_string(img_url) - string_array.append(f'') - else: - raise ValueError( - f"Unrecognized message part type: {part['type']}" - ) - else: - raise ValueError(f"Unrecognized message part format: {part}") - return "".join(string_array) - def _custom_postprocess(self, msg: dict) -> dict: # todo: remove kw_left = msg.copy() out_dict = { @@ -306,9 +322,8 @@ def _custom_postprocess(self, msg: dict) -> dict: # todo: remove "additional_kwargs": {}, "response_metadata": {}, } - for k in list(kw_left.keys()): - if "tool" in k: - 
out_dict["additional_kwargs"][k] = kw_left.pop(k) + if tool_calls := kw_left.pop("tool_calls", None): + out_dict["additional_kwargs"]["tool_calls"] = tool_calls out_dict["response_metadata"] = kw_left return out_dict @@ -365,13 +380,92 @@ def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], *, - tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: - raise NotImplementedError( - "Not implemented, awaiting server-side function-recieving API" - " Consider following open-source LLM agent spec techniques:" - " https://huggingface.co/blog/open-source-llms-as-agents" + """ + Bind tools to the model. + + Args: + tools (list): A list of tools to bind to the model. + tool_choice (Optional[Union[dict, + str, + Literal["auto", "none", "any", "required"], + bool]]): + Control tool choice. + "any" and "required" - force a tool call. + "auto" - let the model decide. + "none" - force no tool call. + string or dict - force a specific tool call. + bool - if True, force a tool call; if False, force no tool call. + Defaults to passing no value. + **kwargs: Additional keyword arguments. + + see https://python.langchain.com/v0.1/docs/modules/model_io/chat/function_calling/#request-forcing-a-tool-call + """ + # check if the model supports tools, warn if it does not + known_good = False + # todo: we need to store model: Model in this class + # instead of model: str (= Model.id) + # this should be: if not self.model.supports_tools: warnings.warn... + candidates = [ + model for model in self.available_models if model.id == self.model + ] + if not candidates: # user must have specified the model themselves + known_good = False + else: + assert len(candidates) == 1, "Multiple models with the same id" + known_good = candidates[0].supports_tools is True + if not known_good: + warnings.warn( + f"Model '{self.model}' is not known to support tools. " + "Your tool binding may fail at inference time." + ) + + tool_name = None + if isinstance(tool_choice, bool): + tool_choice = "required" if tool_choice else "none" + elif isinstance(tool_choice, str): + # LangChain documents "any" as an option, server API uses "required" + if tool_choice == "any": + tool_choice = "required" + # if a string that's not "auto", "none", or "required" + # then it must be a tool name + if tool_choice not in ["auto", "none", "required"]: + tool_name = tool_choice + tool_choice = { + "type": "function", + "function": {"name": tool_choice}, + } + elif isinstance(tool_choice, dict): + # if a dict, it must be a tool choice dict, e.g. 
+ # {"type": "function", "function": {"name": "my_tool"}} + if "type" not in tool_choice: + tool_choice["type"] = "function" + if "function" not in tool_choice: + raise ValueError("Tool choice dict must have a 'function' key") + if "name" not in tool_choice["function"]: + raise ValueError("Tool choice function dict must have a 'name' key") + tool_name = tool_choice["function"]["name"] + + # check that the specified tool is in the tools list + if tool_name: + if not any( + isinstance(tool, BaseTool) and tool.name == tool_name for tool in tools + ) and not any( + isinstance(tool, dict) and tool.get("name") == tool_name + for tool in tools + ): + raise ValueError( + f"Tool choice '{tool_name}' not found in the tools list" + ) + + return super().bind( + tools=[convert_to_openai_tool(tool) for tool in tools], + tool_choice=tool_choice, + **kwargs, ) def bind_functions( @@ -380,11 +474,7 @@ def bind_functions( function_call: Optional[str] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: - raise NotImplementedError( - "Not implemented, awaiting server-side function-recieving API" - " Consider following open-source LLM agent spec techniques:" - " https://huggingface.co/blog/open-source-llms-as-agents" - ) + raise NotImplementedError("Not implemented, use `bind_tools` instead.") def with_structured_output( self, From 2b9779e813aa3362e643a4cceea4feebdac095bf Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 10 Jul 2024 08:38:38 -0400 Subject: [PATCH 04/11] add tests for tool calling (streaming) --- .../integration_tests/test_bind_tools.py | 310 +++++++++++++++--- 1 file changed, 271 insertions(+), 39 deletions(-) diff --git a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py index 83dff6b0..15c47832 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py +++ b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py @@ -1,9 +1,15 @@ import json import warnings -from typing import Any, List, Literal, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union import pytest -from langchain_core.messages import AIMessage, ChatMessage +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, +) from langchain_core.pydantic_v1 import Field from langchain_core.tools import tool @@ -34,14 +40,18 @@ # test types: # 17. deterministic (minimial accuracy tests; relies on basic tool calling skills) # 18. accuracy (proper tool; proper arguments) -# negative tests: +# edge/negative tests: # 19. require unknown named tool (invoke/stream only) # 20. partial function (invoke/stream only) +# 21. not enough tokens to generate tool call +# 22. tool with no arguments +# 23. duplicate tool names +# 24. 
unknown tool (invoke/stream only) # -# todo: streaming -# todo: test tool with no arguments +# todo: async methods # todo: parallel_tool_calls +# todo: too many tools @tool @@ -62,6 +72,31 @@ def zzyyxx( return (b**a) % (a - b) +@tool +def tool_no_args() -> str: + """8-ball""" + return "lookin' good" + + +def eval_stream(llm: ChatNVIDIA, msg: str, tool_choice: Any = None) -> BaseMessageChunk: + if tool_choice: + generator = llm.stream(msg, tool_choice=tool_choice) # type: ignore + else: + generator = llm.stream(msg) + response = next(generator) + for chunk in generator: + assert isinstance(chunk, AIMessageChunk) + response += chunk + return response + + +def eval_invoke(llm: ChatNVIDIA, msg: str, tool_choice: Any = None) -> BaseMessage: + if tool_choice: + return llm.invoke(msg, tool_choice=tool_choice) # type: ignore + else: + return llm.invoke(msg) + + def check_response_structure(response: AIMessage) -> None: assert not response.content # should be `response.content is None` but # AIMessage.content: Union[str, List[Union[str, Dict]]] cannot be None. @@ -77,36 +112,49 @@ def check_response_structure(response: AIMessage) -> None: assert len(response.tool_calls) > 0 -# users can also get at the tool calls from the response.additional_kwargs +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) @pytest.mark.xfail(reason="Accuracy test") -def test_accuracy_default_invoke_additional_kwargs(tool_model: str, mode: dict) -> None: +def test_accuracy_extra(tool_model: str, mode: dict, func: Callable) -> None: llm = ChatNVIDIA(temperature=0, model=tool_model, **mode).bind_tools([xxyyzz]) - response = llm.invoke("What is 11 xxyyzz 3?") + response = func(llm, "What is 11 xxyyzz 3?") assert not response.content # should be `response.content is None` but # AIMessage.content: Union[str, List[Union[str, Dict]]] cannot be None. 
assert response.additional_kwargs is not None - assert "tool_calls" in response.additional_kwargs - assert isinstance(response.additional_kwargs["tool_calls"], list) - assert response.additional_kwargs["tool_calls"] - for tool_call in response.additional_kwargs["tool_calls"]: - assert "id" in tool_call - assert tool_call["id"] is not None - assert "type" in tool_call - assert tool_call["type"] == "function" - assert "function" in tool_call + # todo: this is not good, should not care about the param + if func == eval_invoke: + assert isinstance(response, AIMessage) + assert "tool_calls" in response.additional_kwargs + assert isinstance(response.additional_kwargs["tool_calls"], list) + assert response.additional_kwargs["tool_calls"] + assert response.tool_calls + for tool_call in response.additional_kwargs["tool_calls"]: + assert "id" in tool_call + assert tool_call["id"] is not None + assert "type" in tool_call + assert tool_call["type"] == "function" + assert "function" in tool_call + assert len(response.additional_kwargs["tool_calls"]) > 0 + tool_call = response.additional_kwargs["tool_calls"][0] + assert tool_call["function"]["name"] == "xxyyzz" + assert json.loads(tool_call["function"]["arguments"]) == {"a": 11, "b": 3} + else: + assert isinstance(response, AIMessageChunk) + assert response.tool_call_chunks assert response.response_metadata is not None assert isinstance(response.response_metadata, dict) - assert "content" in response.response_metadata - assert response.response_metadata["content"] is None + if "content" in response.response_metadata: + assert response.response_metadata["content"] is None + assert "model_name" in response.response_metadata + assert response.response_metadata["model_name"] == tool_model assert "finish_reason" in response.response_metadata assert response.response_metadata["finish_reason"] in [ "tool_calls", "stop", ] # todo: remove "stop" - assert len(response.additional_kwargs["tool_calls"]) > 0 - tool_call = response.additional_kwargs["tool_calls"][0] - assert tool_call["function"]["name"] == "xxyyzz" - assert json.loads(tool_call["function"]["arguments"]) == {"a": 11, "b": 3} @pytest.mark.parametrize( @@ -119,12 +167,17 @@ def test_accuracy_default_invoke_additional_kwargs(tool_model: str, mode: dict) ], ids=["none", "required", "partial", "function"], ) -def test_invoke_tool_choice_with_no_tool( - tool_model: str, mode: dict, tool_choice: Any +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_with_no_tool( + tool_model: str, mode: dict, tool_choice: Any, func: Callable ) -> None: llm = ChatNVIDIA(model=tool_model, **mode) with pytest.raises(Exception) as e: - llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) + func(llm, "What is 11 xxyyzz 3?", tool_choice=tool_choice) assert "400" in str(e.value) or "###" in str( e.value ) # todo: stop transforming 400 -> ### @@ -139,10 +192,14 @@ def test_invoke_tool_choice_with_no_tool( ) -def test_invoke_tool_choice_none(tool_model: str, mode: dict) -> None: +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_none(tool_model: str, mode: dict, func: Callable) -> None: llm = ChatNVIDIA(model=tool_model, **mode).bind_tools(tools=[xxyyzz]) - response = llm.invoke("What is 11 xxyyzz 3?", tool_choice="none") # type: ignore - assert isinstance(response, ChatMessage) + response = func(llm, "What is 11 xxyyzz 3?", tool_choice="none") assert "tool_calls" not in 
response.additional_kwargs @@ -153,20 +210,173 @@ def test_invoke_tool_choice_none(tool_model: str, mode: dict) -> None: ], ids=["partial"], ) -def test_invoke_tool_choice_negative( +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_negative( tool_model: str, mode: dict, tool_choice: Optional[ Union[dict, str, Literal["auto", "none", "any", "required"], bool] ], + func: Callable, ) -> None: llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([xxyyzz]) with pytest.raises(Exception) as e: - llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) # type: ignore + func(llm, "What is 11 xxyyzz 3?", tool_choice=tool_choice) + assert "400" in str(e.value) or "###" in str( + e.value + ) # todo: stop transforming 400 -> ### + assert ( + "invalid_request_error" in str(e.value) + or "value_error" in str(e.value) + or "Incorrectly formatted `tool_choice`" in str(e.value) + ) + + +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_negative_max_tokens_required( + tool_model: str, + mode: dict, + func: Callable, +) -> None: + llm = ChatNVIDIA(max_tokens=5, model=tool_model, **mode).bind_tools([xxyyzz]) + with pytest.raises(Exception) as e: + func(llm, "What is 11 xxyyzz 3?", tool_choice="required") assert "400" in str(e.value) or "###" in str( e.value ) # todo: stop transforming 400 -> ### - assert "invalid_request_error" in str(e.value) or "value_error" in str(e.value) + assert "invalid_request_error" in str(e.value) + assert ( + "Could not finish the message because max_tokens was reached. " + "Please try again with higher max_tokens." + ) in str(e.value) + + +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_negative_max_tokens_function( + tool_model: str, + mode: dict, + func: Callable, +) -> None: + llm = ChatNVIDIA(max_tokens=5, model=tool_model, **mode).bind_tools([xxyyzz]) + response = func( + llm, + "What is 11 xxyyzz 3?", + tool_choice={"type": "function", "function": {"name": "xxyyzz"}}, + ) + # todo: this is not good, should not care about the param + if func == eval_invoke: + assert isinstance(response, AIMessage) + assert "tool_calls" in response.additional_kwargs + assert response.invalid_tool_calls + else: + assert isinstance(response, AIMessageChunk) + assert response.tool_call_chunks + assert "finish_reason" in response.response_metadata + assert response.response_metadata["finish_reason"] == "length" + + +@pytest.mark.parametrize( + "tool_choice", + [ + "required", + {"type": "function", "function": {"name": "tool_no_args"}}, + ], + ids=["required", "function"], +) +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_negative_no_args( + tool_model: str, + mode: dict, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ], + func: Callable, +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([tool_no_args]) + response = func(llm, "What does the 8-ball say?", tool_choice=tool_choice) + # todo: this is not good, should not care about the param + if func == eval_invoke: + assert isinstance(response, AIMessage) + assert response.tool_calls + else: + assert isinstance(response, AIMessageChunk) + assert response.tool_call_chunks + # assert "tool_calls" in response.additional_kwargs + + +@pytest.mark.parametrize( + "tool_choice", + [ + "required", + {"type": 
"function", "function": {"name": "tool_no_args"}}, + ], + ids=["required", "function"], +) +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +@pytest.mark.xfail(reason="Accuracy test") +def test_accuracy_tool_choice_negative_no_args( + tool_model: str, + mode: dict, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ], + func: Callable, +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([tool_no_args]) + response = func(llm, "What does the 8-ball say?", tool_choice=tool_choice) + assert isinstance(response, AIMessage) + # assert "tool_calls" in response.additional_kwargs + assert response.tool_calls + assert response.tool_calls[0]["name"] == "tool_no_args" + assert response.tool_calls[0]["args"] == {} + + +@pytest.mark.parametrize( + "tool_choice", + [ + "required", + {"type": "function", "function": {"name": "xxyyzz"}}, + ], + ids=["required", "function"], +) +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_negative_duplicate_tool( + tool_model: str, + mode: dict, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "any", "required"], bool] + ], + func: Callable, +) -> None: + llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([xxyyzz, xxyyzz]) + response = func(llm, "What is 11 xxyyzz 3?", tool_choice=tool_choice) + assert isinstance(response, AIMessage) + assert response.tool_calls + # assert "tool_calls" in response.additional_kwargs @pytest.mark.parametrize( @@ -177,15 +387,21 @@ def test_invoke_tool_choice_negative( ], ids=["required", "function"], ) -def test_invoke_tool_choice( +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice( tool_model: str, mode: dict, tool_choice: Optional[ Union[dict, str, Literal["auto", "none", "any", "required"], bool] ], + func: Callable, ) -> None: llm = ChatNVIDIA(model=tool_model, **mode).bind_tools([xxyyzz]) - response = llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) # type: ignore + response = func(llm, "What is 11 xxyyzz 3?", tool_choice=tool_choice) assert isinstance(response, AIMessage) check_response_structure(response) @@ -205,15 +421,21 @@ def test_invoke_tool_choice( [[xxyyzz], [xxyyzz, zzyyxx], [zzyyxx, xxyyzz]], ids=["xxyyzz", "xxyyzz_and_zzyyxx", "zzyyxx_and_xxyyzz"], ) +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) @pytest.mark.xfail(reason="Accuracy test") -def test_accuracy_invoke_tool_choice( +def test_accuracy_tool_choice( tool_model: str, mode: dict, tools: List, tool_choice: Any, + func: Callable, ) -> None: llm = ChatNVIDIA(temperature=0, model=tool_model, **mode).bind_tools(tools) - response = llm.invoke("What is 11 xxyyzz 3?", tool_choice=tool_choice) # type: ignore + response = func(llm, "What is 11 xxyyzz 3?", tool_choice=tool_choice) assert isinstance(response, AIMessage) check_response_structure(response) tool_call = response.tool_calls[0] @@ -221,13 +443,23 @@ def test_accuracy_invoke_tool_choice( assert tool_call["args"] == {"b": 3, "a": 11} -def test_invoke_tool_choice_with_unknown_tool(tool_model: str, mode: dict) -> None: +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +def test_tool_choice_negative_unknown_tool( + tool_model: str, + mode: dict, + func: Callable, +) -> None: llm = ChatNVIDIA(model=tool_model, **mode).bind_tools(tools=[xxyyzz]) 
with pytest.raises(Exception) as e: - llm.invoke( + func( + llm, "What is 11 xxyyzz 3?", tool_choice={"type": "function", "function": {"name": "zzyyxx"}}, - ) # type: ignore + ) assert ( "not found in the tools list" in str(e.value) or "no function named" in str(e.value) From d05acdc939871be54e96f95f4d43e944d910de71 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 10 Jul 2024 08:39:18 -0400 Subject: [PATCH 05/11] add tool calling implementation (streaming) --- .../langchain_nvidia_ai_endpoints/_statics.py | 16 +++++ .../chat_models.py | 61 +++++++++++++++---- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py index d5649623..e3deb95b 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py @@ -1,3 +1,4 @@ +import os import warnings from typing import Literal, Optional @@ -425,6 +426,18 @@ def validate_client(cls, client: str, values: dict) -> str: # ), # } + +OPENAI_MODEL_TABLE = { + "gpt-3.5-turbo": Model( + id="gpt-3.5-turbo", + model_type="chat", + client="ChatNVIDIA", + endpoint="https://api.openai.com/v1/chat/completions", + supports_tools=True, + ), +} + + MODEL_TABLE = { **CHAT_MODEL_TABLE, **QA_MODEL_TABLE, @@ -433,6 +446,9 @@ def validate_client(cls, client: str, values: dict) -> str: **RANKING_MODEL_TABLE, } +if "_INCLUDE_OPENAI" in os.environ: + MODEL_TABLE.update(OPENAI_MODEL_TABLE) + def register_model(model: Model) -> None: """ diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 9c13c924..af2138f3 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -31,9 +31,9 @@ from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, ChatMessage, - ChatMessageChunk, ) from langchain_core.outputs import ( ChatGeneration, @@ -267,10 +267,10 @@ def _generate( response = self._client.client.get_req(payload=payload) responses, _ = self._client.client.postprocess(response) self._set_callback_out(responses, run_manager) - parsed_response = self._custom_postprocess(responses) - # arguably we should always return an AIMessage, but to maintain - # API compatibility, we only return it for tool_calls. we can - # change this for an API breaking 1.0. + parsed_response = self._custom_postprocess(responses, streaming=False) + # todo: we should always return an AIMessage, but to maintain + # API compatibility, we only return it for tool_calls. we can + # change this for an API breaking 1.0. if "tool_calls" in parsed_response["additional_kwargs"]: message: BaseMessage = AIMessage(**parsed_response) else: @@ -293,10 +293,16 @@ def _stream( payload = self._get_payload(inputs=inputs, stop=stop, stream=True, **kwargs) for response in self._client.client.get_req_stream(payload=payload): self._set_callback_out(response, run_manager) - # todo: AIMessageChunk for tool_calls - chunk = ChatGenerationChunk( - message=ChatMessageChunk(**self._custom_postprocess(response)) - ) + parsed_response = self._custom_postprocess(response, streaming=True) + # todo: we should always return an AIMessage, but to maintain + # API compatibility, we only return it for tool_calls. we can + # change this for an API breaking 1.0. 
+ # if "tool_calls" in parsed_response["additional_kwargs"]: + # message: BaseMessageChunk = AIMessageChunk(**parsed_response) + # else: + # message = ChatMessageChunk(**parsed_response) + message = AIMessageChunk(**parsed_response) + chunk = ChatGenerationChunk(message=message) if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) yield chunk @@ -312,7 +318,9 @@ def _set_callback_out( if hasattr(cb, "llm_output"): cb.llm_output = result - def _custom_postprocess(self, msg: dict) -> dict: # todo: remove + def _custom_postprocess( + self, msg: dict, streaming: bool = False + ) -> dict: # todo: remove kw_left = msg.copy() out_dict = { "role": kw_left.pop("role", "assistant") or "assistant", @@ -322,9 +330,38 @@ def _custom_postprocess(self, msg: dict) -> dict: # todo: remove "additional_kwargs": {}, "response_metadata": {}, } + # "tool_calls" is set for invoke and stream responses if tool_calls := kw_left.pop("tool_calls", None): - out_dict["additional_kwargs"]["tool_calls"] = tool_calls - out_dict["response_metadata"] = kw_left + assert isinstance( + tool_calls, list + ), "invalid response from server: tool_calls must be a list" + # todo: break this into post-processing for invoke and stream + if not streaming: + out_dict["additional_kwargs"]["tool_calls"] = tool_calls + elif streaming: + out_dict["tool_call_chunks"] = [] + for tool_call in tool_calls: + assert "index" in tool_call, ( + "invalid response from server: " + "tool_call must have an 'index' key" + ) + assert "function" in tool_call, ( + "invalid response from server: " + "tool_call must have a 'function' key" + ) + out_dict["tool_call_chunks"].append( + { + "index": tool_call.get("index", None), + "id": tool_call.get("id", None), + "name": tool_call["function"].get("name", None), + "args": tool_call["function"].get("arguments", None), + } + ) + # we only create the response_metadata from the last message in a stream. + # if we do it for all messages, we'll end up with things like + # "model_name" = "mode-xyz" * # messages. + if "finish_reason" in kw_left: + out_dict["response_metadata"] = kw_left return out_dict ###################################################################################### From 493020fe7cc432f7547af102d664f75b97b45be9 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 15 Jul 2024 14:29:03 -0400 Subject: [PATCH 06/11] add tests for parallel_tool_calls (invoke, stream) --- .../integration_tests/test_bind_tools.py | 139 +++++++++++++-- .../unit_tests/test_parallel_tool_calls.py | 160 ++++++++++++++++++ 2 files changed, 289 insertions(+), 10 deletions(-) create mode 100644 libs/ai-endpoints/tests/unit_tests/test_parallel_tool_calls.py diff --git a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py index 15c47832..1d64e576 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py +++ b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py @@ -47,10 +47,13 @@ # 22. tool with no arguments # 23. duplicate tool names # 24. unknown tool (invoke/stream only) -# +# ways to specify parallel_tool_calls: (accuracy only) +# 25. invoke +# 26. 
stream +# todo: parallel_tool_calls w/ bind_tools +# todo: parallel_tool_calls w/ tool_choice = function # todo: async methods -# todo: parallel_tool_calls # todo: too many tools @@ -78,11 +81,31 @@ def tool_no_args() -> str: return "lookin' good" -def eval_stream(llm: ChatNVIDIA, msg: str, tool_choice: Any = None) -> BaseMessageChunk: +@tool +def get_current_weather( + location: str = Field(..., description="The location to get the weather for"), + scale: Optional[str] = Field( + default="Fahrenheit", + description="The temperature scale (e.g., Celsius or Fahrenheit)", + ), +) -> str: + """Get the current weather for a location""" + return f"The current weather in {location} is sunny." + + +def eval_stream( + llm: ChatNVIDIA, + msg: str, + tool_choice: Any = None, + parallel_tool_calls: bool = False, +) -> BaseMessageChunk: + params = {} if tool_choice: - generator = llm.stream(msg, tool_choice=tool_choice) # type: ignore - else: - generator = llm.stream(msg) + params["tool_choice"] = tool_choice + if parallel_tool_calls: + params["parallel_tool_calls"] = True + + generator = llm.stream(msg, **params) # type: ignore response = next(generator) for chunk in generator: assert isinstance(chunk, AIMessageChunk) @@ -90,11 +113,19 @@ def eval_stream(llm: ChatNVIDIA, msg: str, tool_choice: Any = None) -> BaseMessa return response -def eval_invoke(llm: ChatNVIDIA, msg: str, tool_choice: Any = None) -> BaseMessage: +def eval_invoke( + llm: ChatNVIDIA, + msg: str, + tool_choice: Any = None, + parallel_tool_calls: bool = False, +) -> BaseMessage: + params = {} if tool_choice: - return llm.invoke(msg, tool_choice=tool_choice) # type: ignore - else: - return llm.invoke(msg) + params["tool_choice"] = tool_choice + if parallel_tool_calls: + params["parallel_tool_calls"] = True + + return llm.invoke(msg, **params) # type: ignore def check_response_structure(response: AIMessage) -> None: @@ -612,3 +643,91 @@ def test_unknown_warns(mode: dict) -> None: ChatNVIDIA(model="mock-model", **mode).bind_tools([xxyyzz]) assert len(record) == 1 assert "not known to support tools" in str(record[0].message) + + +@pytest.mark.parametrize( + "tool_choice", + [ + "auto", + None, + "required", + ], + ids=["auto", "absent", "required"], +) +@pytest.mark.parametrize( + "tools", + [[xxyyzz, zzyyxx], [zzyyxx, xxyyzz]], + ids=["xxyyzz_and_zzyyxx", "zzyyxx_and_xxyyzz"], +) +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +@pytest.mark.xfail(reason="Accuracy test") +def test_accuracy_parallel_tool_calls_hard( + tool_model: str, + mode: dict, + tools: List, + tool_choice: Any, + func: Callable, +) -> None: + llm = ChatNVIDIA(seed=42, temperature=1, model=tool_model, **mode).bind_tools(tools) + response = func( + llm, + "What is 11 xxyyzz 3 zzyyxx 5?", + tool_choice=tool_choice, + parallel_tool_calls=True, + ) + assert isinstance(response, AIMessage) + check_response_structure(response) + assert len(response.tool_calls) == 2 + valid_tool_names = ["xxyyzz", "zzyyxx"] + tool_call0 = response.tool_calls[0] + assert tool_call0["name"] in valid_tool_names + valid_tool_names.remove(tool_call0["name"]) + tool_call1 = response.tool_calls[1] + assert tool_call1["name"] in valid_tool_names + + +@pytest.mark.parametrize( + "tool_choice", + [ + "auto", + None, + "required", + ], + ids=["auto", "absent", "required"], +) +@pytest.mark.parametrize( + "func", + [eval_invoke, eval_stream], + ids=["invoke", "stream"], +) +@pytest.mark.xfail(reason="Accuracy test") +def 
test_accuracy_parallel_tool_calls_easy( + tool_model: str, + mode: dict, + tool_choice: Any, + func: Callable, +) -> None: + llm = ChatNVIDIA(seed=42, temperature=1, model=tool_model, **mode).bind_tools( + tools=[get_current_weather], + ) + response = func( + llm, + "What is the weather in Boston, and what is the weather in Dublin?", + tool_choice=tool_choice, + parallel_tool_calls=True, + ) + assert isinstance(response, AIMessage) + check_response_structure(response) + assert len(response.tool_calls) == 2 + valid_args = ["Boston", "Dublin"] + tool_call0 = response.tool_calls[0] + assert tool_call0["name"] == "get_current_weather" + assert tool_call0["args"]["location"] in valid_args + valid_args.remove(tool_call0["args"]["location"]) + tool_call1 = response.tool_calls[1] + assert tool_call1["name"] == "get_current_weather" + assert tool_call1["args"]["location"] in valid_args diff --git a/libs/ai-endpoints/tests/unit_tests/test_parallel_tool_calls.py b/libs/ai-endpoints/tests/unit_tests/test_parallel_tool_calls.py new file mode 100644 index 00000000..7f1abd2d --- /dev/null +++ b/libs/ai-endpoints/tests/unit_tests/test_parallel_tool_calls.py @@ -0,0 +1,160 @@ +import warnings + +import pytest +import requests_mock +from langchain_core.messages import AIMessage + +from langchain_nvidia_ai_endpoints import ChatNVIDIA + + +@pytest.fixture(autouse=True) +def mock_v1_models(requests_mock: requests_mock.Mocker) -> None: + requests_mock.get( + "https://integrate.api.nvidia.com/v1/models", + json={ + "data": [ + { + "id": "magic-model", + "object": "model", + "created": 1234567890, + "owned_by": "OWNER", + "root": "magic-model", + }, + ] + }, + ) + + +def test_invoke_parallel_tool_calls(requests_mock: requests_mock.Mocker) -> None: + requests_mock.post( + "https://integrate.api.nvidia.com/v1/chat/completions", + json={ + "id": "cmpl-100f0463deb8421480ab18ed32cb2581", + "object": "chat.completion", + "created": 1721154188, + "model": "magic-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "chatcmpl-tool-7980a682cc24446a8da9148c2c3e37ce", + "type": "function", + "function": { + "name": "xxyyzz", + "arguments": '{"a": 11, "b": 3}', + }, + }, + { + "id": "chatcmpl-tool-299964d0c5fe4fc1b917c8eaabd1cda2", + "type": "function", + "function": { + "name": "zzyyxx", + "arguments": '{"a": 11, "b": 5}', + }, + }, + ], + }, + "logprobs": None, + "finish_reason": "tool_calls", + "stop_reason": None, + } + ], + "usage": { + "prompt_tokens": 194, + "total_tokens": 259, + "completion_tokens": 65, + }, + }, + ) + + warnings.filterwarnings("ignore", r".*Found magic-model in available_models.*") + llm = ChatNVIDIA(model="magic-model") + response = llm.invoke( + "What is 11 xxyyzz 3 zzyyxx 5?", + ) + assert isinstance(response, AIMessage) + assert len(response.tool_calls) == 2 + tool_call0 = response.tool_calls[0] + assert tool_call0["name"] == "xxyyzz" + assert tool_call0["args"] == {"b": 3, "a": 11} + tool_call1 = response.tool_calls[1] + assert tool_call1["name"] == "zzyyxx" + assert tool_call1["args"] == {"b": 5, "a": 11} + + +def test_stream_parallel_tool_calls_A(requests_mock: requests_mock.Mocker) -> None: + response_contents = "\n\n".join( + [ + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: 
{"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_ID0","type":"function","function":{"name":"xxyyzz","arguments":""}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"a\\""}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":": 11,"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" \\"b\\": "}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"3}"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_ID1","type":"function","function":{"name":"zzyyxx","arguments":""}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\\"a\\""}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":": 5, "}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"\\"b\\": 3"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"chatcmpl-ID0","object":"chat.completion.chunk","created":1721155403,"model":"magic-model","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}', # noqa: E501 + "data: [DONE]", + ] + ) + + requests_mock.post( + "https://integrate.api.nvidia.com/v1/chat/completions", + text=response_contents, + ) + + warnings.filterwarnings("ignore", r".*Found magic-model in available_models.*") + llm = ChatNVIDIA(model="magic-model") + generator = llm.stream( + "What is 11 xxyyzz 3 zzyyxx 5?", + ) + response = next(generator) + for chunk in generator: + response += chunk + assert isinstance(response, AIMessage) + 
assert len(response.tool_calls) == 2 + tool_call0 = response.tool_calls[0] + assert tool_call0["name"] == "xxyyzz" + assert tool_call0["args"] == {"b": 3, "a": 11} + tool_call1 = response.tool_calls[1] + assert tool_call1["name"] == "zzyyxx" + assert tool_call1["args"] == {"b": 3, "a": 5} + + +def test_stream_parallel_tool_calls_B(requests_mock: requests_mock.Mocker) -> None: + response_contents = "\n\n".join( + [ + 'data: {"id":"cmpl-call_ID0","object":"chat.completion.chunk","created":1721155320,"model":"magic-model","choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}', # noqa: E501 + 'data: {"id":"cmpl-call_ID0","object":"chat.completion.chunk","created":1721155320,"model":"magic-model","choices":[{"index":0,"delta":{"role":null,"content":null,"tool_calls":[{"index":0,"id":"chatcmpl-tool-IDA","type":"function","function":{"name":"xxyyzz","arguments":"{\\"a\\": 11, \\"b\\": 3}"}},{"index":1,"id":"chatcmpl-tool-IDB","type":"function","function":{"name":"zzyyxx","arguments":"{\\"a\\": 11, \\"b\\": 5}"}}]},"logprobs":null,"finish_reason":"tool_calls","stop_reason":null}]}', # noqa: E501 + "data: [DONE]", + ] + ) + + requests_mock.post( + "https://integrate.api.nvidia.com/v1/chat/completions", + text=response_contents, + ) + + warnings.filterwarnings("ignore", r".*Found magic-model in available_models.*") + llm = ChatNVIDIA(model="magic-model") + generator = llm.stream( + "What is 11 xxyyzz 3 zzyyxx 5?", + ) + response = next(generator) + for chunk in generator: + response += chunk + assert isinstance(response, AIMessage) + assert len(response.tool_calls) == 2 + tool_call0 = response.tool_calls[0] + assert tool_call0["name"] == "xxyyzz" + assert tool_call0["args"] == {"b": 3, "a": 11} + tool_call1 = response.tool_calls[1] + assert tool_call1["name"] == "zzyyxx" + assert tool_call1["args"] == {"b": 5, "a": 11} From bfbcd47f2a9c9cd5162befe44aaab4c0faefa87a Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 17 Jul 2024 10:26:54 -0400 Subject: [PATCH 07/11] to support langchain.agents.AgentExecutor, change invoke response to be an AIMessage in all cases (bump version to 0.2.0) applications that were expecting a ChatMessage may break --- .../chat_models.py | 17 +---------------- libs/ai-endpoints/pyproject.toml | 2 +- .../tests/integration_tests/test_bind_tools.py | 3 +-- 3 files changed, 3 insertions(+), 19 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index af2138f3..d6f40593 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -33,7 +33,6 @@ AIMessage, AIMessageChunk, BaseMessage, - ChatMessage, ) from langchain_core.outputs import ( ChatGeneration, @@ -268,14 +267,7 @@ def _generate( responses, _ = self._client.client.postprocess(response) self._set_callback_out(responses, run_manager) parsed_response = self._custom_postprocess(responses, streaming=False) - # todo: we should always return an AIMessage, but to maintain - # API compatibility, we only return it for tool_calls. we can - # change this for an API breaking 1.0. 
- if "tool_calls" in parsed_response["additional_kwargs"]: - message: BaseMessage = AIMessage(**parsed_response) - else: - message = ChatMessage(**parsed_response) - generation = ChatGeneration(message=message) + generation = ChatGeneration(message=AIMessage(**parsed_response)) return ChatResult(generations=[generation], llm_output=responses) def _stream( @@ -294,13 +286,6 @@ def _stream( for response in self._client.client.get_req_stream(payload=payload): self._set_callback_out(response, run_manager) parsed_response = self._custom_postprocess(response, streaming=True) - # todo: we should always return an AIMessage, but to maintain - # API compatibility, we only return it for tool_calls. we can - # change this for an API breaking 1.0. - # if "tool_calls" in parsed_response["additional_kwargs"]: - # message: BaseMessageChunk = AIMessageChunk(**parsed_response) - # else: - # message = ChatMessageChunk(**parsed_response) message = AIMessageChunk(**parsed_response) chunk = ChatGenerationChunk(message=message) if run_manager: diff --git a/libs/ai-endpoints/pyproject.toml b/libs/ai-endpoints/pyproject.toml index 89c36ab1..5e6f2a29 100644 --- a/libs/ai-endpoints/pyproject.toml +++ b/libs/ai-endpoints/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-nvidia-ai-endpoints" -version = "0.1.6" +version = "0.2.0" description = "An integration package connecting NVIDIA AI Endpoints and LangChain" authors = [] readme = "README.md" diff --git a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py index 1d64e576..f4e9a415 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py +++ b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py @@ -8,7 +8,6 @@ AIMessageChunk, BaseMessage, BaseMessageChunk, - ChatMessage, ) from langchain_core.pydantic_v1 import Field from langchain_core.tools import tool @@ -564,7 +563,7 @@ def test_bind_tool_tool_choice_none( tools=[xxyyzz], tool_choice=tool_choice ) response = llm.invoke("What is 11 xxyyzz 3?") - assert isinstance(response, ChatMessage) + assert isinstance(response, AIMessage) assert "tool_calls" not in response.additional_kwargs From 63d48bb1ffb95f8a5bb1cdbe0201e2111421f43a Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 19 Jul 2024 08:43:03 -0400 Subject: [PATCH 08/11] add compatibility role property to mitigate ChatMessage -> AIMessage change note: this does not work for AIMessageChunk compatibility --- .../chat_models.py | 8 +++++++ .../integration_tests/test_chat_models.py | 21 ++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index d6f40593..e3661659 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -267,6 +267,9 @@ def _generate( responses, _ = self._client.client.postprocess(response) self._set_callback_out(responses, run_manager) parsed_response = self._custom_postprocess(responses, streaming=False) + # for pre 0.2 compatibility w/ ChatMessage + # ChatMessage had a role property that was not present in AIMessage + parsed_response.update({"role": "assistant"}) generation = ChatGeneration(message=AIMessage(**parsed_response)) return ChatResult(generations=[generation], llm_output=responses) @@ -286,6 +289,11 @@ def _stream( for response in 
self._client.client.get_req_stream(payload=payload): self._set_callback_out(response, run_manager) parsed_response = self._custom_postprocess(response, streaming=True) + # for pre 0.2 compatibility w/ ChatMessageChunk + # ChatMessageChunk had a role property that was not + # present in AIMessageChunk + # unfortunately, AIMessageChunk does not have extensible propery + # parsed_response.update({"role": "assistant"}) message = AIMessageChunk(**parsed_response) chunk = ChatGenerationChunk(message=message) if run_manager: diff --git a/libs/ai-endpoints/tests/integration_tests/test_chat_models.py b/libs/ai-endpoints/tests/integration_tests/test_chat_models.py index bb1490de..6ed7c4ae 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_chat_models.py +++ b/libs/ai-endpoints/tests/integration_tests/test_chat_models.py @@ -5,7 +5,12 @@ import pytest from langchain_core.load.dump import dumps from langchain_core.load.load import loads -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) from langchain_nvidia_ai_endpoints.chat_models import ChatNVIDIA @@ -25,6 +30,10 @@ def test_chat_ai_endpoints(chat_model: str, mode: dict) -> None: response = chat.invoke([message]) assert isinstance(response, BaseMessage) assert isinstance(response.content, str) + # compatibility test for ChatMessage (pre 0.2) + # assert isinstance(response, ChatMessage) + assert hasattr(response, "role") + assert response.role == "assistant" def test_unknown_model() -> None: @@ -145,11 +154,17 @@ def test_ai_endpoints_streaming(chat_model: str, mode: dict) -> None: """Test streaming tokens from ai endpoints.""" llm = ChatNVIDIA(model=chat_model, max_tokens=36, **mode) + generator = llm.stream("I'm Pickle Rick") + response = next(generator) cnt = 0 - for token in llm.stream("I'm Pickle Rick"): - assert isinstance(token.content, str) + for chunk in generator: + assert isinstance(chunk.content, str) + response += chunk cnt += 1 assert cnt > 1 + # compatibility test for ChatMessageChunk (pre 0.2) + # assert hasattr(response, "role") + # assert response.role == "assistant" # does not work, role not passed through async def test_ai_endpoints_astream(chat_model: str, mode: dict) -> None: From d9592df15c9d85e47714ab6ab893e1d01c46f094 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 19 Jul 2024 10:27:02 -0400 Subject: [PATCH 09/11] add tool calling section to doc notebook --- .../docs/chat/nvidia_ai_endpoints.ipynb | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb index dc2236a9..e1025759 100644 --- a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb +++ b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb @@ -548,6 +548,75 @@ "source": [ "conversation.invoke(\"Tell me about yourself.\")[\"response\"]" ] + }, + { + "cell_type": "markdown", + "id": "f3cbbba0", + "metadata": {}, + "source": [ + "## Tool calling\n", + "\n", + "Starting in v0.2, `ChatNVIDIA` supports [bind_tools](https://api.python.langchain.com/en/latest/language_models/langchain_core.language_models.chat_models.BaseChatModel.html#langchain_core.language_models.chat_models.BaseChatModel.bind_tools).\n", + "\n", + "`ChatNVIDIA` provides integration with the variety of models on [build.nvidia.com](https://build.nvidia.com) as well as local NIMs. 
Not all these models are trained for tool calling. Be sure to select a model that does have tool calling for your experimention and applications." + ] + }, + { + "cell_type": "markdown", + "id": "6f7b535e", + "metadata": {}, + "source": [ + "You can get a list of models that are known to support tool calling with," + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e36c8911", + "metadata": {}, + "outputs": [], + "source": [ + "tool_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_tools]\n", + "tool_models" + ] + }, + { + "cell_type": "markdown", + "id": "b01d75a7", + "metadata": {}, + "source": [ + "With a tool capable model," + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd54f174", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.pydantic_v1 import Field\n", + "from langchain_core.tools import tool\n", + "\n", + "@tool\n", + "def get_current_weather(\n", + " location: str = Field(..., description=\"The location to get the weather for.\")\n", + "):\n", + " \"\"\"Get the current weather for a location.\"\"\"\n", + " ...\n", + "\n", + "llm = ChatNVIDIA(model=tool_models[0].id).bind_tools(tools=[get_current_weather])\n", + "response = llm.invoke(\"What is the weather in Boston?\")\n", + "response.tool_calls" + ] + }, + { + "cell_type": "markdown", + "id": "e08df68c", + "metadata": {}, + "source": [ + "See [How to use chat models to call tools](https://python.langchain.com/v0.2/docs/how_to/tool_calling/) for additional examples." + ] } ], "metadata": { From f743d781710b4f7cacdf141dfc44a1a3a70f65fc Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 23 Jul 2024 09:42:34 -0400 Subject: [PATCH 10/11] workaround for missing index field on streamed tool calls (revert when nim bug fixed) --- .../langchain_nvidia_ai_endpoints/chat_models.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index e3661659..86a6a8ca 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -334,10 +334,16 @@ def _custom_postprocess( elif streaming: out_dict["tool_call_chunks"] = [] for tool_call in tool_calls: - assert "index" in tool_call, ( - "invalid response from server: " - "tool_call must have an 'index' key" - ) + # todo: the nim api does not return the function index + # for tool calls in stream responses. this is + # an issue that needs to be resolved server-side. + # the only reason we can skip this for now + # is because the nim endpoint returns only full + # tool calls, no deltas. 
+ # assert "index" in tool_call, ( + # "invalid response from server: " + # "tool_call must have an 'index' key" + # ) assert "function" in tool_call, ( "invalid response from server: " "tool_call must have a 'function' key" From a26088ee2c9e43478c1a728f50424a42f2739093 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 23 Jul 2024 12:15:35 -0400 Subject: [PATCH 11/11] add tool support for meta/llama-3.1-8b-instruct, meta/llama-3.1-70b-instruct & meta/llama-3.1-405b-instruct --- libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py index cbfcf7f4..7c51466b 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py @@ -289,16 +289,19 @@ def validate_client(cls, client: str, values: dict) -> str: id="meta/llama-3.1-8b-instruct", model_type="chat", client="ChatNVIDIA", + supports_tools=True, ), "meta/llama-3.1-70b-instruct": Model( id="meta/llama-3.1-70b-instruct", model_type="chat", client="ChatNVIDIA", + supports_tools=True, ), "meta/llama-3.1-405b-instruct": Model( id="meta/llama-3.1-405b-instruct", model_type="chat", client="ChatNVIDIA", + supports_tools=True, ), }
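Editor's note (not part of the patch series): a minimal end-to-end sketch of the tool calling support the patches above add, assuming the patched package (>= 0.2.0), a valid NVIDIA_API_KEY in the environment, and access to a tool-calling model such as meta/llama-3.1-8b-instruct; the weather tool is illustrative only.

from langchain_core.pydantic_v1 import Field
from langchain_core.tools import tool

from langchain_nvidia_ai_endpoints import ChatNVIDIA


@tool
def get_current_weather(
    location: str = Field(..., description="The location to get the weather for"),
) -> str:
    """Get the current weather for a location"""
    return f"The current weather in {location} is sunny."


# bind the tool and force a tool call; "required" is passed through to the server API
llm = ChatNVIDIA(model="meta/llama-3.1-8b-instruct").bind_tools(
    [get_current_weather], tool_choice="required"
)

# invoke returns an AIMessage carrying structured tool_calls (0.2.0 behavior)
response = llm.invoke("What is the weather in Boston?")
print(response.tool_calls)

# streaming yields AIMessageChunks whose tool_call_chunks aggregate into tool_calls
chunks = llm.stream("What is the weather in Dublin?")
aggregate = next(chunks)
for chunk in chunks:
    aggregate += chunk
print(aggregate.tool_calls)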