diff --git a/examples/pytorch/gpt/utils/huggingface_bloom_convert.py b/examples/pytorch/gpt/utils/huggingface_bloom_convert.py
index 7c78711f4..60adad2db 100644
--- a/examples/pytorch/gpt/utils/huggingface_bloom_convert.py
+++ b/examples/pytorch/gpt/utils/huggingface_bloom_convert.py
@@ -20,11 +20,12 @@
 import configparser
 import logging
 import multiprocessing
+import os
 import re
 import time
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -77,6 +78,9 @@ def get_args():
     parser.add_argument(
         '-v', '--verbose', action='store_true',
         help='Enable verbose logging')
+    parser.add_argument(
+        '-s', '--by-shard', action='store_true',
+        help='Process shard by shard; enable when converting a big model like bloom 175B')
     _args = parser.parse_args()
 
     set_logger(_args.verbose)
@@ -301,40 +305,176 @@ def save_bloom_config(model_config: BloomConfig, save_dir: PathLike):
         config.write(f, space_around_delimiters=False)
 
 
+def load_state_dict(file_path: Path, dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+    """ Load weights from a model file.
+
+    `safetensors` or PyTorch binary format is supported.
+
+    # Args.
+    file_path: model file path, ending with .bin or .safetensors.
+    dtype: torch.dtype, data type.
+    # Returns.
+    Dict[str, torch.Tensor]
+    """
+
+    state_dict = {}
+    if file_path.suffix == ".safetensors":
+        # Load from a safetensors file.
+        from safetensors import safe_open
+        with safe_open(file_path, framework="pt", device="cpu") as f:
+            for k in f.keys():
+                state_dict[k] = f.get_tensor(k).type(dtype)
+    else:
+        # Load from a pytorch bin file.
+        state_dict = torch.load(file_path, map_location="cpu")
+        for k in state_dict:
+            state_dict[k] = state_dict[k].type(dtype)
+    return state_dict
+
+
+def get_model_files(model_name: str) -> List[Path]:
+    """ List all model files to load and convert.
+
+    # Args.
+    model_name: name (like `bigscience/bloom`) or local directory of the model.
+    # Returns.
+    List[Path] model file paths.
+    """
+
+    import glob
+    from huggingface_hub import try_to_load_from_cache
+
+    model_dir = model_name
+
+    # Get the local model directory.
+    try:
+        config_file = "config.json"
+        # Falls back to HUGGINGFACE_HUB_CACHE when TRANSFORMERS_CACHE is unset.
+        config_path = try_to_load_from_cache(
+            model_name, config_file, cache_dir=os.getenv("TRANSFORMERS_CACHE")
+        )
+
+        if config_path is not None:
+            # Treat the model name as a huggingface model path.
+            model_dir = os.path.dirname(config_path)
+    except Exception:
+        # Treat the model name as an explicit model path.
+        pass
+
+    model_files = glob.glob(model_dir + "/*.bin")
+    try:
+        from safetensors import safe_open as _
+
+        st_files = glob.glob(model_dir + "/*.safetensors")
+        if st_files:
+            model_files = st_files
+            logger.info("loading from safetensors format")
+    except ImportError:
+        logger.info("loading from pytorch bin format")
+
+    if not model_files:
+        raise FileNotFoundError('model files not found')
+
+    logger.info(f"model file num: {len(model_files)}")
+    return [Path(i) for i in model_files]
+
+
+def process_by_model_param(model_id: str, dtype: torch.dtype, tp_size: int, save_dir: Path, nproc: int):
+    """ Process the conversion parameter by parameter.
+
+    With nproc > 1, every parameter is accumulated in the parent process
+    before being dispatched to the pool, so peak memory grows with model size.
+    """
+
+    # Init the bloom config.
+    model_config = BloomConfig.from_pretrained(model_id)
+    # List all model files.
+    model_files = get_model_files(model_id)
+    # Save the bloom config to the output dir.
+    save_bloom_config(model_config, save_dir)
+
+    if nproc > 1:
+        pool = multiprocessing.Pool(nproc)
+        star_args = []
+        for model_file in model_files:
+            state_dict = load_state_dict(model_file, dtype)
+            for name in state_dict:
+                param = state_dict[name]
+                # Preprocess
+                param_name = convert_parameter_name(name)
+                param = safe_transpose(param)
+                param = handle_exceptions(model_config, param_name, param)
+                star_args.append((param_name, param.detach().cpu().numpy(), tp_size, save_dir))
+        pool.starmap_async(convert_and_save_parameter, star_args)
+        pool.close()
+        pool.join()
+    else:
+        for model_file in model_files:
+            state_dict = load_state_dict(model_file, dtype)
+            for name in state_dict:
+                param = state_dict[name]
+                # Preprocess
+                param_name = convert_parameter_name(name)
+                param = safe_transpose(param)
+                param = handle_exceptions(model_config, param_name, param)
+                convert_and_save_parameter(param_name, param.detach().cpu().numpy(), tp_size, save_dir)
+
+
+def _process_by_model_shard(model_config, model_file, dtype: torch.dtype, tp_size: int, save_dir: Path):
+    state_dict = load_state_dict(model_file, dtype)
+    for name in state_dict:
+        param = state_dict[name]
+        # Preprocess
+        param_name = convert_parameter_name(name)
+        param = safe_transpose(param)
+        param = handle_exceptions(model_config, param_name, param)
+        convert_and_save_parameter(param_name, param.detach().cpu().numpy(), tp_size, save_dir)
+
+
+def process_by_model_shard(model_id: str, dtype: torch.dtype, tp_size: int, save_dir: Path, nproc: int):
+    """ Process the conversion shard by shard.
+
+    Each worker loads and converts one shard at a time, which bounds peak memory.
+
+    Benchmarks @ 64C (Intel Xeon 6326, 2.90GHz) x 756G RAM:
+
+    | model      | format           | by-shard | nproc | elapsed(s) | mem  |
+    |------------|------------------|----------|-------|------------|------|
+    | bloom-175b | safetensors x 72 | NO       | 8     | 1516.66    | 350G |
+    | bloom-175b | safetensors x 72 | YES      | 8     | 1165.03    | 50G  |
+    | bloom-175b | safetensors x 72 | YES      | 24    | 494.81     | 150G |
+    """
+
+    # Init the bloom config.
+    model_config = BloomConfig.from_pretrained(model_id)
+    # List all model files.
+    model_files = get_model_files(model_id)
+    # Save the bloom config to the output dir.
+    save_bloom_config(model_config, save_dir)
+
+    if nproc > 1:
+        pool = multiprocessing.Pool(nproc)
+        star_args = []
+        for model_file in model_files:
+            star_args.append((model_config, model_file, dtype, tp_size, save_dir))
+        pool.starmap_async(_process_by_model_shard, star_args)
+        pool.close()
+        pool.join()
+    else:
+        for model_file in model_files:
+            _process_by_model_shard(model_config, model_file, dtype, tp_size, save_dir)
+
+
 def main():
+    start_time = time.time()
     args = get_args()
 
     tp_size = args.tensor_para_size
-
     dtype = DATATYPE_MAP[args.data_type]
-    model = AutoModel.from_pretrained(args.input_dir).cpu().type(dtype)
-    assert isinstance(model, torch.nn.Module)
 
     save_dir = Path(args.output_dir) / f'{tp_size}-gpu'
     save_dir.mkdir(exist_ok=True, parents=True)
-    save_bloom_config(model.config, save_dir)
 
-    start_time = time.time()
-    logger.info(f'Start the checkpoint conversion: '
-                f'{len(list(model.parameters()))} params')
-    if args.processes > 1:
-        pool = multiprocessing.Pool(args.processes)
-        star_args = []
-        for name, param in model.named_parameters():
-            # Preprocess
-            param_name = convert_parameter_name(name)
-            param = safe_transpose(param)
-            param = handle_exceptions(model.config, param_name, param)
-            star_args.append((param_name, param.detach().cpu().numpy(), tp_size, save_dir))
-        pool.starmap_async(convert_and_save_parameter, star_args)
-        pool.close()
-        pool.join()
+    if args.by_shard:
+        process_by_model_shard(args.input_dir, dtype, tp_size, save_dir, args.processes)
     else:
-        for name, param in model.named_parameters():
-            # Preprocess
-            param_name = convert_parameter_name(name)
-            param = safe_transpose(param)
-            param = handle_exceptions(model.config, param_name, param)
-            convert_and_save_parameter(param_name, param.detach().cpu().numpy(), tp_size, save_dir)
+        process_by_model_param(args.input_dir, dtype, tp_size, save_dir, args.processes)
+
     elapsed_time = time.time() - start_time
     logger.info(f'Checkpoint conversion (HF >> FT) has done '
                 f'(elapsed time: {elapsed_time:.2f} sec)')
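A note on the loading path: the memory behavior above hinges on `safetensors` memory-mapping each shard, so `load_state_dict` materializes one tensor at a time instead of deserializing the whole file. A minimal standalone sketch of that access pattern (the shard filename is hypothetical):

```python
# Sketch of the lazy access pattern load_state_dict relies on: safe_open
# memory-maps the file, and a tensor is only materialized by get_tensor().
import torch
from safetensors import safe_open

shard = "model_00001-of-00072.safetensors"  # hypothetical shard name

with safe_open(shard, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensor = f.get_tensor(key).type(torch.float16)
        print(key, tuple(tensor.shape))
        # `tensor` is rebound on each iteration, so the previous tensor can
        # be freed and peak memory stays near the size of a single parameter.
```

By contrast, `torch.load` on a `.bin` shard deserializes the entire shard at once, which is why the script prefers safetensors files when the package is available.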
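`get_model_files` resolves a hub name such as `bigscience/bloom` to a local directory by asking the cache for one known file (`config.json`) and taking its parent directory. A standalone sketch of the same trick; note that `try_to_load_from_cache` may also return a "known missing" sentinel rather than a path, so testing for `str` is the safer check:

```python
import os
from huggingface_hub import try_to_load_from_cache

model_name = "bigscience/bloom"  # hub name or a local directory

# On a cache hit this returns the path of the cached config.json; its parent
# is the snapshot directory that holds the weight shards.
config_path = try_to_load_from_cache(
    model_name, "config.json", cache_dir=os.getenv("TRANSFORMERS_CACHE")
)
model_dir = os.path.dirname(config_path) if isinstance(config_path, str) else model_name
print(model_dir)
```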
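Putting it together: on the command line, the new path is selected with the added `-s`/`--by-shard` flag, and it can also be driven from Python. A sketch, assuming the script is importable as a module (the import path is hypothetical; the signature matches `process_by_model_shard` in the diff):

```python
from pathlib import Path

import torch

# Hypothetical import; in the repo this file ships as a script, not a package.
from huggingface_bloom_convert import process_by_model_shard

tp_size = 2
save_dir = Path("bloom-ft") / f"{tp_size}-gpu"
save_dir.mkdir(exist_ok=True, parents=True)

# Per the benchmark table, nproc trades memory for speed on bloom-175b:
# 8 workers peaked around 50G, 24 workers around 150G.
process_by_model_shard(
    model_id="bigscience/bloom",  # hub name or a local model directory
    dtype=torch.float16,
    tp_size=tp_size,
    save_dir=save_dir,
    nproc=8,
)
```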