Add speculative decoding params to lm_bench
sbalandi committed Nov 18, 2024
1 parent 3b02a29 commit 7ac52d7
Showing 4 changed files with 45 additions and 6 deletions.
6 changes: 6 additions & 0 deletions tools/llm_bench/benchmark.py
@@ -140,6 +140,12 @@ def get_argprser():
parser.add_argument('--lora_alphas', nargs='*', help='Alphas params for LoRA adapters.', required=False, default=[])
parser.add_argument("--use_cb", action="store_true", help="Use Continuous Batching inference mode")
parser.add_argument("--cb_config", required=False, default=None, help="Path to file with Continuous Batching Scheduler settings or dict")
parser.add_argument("--draft_model", required=False, default=None,
help="Path to draft model folder including IR files for Speculative decoding generation.")
parser.add_argument("--draft_device", required=False, default='cpu', help="Inference device for Speculative decoding generation.")
parser.add_argument("--num_assistant_tokens", required=False, default=5, help="Config option num_assistant_tokens for Speculative decoding")
parser.add_argument("--assistant_confidence_threshold", required=False, default=0,
help="Config option assistant_confidence_threshold for Speculative decodin")
parser.add_argument(
'--end_token_stopping',
action='store_true',
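For quick experimentation, the four new flags can be reproduced in a standalone parser. The sketch below is illustrative only: the explicit type= coercions are an assumption added for clarity (the commit leaves CLI-supplied values untyped), and the sample paths are placeholders.

import argparse

# Hypothetical stand-alone parser mirroring the four new llm_bench flags.
parser = argparse.ArgumentParser()
parser.add_argument("--draft_model", default=None,
                    help="Path to draft model folder with IR files for speculative decoding")
parser.add_argument("--draft_device", default="cpu",
                    help="Inference device for the draft model")
parser.add_argument("--num_assistant_tokens", type=int, default=5,
                    help="Fixed number of candidate tokens proposed by the draft model per step")
parser.add_argument("--assistant_confidence_threshold", type=float, default=0,
                    help="Confidence cut-off for dynamic candidate proposals")

# Parse a sample command line (placeholder draft model path).
args = parser.parse_args(["--draft_model", "models/tiny-draft-ov", "--num_assistant_tokens", "4"])
print(args.draft_model, args.draft_device, args.num_assistant_tokens)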
8 changes: 7 additions & 1 deletion tools/llm_bench/llm_bench_utils/model_utils.py
@@ -135,12 +135,18 @@ def analyze_args(args):
model_args['model_type'] = get_model_type(model_name, use_case, model_framework)
model_args['model_name'] = model_name

if args.use_cb and not args.genai:
if (args.use_cb or args.draft_model) and not args.genai:
raise RuntimeError("Continuous batching mode supported only via OpenVINO GenAI")
cb_config = None
if args.cb_config:
cb_config = get_config(args.cb_config)
model_args["cb_config"] = cb_config
model_args['draft_model'] = args.draft_model
model_args['draft_device'] = args.draft_device
if (args.num_assistant_tokens > 0 and args.assistant_confidence_threshold > 0):
raise RuntimeError("Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually exclusive")
model_args['num_assistant_tokens'] = args.num_assistant_tokens
model_args['assistant_confidence_threshold'] = args.assistant_confidence_threshold
return model_path, model_framework, model_args, model_name


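The two generation knobs steer different candidate-selection strategies in speculative decoding: num_assistant_tokens fixes how many tokens the draft model proposes per step, while assistant_confidence_threshold lets the draft model stop proposing once its own confidence drops below the threshold, which is why analyze_args treats them as mutually exclusive. A minimal sketch of the same validation in isolation (the helper name validate_speculative_args is illustrative, not part of the commit):

def validate_speculative_args(num_assistant_tokens: int, assistant_confidence_threshold: float) -> None:
    """Mirror of the check in analyze_args(): only one candidate-selection
    strategy may be active at a time."""
    if num_assistant_tokens > 0 and assistant_confidence_threshold > 0:
        raise RuntimeError(
            "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually exclusive"
        )

# With the benchmark defaults (5 draft tokens, threshold disabled) the check passes.
validate_speculative_args(5, 0.0)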
12 changes: 10 additions & 2 deletions tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -214,8 +214,9 @@ def create_genai_text_gen_model(model_path, device, ov_config, **kwargs):

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

draft_model_path = kwargs.get("draft_model", '')
cb = kwargs.get("use_cb", False)
if cb:
if cb or draft_model_path:
log.info("Continuous Batching mode activated")
default_cb_config = {"cache_size": 1}
scheduler_config = openvino_genai.SchedulerConfig()
@@ -227,6 +228,13 @@ def create_genai_text_gen_model(model_path, device, ov_config, **kwargs):
setattr(scheduler_config, param, value)
ov_config["scheduler_config"] = scheduler_config

if draft_model_path:
if not Path(draft_model_path).exists():
raise RuntimeError(f'==Failure ==: draft model by path: {draft_model_path} does not exist')
log.info("Speculative Decoding is activated")

ov_config['draft_model'] = openvino_genai.draft_model(draft_model_path, kwargs['draft_device'].upper())

adapter_config = get_lora_config(kwargs.get("lora", None), kwargs.get("lora_alphas", []))
if adapter_config:
ov_config['adapters'] = adapter_config
@@ -263,7 +271,7 @@ def get_tokens(self):

def get_time_list(self):
return self.token_generation_time
streamer = TokenStreamer(llm_pipe.get_tokenizer()) if cb else None
streamer = TokenStreamer(llm_pipe.get_tokenizer()) if cb or draft_model_path else None

return llm_pipe, tokenizer, end - start, streamer, True

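Putting the ov_utils.py pieces together: the draft model is handed to the pipeline as an extra plugin property alongside the scheduler config, and speculative decoding therefore runs on the same continuous-batching backend as --use_cb. A minimal sketch of that wiring, assuming openvino_genai is installed and that the two model directories (placeholders below) contain exported OpenVINO IR folders:

from pathlib import Path

import openvino_genai

main_model_dir = "models/llama-3-8b-int4-ov"      # placeholder path
draft_model_dir = "models/llama-3.2-1b-int4-ov"   # placeholder path

if not Path(draft_model_dir).exists():
    raise RuntimeError(f"draft model path {draft_model_dir} does not exist")

# Speculative decoding uses the continuous-batching scheduler, configured the
# same way as for --use_cb; cache_size=1 mirrors default_cb_config in the diff.
scheduler_config = openvino_genai.SchedulerConfig()
scheduler_config.cache_size = 1

pipe = openvino_genai.LLMPipeline(
    main_model_dir,
    "CPU",
    scheduler_config=scheduler_config,
    draft_model=openvino_genai.draft_model(draft_model_dir, "CPU"),
)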
25 changes: 22 additions & 3 deletions tools/llm_bench/task/text_generation.py
@@ -194,12 +194,21 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data
if args['infer_count'] is not None:
out_str += 'all max_output_token_size: {} * {}'.format(args['infer_count'], args['batch_size'])
log.info(out_str)
gen_config = model.get_generation_config()
gen_config.max_new_tokens = max_gen_tokens
gen_config.num_beams = args["num_beams"]
gen_config.do_sample = False
if args.get('draft_model', ''):
gen_config.num_assistant_tokens = args['num_assistant_tokens']
gen_config.assistant_confidence_threshold = args['assistant_confidence_threshold']
log.info("Speculative decoding config: ")
log.info(f" num_assistant_tokens {gen_config.num_assistant_tokens}")
log.info(f" assistant_confidence_threshold {gen_config.assistant_confidence_threshold}")
start = time.perf_counter()
generation_result = model.generate(input_text_list, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], do_sample=False)
generation_result = model.generate(input_text_list, gen_config)
end = time.perf_counter()
generated_text = generation_result.texts
perf_metrics = generation_result.perf_metrics

if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
mem_consumption.end_collect_momory_consumption()
max_rss_mem_consumption, max_shared_mem_consumption, max_uss_mem_consumption = mem_consumption.get_max_memory_consumption()
@@ -314,8 +323,18 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg
mem_consumption.start_collect_memory_consumption()
max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count']
streamer.reset()
gen_config = model.get_generation_config()
gen_config.max_new_tokens = max_gen_tokens
gen_config.num_beams = args["num_beams"]
gen_config.do_sample = False
if args.get('draft_model', ''):
gen_config.num_assistant_tokens = args['num_assistant_tokens']
gen_config.assistant_confidence_threshold = args['assistant_confidence_threshold']
log.info("Speculative decoding config: ")
log.info(f" num_assistant_tokens {gen_config.num_assistant_tokens}")
log.info(f" assistant_confidence_threshold {gen_config.assistant_confidence_threshold}")
start = time.perf_counter()
generated_tokens = model.generate(input_data, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], streamer=streamer, do_sample=False).tokens
generated_tokens = model.generate(input_data, gen_config, streamer=streamer).tokens
end = time.perf_counter()
if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
mem_consumption.end_collect_momory_consumption()
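Both generation paths now route their options through a GenerationConfig object instead of keyword arguments to generate(), which is what lets the speculative-decoding fields ride along. A minimal end-to-end sketch under the same assumptions as above (placeholder model paths, illustrative token counts, exactly one of the two assistant knobs set):

import time

import openvino_genai

pipe = openvino_genai.LLMPipeline(
    "models/llama-3-8b-int4-ov", "CPU",                                   # placeholder main model
    scheduler_config=openvino_genai.SchedulerConfig(),
    draft_model=openvino_genai.draft_model("models/llama-3.2-1b-int4-ov", "CPU"),  # placeholder draft
)

gen_config = pipe.get_generation_config()
gen_config.max_new_tokens = 128
gen_config.num_beams = 1
gen_config.do_sample = False
# Pick exactly one strategy: a fixed draft length ...
gen_config.num_assistant_tokens = 5
# ... or a confidence cut-off (kept at 0 here because num_assistant_tokens is set).
gen_config.assistant_confidence_threshold = 0.0

start = time.perf_counter()
result = pipe.generate(["What is speculative decoding?"], gen_config)
print(result.texts[0])
print(f"generation took {time.perf_counter() - start:.2f} s")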
