Handling tokenizer in PTQ for Nemo 2.0 (#11237)
* Handling tokenizer in PTQ for Nemo 2.0

Signed-off-by: Jan Lasek <[email protected]>

* Print log msg and enable overriding

Signed-off-by: Jan Lasek <[email protected]>

* Warning for legacy tokenizer config

Signed-off-by: Jan Lasek <[email protected]>

* Save HF tokenizer to make tokenizer_config.yaml (almost) redundant

Signed-off-by: Jan Lasek <[email protected]>

* Handle tokenizer in a unified way

Signed-off-by: Jan Lasek <[email protected]>

* Move saving context within export

Signed-off-by: Jan Lasek <[email protected]>

* Fix typo in get_tokenzier

Signed-off-by: Jan Lasek <[email protected]>

* Reduce diff

Signed-off-by: Jan Lasek <[email protected]>

* Drop unused import

Signed-off-by: Jan Lasek <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
janekl authored Nov 12, 2024
1 parent 66766b1 commit d32c664
Showing 6 changed files with 33 additions and 23 deletions.
21 changes: 12 additions & 9 deletions nemo/collections/llm/quantization/quantizer.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import shutil
from dataclasses import dataclass
from typing import Optional, Union

@@ -22,6 +23,7 @@
from tqdm import tqdm

from nemo.collections import llm
from nemo.lightning.ckpt_utils import CONTEXT_PATH
from nemo.utils import logging

from .utils import get_unwrapped_mcore_model
@@ -259,7 +261,7 @@ def loop(model):

return loop

def export(self, model: llm.GPTModel) -> None:
def export(self, model: llm.GPTModel, model_dir: str) -> None:
assert self.export_config is not None, "Export config is not set"
# TODO: Add sample generate
# TODO: Support megatron_amp_O2
@@ -277,15 +279,16 @@ def export(self, model: llm.GPTModel) -> None:
use_nfs_workspace=use_nfs_workspace,
)

dist.barrier() # Wait until all ranks complete export_model_config step
logging.info(f"Export succeeded, model has been exported to {export_dir}. Saving tokenizer if possible...")
# Save the model context in order to restore its tokenizer later. The destination
# path is "nemo_context" as this name is used in nemo.export to setup tokenizer.
shutil.copytree(
os.path.join(model_dir, CONTEXT_PATH),
os.path.join(export_dir, "nemo_context"),
dirs_exist_ok=True,
)
logging.info(f"Model context saved.")

if dist.get_rank() == 0:
try:
tokenizer_dst = os.path.join(export_dir, 'tokenizer')
model.tokenizer.tokenizer.save_pretrained(tokenizer_dst)
except Exception as err:
logging.warning("Could not save the tokenizer: " + str(err))
logging.info(f"Export succeeded, model has been exported to {export_dir}.")


def get_calib_data_iter(
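For context on the quantizer.py change: the "nemo_context" directory copied during export is what the inference side later reads back to rebuild the tokenizer. A minimal sketch of that consumption path, mirroring get_tokenizer in nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py (the export path below is hypothetical):

```python
from pathlib import Path

from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import build_tokenizer
from nemo.lightning import io

# Hypothetical directory produced by Quantizer.export(model, model_dir) above.
export_dir = Path("/results/llama3-8b-fp8")

# Load the tokenizer spec from the copied NeMo 2.0 model context and build the tokenizer,
# the same steps get_tokenizer() takes when it finds a "nemo_context" subdirectory.
tokenizer_spec = io.load_context(export_dir / "nemo_context", subpath="model.tokenizer")
tokenizer = build_tokenizer(tokenizer_spec)
```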
15 changes: 11 additions & 4 deletions nemo/export/tensorrt_llm.py
@@ -37,12 +37,12 @@
from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo
from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import (
build_tokenizer,
get_tokenzier,
get_tokenizer,
is_nemo_file,
load_nemo_model,
)
from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm
from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
from nemo.export.trt_llm.qnemo.tokenizer_utils import TOKENIZER_CONFIG_FILE, get_nmt_tokenizer
from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint
from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
from nemo.export.trt_llm.tensorrt_llm_run import (
@@ -294,7 +294,7 @@ def export(
else:
unpack_tarball(nemo_checkpoint_path, tmp_dir.name)
nemo_checkpoint_path = tmp_dir.name
self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path)

if os.path.exists(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)):
# Instantiate tokenizer for a legacy "Nemo 1" quantized checkpoint from a tokenizer config.
# Note that using the config is deprecated and it will be removed in future releases.
LOGGER.warning("Detected legacy tokenizer_config.yaml, using it to build tokenizer.")
self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path)
else:
self.tokenizer = get_tokenizer(nemo_checkpoint_path)

qnemo_to_tensorrt_llm(
nemo_checkpoint_path=nemo_checkpoint_path,
@@ -1092,7 +1099,7 @@ def _load(self):
if len(folders) > 0:
try:
self._load_config_file()
self.tokenizer = get_tokenzier(Path(os.path.join(self.model_dir)))
self.tokenizer = get_tokenizer(self.model_dir)
self.model = load(
tokenizer=self.tokenizer,
engine_dir=self.model_dir,
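The branching added in tensorrt_llm.py amounts to a small fallback: prefer the deprecated tokenizer_config.yaml when a legacy quantized NeMo 1 checkpoint still ships one, otherwise resolve the tokenizer the new way. A sketch of that logic factored into a helper (the function name is made up; the imports are the ones used in this file):

```python
import os

from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenizer
from nemo.export.trt_llm.qnemo.tokenizer_utils import TOKENIZER_CONFIG_FILE, get_nmt_tokenizer


def resolve_qnemo_tokenizer(checkpoint_path: str):
    """Pick the tokenizer source for a quantized (qnemo) checkpoint."""
    if os.path.exists(os.path.join(checkpoint_path, TOKENIZER_CONFIG_FILE)):
        # Legacy NeMo 1 checkpoints ship a tokenizer_config.yaml; using it is deprecated.
        return get_nmt_tokenizer(checkpoint_path)
    # NeMo 2.0 checkpoints: resolve via nemo_context / huggingface_tokenizer / tokenizer.model.
    return get_tokenizer(checkpoint_path)
```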
9 changes: 5 additions & 4 deletions nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -283,16 +283,17 @@ def copy_tokenizer_files(config, out_dir):
outfile.write(infile.read())


def get_tokenzier(tokenizer_dir_or_path: Path) -> PreTrainedTokenizer:
"""Loads the tokenizer from the decoded NEMO weights dir."""
def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenizer:
"""Loads the tokenizer from the decoded NeMo weights dir."""
tokenizer_dir_or_path = Path(tokenizer_dir_or_path)
if (tokenizer_dir_or_path / "nemo_context").exists():
from nemo.lightning import io

tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer")
return build_tokenizer(tokenizer_spec)
else:
if os.path.isdir(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")):
return AutoTokenizer.from_pretrained(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer"))
if (tokenizer_dir_or_path / "huggingface_tokenizer").is_dir():
return AutoTokenizer.from_pretrained(tokenizer_dir_or_path / "huggingface_tokenizer")

model_path = (
tokenizer_dir_or_path / "tokenizer.model" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path
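A short usage sketch of the renamed loader; the checkpoint path is hypothetical. With this change the function accepts either str or Path and resolves the tokenizer in order: a nemo_context directory, a huggingface_tokenizer directory, then a tokenizer.model file.

```python
from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenizer

# Hypothetical path to a decoded/exported checkpoint directory.
tokenizer = get_tokenizer("/results/llama3-8b-fp8")
print(type(tokenizer).__name__)
```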
5 changes: 0 additions & 5 deletions nemo/export/trt_llm/qnemo/tokenizer_utils.py
@@ -29,11 +29,6 @@
def get_nmt_tokenizer(nemo_checkpoint_path: str):
"""Build tokenizer from Nemo tokenizer config."""

tokenizer_dir = os.path.join(nemo_checkpoint_path, TOKENIZER_DIR)
if os.path.exists(tokenizer_dir):
print(f"Initializing tokenizer from {TOKENIZER_DIR} directory")
return AutoTokenizer.from_pretrained(tokenizer_dir)

print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}")
tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE))

4 changes: 4 additions & 0 deletions nemo/utils/model_utils.py
@@ -724,6 +724,10 @@ def save_artifacts(model, output_dir: str, use_abspath: bool = False) -> None:
app_state = AppState()
model_file = app_state.model_restore_path
model_cfg = copy.deepcopy(model.cfg)

if model_cfg.tokenizer.library == "huggingface":
model.tokenizer.save_pretrained(os.path.join(output_dir, "huggingface_tokenizer"))

if not hasattr(model, "artifacts"):
if hasattr(model_cfg, "tokenizer"):
OmegaConf.save(model_cfg.tokenizer, os.path.join(output_dir, "tokenizer_config.yaml"))
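The huggingface_tokenizer directory written by save_artifacts is what makes tokenizer_config.yaml (almost) redundant: the loader can read the Hugging Face tokenizer files directly. A minimal sketch of that round trip, with a hypothetical output directory:

```python
import os

from transformers import AutoTokenizer

out_dir = "/tmp/qnemo_artifacts"  # hypothetical save_artifacts output_dir
hf_dir = os.path.join(out_dir, "huggingface_tokenizer")

if os.path.isdir(hf_dir):
    # Same directory name the export-side loader checks before falling back to tokenizer.model.
    tokenizer = AutoTokenizer.from_pretrained(hf_dir)
```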
2 changes: 1 addition & 1 deletion scripts/llm/ptq.py
@@ -92,7 +92,7 @@ def main():
quantizer = quantization.Quantizer(quantization_config, export_config)
model = quantization.load_with_modelopt_layer_spec(args.nemo_checkpoint, args.calib_tp, args.calib_pp)
model = quantizer.quantize(model)
quantizer.export(model)
quantizer.export(model, args.nemo_checkpoint)


if __name__ == '__main__':
