Skip to content

Commit

Permalink
Fix the decoding logic in test_local_grpc.py (#44)
Browse files Browse the repository at this point in the history
* fix the test_local_grpc script

* lint fix
  • Loading branch information
alfredgui2 authored and tjluyao committed Jul 7, 2024
1 parent cea2a41 commit d6e3758
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions server/examples/test_local_grpc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -46,18 +46,8 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):


requests = [
make_input(
"abcdabcd987/gsm8k-llama2-7b-lora-16",
"base",
id=0,
promptOverride="Give me a breif introduction to Byznatine Fault Tolerance and why it is important?",
),
make_input(
"abcdabcd987/gsm8k-llama2-7b-lora-16",
"lora",
id=1,
promptOverride="Which network interface card is more suitable for distributed systems, Meallanox or Broadcom?",
),
make_input("tjluyao/gemma-2b-it-math", "base", id=0),
make_input("tjluyao/gemma-2b-it-math", "base", id=1),
]

# Assemble input batch
Expand All @@ -78,11 +68,26 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
)
stub.Warmup(wr)
# Prefill
pr = generate_pb2.PrefillRequest(batch=pb_batch_empty)
pr = generate_pb2.PrefillRequest(batch=pb_batch_with_inputs)
resp = stub.Prefill(pr)
gen, cbatch = resp.generations, resp.batch
# Decode
dr = generate_pb2.DecodeRequest(batches=[cbatch])
resp = stub.Decode(dr)
gen, cbatch = resp.generations, resp.batch
print("done")
# Show the token texts returned by the Prefill call for each generation.
generations = resp.generations
cbatch = resp.batch
for gen in generations:
    print(gen.tokens.texts)

print("finished prefill tokens")

# Keep issuing Decode calls on the cached batch until a response contains
# at least one generation whose generated text is non-empty.
finished = False
while not finished:
    dr = generate_pb2.DecodeRequest(batches=[cbatch])
    resp = stub.Decode(dr)
    generations, cbatch = resp.generations, resp.batch
    for gen in generations:
        text = gen.generated_text.text
        if text:
            print("finished")
            # If several generations finish in the same step, the last one
            # seen is the one that gets printed below.
            res = text
            finished = True

print(res)

0 comments on commit d6e3758

Please sign in to comment.