Commit 07674de: Example for multimodal inference

Shanmugam Ramasamy committed Nov 13, 2024
1 parent 7c0a38d commit 07674de
Showing 5 changed files with 228 additions and 30 deletions.
36 changes: 36 additions & 0 deletions examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml
@@ -0,0 +1,36 @@
inference:
  greedy: True # Whether to use greedy decoding; otherwise sampling is used
  top_k: 0 # The number of highest-probability vocabulary tokens to keep for top-k filtering.
  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  temperature: 1.0 # sampling temperature
  add_BOS: True # add the bos token at the beginning of the prompt
  tokens_to_generate: 30 # The maximum number of tokens to generate.
  all_probs: False # whether to return the log probs for all the tokens in the vocab
  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0 # The minimum number of tokens to generate.
  compute_logprob: False # whether to compute the logprob of the input text; a special inference mode, default False
  end_strings: ["<|extra_204|>"] # generation stops when one of these tokens is generated

trainer:
  devices: 1
  num_nodes: 1
  accelerator: gpu
  logger: False # logger provided by exp_manager
  precision: bf16 # 16, 32, or bf16
  use_distributed_sampler: False

tensor_model_parallel_size: -1
pipeline_model_parallel_size: -1
pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others)
megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory
image_encoder: Cosmos-Tokenizer-DV8x16x16
gpt_model_file: null # GPT nemo file path
checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training
checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
hparams_file: null # model configuration file, only used for PTL checkpoint loading
captions: # prompts for GPT inference
  - "a drawing of a green pokemon with red eyes"
  - "a red pokemon with green eyes"
  - "a cartoon fish with a big smile"
images_output_path: null # Path to the directory to store the output images
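The `top_k` and `top_p` settings above control how the next-token distribution is truncated before sampling. As a rough illustration only (NeMo applies this filtering internally during generation; this is not its implementation), a minimal PyTorch sketch of the two filters might look like:

```python
import torch

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.9) -> torch.Tensor:
    """Illustrative top-k / top-p (nucleus) filtering over a [batch, vocab] logits tensor."""
    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth_best = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cumulative > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token past the threshold
        remove[..., 0] = False                      # always keep the most probable token
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    return logits
```

With `top_k: 0` the top-k branch is skipped, so only nucleus filtering at `top_p: 0.9` applies; greedy decoding bypasses both.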

2 changes: 1 addition & 1 deletion examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml
@@ -1,5 +1,5 @@
 inference:
-  greedy: False # Whether to use greedy decoding; otherwise sampling is used
+  greedy: True # Whether to use greedy decoding; otherwise sampling is used
   top_k: 0 # The number of highest-probability vocabulary tokens to keep for top-k filtering.
   top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
   temperature: 1.0 # sampling temperature
184 changes: 184 additions & 0 deletions examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py
@@ -0,0 +1,184 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import os
import math
import datetime
import torch
import torchvision
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.common.video_tokenizers.cosmos_tokenizer import CausalVideoTokenizer
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy
from nemo.core.config import hydra_runner
from examples.nlp.language_modeling.megatron_gpt_eval import (
    load_model_from_config,
    remove_padded_prompts,
    round_to_mult,
)

"""
This is the script to run multimodal autoregresssive text generation.
Make sure you install tiktoken==0.6.0
Usage:
Assume the model has TP=1, PP=1 in the following use cases.
a. run greedy inference from a nemo file:
python megatron_mm_autoregresssive_eval.py \
gpt_model_file=PATH_TO_MODEL \
inference.greedy=True \
inference.add_BOS=True \
trainer.devices=1 \
trainer.num_nodes=1 \
tensor_model_parallel_size=-1 \
pipeline_model_parallel_size=-1 \
captions=[caption1,caption2]
b. run greedy inference from a PTL checkpoint file:
python megatron_mm_autoregresssive_eval.py \
checkpoint_dir=PATH_TO_CHECKPOINT_FILE \
checkpoint_name=CHECKPOINT_FILE_NAME \
hparams_file=HPARAMS_FILE \
inference.greedy=True \
inference.add_BOS=True \
trainer.devices=1 \
trainer.num_nodes=1 \
tensor_model_parallel_size=-1 \
pipeline_model_parallel_size=-1 \
captions=[caption1,caption2]
c. run top_p inference from a nemo file:
python megatron_mm_autoregresssive_eval.py \
gpt_model_file=PATH_TO_MODEL \
inference.greedy=False \
inference.top_k=0 \
inference.top_p=0.9 \
inference.repetition_penalty=1.2 \
inference.add_BOS=True \
trainer.devices=1 \
trainer.num_nodes=1 \
tensor_model_parallel_size=-1 \
pipeline_model_parallel_size=-1 \
captions=[caption1,caption2]
d. If you don't need to generate tokens and need model to compute logprobs:
python megatron_mm_autoregresssive_eval.py \
gpt_model_file=PATH_TO_MODEL \
inference.compute_logprob=True \
trainer.devices=1 \
trainer.num_nodes=1 \
tensor_model_parallel_size=-1 \
pipeline_model_parallel_size=-1 \
captions=[caption1,caption2]
"""


def to_img(tokens_string, image_tokenizer):
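    """Decode generated visual tokens back into a PIL image.

    Extracts the `<|visual token NNNN|>` IDs from the generated string and
    decodes them with the Cosmos tokenizer's decoder.
    """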
    visual_token_pattern = r"<\|visual token (\d+)\|>"
    visual_tokens = [int(match) for match in re.findall(visual_token_pattern, tokens_string)]
    # We assume the image is square, so if 64 tokens are present we reshape them to 8x8 before passing to the decoder
    dim = int(math.sqrt(len(visual_tokens)))
    visual_tokens_tensor = torch.tensor(visual_tokens[: dim * dim])
    # The decoder accepts input of the format [bs, channel_dim, h, w]
    visual_tokens_tensor_reshaped = visual_tokens_tensor.reshape((dim, dim)).unsqueeze(0).unsqueeze(0)
    visual_tokens_final = visual_tokens_tensor_reshaped.to(image_tokenizer._device)
    img = image_tokenizer.decode(visual_tokens_final)

    # Convert from bf16 to fp32 and to the format [channel_dim, h, w]
    image = torchvision.transforms.functional.to_pil_image(img.float().squeeze())
    return image


def load_prompts(cfg):
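    """Wrap each configured caption in the chat template the model expects."""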
    prompts = []
    for caption in cfg.captions:
        prompt = f'You are a helpful assistant. Draw a picture for the caption given by the user. USER: {caption}. ASSISTANT: '
        prompts.append(prompt)
    return prompts

if not torch.cuda.is_available():
    raise EnvironmentError("GPU is needed for the inference")

@hydra_runner(config_path="conf", config_name="megatron_mm_ar_inference_image_generation")
def main(cfg) -> None:

    callbacks = []
    # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
    if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar:
        callbacks.append(CustomProgressBar())
    # trainer required for restoring model parallel models
    trainer = Trainer(
        strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
        **cfg.trainer,
        callbacks=callbacks,
    )

    image_tokenizer = CausalVideoTokenizer.from_pretrained(
        tokenizer_type=cfg.image_encoder,
        load_encoder=False,
        load_decoder=True,
        load_full_model=False,
    )

    model = load_model_from_config(trainer, cfg)
    model.freeze()

    # Have to turn off activations_checkpoint_method for inference
    try:
        model.model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        # Not all model variants expose activations_checkpoint_method; nothing to disable in that case.
        pass

    length_params: LengthParam = {
        "max_length": cfg.inference.tokens_to_generate,
        "min_length": cfg.inference.min_tokens_to_generate,
    }

    sampling_params: SamplingParam = {
        "use_greedy": cfg.inference.greedy,
        "temperature": cfg.inference.temperature,
        "top_k": cfg.inference.top_k,
        "top_p": cfg.inference.top_p,
        "repetition_penalty": cfg.inference.repetition_penalty,
        "add_BOS": cfg.inference.add_BOS,
        "all_probs": cfg.inference.all_probs,
        "compute_logprob": cfg.inference.compute_logprob,
        "end_strings": cfg.inference.end_strings,
    }

    with torch.no_grad():
        prompts = load_prompts(cfg)

    fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True)
    if fp8_enabled and len(prompts) > 0:
        # FP8 inference requires the batch size to be a multiple of 8, so pad with empty prompts.
        padded_len = round_to_mult(len(prompts), 8)
        nb_paddings = padded_len - len(prompts)
        if nb_paddings > 0:
            prompts += [''] * nb_paddings

    # First method of running text generation, call model.generate method
    response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params)

    if fp8_enabled:
        response = remove_padded_prompts(response, nb_paddings)

    output_tokens_strings = response['sentences']
    for idx, output_token_string in enumerate(output_tokens_strings):
        image = to_img(output_token_string, image_tokenizer)
        image.save(os.path.join(cfg.images_output_path, f'{idx}.jpg'))

    print(f'Images saved to {cfg.images_output_path}')

if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
2 changes: 1 addition & 1 deletion examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py
@@ -124,7 +124,7 @@ def load_prompts(cfg, image_tokenizer, tokenizer):
if not torch.cuda.is_available():
    raise EnvironmentError("GPU is needed for the inference")

@hydra_runner(config_path="conf", config_name="megatron_mm_ar_inference")
@hydra_runner(config_path="conf", config_name="megatron_mm_ar_inference_vision_understanding")
def main(cfg) -> None:

    callbacks = []
34 changes: 6 additions & 28 deletions nemo/collections/multimodal_autoregressive/data/README.md
@@ -8,27 +8,7 @@ This is an example of how to do autoregressive generation for multiple modalities
### 1. Vision Understanding using EMU3 Tokenizer

#### Download and Extract data
-We will be working with the COYO dataset, which has 700 million images.
-
-First, create credentials for rclone. Create this file at `~/.config/rclone/rclone.conf`
-```
-[pbss-team-vfm-share-ro-s3]
-type = s3
-env_auth = true
-access_key_id = <ACCESS ID>
-secret_access_key = <ACCESS KEY>
-region = us-east-1
-endpoint = https://pdx.s8k.io
-```
-To download the images
-```
-rclone copy pbss-team-vfm-share-ro-s3:webdataset_images/webdataset_edify_image_v3/coyo_700m/resolution_lt_720/aspect_ratio_16_9/images images --transfers=16 --multi-thread-streams=16 --checkers=8 -P --stats 5s
-```
-
-To download the captions
-```
-rclone copy pbss-team-vfm-share-ro-s3:webdataset_images/webdataset_edify_image_v3/coyo_700m/resolution_lt_720/aspect_ratio_16_9/captions_ai_v3p1 captions_ai_v3p1 --transfers=16 --multi-thread-streams=16 --checkers=8 -P --stats 5s
-```
+Download the [COYO700M dataset](https://github.com/kakaobrain/coyo-dataset)

Once downloaded, extract the data using tar utilities (a sketch follows below).
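As a minimal sketch of that extraction step (the `images/` and `extracted/` paths are placeholders; adjust them to your download layout):

```python
import glob
import os
import tarfile

# Hypothetical paths: point these at wherever the dataset shards were downloaded.
os.makedirs("extracted", exist_ok=True)
for shard in sorted(glob.glob("images/*.tar")):
    with tarfile.open(shard) as tf:
        tf.extractall("extracted")  # each shard unpacks its image/caption files
```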

@@ -70,13 +50,13 @@ Follow usual nemo instructions to train any autoregressive model.
```

#### Inference
-To run inference, edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference.yaml)
+To run inference, edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml)
*NOTE*: Make sure you have a .nemo checkpoint file. If you only have a regular Megatron checkpoint, you have to convert it as shown in [this doc](https://docs.nvidia.com/nemo-framework/user-guide/latest/llms/gpt/checkpointconversion.html?highlight=convert)

Run inference as follows:

```
-torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval.py
+torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py
```


@@ -116,13 +96,11 @@ Follow usual nemo instructions to train any autoregressive model.
```

#### Inference
-To run inference, edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference.yaml)
+To run inference, edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml)
*NOTE*: Make sure you have a .nemo checkpoint file. If you only have a regular Megatron checkpoint, you have to convert it as shown in [this doc](https://docs.nvidia.com/nemo-framework/user-guide/latest/llms/gpt/checkpointconversion.html?highlight=convert)

Run inference as follows:

```
-torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval.py
-```
-
-TODO: Instructions to convert visual tokens to images coming soon.
+torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py
+```
