From 94677c74175249502085dcc42428ee89a6891e13 Mon Sep 17 00:00:00 2001
From: foin6 <61218792+foin6@users.noreply.github.com>
Date: Fri, 12 Jan 2024 02:22:40 +0800
Subject: [PATCH] Modify codes so that different accelerators can be called
 according to specific device conditions (#844)

* modify inference-test.py to meet the requirement of using Intel devices

* modify ds-hf-compare.py to meet the requirement of using Intel devices

* use deepspeed.accelerator.get_accelerator() to replace the original hard-coded cuda references, so that the accelerators available on the current device (not just Nvidia's GPUs) can be accessed and enabled

* line 117: self.model.xpu().to(self.device) ---> self.model.to(self.device) for generalization

* For upstream, use get_accelerator() to hide the backend. Add bf16 dtype for cpu.

* Update README.md

* Delete redundant comment code

* Delete +123 in README title

* delete checkpoints.json

* modify inference-test.py

* modify inference-test.py v2

* modify inference.py v3

* add bfloat16 for cpu

* fix an error in the conda setup commands

---------

Co-authored-by: Olatunji Ruwase
---
 inference/huggingface/text-generation/README.md         | 2 +-
 inference/huggingface/text-generation/arguments.py      | 2 +-
 inference/huggingface/text-generation/ds-hf-compare.py  | 5 +++--
 inference/huggingface/text-generation/inference-test.py | 5 +++--
 inference/huggingface/text-generation/utils.py          | 5 +++--
 5 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/inference/huggingface/text-generation/README.md b/inference/huggingface/text-generation/README.md
index 8019aa298..318e37416 100644
--- a/inference/huggingface/text-generation/README.md
+++ b/inference/huggingface/text-generation/README.md
@@ -20,7 +20,7 @@ If you are using conda, the following works:
 conda create -c conda-forge -n deepspeed python=3.10
 conda activate deepspeed
 pip install -r requirements.txt
-deepspeed --num_gpus 1 inference-test.py --name bigscience/bloom-3b --batch_size 2
+deepspeed --num_gpus 1 inference-test.py --model bigscience/bloom-3b --batch_size 2
 
 # Inference Test
 
diff --git a/inference/huggingface/text-generation/arguments.py b/inference/huggingface/text-generation/arguments.py
index b50198ff9..a6dade23f 100644
--- a/inference/huggingface/text-generation/arguments.py
+++ b/inference/huggingface/text-generation/arguments.py
@@ -7,7 +7,7 @@
 parser.add_argument("--checkpoint_path", required=False, default=None, type=str, help="model checkpoint path")
 parser.add_argument("--save_mp_checkpoint_path", required=False, default=None, type=str, help="save-path to store the new model checkpoint")
 parser.add_argument("--batch_size", default=1, type=int, help="batch size")
-parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
+parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8", "bfloat16"], help="data-type")
 parser.add_argument("--hf_baseline", action='store_true', help="disable DeepSpeed inference")
 parser.add_argument("--use_kernel", action='store_true', help="enable kernel-injection")
 parser.add_argument("--max_tokens", default=1024, type=int, help="maximum tokens used for the text-generation KV-cache")
diff --git a/inference/huggingface/text-generation/ds-hf-compare.py b/inference/huggingface/text-generation/ds-hf-compare.py
index 378a13940..27f307a32 100644
--- a/inference/huggingface/text-generation/ds-hf-compare.py
+++ b/inference/huggingface/text-generation/ds-hf-compare.py
@@ -3,11 +3,12 @@
 from transformers import pipeline
 from difflib import SequenceMatcher
 from argparse import ArgumentParser
+from deepspeed.accelerator import get_accelerator
 
 parser = ArgumentParser()
 
 parser.add_argument("--model", required=True, type=str, help="model_name")
-parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
+parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8", "bfloat16"], help="data-type")
 parser.add_argument("--num_inputs", default=1, type=int, help="number of test inputs")
 parser.add_argument("--min_length", default=200, type=int, help="minimum tokens generated")
 parser.add_argument("--max_length", default=300, type=int, help="maximum tokens generated")
@@ -73,7 +74,7 @@ def string_similarity(str1, str2):
 inputs = test_inputs
 
 data_type = getattr(torch, args.dtype)
-pipe = pipeline('text-generation', args.model, torch_dtype=data_type, device=0)
+pipe = pipeline('text-generation', args.model, torch_dtype=data_type, device=torch.device(get_accelerator().device_name(0)))
 base_out_list = []
 match_count=0
 
diff --git a/inference/huggingface/text-generation/inference-test.py b/inference/huggingface/text-generation/inference-test.py
index 827d8db35..0ba3b20cd 100644
--- a/inference/huggingface/text-generation/inference-test.py
+++ b/inference/huggingface/text-generation/inference-test.py
@@ -6,6 +6,7 @@
 import time
 from utils import DSPipeline, Performance
 from deepspeed.runtime.utils import see_memory_usage
+from deepspeed.accelerator import get_accelerator
 from arguments import parser
 
 args = parser.parse_args()
@@ -76,12 +77,12 @@
 iters = 30 if args.test_performance else 2 #warmup
 times = []
 for i in range(iters):
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()
     start = time.time()
     outputs = pipe(inputs,
               num_tokens=args.max_new_tokens,
               do_sample=(not args.greedy))
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()
     end = time.time()
     times.append(end - start)
 print(f"generation time is {times[1]} sec")
diff --git a/inference/huggingface/text-generation/utils.py b/inference/huggingface/text-generation/utils.py
index 173eac039..bf727fefc 100644
--- a/inference/huggingface/text-generation/utils.py
+++ b/inference/huggingface/text-generation/utils.py
@@ -10,6 +10,7 @@
 import torch
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizerFast
+from deepspeed.accelerator import get_accelerator
 
 class DSPipeline():
     '''
@@ -34,7 +35,7 @@ def __init__(self,
         elif device < 0:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(f"cuda:{device}")
+            self.device = torch.device(get_accelerator().device_name(device))
 
         # the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time.
         self.tp_presharded_models = ["microsoft/bloom-deepspeed-inference-int8", "microsoft/bloom-deepspeed-inference-fp16"]
@@ -110,7 +111,7 @@ def generate_outputs(self,
             if torch.is_tensor(input_tokens[t]):
                 input_tokens[t] = input_tokens[t].to(self.device)
 
-        self.model.cuda().to(self.device)
+        self.model.to(self.device)
 
         if isinstance(self.tokenizer, LlamaTokenizerFast):
             # NOTE: Check if Llamma can work w/ **input_tokens
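
For reference, a minimal sketch of the accelerator-agnostic pattern the diff adopts: device selection and synchronization go through DeepSpeed's get_accelerator() rather than hard-coded torch.cuda calls. The helper name select_device is illustrative only; the calls it wraps (get_accelerator().device_name() and get_accelerator().synchronize()) are the ones introduced in the hunks above.

# Sketch only -- assumes torch and deepspeed are installed; select_device is a
# hypothetical helper, not a function added by this patch.
import torch
from deepspeed.accelerator import get_accelerator

def select_device(device: int) -> torch.device:
    # Negative index falls back to CPU, mirroring DSPipeline.__init__ above.
    if device < 0:
        return torch.device("cpu")
    # device_name(i) returns e.g. "cuda:0", "xpu:0", or "cpu" depending on the
    # accelerator DeepSpeed detects at runtime, so no backend is hard-coded.
    return torch.device(get_accelerator().device_name(device))

if __name__ == "__main__":
    dev = select_device(0)
    x = torch.ones(4, device=dev)
    # Backend-agnostic replacement for torch.cuda.synchronize().
    get_accelerator().synchronize()
    print(dev, x.sum().item())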