Meta inference for different models #223

Merged: 6 commits, Jan 3, 2023
49 changes: 40 additions & 9 deletions inference/huggingface/text-generation/inference-test.py
@@ -4,31 +4,44 @@
import math
import os
import torch
+import time
from utils import DSPipeline

+from deepspeed.runtime.utils import see_memory_usage

parser = ArgumentParser()

parser.add_argument("--name", required=True, type=str, help="model_name")
parser.add_argument("--checkpoint_path", required=False, default=None, type=str, help="model checkpoint path")
parser.add_argument("--save_mp_checkpoint_path", required=False, default=None, type=str, help="save-path to store the new model checkpoint")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
parser.add_argument("--ds_inference", default=True, type=bool, help="enable ds-inference")
parser.add_argument("--ds_inference", action='store_true', help="enable ds-inference")
parser.add_argument("--use_kernel", action='store_true', help="enable kernel-injection")
parser.add_argument("--max_tokens", default=1024, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=50, type=int, help="maximum new tokens to generate")
parser.add_argument("--greedy", default=False, type=bool, help="greedy generation mode")
parser.add_argument("--use_meta_tensor", default=False, type=bool, help="use the meta tensors to initialize model")
parser.add_argument("--greedy", action='store_true', help="greedy generation mode")
parser.add_argument("--use_meta_tensor", action='store_true', help="use the meta tensors to initialize model")
parser.add_argument("--use_cache", default=True, type=bool, help="use cache for generation")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
args = parser.parse_args()

world_size = int(os.getenv('WORLD_SIZE', '1'))
local_rank = int(os.getenv('LOCAL_RANK', '0'))

data_type = getattr(torch, args.dtype)

+if local_rank == 0:
+    see_memory_usage("before init", True)

+t0 = time.time()
pipe = DSPipeline(model_name=args.name,
                  dtype=data_type,
                  is_meta=args.use_meta_tensor,
-                 device=args.local_rank)
-
+                 device=args.local_rank,
+                 checkpoint_path=args.checkpoint_path)
+if local_rank == 0:
+    print(f"initialization time: {(time.time()-t0) * 1000}ms")
+    see_memory_usage("after init", True)
if args.use_meta_tensor:
    ds_kwargs = dict(base_dir=pipe.repo_root, checkpoint=pipe.checkpoints_json)
else:
@@ -38,10 +51,15 @@
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=data_type,
                                      mp_size=world_size,
-                                     replace_with_kernel_inject=True,
+                                     replace_with_kernel_inject=args.use_kernel,
                                      replace_method='auto',
                                      max_tokens=args.max_tokens,
                                      save_mp_checkpoint_path=args.save_mp_checkpoint_path,
                                      **ds_kwargs
                                      )
+if local_rank == 0:
+    see_memory_usage("after init_inference", True)


input_sentences = [
"DeepSpeed is a machine learning framework",
@@ -60,10 +78,23 @@

inputs = input_sentences[:args.batch_size]

+# warmup
outputs = pipe(inputs,
               num_tokens=args.max_new_tokens,
               do_sample=(not args.greedy))

-for i, o in zip(inputs, outputs):
-    print(f"\nin={i}\nout={o}\n{'-'*60}")
+torch.cuda.synchronize()
+start = time.time()

+outputs = pipe(inputs,
+               num_tokens=args.max_new_tokens,
+               do_sample=(not args.greedy))

+torch.cuda.synchronize()
+end = time.time()
+print(f'generation time is {end-start} sec')

+if args.local_rank == 0:
+    for i, o in zip(inputs, outputs):
+        print(f"\nin={i}\nout={o}\n{'-'*60}")

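Note on the argument-parser changes above: with argparse, type=bool converts the raw command-line string, and bool() of any non-empty string, including "False", is True, so a value such as --ds_inference False (or --greedy False) was still parsed as True. Switching to action='store_true' makes each flag a plain on/off switch that is False unless the flag is passed. A minimal stand-alone sketch (not part of this PR) showing the difference:

from argparse import ArgumentParser

# Old style: type=bool silently maps the string "False" to True.
old = ArgumentParser()
old.add_argument("--greedy", default=False, type=bool)
print(old.parse_args(["--greedy", "False"]).greedy)   # True, because bool("False") is truthy

# New style: the flag is False unless it is present on the command line.
new = ArgumentParser()
new.add_argument("--greedy", action='store_true')
print(new.parse_args([]).greedy)            # False
print(new.parse_args(["--greedy"]).greedy)  # True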
26 changes: 18 additions & 8 deletions inference/huggingface/text-generation/utils.py
@@ -20,7 +20,8 @@ def __init__(self,
                 model_name='bigscience/bloom-3b',
                 dtype=torch.float16,
                 is_meta=True,
-                device=-1
+                device=-1,
+                checkpoint_path=None
                 ):
        self.model_name = model_name
        self.dtype = dtype
@@ -43,7 +44,7 @@ def __init__(self,
        if (is_meta):
            '''When meta tensors enabled, use checkpoints'''
            self.config = AutoConfig.from_pretrained(self.model_name)
-            self.repo_root, self.checkpoints_json = self._generate_json()
+            self.repo_root, self.checkpoints_json = self._generate_json(checkpoint_path)

            with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
                self.model = AutoModelForCausalLM.from_config(self.config)
@@ -66,18 +67,27 @@ def __call__(self,
        return outputs


-    def _generate_json(self):
-        repo_root = snapshot_download(self.model_name, allow_patterns=["*"], local_files_only=False, revision=None)
-
-        if (self.model_name in self.tp_presharded_models):
+    def _generate_json(self, checkpoint_path=None):
+        if checkpoint_path is None:
+            repo_root = snapshot_download(self.model_name,
+                                          allow_patterns=["*"],
+                                          cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
+                                          ignore_patterns=["*.safetensors"],
+                                          local_files_only=False,
+                                          revision=None)
+        else:
+            repo_root = checkpoint_path
+        if os.path.exists(os.path.join(repo_root, "ds_inference_config.json")):
+            checkpoints_json = os.path.join(repo_root, "ds_inference_config.json")
+        elif (self.model_name in self.tp_presharded_models):
            # tp presharded repos come with their own checkpoints config file
            checkpoints_json = os.path.join(repo_root, "ds_inference_config.json")
        else:
            checkpoints_json = "checkpoints.json"

        with io.open(checkpoints_json, "w", encoding="utf-8") as f:
-            file_list = [str(entry) for entry in Path(repo_root).rglob("*.[bp][it][n]") if entry.is_file()]
-            data = {"type": self.config.model_type, "checkpoints": file_list, "version": 1.0}
+            file_list = [str(entry).split('/')[-1] for entry in Path(repo_root).rglob("*.[bp][it][n]") if entry.is_file()]
+            data = {"type": "BLOOM", "checkpoints": file_list, "version": 1.0}
            json.dump(data, f)

        return repo_root, checkpoints_json
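For reference, _generate_json writes the checkpoint index that inference-test.py hands to deepspeed.init_inference (as checkpoint=pipe.checkpoints_json, next to base_dir=pipe.repo_root). The glob pattern "*.[bp][it][n]" is meant to pick up both .bin and .pt shard files, and the new .split('/')[-1] keeps only the file names, which are presumably resolved against base_dir. A sketch of the kind of checkpoints.json this produces; the shard names below are made up for illustration:

import json

# Illustrative only: the shard file names are hypothetical; the real list
# comes from Path(repo_root).rglob("*.[bp][it][n]") in _generate_json above.
data = {
    "type": "BLOOM",
    "checkpoints": [
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    "version": 1.0,
}

with open("checkpoints.json", "w", encoding="utf-8") as f:
    json.dump(data, f)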