feat: data parallel inference examples #2805

Merged · 1 commit · May 17, 2024
6 changes: 6 additions & 0 deletions docsrc/index.rst
@@ -111,7 +111,13 @@ Tutorials
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
tutorials/_rendered_examples/dynamo/custom_kernel_plugins
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion

Python API Documentation
------------------------
14 changes: 14 additions & 0 deletions examples/distributed_inference/README.md
@@ -0,0 +1,14 @@
# Torch-TensorRT parallelism for distributed inference

Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.

1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)

Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model
is loaded onto each GPU and a different chunk of the input batch is processed on each device.

See the examples starting with `data_parallel` for more details; a minimal sketch of the pattern follows this list.

2. Tensor parallel distributed inference

In development.
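
The sketch below shows the data parallel pattern both examples follow. The toy `nn.Linear` model and random batch are placeholders standing in for the real GPT2 / Stable Diffusion pipelines, and the script is assumed to be started with `accelerate launch` so that one process runs per GPU:

```python
# Minimal sketch (toy model; run with: accelerate launch --num_processes 2 sketch.py)
import torch
from accelerate import PartialState

import torch_tensorrt  # noqa: F401  # registers the "torch_tensorrt" backend

distributed_state = PartialState()  # one process / device under accelerate launch

# The full model is replicated on every GPU
model = torch.nn.Linear(16, 4).eval().to(distributed_state.device)

# Compile with the Torch-TensorRT backend, as in the examples in this folder
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={"enabled_precisions": {torch.float16}},
    dynamic=False,
)

# Each process receives a different chunk of the input batch
batch = [torch.randn(1, 16) for _ in range(2)]
with distributed_state.split_between_processes(batch) as chunks:
    for x in chunks:
        out = model(x.to(distributed_state.device))
        print(distributed_state.process_index, out.shape)
```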
64 changes: 64 additions & 0 deletions examples/distributed_inference/data_parallel_gpt2.py
@@ -0,0 +1,64 @@
"""
.. _data_parallel_gpt2:

Torch-TensorRT Distributed Inference
======================================================

This interactive script is intended as a sample of distributed inference using data
parallelism with the Accelerate library and the Torch-TensorRT workflow on the GPT2 model.

"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

import torch_tensorrt

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set input prompts for different devices
prompt1 = "GPT2 is a model developed by."
prompt2 = "Llama is a model developed by "

input_id1 = tokenizer(prompt1, return_tensors="pt").input_ids
input_id2 = tokenizer(prompt2, return_tensors="pt").input_ids

distributed_state = PartialState()

# Load the GPT2 model and move it to this process's device
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)


# Compile the model's forward pass with the Torch-TensorRT backend
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
    },
    dynamic=False,
)

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
    cur_input = torch.clone(prompt[0]).to(distributed_state.device)

    gen_tokens = model.generate(
        cur_input,
        do_sample=True,
        temperature=0.9,
        max_length=100,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
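
# %%
# A typical way to launch this script so that one process is spawned per GPU
# (an assumption for illustration; it requires Accelerate to be installed and
# at least two devices) is:
#
#   accelerate launch --num_processes 2 data_parallel_gpt2.py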
61 changes: 61 additions & 0 deletions examples/distributed_inference/data_parallel_stable_diffusion.py
@@ -0,0 +1,61 @@
"""
.. _data_parallel_stable_diffusion:

Torch-TensorRT Distributed Inference
======================================================

This interactive script is intended as a sample of distributed inference using data
parallelism with the Accelerate library and the Torch-TensorRT workflow on the Stable Diffusion model.

"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

import torch_tensorrt

model_id = "CompVis/stable-diffusion-v1-4"

# Instantiate Stable Diffusion Pipeline with FP16 weights
pipe = DiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
)

distributed_state = PartialState()
pipe = pipe.to(distributed_state.device)

backend = "torch_tensorrt"

# Optimize the UNet portion with Torch-TensorRT
pipe.unet = torch.compile(
    pipe.unet,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "precision": torch.float16,
        "debug": True,
        "use_python_runtime": True,
    },
    dynamic=False,
)
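# Multi-device safe mode adds runtime checks that each compiled engine runs on
# the device it expects, which helps when several processes target different GPUs.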
torch_tensorrt.runtime.set_multi_device_safe_mode(True)


# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    print("before \n")
    result = pipe(prompt).images[0]
    print("after ")
    result.save(f"result_{distributed_state.process_index}.png")
3 changes: 3 additions & 0 deletions examples/distributed_inference/requirement.txt
@@ -0,0 +1,3 @@
accelerate
transformers
diffusers