Skip to content

Commit

Permalink
Fix 3
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Dec 27, 2024
1 parent 202b69e commit 0883e3f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions tests/python_tests/test_continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,12 @@ def test_post_oom_health(tmp_path, sampling_config):
# Pre-emption
#

def get_greedy_seq_len_300() -> GenerationConfig:
def get_parallel_samppling_seq_len_300() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_return_sequences = 3
generation_config.do_sample = True
generation_config.top_k = 10
generation_config.top_p = 0.5
generation_config.max_new_tokens = 300
return generation_config

Expand All @@ -185,8 +188,8 @@ def get_beam_search_seq_len_300() -> GenerationConfig:

scheduler_params_list = [({"num_kv_blocks": 2, "dynamic_split_fuse": True, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_greedy()),
({"num_kv_blocks": 2, "dynamic_split_fuse": False, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_greedy()),
({"num_kv_blocks": 10, "dynamic_split_fuse": True}, get_greedy_seq_len_300()),
({"num_kv_blocks": 10, "dynamic_split_fuse": False}, get_greedy_seq_len_300()),
({"num_kv_blocks": 10, "dynamic_split_fuse": True}, get_parallel_samppling_seq_len_300()),
({"num_kv_blocks": 10, "dynamic_split_fuse": False}, get_parallel_samppling_seq_len_300()),
({"num_kv_blocks": 34, "dynamic_split_fuse": True, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_beam_search()),
({"num_kv_blocks": 34, "dynamic_split_fuse": False, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_beam_search()),
({"num_kv_blocks": 100, "dynamic_split_fuse": True}, get_beam_search_seq_len_300()),
Expand Down
2 changes: 1 addition & 1 deletion tests/python_tests/test_llm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def test_stop_token_ids():
res = ov_pipe.generate(
ov.Tensor([(1,)]),
max_new_tokens=3,
stop_token_ids={-1, 9935, ov_pipe.get_tokenizer().get_eos_token_id()},
stop_token_ids={9935, ov_pipe.get_tokenizer().get_eos_token_id()},
include_stop_str_in_output=False
)
assert 2 == len(res.tokens[0])
Expand Down

0 comments on commit 0883e3f

Please sign in to comment.