NVIDIA Triton+TRT-LLM connector needs to handle dynamic model parameters #15045
🤖 Hello @fciannella! 👋 I am Dosu, a friendly bot here to lend a hand while you're waiting for the human maintainers to respond. I'm designed to help you troubleshoot bugs, answer queries, and guide you toward becoming a contributor. So, feel free to bounce anything off me. Let's get your issue sorted! 🛠️😀

Thank you for your detailed request. I understand that you want the NVIDIA Triton+TRT-LLM connector in LangChain to handle dynamic model parameters, discover input and output parameters from the server, and provide a list of the mandatory parameters in the error when a parameter is missing from the request. Here's how you can achieve this:
```python
import numpy as np
import tritonclient.grpc as grpcclient
from typing import Any, Dict, List, Sequence

def _generate_inputs(
    self,
    prompt: Sequence[Sequence[str]],
    params: Dict[str, Any],
    stream: bool = True,
) -> List[grpcclient.InferInput]:
    """Create the inputs for the Triton inference server."""
    query = np.array(prompt).astype(object)
    inputs = []
    # Pack each dynamic model parameter into a 1 x N float32 tensor.
    for param, value in params.items():
        input_data = np.array([value]).astype(np.float32).reshape((1, -1))
        inputs.append(self._prepare_tensor(param, input_data))
    inputs.append(self._prepare_tensor("text_input", query))
    inputs.append(self._prepare_tensor("stream", np.array([[stream]], dtype=bool)))
    return inputs
```
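The helper `_prepare_tensor` is referenced but not shown. A plausible implementation, offered as a sketch rather than the connector's actual helper, wraps a numpy array in a `grpcclient.InferInput`:

```python
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

def _prepare_tensor(name: str, input_data: np.ndarray) -> grpcclient.InferInput:
    """Wrap a numpy array in an InferInput carrying its name, shape, and dtype."""
    tensor = grpcclient.InferInput(
        name, input_data.shape, np_to_triton_dtype(input_data.dtype)
    )
    tensor.set_data_from_numpy(input_data)
    return tensor
```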
To discover the model's input and output parameters from the server, query its metadata:

```python
def get_model_metadata(self) -> Dict[str, Any]:
    """Get the metadata of the model from the Triton server."""
    return self.client.get_model_metadata(self.model_name, as_json=True)
```
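With `as_json=True`, the call returns a plain dict. For reference, a trimmed example of its shape — the model name, tensor names, datatypes, and shapes below are illustrative and depend on the deployed model; note that protobuf's JSON encoding renders the int64 shape values as strings:

```python
# Illustrative metadata for a hypothetical "ensemble" model, not real output.
{
    "name": "ensemble",
    "platform": "ensemble",
    "inputs": [
        {"name": "text_input", "datatype": "BYTES", "shape": ["-1", "-1"]},
        {"name": "temperature", "datatype": "FP32", "shape": ["-1", "1"]},
    ],
    "outputs": [
        {"name": "text_output", "datatype": "BYTES", "shape": ["-1", "-1"]},
    ],
}
```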
To validate a request and report exactly which mandatory parameters are missing:

```python
MANDATORY_PARAMS = [
    "tokens", "top_k", "top_p", "temperature",
    "repetition_penalty", "length_penalty", "beam_width",
]

def validate_params(self, params: Dict[str, Any]) -> None:
    """Validate that all mandatory parameters are present."""
    missing_params = [param for param in self.MANDATORY_PARAMS if param not in params]
    if missing_params:
        raise TritonTensorRTRuntimeError(f"Missing mandatory parameters: {missing_params}")
```

Remember to call `validate_params` before sending the request. I hope this helps! If you have any further questions, feel free to ask.
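To illustrate the intended behavior end to end, here is a self-contained sketch of the validation flow; the exception class and the parameter values are stand-ins, not the connector's actual definitions:

```python
from typing import Any, Dict, List

MANDATORY_PARAMS: List[str] = [
    "tokens", "top_k", "top_p", "temperature",
    "repetition_penalty", "length_penalty", "beam_width",
]

class TritonTensorRTRuntimeError(Exception):
    """Raised when a request omits mandatory Triton parameters."""

def validate_params(params: Dict[str, Any]) -> None:
    """Raise an error naming every missing mandatory parameter."""
    missing = [p for p in MANDATORY_PARAMS if p not in params]
    if missing:
        raise TritonTensorRTRuntimeError(f"Missing mandatory parameters: {missing}")

# A request supplying only two of the seven mandatory parameters:
validate_params({"tokens": 100, "top_k": 1})
# -> TritonTensorRTRuntimeError: Missing mandatory parameters:
#    ['top_p', 'temperature', 'repetition_penalty', 'length_penalty', 'beam_width']
```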
System Info
LangChain 0.0.348
langchain-nvidia-trt 0.0.1rc0
Python 3.11
Who can help?
@jdye64
Reproduction
You can set up the Triton server with a Nemotron model as described in the quick start guide here:
https://github.com/fciannella/langchain-fciannella/blob/master/libs/partners/nvidia-trt/README.md
Then send a request from the LangChain plugin: it will not work. The same happens if you follow the more complex setup.
The input and output parameters need to be discovered by the client. One option is to use pytriton as the client; another is to pull the parameters from the server with an API call that returns the Triton model configuration, as sketched below.
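A minimal sketch of that discovery call with `tritonclient`, assuming a server at `localhost:8001` and a model named `ensemble` (both placeholders):

```python
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")  # placeholder address
model_name = "ensemble"  # placeholder model name

# Both calls are part of the tritonclient gRPC API; as_json=True returns
# plain dicts instead of protobuf messages.
metadata = client.get_model_metadata(model_name, as_json=True)
config = client.get_model_config(model_name, as_json=True)

# List the tensor names the server actually expects and produces.
input_names = [inp["name"] for inp in metadata.get("inputs", [])]
output_names = [out["name"] for out in metadata.get("outputs", [])]
print("inputs:", input_names)
print("outputs:", output_names)
```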
Expected behavior
When a mandatory parameter is missing from the request, the client should include the list of mandatory parameters in the resulting error.