Add speculative decoding params to lm_bench (#1221)
Task: [CVS-155520](https://jira.devtools.intel.com/browse/CVS-155520)

---------

Co-authored-by: Ekaterina Aidova <[email protected]>
sbalandi and eaidova authored Nov 20, 2024
1 parent 37d01e8 commit a2e1ae9
Showing 5 changed files with 78 additions and 14 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/llm_bench-python.yml
@@ -82,6 +82,12 @@ jobs:
run: |
wget -O ./ov_models/soulcard.safetensors https://civitai.com/api/download/models/72591
python ./tools/llm_bench/benchmark.py -m ./ov_models/dreamlike-art-dreamlike-anime-1.0/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --genai --lora ./ov_models/soulcard.safetensors --lora_alphas 0.7
- name: Test TinyLlama-1.1B-Chat-v1.0 in Speculative Decoding mode on Linux
run: |
optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format fp16 ov_models/TinyLlama-1.1B-Chat-v1.0/FP16
optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format int8 ov_models/TinyLlama-1.1B-Chat-v1.0/INT8
python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --genai --assistant_confidence_threshold 0.4
python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --genai --num_assistant_tokens 5
- name: Test whisper-tiny on Linux
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 --branch main --single-branch https://huggingface.co/datasets/facebook/multilingual_librispeech
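For context, the new CI step exports an FP16 main model and an INT8 draft of the same TinyLlama checkpoint, then benchmarks speculative decoding twice: once with a fixed number of draft tokens (`--num_assistant_tokens`) and once with a dynamic confidence cut-off (`--assistant_confidence_threshold`). A minimal standalone sketch of the flow the benchmark exercises, assuming the openvino_genai API already used elsewhere in this diff (paths and the prompt are placeholders, not part of this change):

```python
import openvino_genai

# Placeholder paths matching the optimum-cli exports in the CI step above
main_path = "ov_models/TinyLlama-1.1B-Chat-v1.0/FP16"
draft_path = "ov_models/TinyLlama-1.1B-Chat-v1.0/INT8"

# Attach the draft model to the main pipeline to enable speculative decoding
pipe = openvino_genai.LLMPipeline(
    main_path, "CPU",
    draft_model=openvino_genai.draft_model(draft_path, "CPU"),
)

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 128
config.num_assistant_tokens = 5  # static mode: the draft proposes 5 candidate tokens per step
# config.assistant_confidence_threshold = 0.4  # alternative dynamic mode

print(pipe.generate("Why is the Sun yellow?", config))
```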
8 changes: 8 additions & 0 deletions tools/llm_bench/benchmark.py
@@ -140,6 +140,14 @@ def get_argprser():
parser.add_argument('--lora_alphas', nargs='*', help='Alphas params for LoRA adapters.', required=False, default=[])
parser.add_argument("--use_cb", action="store_true", help="Use Continuous Batching inference mode")
parser.add_argument("--cb_config", required=False, default=None, help="Path to file with Continuous Batching Scheduler settings or dict")
parser.add_argument("--draft_model", required=False, default=None,
help="Path to draft model folder including IR files for Speculative decoding generation")
parser.add_argument("--draft_device", required=False, default=None, help="Inference device for Speculative decoding of draft model")
parser.add_argument("--draft_cb_config", required=False, default=None,
help="Path to file with Continuous Batching Scheduler settings or dict for Speculative decoding of draft model")
parser.add_argument("--num_assistant_tokens", required=False, default=None, help="Config option num_assistant_tokens for Speculative decoding")
parser.add_argument("--assistant_confidence_threshold", required=False, default=None,
help="Config option assistant_confidence_threshold for Speculative decoding")
parser.add_argument(
'--end_token_stopping',
action='store_true',
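Note that the two new numeric options are declared without a `type=` argument, so they reach the rest of the tool as strings and are cast to int/float where they are assigned to the GenAI generation config below. A hedged alternative, not part of this change, would be to declare the types at parse time:

```python
# Hypothetical variant of the two options above with types declared in argparse,
# which would make the downstream int()/float() casts unnecessary.
parser.add_argument("--num_assistant_tokens", required=False, default=None, type=int,
                    help="Config option num_assistant_tokens for Speculative decoding")
parser.add_argument("--assistant_confidence_threshold", required=False, default=None, type=float,
                    help="Config option assistant_confidence_threshold for Speculative decoding")
```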
10 changes: 9 additions & 1 deletion tools/llm_bench/llm_bench_utils/model_utils.py
@@ -135,12 +135,20 @@ def analyze_args(args):
model_args['model_type'] = get_model_type(model_name, use_case, model_framework)
model_args['model_name'] = model_name

if args.use_cb and not args.genai:
if (args.use_cb or args.draft_model) and not args.genai:
raise RuntimeError("Continuous batching mode supported only via OpenVINO GenAI")
cb_config = None
if args.cb_config:
cb_config = get_config(args.cb_config)
model_args["cb_config"] = cb_config
model_args['draft_model'] = args.draft_model
model_args['draft_device'] = args.draft_device
draft_cb_config = None
if args.draft_cb_config:
draft_cb_config = get_config(args.draft_cb_config)
model_args["draft_cb_config"] = draft_cb_config
model_args['num_assistant_tokens'] = args.num_assistant_tokens
model_args['assistant_confidence_threshold'] = args.assistant_confidence_threshold
return model_path, model_framework, model_args, model_name


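analyze_args now forwards the speculative-decoding options into model_args and extends the GenAI-only guard to cover --draft_model. For a run like the CI step above, the dictionary would roughly gain the following entries (values are hypothetical, shown only to illustrate the flow):

```python
# Illustrative only: extra model_args entries produced by analyze_args for a
# speculative-decoding run; num_assistant_tokens is still a string at this point
# because argparse is not given a type for it.
speculative_args = {
    "draft_model": "./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/",
    "draft_device": "cpu",
    "draft_cb_config": None,   # populated via get_config() when --draft_cb_config is passed
    "num_assistant_tokens": "5",
    "assistant_confidence_threshold": None,
}
```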
37 changes: 27 additions & 10 deletions tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -204,6 +204,21 @@ def create_text_gen_model(model_path, device, **kwargs):
return ov_model, tokenizer, from_pretrained_time, bench_hook, False


def get_scheduler_config_genai(user_config, config_name="CB config"):
import openvino_genai

default_cb_config = {"cache_size": 1}
scheduler_config = openvino_genai.SchedulerConfig()
scheduler_params = user_config or default_cb_config
if scheduler_params:
log.info(f"Scheduler parameters for {config_name}:\n{scheduler_params}")

for param, value in scheduler_params.items():
setattr(scheduler_config, param, value)

return scheduler_config


def create_genai_text_gen_model(model_path, device, ov_config, **kwargs):
import openvino_tokenizers # noqa: F401
import openvino_genai
@@ -214,18 +229,20 @@ def create_genai_text_gen_model(model_path, device, ov_config, **kwargs):

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

draft_model_path = kwargs.get("draft_model", '')
cb = kwargs.get("use_cb", False)
if cb:
if cb or draft_model_path:
log.info("Continuous Batching mode activated")
default_cb_config = {"cache_size": 1}
scheduler_config = openvino_genai.SchedulerConfig()
scheduler_params = kwargs.get("cb_config") or default_cb_config
if scheduler_params:
log.info(f"Scheduler parameters:\n{scheduler_params}")
ov_config["scheduler_config"] = get_scheduler_config_genai(kwargs.get("cb_config"))

for param, value in scheduler_params.items():
setattr(scheduler_config, param, value)
ov_config["scheduler_config"] = scheduler_config
if draft_model_path:
if not Path(draft_model_path).exists():
raise RuntimeError(f'==Failure ==: draft model path {draft_model_path} does not exist')
log.info("Speculative Decoding is activated")
draft_device = kwargs.get('draft_device', None) or device
draft_model_load_kwargs = {'scheduler_config': get_scheduler_config_genai(kwargs.get("draft_cb_config"), "draft CB config")}\
if kwargs.get("draft_cb_config") is not None else {}
ov_config['draft_model'] = openvino_genai.draft_model(draft_model_path, draft_device.upper(), **draft_model_load_kwargs)

adapter_config = get_lora_config(kwargs.get("lora", None), kwargs.get("lora_alphas", []))
if adapter_config:
@@ -263,7 +280,7 @@ def get_tokens(self):

def get_time_list(self):
return self.token_generation_time
streamer = TokenStreamer(llm_pipe.get_tokenizer()) if cb else None
streamer = TokenStreamer(llm_pipe.get_tokenizer()) if cb or draft_model_path else None

return llm_pipe, tokenizer, end - start, streamer, True

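The refactor extracts get_scheduler_config_genai so the same SchedulerConfig construction serves both the main pipeline and the draft model. A short usage sketch, assuming only the attributes already used in this diff (cache_size) and the scheduler_config keyword that draft_model() receives above; the path is a placeholder:

```python
import openvino_genai

# Main pipeline falls back to the default {"cache_size": 1} when no --cb_config is given
main_scheduler = get_scheduler_config_genai(None)

# The draft model can get its own scheduler settings via --draft_cb_config
draft_scheduler = get_scheduler_config_genai({"cache_size": 2}, "draft CB config")

ov_config = {"scheduler_config": main_scheduler}
ov_config["draft_model"] = openvino_genai.draft_model(
    "ov_models/TinyLlama-1.1B-Chat-v1.0/INT8",  # placeholder path
    "CPU",
    scheduler_config=draft_scheduler,
)
```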
31 changes: 28 additions & 3 deletions tools/llm_bench/task/text_generation.py
@@ -194,12 +194,24 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data
if args['infer_count'] is not None:
out_str += 'all max_output_token_size: {} * {}'.format(args['infer_count'], args['batch_size'])
log.info(out_str)
gen_config = model.get_generation_config()
gen_config.max_new_tokens = max_gen_tokens
gen_config.num_beams = args["num_beams"]
gen_config.do_sample = False
if args.get('draft_model', ''):
config_info = "Speculative decoding config: "
if args.get('num_assistant_tokens', None):
gen_config.num_assistant_tokens = int(args['num_assistant_tokens'])
config_info += f" num_assistant_tokens {gen_config.num_assistant_tokens}"
if args.get('assistant_confidence_threshold', None):
gen_config.assistant_confidence_threshold = float(args['assistant_confidence_threshold'])
config_info += f" assistant_confidence_threshold {gen_config.assistant_confidence_threshold}"
log.info(config_info)
start = time.perf_counter()
generation_result = model.generate(input_text_list, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], do_sample=False)
generation_result = model.generate(input_text_list, gen_config)
end = time.perf_counter()
generated_text = generation_result.texts
perf_metrics = generation_result.perf_metrics

if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
mem_consumption.end_collect_momory_consumption()
max_rss_mem_consumption, max_shared_mem_consumption, max_uss_mem_consumption = mem_consumption.get_max_memory_consumption()
@@ -314,8 +326,21 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg
mem_consumption.start_collect_memory_consumption()
max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count']
streamer.reset()
gen_config = model.get_generation_config()
gen_config.max_new_tokens = max_gen_tokens
gen_config.num_beams = args["num_beams"]
gen_config.do_sample = False
if args.get('draft_model', ''):
config_info = "Speculative decoding config: "
if args.get("num_assistant_tokens", None):
gen_config.num_assistant_tokens = int(args["num_assistant_tokens"])
config_info += f'num_assistant_tokens {args["num_assistant_tokens"]}'
if args.get("assistant_confidence_threshold", None):
gen_config.assistant_confidence_threshold = float(args["assistant_confidence_threshold"])
config_info += f'assistant_confidence_threshold {args["assistant_confidence_threshold"]}'
log.info(config_info)
start = time.perf_counter()
generated_tokens = model.generate(input_data, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], streamer=streamer, do_sample=False).tokens
generated_tokens = model.generate(input_data, gen_config, streamer=streamer).tokens
end = time.perf_counter()
if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
mem_consumption.end_collect_momory_consumption()
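Both generation paths now build the request from model.get_generation_config() rather than passing keyword arguments to generate(), and they log which speculative option is in effect. Since the point of the benchmark is to quantify the speed-up, a follow-up sketch for reading timings back from the result above, assuming the openvino_genai PerfMetrics accessors named here are available:

```python
# Illustrative only: perf_metrics comes from generation_result.perf_metrics above;
# the accessors are assumed to return mean/std pairs in ms and tokens per second.
metrics = generation_result.perf_metrics
print(f"TTFT:       {metrics.get_ttft().mean:.2f} ms")
print(f"TPOT:       {metrics.get_tpot().mean:.2f} ms/token")
print(f"Throughput: {metrics.get_throughput().mean:.2f} tokens/s")
```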
