Fix llm.generate (#11217)
* Fix llm.generate

Signed-off-by: Hemil Desai <[email protected]>

* fix

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
hemildesai and hemildesai authored Nov 8, 2024
1 parent 922e840 commit 43ba11a
Showing 2 changed files with 168 additions and 5 deletions.
66 changes: 66 additions & 0 deletions nemo/collections/llm/api.py
@@ -564,6 +564,72 @@ def generate(
inference_params: Optional["CommonInferenceParams"] = None,
text_only: bool = False,
) -> list[Union["InferenceRequest", str]]:
"""
Generates text using a NeMo LLM model.
This function takes a checkpoint path and a list of prompts,
and generates text based on the loaded model and parameters.
It returns a list of generated text, either as a string or as an InferenceRequest object.
Python Usage:
```python
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=2,
pipeline_model_parallel_size=1,
context_parallel_size=1,
sequence_parallel=False,
setup_optimizers=False,
store_optimizer_states=False,
)
trainer = nl.Trainer(
accelerator="gpu",
devices=2,
num_nodes=1,
strategy=strategy,
plugins=nl.MegatronMixedPrecision(
precision="bf16-mixed",
params_dtype=torch.bfloat16,
pipeline_dtype=torch.bfloat16,
autocast_enabled=False,
grad_reduce_in_fp32=False,
),
)
prompts = [
"Hello, how are you?",
"How many r's are in the word 'strawberry'?",
"Which number is bigger? 10.119 or 10.19?",
]
if __name__ == "__main__":
results = api.generate(
path=os.path.join(os.environ["NEMO_HOME"], "models", "meta-llama/Meta-Llama-3-8B"),
prompts=prompts,
trainer=trainer,
inference_params=CommonInferenceParams(temperature=0.1, top_k=10, num_tokens_to_generate=512),
text_only=True,
)
```
Args:
path (Union[Path, str]): The path to the model checkpoint.
prompts (list[str]): The list of prompts to generate text for.
trainer (nl.Trainer): The trainer object.
encoder_prompts (Optional[list[str]], optional): The list of encoder prompts. Defaults to None.
params_dtype (torch.dtype, optional): The data type of the model parameters. Defaults to torch.bfloat16.
add_BOS (bool, optional): Whether to add the beginning of sequence token. Defaults to False.
max_batch_size (int, optional): The maximum batch size. Defaults to 4.
random_seed (Optional[int], optional): The random seed. Defaults to None.
inference_batch_times_seqlen_threshold (int, optional): If batch-size times sequence-length is smaller than
this threshold then we will not use pipelining, otherwise we will. Defaults to 1000.
inference_params (Optional["CommonInferenceParams"], optional): The inference parameters defined in
Mcore's CommonInferenceParams. Defaults to None.
text_only (bool, optional): Whether to return only the generated text as a string. Defaults to False.
Returns:
list[Union["InferenceRequest", str]]: A list of generated text,
either as a string or as an InferenceRequest object.
"""
from nemo.collections.llm import inference

inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
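For orientation, here is a minimal sketch (not part of this commit) of consuming the two return modes documented above: with `text_only=True` the call returns plain strings, otherwise Megatron Core `InferenceRequest` objects. The `checkpoint_path`, `prompts`, and `trainer` names, and the `prompt`/`generated_text` attributes, are assumptions rather than values taken from the diff.

```python
# Illustrative only, not part of this commit. `checkpoint_path`, `prompts`, and
# `trainer` are assumed to be set up as in the docstring example above.
from nemo.collections.llm import api

requests = api.generate(
    path=checkpoint_path,
    prompts=prompts,
    trainer=trainer,
    text_only=False,  # with text_only=True the call returns plain strings instead
)
for request in requests:
    # Attribute names are assumed from Megatron Core's InferenceRequest dataclass
    # and may differ between versions.
    print(request.prompt, "->", request.generated_text)
```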
107 changes: 102 additions & 5 deletions nemo/collections/llm/inference/base.py
@@ -36,46 +36,105 @@
import nemo.lightning as nl
from nemo.collections.llm.peft import LoRA
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir
from nemo.lightning.io.pl import ckpt_to_weights_subdir
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig


# We need this wrapper since mcore generate uses methods/properties such as tokenizer.detokenize, tokenizer.tokenize, tokenizer.bos, tokenizer.pad, etc. to encode and decode prompts
class MCoreTokenizerWrappper:
"""
We need this wrapper since mcore generate uses methods/properties such as
tokenizer.detokenize, tokenizer.tokenize, tokenizer.bos, tokenizer.pad, etc. to encode and decode prompts
"""

def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.eod = tokenizer.eod
self.vocab_size = tokenizer.vocab_size

def detokenize(self, tokens, remove_special_tokens=False):
"""
Detokenizes a list of tokens into a string.
Args:
tokens (list): The list of tokens to detokenize.
remove_special_tokens (bool, optional): Whether to remove special tokens. Defaults to False.
Returns:
str: The detokenized string.
"""
return self.tokenizer.ids_to_text(tokens, remove_special_tokens)

def tokenize(self, prompt):
"""
Tokenizes a prompt into a list of tokens.
Args:
prompt (str): The prompt to tokenize.
Returns:
list: The list of tokens.
"""
return self.tokenizer.text_to_ids(prompt)

@property
def additional_special_tokens_ids(self):
"""
Gets the IDs of additional special tokens.
Returns:
list: The IDs of additional special tokens.
"""
return self.tokenizer.additional_special_tokens_ids

@property
def bos(self):
"""
Gets the ID of the beginning of sequence token.
Returns:
int: The ID of the beginning of sequence token.
"""
return self.tokenizer.bos_id

@property
def pad(self):
"""
Gets the ID of the padding token.
Returns:
int: The ID of the padding token.
"""
return self.tokenizer.pad_id
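For context (not part of the diff), a minimal round-trip sketch of the wrapper above: Megatron Core's text generation controllers call `tokenize`/`detokenize` and read `bos`, `pad`, `eod`, and `vocab_size`, which the wrapper maps onto NeMo's `text_to_ids`/`ids_to_text` and the corresponding `*_id` attributes. `nemo_tokenizer` is a hypothetical NeMo tokenizer instance.

```python
# Illustrative only. `nemo_tokenizer` is a hypothetical NeMo tokenizer exposing
# text_to_ids/ids_to_text and eod/bos_id/pad_id/vocab_size.
from nemo.collections.llm.inference.base import MCoreTokenizerWrappper

mcore_tokenizer = MCoreTokenizerWrappper(nemo_tokenizer)

token_ids = mcore_tokenizer.tokenize("Hello, how are you?")  # list of ids via text_to_ids
text = mcore_tokenizer.detokenize(token_ids)                 # string via ids_to_text
print(mcore_tokenizer.bos, mcore_tokenizer.pad, mcore_tokenizer.vocab_size)
```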


# TODO: Move to lightning Fabric API.
def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
"""
Sets up the trainer and restores the model from the given checkpoint path.
It does the following:
- Defines a RestoreConfig to restore only model weights
- Disables setting up optimizers in the Trainer
- Calls strategy.setup_environment(), model.configure_model() and strategy.setup_megatron_parallel(trainer=trainer)
- Finally loads the model weights
Args:
path (Path): The path to the checkpoint file.
trainer (nl.Trainer): The trainer object.
model (pl.LightningModule): The model object.
Returns:
None
"""
assert isinstance(trainer.strategy, MegatronStrategy), "Only MegatronStrategy is supported for trainer.strategy."
assert trainer.strategy.context_parallel_size <= 1, "Context parallelism is not supported for inference."
if (adapter_meta_path := ckpt_to_weights_subdir(path) / ADAPTER_META_FILENAME).exists():
if (adapter_meta_path := ckpt_to_weights_subdir(path, is_saving=False) / ADAPTER_META_FILENAME).exists():
with open(adapter_meta_path, "r") as f:
metadata = json.load(f)
restore_config = RestoreConfig(
path=metadata['model_ckpt_path'],
path=metadata["model_ckpt_path"],
load_model_state=True,
load_optim_state=False,
)
@@ -107,7 +107,166 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
model = lora(model)
adapter_sharded_state_dict = {k: v for k, v in model.sharded_state_dict().items() if ".adapter." in k}
adapter_state = trainer.strategy.checkpoint_io.load_checkpoint(
ckpt_to_weights_subdir(path), sharded_state_dict=adapter_sharded_state_dict
ckpt_to_weights_subdir(path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict
)
trainer.strategy.load_model_state_dict(adapter_state, strict=False)
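Stripped of the new docstrings, the functional fix in this file is the checkpoint-path helper: `ckpt_to_weights_subdir` is now imported from `nemo.lightning.io.pl` and called with `is_saving=False` when resolving weights for loading, both for the adapter metadata check and for loading the adapter state above. A minimal before/after sketch, with a hypothetical checkpoint path:

```python
from pathlib import Path

from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
# Previously imported from nemo.lightning.ckpt_utils and called as
# ckpt_to_weights_subdir(path); it now takes an explicit is_saving flag.
from nemo.lightning.io.pl import ckpt_to_weights_subdir

path = Path("/path/to/checkpoint")  # hypothetical checkpoint directory
adapter_meta_path = ckpt_to_weights_subdir(path, is_saving=False) / ADAPTER_META_FILENAME
```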

@@ -118,6 +118,177 @@ def setup_model_and_tokenizer(
params_dtype: torch.dtype = torch.bfloat16,
inference_batch_times_seqlen_threshold: int = 1000,
) -> tuple[MegatronModule, MCoreTokenizerWrappper]:
"""
Sets up the model and tokenizer for inference.
This function loads the model and tokenizer from the given checkpoint path,
sets up the trainer, and returns the Megatron inference-wrapped model and tokenizer.
Args:
path (Path): The path to the checkpoint file.
trainer (nl.Trainer): The trainer object.
params_dtype (torch.dtype, optional): The data type of the model parameters.
Defaults to torch.bfloat16.
inference_batch_times_seqlen_threshold (int, optional): If batch-size times sequence-length is smaller
than this threshold then we will not use pipelining, otherwise we will.
Returns:
tuple[MegatronModule, MCoreTokenizerWrappper]:
A tuple containing the inference-wrapped model and Mcore wrapped tokenizer.
"""
model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
_setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

@@ -135,6 +135,212 @@ def generate(
random_seed: Optional[int] = None,
inference_params: Optional[CommonInferenceParams] = None,
) -> dict:
"""
Runs generate on the model with the given prompts.
This function uses the loaded model, loaded tokenizer, and prompts to generate text.
It returns a dictionary containing the generated text.
Args:
model (AbstractModelInferenceWrapper): The inference-wrapped model.
tokenizer (MCoreTokenizerWrappper): The tokenizer.
prompts (list[str]): The list of prompts to generate text for.
encoder_prompts (Optional[list[str]], optional): The list of encoder prompts. Defaults to None.
add_BOS (bool, optional): Whether to add the beginning of sequence token. Defaults to False.
max_batch_size (int, optional): The maximum batch size. Defaults to 4.
random_seed (Optional[int], optional): The random seed. Defaults to None.
inference_params (Optional[CommonInferenceParams], optional): The inference parameters defined in
Mcore's CommonInferenceParams. Defaults to None.
Returns:
dict: A dictionary containing the generated results.
"""
if encoder_prompts is not None:
text_generation_controller = EncoderDecoderTextGenerationController(
inference_wrapped_model=model, tokenizer=tokenizer
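Taken together, a minimal end-to-end sketch of the lower-level flow that `llm.generate` wraps, assuming a `trainer` built with `MegatronStrategy` as in the api.py docstring; the checkpoint path is hypothetical, and the `CommonInferenceParams` import path reflects current Megatron Core and may differ between releases.

```python
from pathlib import Path

import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams

from nemo.collections.llm import inference

# `trainer` is assumed to be an nl.Trainer configured with MegatronStrategy, as in
# the api.py docstring above; the checkpoint path below is hypothetical.
inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
    path=Path("/path/to/nemo2_checkpoint"),
    trainer=trainer,
    params_dtype=torch.bfloat16,
    inference_batch_times_seqlen_threshold=1000,
)

results = inference.generate(
    model=inference_wrapped_model,
    tokenizer=mcore_tokenizer,
    prompts=["Hello, how are you?"],
    max_batch_size=4,
    random_seed=42,
    inference_params=CommonInferenceParams(temperature=0.1, top_k=10, num_tokens_to_generate=64),
)
print(results)  # a dict of generated results, per the docstring above
```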
