perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the memory usage (#568)

Co-authored-by: r.yang <[email protected]>
Yangruipis and r.yang authored Apr 24, 2023
1 parent 3460e20 commit 19b2956
Showing 1 changed file with 166 additions and 26 deletions.
192 changes: 166 additions & 26 deletions examples/pytorch/gpt/utils/huggingface_bloom_convert.py
@@ -20,11 +20,12 @@
import configparser
import logging
import multiprocessing
import os
import re
import time

from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import torch
@@ -77,6 +78,9 @@ def get_args():
    parser.add_argument(
        '-v', '--verbose', action='store_true',
        help='Enable verbose logging')
    parser.add_argument(
        '-s', '--by-shard', action='store_true',
        help='Process the checkpoint shard by shard; enable when converting a large model like BLOOM 175B')
    _args = parser.parse_args()

    set_logger(_args.verbose)
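A usage sketch for the new flag (the long option names are inferred from how the parsed arguments are used later in this file, e.g. `args.input_dir`, `args.tensor_para_size` and `args.processes`, and `fp16` is assumed to be an accepted data-type value, so the exact spellings may differ from the parser definitions outside this hunk):

    python huggingface_bloom_convert.py --input-dir bigscience/bloom \
        --output-dir /path/to/ft-checkpoint --tensor-para-size 8 \
        --data-type fp16 --processes 24 --by-shard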
@@ -301,40 +305,176 @@ def save_bloom_config(model_config: BloomConfig, save_dir: PathLike):
        config.write(f, space_around_delimiters=False)


def load_state_dict(file_path: Path, dtype: torch.dtype) -> Dict[str, torch.Tensor]:
    """ Load weights from a model file.
    Either `safetensors` or PyTorch binary format is supported.
    # Args.
        file_path: model file path, ends with .bin or .safetensors.
        dtype: torch.dtype, data type.
    # Returns.
        Dict[str, torch.Tensor]
    """

    state_dict = {}
    if file_path.suffix == ".safetensors":
        # load from a safetensors file
        from safetensors import safe_open
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k).type(dtype)
    else:
        # load from a pytorch bin file
        state_dict = torch.load(file_path, map_location="cpu")
        for k in state_dict:
            state_dict[k] = state_dict[k].type(dtype)
    return state_dict
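# Usage sketch for load_state_dict (illustrative only; the shard file name below is
# hypothetical and would have to exist locally):
#
#   shard = Path("model_00001-of-00072.safetensors")
#   weights = load_state_dict(shard, torch.float16)
#   for name, tensor in weights.items():
#       print(name, tuple(tensor.shape), tensor.dtype)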


def get_model_files(model_name: str) -> List[Path]:
    """ List all model files that you want to load and convert.
    # Args.
        model_name: name (like `bigscience/bloom`) or local directory of the model
    # Returns.
        List[Path]: model file paths
    """

    import glob
    from huggingface_hub import try_to_load_from_cache

    model_dir = model_name

    # get the local model directory
    try:
        config_file = "config.json"
        # will fall back to HUGGINGFACE_HUB_CACHE
        config_path = try_to_load_from_cache(
            model_name, config_file, cache_dir=os.getenv("TRANSFORMERS_CACHE")
        )

        if config_path is not None:
            # treat the model name as a huggingface model path
            model_dir = os.path.dirname(config_path)
    except Exception:
        # treat the model name as an explicit model path
        pass

    model_files = glob.glob(model_dir + "/*.bin")
    try:
        # imported only to check that safetensors is available
        from safetensors import safe_open as _

        st_files = glob.glob(model_dir + "/*.safetensors")
        if st_files:
            model_files = st_files
            logger.info("loading from safetensors format")
    except ImportError:
        logger.info("loading from pytorch bin format")

    if not model_files:
        raise FileNotFoundError('model files not found')

    logger.info(f"model file num: {len(model_files)}")
    return [Path(i) for i in model_files]
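# Behaviour sketch for get_model_files (paths hypothetical, for illustration): a hub
# name such as "bigscience/bloom" is resolved through the local Hugging Face cache,
# while a plain directory is globbed directly; *.safetensors shards are preferred
# over *.bin files whenever the `safetensors` package can be imported.
#
#   files = get_model_files("bigscience/bloom")   # shards from the local cache
#   files = get_model_files("/data/bloom-175b")   # shards from a local directory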


def process_by_model_param(model_id: str, dtype: torch.dtype, tp_size: int, save_dir: Path, nproc: int):
    """ Process conversion parameter by parameter.
    """

    # init bloom config
    model_config = BloomConfig.from_pretrained(model_id)
    # list all model files
    model_files = get_model_files(model_id)
    # save bloom config to output dir
    save_bloom_config(model_config, save_dir)

    if nproc > 1:
        pool = multiprocessing.Pool(nproc)
        star_args = []
        for model_file in model_files:
            state_dict = load_state_dict(model_file, dtype)
            for name in state_dict:
                param = state_dict[name]
                # Preprocess
                param_name = convert_parameter_name(name)
                param = safe_transpose(param)
                param = handle_exceptions(model_config, param_name, param)
                star_args.append((param_name, param.detach().cpu().numpy(), tp_size, save_dir))
        pool.starmap_async(convert_and_save_parameter, star_args)
        pool.close()
        pool.join()
    else:
        for model_file in model_files:
            state_dict = load_state_dict(model_file, dtype)
            for name in state_dict:
                param = state_dict[name]
                # Preprocess
                param_name = convert_parameter_name(name)
                param = safe_transpose(param)
                param = handle_exceptions(model_config, param_name, param)
                convert_and_save_parameter(param_name, param.detach().cpu().numpy(), tp_size, save_dir)
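# Note on memory behaviour (inferred from the code above): with nproc > 1, this
# per-parameter path converts every tensor to a numpy array and collects all of them
# in star_args before the worker pool starts, so peak memory grows with the size of
# the whole model; the by-shard path below is meant to avoid exactly that.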


def _process_by_model_shard(model_config, model_file, dtype: torch.dtype, tp_size: int, save_dir: Path):
    state_dict = load_state_dict(model_file, dtype)
    for name in state_dict:
        param = state_dict[name]
        # Preprocess
        param_name = convert_parameter_name(name)
        param = safe_transpose(param)
        param = handle_exceptions(model_config, param_name, param)
        convert_and_save_parameter(param_name, param.detach().cpu().numpy(), tp_size, save_dir)


def process_by_model_shard(model_id: str, dtype: torch.dtype, tp_size: int, save_dir: Path, nproc: int):
    """ Process conversion shard by shard.
    Benchmarks @ 64C (Intel Xeon 6326, 2.90GHz) x 756G:
    | model      | format           | by-shard | nproc | elapsed(s) | mem  |
    |------------|------------------|----------|-------|------------|------|
    | bloom-175b | safetensors x 72 | NO       | 8     | 1516.66    | 350G |
    | bloom-175b | safetensors x 72 | YES      | 8     | 1165.03    | 50G  |
    | bloom-175b | safetensors x 72 | YES      | 24    | 494.81     | 150G |
    """

    # init bloom config
    model_config = BloomConfig.from_pretrained(model_id)
    # list all model files
    model_files = get_model_files(model_id)
    # save bloom config to output dir
    save_bloom_config(model_config, save_dir)

    if nproc > 1:
        pool = multiprocessing.Pool(nproc)
        star_args = []
        for model_file in model_files:
            star_args.append((model_config, model_file, dtype, tp_size, save_dir))
        pool.starmap_async(_process_by_model_shard, star_args)
        pool.close()
        pool.join()
    else:
        for model_file in model_files:
            _process_by_model_shard(model_config, model_file, dtype, tp_size, save_dir)
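# Note on the benchmark table above (inferred from the code): each worker converts
# one shard at a time, so peak memory scales roughly with nproc x shard size rather
# than with the whole model, which is consistent with ~50G at 8 processes versus
# ~150G at 24 processes, compared to ~350G for the non-sharded path.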


def main():
    start_time = time.time()
    args = get_args()
    tp_size = args.tensor_para_size

    dtype = DATATYPE_MAP[args.data_type]

    save_dir = Path(args.output_dir) / f'{tp_size}-gpu'
    save_dir.mkdir(exist_ok=True, parents=True)

    if args.by_shard:
        process_by_model_shard(args.input_dir, dtype, tp_size, save_dir, args.processes)
    else:
        process_by_model_param(args.input_dir, dtype, tp_size, save_dir, args.processes)

    elapsed_time = time.time() - start_time
    logger.info(f'Checkpoint conversion (HF >> FT) has done '
                f'(elapsed time: {elapsed_time:.2f} sec)')
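# Output layout sketch (only the `<tp_size>-gpu` suffix comes from save_dir above;
# the rest is illustrative): converted weights are written under
# <output_dir>/<tensor_para_size>-gpu/ together with the config file produced by
# save_bloom_config.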
