Skip to content

Commit

Permalink
support trust_remote_code in inference test (#709)
Browse files Browse the repository at this point in the history
* support trust_remote_code

* make trust_remote _code as an argument

---------

Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
3 people authored Oct 3, 2023
1 parent 0d11c63 commit ca41e8b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
3 changes: 2 additions & 1 deletion inference/huggingface/text-generation/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
parser.add_argument("--test_performance", action='store_true', help="enable latency, bandwidth, and throughout testing")
parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
parser.add_argument("--world_size", type=int, default=int(os.getenv("WORLD_SIZE", "1")), help="world_size")
parser.add_argument("--test_hybrid_engine", action='store_true', help="enable hybrid engine testing")
parser.add_argument("--test_hybrid_engine", action='store_true', help="enable hybrid engine testing")
parser.add_argument("--trust_remote_code", action='store_true', help="Trust remote code for hugging face models")
6 changes: 3 additions & 3 deletions inference/huggingface/text-generation/inference-test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
dtype=data_type,
is_meta=args.use_meta_tensor,
device=args.local_rank,
checkpoint_path=args.checkpoint_path)
checkpoint_path=args.checkpoint_path,
trust_remote_code=args.trust_remote_code)

if args.local_rank == 0:
print(f"initialization time: {(time.time()-t0) * 1000}ms")
Expand All @@ -51,7 +52,7 @@
save_mp_checkpoint_path=args.save_mp_checkpoint_path,
**ds_kwargs
)

if args.local_rank == 0:
see_memory_usage("after init_inference", True)

Expand Down Expand Up @@ -90,4 +91,3 @@
print(f"\nin={i}\nout={o}\n{'-'*60}")
if args.test_performance:
Performance.print_perf_stats(map(lambda t: t / args.max_new_tokens, times), pipe.model.config, args.dtype, args.batch_size)

11 changes: 6 additions & 5 deletions inference/huggingface/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def __init__(self,
dtype=torch.float16,
is_meta=True,
device=-1,
checkpoint_path=None
checkpoint_path=None,
trust_remote_code=False,
):
self.model_name = model_name
self.dtype = dtype
Expand All @@ -38,18 +39,18 @@ def __init__(self,
# the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time.
self.tp_presharded_models = ["microsoft/bloom-deepspeed-inference-int8", "microsoft/bloom-deepspeed-inference-fp16"]

self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", trust_remote_code=trust_remote_code)
self.tokenizer.pad_token = self.tokenizer.eos_token

if (is_meta):
'''When meta tensors enabled, use checkpoints'''
self.config = AutoConfig.from_pretrained(self.model_name)
self.config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=trust_remote_code)
self.repo_root, self.checkpoints_json = self._generate_json(checkpoint_path)

with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
self.model = AutoModelForCausalLM.from_config(self.config)
self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=trust_remote_code)
else:
self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=trust_remote_code)

self.model.eval()

Expand Down

0 comments on commit ca41e8b

Please sign in to comment.