Skip to content

Commit

Permalink
nvidia-trt[patch]: add TritonTensorRTLLM(verbose_client=False) (langc…
Browse files Browse the repository at this point in the history
…hain-ai#16848)

- **Description:** adding verbose flag to TritonTensorRTLLM, 
  - **Issue:** nope,
  - **Dependencies:** not any,
  - **Twitter handle:**
  • Loading branch information
mkhludnev authored and thebhulawat committed Mar 6, 2024
1 parent e7f9fdf commit ab6f0eb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
6 changes: 5 additions & 1 deletion libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class TritonTensorRTLLM(BaseLLM):
length_penalty: (float) The penalty to apply repeated tokens
tokens: (int) The maximum number of tokens to generate.
client: The client object used to communicate with the inference server
verbose_client: flag to pass to the client on creation
Example:
.. code-block:: python
Expand Down Expand Up @@ -73,6 +74,7 @@ class TritonTensorRTLLM(BaseLLM):
description="Request the inference server to load the specified model.\
Certain Triton configurations do not allow for this operation.",
)
verbose_client: bool = False

def __del__(self):
"""Ensure the client streaming connection is properly shutdown"""
Expand All @@ -82,7 +84,9 @@ def __del__(self):
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that python package exists in environment."""
if not values.get("client"):
values["client"] = grpcclient.InferenceServerClient(values["server_url"])
values["client"] = grpcclient.InferenceServerClient(
values["server_url"], verbose=values.get("verbose_client", False)
)
return values

@property
Expand Down
26 changes: 26 additions & 0 deletions libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
"""Test TritonTensorRT Chat API wrapper."""
import sys
from io import StringIO
from unittest.mock import patch

from langchain_nvidia_trt import TritonTensorRTLLM


def test_initialization() -> None:
"""Test integration initialization."""
TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")


@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
def test_default_verbose(ignore) -> None:
llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
captured = StringIO()
sys.stdout = captured
llm.client.is_server_live()
sys.stdout = sys.__stdout__
assert "is_server_live" not in captured.getvalue()


@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
def test_verbose(ignore) -> None:
llm = TritonTensorRTLLM(
server_url="http://localhost:8001", model_name="ensemble", verbose_client=True
)
captured = StringIO()
sys.stdout = captured
llm.client.is_server_live()
sys.stdout = sys.__stdout__
assert "is_server_live" in captured.getvalue()

0 comments on commit ab6f0eb

Please sign in to comment.