NVIDIA Triton+TRT-LLM connector needs to handle dynamic model parameters #15045

Closed
fciannella opened this issue Dec 22, 2023 · 2 comments
Labels
🤖:bug Related to a bug, vulnerability, unexpected error with an existing feature
Ɑ: models Related to LLMs or chat model modules
🔌: nvidia Related to nvidia integrations

Comments


fciannella commented Dec 22, 2023

System Info

LangChain 0.0.348
langchain-nvidia-trt 0.0.1rc0
Python 3.11

Who can help?

@jdye64

Information

  • The official example notebooks/scripts
  • My own modified scripts

Related Components

  • LLMs/Chat Models
  • Embedding Models
  • Prompts / Prompt Templates / Prompt Selectors
  • Output Parsers
  • Document Loaders
  • Vector Stores / Retrievers
  • Memory
  • Agents / Agent Executors
  • Tools / Toolkits
  • Chains
  • Callbacks/Tracing
  • Async

Reproduction

You can set up the Triton server with a Nemotron model as described in the quick start guide here:

https://github.com/fciannella/langchain-fciannella/blob/master/libs/partners/nvidia-trt/README.md

Then try sending a request through the LangChain connector; the request will fail.

The same happens if you follow the more complex setup.

The input and output parameters need to be discovered by the client. One option is to use pytriton as the client; another is to pull the parameters from the server via an API call that returns the Triton model configuration.
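As a rough sketch of the second option (the URL, model name, and flow here are assumptions, not part of the connector), Triton's model-configuration endpoint can be queried to list the input tensors a deployed model expects:

import tritonclient.http as httpclient

# Sketch only: assumes a Triton server on localhost:8000 serving a model
# named "ensemble"; adjust both for your deployment.
client = httpclient.InferenceServerClient(url="localhost:8000")

# The config endpoint describes every declared input tensor, so a client
# could build its request dynamically instead of hard-coding parameter names.
config = client.get_model_config("ensemble")
for inp in config["input"]:
    print(inp["name"], inp["data_type"], inp["dims"])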

Expected behavior

When a parameter is missing from the request, the error raised by the client should include the list of mandatory parameters.


dosubot bot commented Dec 22, 2023

🤖

Hello @fciannella! 👋

I am Dosu, a friendly bot here to lend a hand while you're waiting for the human maintainers to respond. I'm designed to help you troubleshoot bugs, answer questions, and guide you toward becoming a contributor. So, feel free to bounce anything off me. Let's get your issue sorted! 🛠️😀

Thank you for your detailed request. I understand that you want the NVIDIA Triton+TRT-LLM connector in LangChain to handle dynamic model parameters, discover input and output parameters from the server, and provide a list of mandatory parameters in the error code if a parameter is missing in the request.

Here's how you can achieve this:

  1. Handle Dynamic Model Parameters: You can modify the _generate_inputs method to accept a dictionary of parameters instead of individual parameters. This way, you can pass any number of parameters dynamically.
def _generate_inputs(
    self,
    prompt: Sequence[Sequence[str]],
    params: Dict[str, Any],
    stream: bool = True,
) -> List[grpcclient.InferInput]:
    """Create the input tensors for the Triton inference server."""
    query = np.array(prompt).astype(object)
    inputs = []
    # Turn every supplied parameter into its own input tensor; note that the
    # blanket float32 cast only suits numeric parameters.
    for param, value in params.items():
        input_data = np.array([value]).astype(np.float32).reshape((1, -1))
        inputs.append(self._prepare_tensor(param, input_data))
    inputs.append(self._prepare_tensor("text_input", query))
    inputs.append(self._prepare_tensor("stream", np.array([[stream]], dtype=bool)))
    return inputs
  2. Discover Input and Output Parameters: You can use the get_model_metadata method of the InferenceServerClient class to get the metadata of the model, which includes the input and output parameters.
def get_model_metadata(self) -> Dict[str, Any]:
    """Get the metadata of the model."""
    return self.client.get_model_metadata(self.model_name, as_json=True)
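For instance (a sketch only: llm stands for a hypothetical instance of the connector, and the key names follow Triton's model-metadata schema), the returned metadata could be used to work out which inputs a request is still missing:

# Sketch: derive the input names the deployed model expects and compare them
# against what the request provides before calling the server.
metadata = llm.get_model_metadata()          # llm: hypothetical connector instance
expected = {inp["name"] for inp in metadata["inputs"]}
provided = {"text_input", "stream", "temperature", "top_k"}
missing = expected - provided
if missing:
    print(f"Request is missing required inputs: {sorted(missing)}")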
  3. Provide a List of Mandatory Parameters in the Error Code: You can define a list of mandatory parameters and check if these parameters are present in the request. If a parameter is missing, you can raise an error with the list of missing parameters.
MANDATORY_PARAMS = ["tokens", "top_k", "top_p", "temperature", "repetition_penalty", "length_penalty", "beam_width"]

def validate_params(self, params: Dict[str, Any]) -> None:
    """Validate that all mandatory parameters are present."""
    missing_params = [param for param in self.MANDATORY_PARAMS if param not in params]
    if missing_params:
        raise TritonTensorRTRuntimeError(f"Missing mandatory parameters: {missing_params}")

Remember to call validate_params before calling _generate_inputs in the _generate and _stream methods.
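A minimal sketch of that wiring (the wrapper name and the way parameters are collected are assumptions, not the connector's actual code):

from typing import Any, Dict

def generate_with_validation(llm, prompt: str, **kwargs: Any) -> None:
    """Hypothetical wrapper showing the suggested call order."""
    params: Dict[str, Any] = dict(kwargs)
    llm.validate_params(params)             # fails fast, naming any missing parameters
    inputs = llm._generate_inputs([[prompt]], params, stream=False)
    # ... submit `inputs` through the Triton client exactly as _generate does today ...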

The current structure of the _generate_inputs method in the NVIDIA Triton+TRT-LLM connector can be found in LangChain's libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py file.

I hope this helps! If you have any further questions, feel free to ask.

Best,
Dosu

This response is meant to be useful and save you time. It is not meant to be a precise solution, but rather a starting point for your own research.

Help me be more useful! Please leave a 👍 if this is helpful and 👎 if it is irrelevant.

If you want to continue the conversation, start your reply with @dosu-bot.

dosubot added the Ɑ: models and 🤖:bug labels on Dec 22, 2023
hinthornw added the 🔌: nvidia label on Dec 22, 2023
@mkhludnev
Contributor

Here's an example of passing max_tokens via sampling_parameters in one backend: triton-inference-server/vllm_backend#34
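For reference, a minimal sketch of that pattern (the input names follow the vLLM backend's example client; nothing here comes from the LangChain connector):

import json

import numpy as np
import tritonclient.grpc as grpcclient

# The vLLM backend takes the prompt plus an optional JSON string of sampling
# parameters, so per-request values such as max_tokens travel in one tensor.
text = grpcclient.InferInput("text_input", [1], "BYTES")
text.set_data_from_numpy(np.array(["What is Triton?"], dtype=object))

sampling = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
sampling.set_data_from_numpy(
    np.array([json.dumps({"max_tokens": 128, "temperature": 0.7})], dtype=object)
)

# These inputs would then be sent with the streaming inference API (the
# backend runs decoupled), as in the backend's sample client.
inputs = [text, sampling]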

dosubot added the stale label on Jun 1, 2024
dosubot closed this as not planned (won't fix, can't repro, duplicate, stale) on Jun 8, 2024
dosubot removed the stale label on Jun 8, 2024