Skip to content

Commit

Permalink
python script.
Browse files Browse the repository at this point in the history
  • Loading branch information
Duyi-Wang committed Nov 22, 2023
1 parent 3ea284e commit f075c27
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions examples/pytorch/prefix_sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def build_inputs_baichuan(tokenizer, query: str, padding, history: List[Tuple[st

# Master
if model.rank == 0:
print(f"Input prompt length is :{input_ids.shape[-1]}.")
print(f"Shared prefix length is :{args.prefix_len}.")
print(f"[INFO] Input prompt length is :{input_ids.shape[-1]}.")
print(f"[INFO] Shared prefix length is :{args.prefix_len}.")
# Base
start_time = time.perf_counter()
model.config(max_length=input_ids.shape[-1] + args.output_len, num_beams=args.num_beams)
Expand All @@ -118,8 +118,8 @@ def build_inputs_baichuan(tokenizer, query: str, padding, history: List[Tuple[st
print(snt)
print("=" * 20 + "Performance" + "=" * 20)
execution_time = end_time - start_time
print(f"Origin 1st token time:\t{first_token_time:.2f} s")
print(f"Origin execution time:\t{execution_time:.2f} s")
print(f"[INFO] Origin 1st token time:\t{first_token_time:.2f} s")
print(f"[INFO] Origin execution time:\t{execution_time:.2f} s")

# Enable prefix sharing
truncate_tail = input_ids.shape[-1] - args.prefix_len
Expand Down Expand Up @@ -150,4 +150,6 @@ def build_inputs_baichuan(tokenizer, query: str, padding, history: List[Tuple[st
else:
# Slave
model.generate()

model.prefix_sharing()
model.generate()

0 comments on commit f075c27

Please sign in to comment.