diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index 128d8b7413..14ce9bfed3 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -168,9 +168,12 @@ def test_post_oom_health(tmp_path, sampling_config): # Pre-emption # -def get_greedy_seq_len_300() -> GenerationConfig: +def get_parallel_samppling_seq_len_300() -> GenerationConfig: generation_config = GenerationConfig() generation_config.num_return_sequences = 3 + generation_config.do_sample = True + generation_config.top_k = 10 + generation_config.top_p = 0.5 generation_config.max_new_tokens = 300 return generation_config @@ -185,8 +188,8 @@ def get_beam_search_seq_len_300() -> GenerationConfig: scheduler_params_list = [({"num_kv_blocks": 2, "dynamic_split_fuse": True, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_greedy()), ({"num_kv_blocks": 2, "dynamic_split_fuse": False, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_greedy()), - ({"num_kv_blocks": 10, "dynamic_split_fuse": True}, get_greedy_seq_len_300()), - ({"num_kv_blocks": 10, "dynamic_split_fuse": False}, get_greedy_seq_len_300()), + ({"num_kv_blocks": 10, "dynamic_split_fuse": True}, get_parallel_samppling_seq_len_300()), + ({"num_kv_blocks": 10, "dynamic_split_fuse": False}, get_parallel_samppling_seq_len_300()), ({"num_kv_blocks": 34, "dynamic_split_fuse": True, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_beam_search()), ({"num_kv_blocks": 34, "dynamic_split_fuse": False, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_beam_search()), ({"num_kv_blocks": 100, "dynamic_split_fuse": True}, get_beam_search_seq_len_300()), diff --git a/tests/python_tests/test_llm_pipeline.py b/tests/python_tests/test_llm_pipeline.py index e0def3b433..ae20f90a59 100644 --- a/tests/python_tests/test_llm_pipeline.py +++ b/tests/python_tests/test_llm_pipeline.py @@ -663,7 +663,7 @@ def test_stop_token_ids(): res = ov_pipe.generate( ov.Tensor([(1,)]), max_new_tokens=3, - stop_token_ids={-1, 9935, ov_pipe.get_tokenizer().get_eos_token_id()}, + stop_token_ids={9935, ov_pipe.get_tokenizer().get_eos_token_id()}, include_stop_str_in_output=False ) assert 2 == len(res.tokens[0])