diff --git a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
index 161da60ba5758..0474e9b035285 100644
--- a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
+++ b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
@@ -40,6 +40,7 @@ class TritonTensorRTLLM(BaseLLM):
         length_penalty: (float) The penalty to apply repeated tokens
         tokens: (int) The maximum number of tokens to generate.
         client: The client object used to communicate with the inference server
+        verbose_client: flag to pass to the client on creation
 
     Example:
         .. code-block:: python
@@ -73,6 +74,7 @@ class TritonTensorRTLLM(BaseLLM):
         description="Request the inference server to load the specified model.\
             Certain Triton configurations do not allow for this operation.",
     )
+    verbose_client: bool = False
 
     def __del__(self):
         """Ensure the client streaming connection is properly shutdown"""
@@ -82,7 +84,9 @@ def __del__(self):
     def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate that python package exists in environment."""
         if not values.get("client"):
-            values["client"] = grpcclient.InferenceServerClient(values["server_url"])
+            values["client"] = grpcclient.InferenceServerClient(
+                values["server_url"], verbose=values.get("verbose_client", False)
+            )
         return values
 
     @property
diff --git a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
index 0113b83507315..8ef80cdea4688 100644
--- a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
+++ b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
@@ -1,7 +1,33 @@
 """Test TritonTensorRT Chat API wrapper."""
+import sys
+from io import StringIO
+from unittest.mock import patch
+
 from langchain_nvidia_trt import TritonTensorRTLLM
 
 
 def test_initialization() -> None:
     """Test integration initialization."""
     TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")
+
+
+@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
+def test_default_verbose(ignore) -> None:
+    llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
+    captured = StringIO()
+    sys.stdout = captured
+    llm.client.is_server_live()
+    sys.stdout = sys.__stdout__
+    assert "is_server_live" not in captured.getvalue()
+
+
+@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
+def test_verbose(ignore) -> None:
+    llm = TritonTensorRTLLM(
+        server_url="http://localhost:8001", model_name="ensemble", verbose_client=True
+    )
+    captured = StringIO()
+    sys.stdout = captured
+    llm.client.is_server_live()
+    sys.stdout = sys.__stdout__
+    assert "is_server_live" in captured.getvalue()