diff --git a/examples/pytorch/prefix_sharing.py b/examples/pytorch/prefix_sharing.py
index b8c0c1b9..6ed822f5 100644
--- a/examples/pytorch/prefix_sharing.py
+++ b/examples/pytorch/prefix_sharing.py
@@ -96,8 +96,8 @@ def build_inputs_baichuan(tokenizer, query: str, padding, history: List[Tuple[st
 
     # Master
     if model.rank == 0:
-        print(f"Input prompt length is :{input_ids.shape[-1]}.")
-        print(f"Shared prefix length is :{args.prefix_len}.")
+        print(f"[INFO] Input prompt length is :{input_ids.shape[-1]}.")
+        print(f"[INFO] Shared prefix length is :{args.prefix_len}.")
         # Base
         start_time = time.perf_counter()
         model.config(max_length=input_ids.shape[-1] + args.output_len, num_beams=args.num_beams)
@@ -118,8 +118,8 @@ def build_inputs_baichuan(tokenizer, query: str, padding, history: List[Tuple[st
             print(snt)
         print("=" * 20 + "Performance" + "=" * 20)
         execution_time = end_time - start_time
-        print(f"Origin 1st token time:\t{first_token_time:.2f} s")
-        print(f"Origin execution time:\t{execution_time:.2f} s")
+        print(f"[INFO] Origin 1st token time:\t{first_token_time:.2f} s")
+        print(f"[INFO] Origin execution time:\t{execution_time:.2f} s")
 
         # Enable perfix sharing
         truncate_tail = input_ids.shape[-1] - args.prefix_len
@@ -150,4 +150,6 @@ def build_inputs_baichuan(tokenizer, query: str, padding, history: List[Tuple[st
     else:
         # Slave
         model.generate()
+
+        model.prefix_sharing()
         model.generate()