diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py
index 9d51591d705ce..f57136b5ca613 100644
--- a/tests/lora/test_long_context.py
+++ b/tests/lora/test_long_context.py
@@ -28,9 +28,15 @@
 def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
-    return LoRARequest(context_len, lora_id,
-                       long_context_infos[lora_id]["lora"], None,
-                       4096 * scaling_factor)
+    return LoRARequest(
+        # There are 2 LoRAs for 16K, so we need to add lora_id to indicate
+        # that they are different LoRAs.
+        context_len + str(lora_id),
+        lora_id,
+        long_context_infos[lora_id]["lora"],
+        None,
+        4096 * scaling_factor,
+    )
 
 
 def evaluate_json_response(model_response, golden_response):
@@ -56,13 +62,14 @@ def evaluate_json_response(model_response, golden_response):
     for person_attribute, person_attribute_value in golden_response.items():
         if person_attribute in model_response:
             if isinstance(person_attribute_value, dict):
-                for (sub_attribute,
-                     sub_attribute_value) in person_attribute_value.items():
+                for (
+                        sub_attribute,
+                        sub_attribute_value,
+                ) in person_attribute_value.items():
                     total_values += 1
-                    if sub_attribute in model_response[
-                            person_attribute] and model_response[
-                                person_attribute][
-                                    sub_attribute] == sub_attribute_value:
+                    if (sub_attribute in model_response[person_attribute]
+                            and model_response[person_attribute][sub_attribute]
+                            == sub_attribute_value):
                         positive_values += 1
             else:
                 total_values += 1
@@ -92,10 +99,12 @@ def batched_generate(
     for input in inputs:
         prompt, sampling_param, lora_req = input
         # Add requests to the engine and run the engine
-        llm._validate_and_add_requests(prompt,
-                                       sampling_param,
-                                       lora_request=lora_req,
-                                       prompt_adapter_request=None)
+        llm._validate_and_add_requests(
+            prompt,
+            sampling_param,
+            lora_request=lora_req,
+            prompt_adapter_request=None,
+        )
 
     outputs = llm._run_engine(use_tqdm=True)
     return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
@@ -118,7 +127,8 @@ def lora_llm(long_context_infos):
         tensor_parallel_size=4,
         # FIXME enable async output processor
         disable_async_output_proc=True,
-        distributed_executor_backend="mp")
+        distributed_executor_backend="mp",
+    )
     yield llm
     del llm
 
@@ -127,9 +137,12 @@ def test_rotary_emb_replaced(dist_init):
     """Verify rotary emb in all the layers are replaced"""
     from vllm.engine.arg_utils import EngineArgs
     from vllm.worker.model_runner import ModelRunner
-    engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
-                             long_lora_scaling_factors=(4.0, ),
-                             enable_lora=True)
+
+    engine_args = EngineArgs(
+        "meta-llama/Llama-2-7b-hf",
+        long_lora_scaling_factors=(4.0, ),
+        enable_lora=True,
+    )
     engine_config = engine_args.create_engine_config()
     model_runner = ModelRunner(
         model_config=engine_config.model_config,
@@ -158,16 +171,18 @@ def test_rotary_emb_replaced(dist_init):
 @pytest.mark.skip_global_cleanup
 def test_batched_rope_kernel(lora_llm, long_context_infos):
-    """We test the batched kernel by comparing the results of batched an
-    non-batched generation.
+    """We test the batched kernel by comparing the results of batched and
+    non-batched generation.
     """
     # Create non batched results first to compare against batched results
     non_batched_results: List[str] = []
 
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
-        lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
-                       sampling_params,
-                       _create_lora_request(lora_id, long_context_infos))
+        lora_prompt = (
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )
         lora_output = generate(lora_llm, lora_prompt)
         non_batched_results.append(lora_output)
 
@@ -178,10 +193,11 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
                                 Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
-        batched_prompts.extend([
-            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
-             _create_lora_request(lora_id, long_context_infos))
-        ])
+        batched_prompts.extend([(
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )])
     batched_results = batched_generate(lora_llm, batched_prompts)
 
     # Results should be the same
@@ -203,10 +219,11 @@ def test_self_consistency(lora_llm, long_context_infos):
                                 Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
-        batched_prompts.extend([
-            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
-             _create_lora_request(lora_id, long_context_infos))
-        ])
+        batched_prompts.extend([(
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )])
 
     batched_results = batched_generate(lora_llm, batched_prompts)
 
@@ -217,19 +234,20 @@ def test_self_consistency(lora_llm, long_context_infos):
     for i in permutation:
         lora_id, info = list(long_context_infos.items())[i]
         context_len = info["context_length"]
-        batched_prompts.extend([
-            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
-             _create_lora_request(lora_id, long_context_infos))
-        ])
+        batched_prompts.extend([(
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )])
 
     permutated_batched_results = batched_generate(lora_llm, batched_prompts)
 
     # Results should be the same
     for i in range(num_loras):
-        assert batched_results[i] == permutated_batched_results[
-            permutation[i]], (
-                f"Results should be the same:\n{batched_results[i]}"
-                f"\n{permutated_batched_results[permutation[i]]}")
+        assert (
+            batched_results[i] == permutated_batched_results[permutation[i]]
+        ), (f"Results should be the same:\n{batched_results[i]}"
+            f"\n{permutated_batched_results[permutation[i]]}")
 
 
 @pytest.mark.skip_global_cleanup
@@ -252,8 +270,11 @@ def test_quality(lora_llm, long_context_infos):
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         for prompt_and_response in prompts_and_responses[context_len]:
-            lora_prompt = (prompt_and_response["prompt"], sampling_params,
-                           _create_lora_request(lora_id, long_context_infos))
+            lora_prompt = (
+                prompt_and_response["prompt"],
+                sampling_params,
+                _create_lora_request(lora_id, long_context_infos),
+            )
             response = generate(lora_llm, lora_prompt)
             golden_answer = prompt_and_response["golden_answer"]
             score = evaluate_json_response(response, golden_answer)
@@ -266,7 +287,7 @@
 @pytest.mark.skip_global_cleanup
 def test_max_len(lora_llm, long_context_infos):
-    """Test that we raise an ValueError when the input of a given LoRA
-    model exceeds the maximum length."""
+    """Test that we raise a ValueError when the input of a given LoRA
+    model exceeds the maximum length."""
     # Since each LoRA model has a different maximum length, we need to
     # test each one separately
     for lora_id, info in long_context_infos.items():
@@ -286,12 +307,12 @@
     for lora_id_with_bad_inputs in long_context_infos:
         for lora_id, info in long_context_infos.items():
            context_len = info["context_length"]
-            batched_prompts.extend([
-                (prompts_and_responses[context_len][0]["prompt"] *
-                 (2 if lora_id == lora_id_with_bad_inputs else 1),
-                 sampling_params,
-                 _create_lora_request(lora_id, long_context_infos))
-            ])
+            batched_prompts.extend([(
+                prompts_and_responses[context_len][0]["prompt"] *
+                (2 if lora_id == lora_id_with_bad_inputs else 1),
+                sampling_params,
+                _create_lora_request(lora_id, long_context_infos),
+            )])
 
         # Turn good prompt into bad prompt inside of batched prompts
         with pytest.raises(ValueError):
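
Note: beyond the mechanical reformatting, the one behavioral change in this patch is the LoRARequest name in _create_lora_request, which moves from context_len alone to context_len + str(lora_id). A minimal sketch of the name collision this avoids follows; the dict contents below are illustrative assumptions for the example, not values taken from the test fixtures.

# Hypothetical illustration (not part of the patch): with two 16K adapters,
# keying the request name on context length alone yields one name for two
# different LoRAs, while suffixing the integer lora_id keeps them distinct.
long_context_infos = {
    11: {"context_length": "16k", "lora": "/path/to/lora-a"},
    12: {"context_length": "16k", "lora": "/path/to/lora-b"},
}

names_before = {info["context_length"] for info in long_context_infos.values()}
names_after = {
    info["context_length"] + str(lora_id)
    for lora_id, info in long_context_infos.items()
}

assert names_before == {"16k"}            # both adapters collapse to one name
assert names_after == {"16k11", "16k12"}  # one unique name per adapter

The sketch only demonstrates the naming arithmetic; as the diff shows, each request still carries the integer lora_id and the adapter path alongside the now-unique name.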