
Commit

Fix lora_name bug
jeejeelee committed Oct 22, 2024
1 parent e59c8e4 commit 69c3141
Showing 1 changed file with 67 additions and 46 deletions.
tests/lora/test_long_context.py (113 changes: 67 additions, 46 deletions)
@@ -28,9 +28,15 @@
 def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
-    return LoRARequest(context_len, lora_id,
-                       long_context_infos[lora_id]["lora"], None,
-                       4096 * scaling_factor)
+    return LoRARequest(
+        # There are 2 LoRAs for 16K, we need to add lora_id to indicate
+        # they are different LoRAs.
+        context_len + str(lora_id),
+        lora_id,
+        long_context_infos[lora_id]["lora"],
+        None,
+        4096 * scaling_factor,
+    )


 def evaluate_json_response(model_response, golden_response):
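
Note on the hunk above, which is the substantive fix: `context_len` is the context-length key (a string such as "16k"), and, as the new inline comment says, two different LoRAs share the 16K bucket, so using the bare key as `lora_name` gave both adapters the same name. Appending `lora_id` makes every name unique; the second argument stays the integer `lora_id`. A minimal sketch of the naming scheme, using a hypothetical stand-in for the `long_context_infos` fixture (ids and paths are illustrative, not the test's real data):

# Hypothetical stand-in for long_context_infos; ids and paths are illustrative.
long_context_infos = {
    1: {"context_length": "16k", "lora": "/path/to/lora-16k-a"},
    2: {"context_length": "16k", "lora": "/path/to/lora-16k-b"},
    3: {"context_length": "32k", "lora": "/path/to/lora-32k"},
}

# Before the fix: lora_name was just the context-length key, so both 16K
# adapters collided on the name "16k".
old_names = [info["context_length"] for info in long_context_infos.values()]
assert len(set(old_names)) < len(old_names)

# After the fix: lora_name is context_length + str(lora_id), e.g. "16k1",
# "16k2", "32k3", so every LoRARequest gets a distinct name.
new_names = [
    info["context_length"] + str(lora_id)
    for lora_id, info in long_context_infos.items()
]
assert len(set(new_names)) == len(new_names)
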
@@ -56,13 +62,14 @@ def evaluate_json_response(model_response, golden_response):
     for person_attribute, person_attribute_value in golden_response.items():
         if person_attribute in model_response:
             if isinstance(person_attribute_value, dict):
-                for (sub_attribute,
-                     sub_attribute_value) in person_attribute_value.items():
+                for (
+                    sub_attribute,
+                    sub_attribute_value,
+                ) in person_attribute_value.items():
                     total_values += 1
-                    if sub_attribute in model_response[
-                            person_attribute] and model_response[
-                                person_attribute][
-                                    sub_attribute] == sub_attribute_value:
+                    if (sub_attribute in model_response[person_attribute]
+                            and model_response[person_attribute][sub_attribute]
+                            == sub_attribute_value):
                         positive_values += 1
             else:
                 total_values += 1
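
The hunk above only reformats `evaluate_json_response`; the scoring logic is unchanged. It counts every golden attribute (and, for dict-valued attributes, every sub-attribute) and tracks how many of them the model reproduced exactly, so the score is the matched fraction. A small self-contained sketch of that idea with made-up data (not the test's golden answers):

# Made-up golden/model responses to illustrate the scoring scheme.
golden = {"name": "Alice", "address": {"city": "Paris", "zip": "75001"}}
model = {"name": "Alice", "address": {"city": "Paris", "zip": "75002"}}

total_values, positive_values = 0, 0
for attribute, value in golden.items():
    if isinstance(value, dict):
        # Dict-valued attribute: score each sub-attribute separately.
        for sub_attribute, sub_value in value.items():
            total_values += 1
            if model.get(attribute, {}).get(sub_attribute) == sub_value:
                positive_values += 1
    else:
        total_values += 1
        if model.get(attribute) == value:
            positive_values += 1

score = positive_values / total_values
print(score)  # 2/3 here: "name" and "city" match, "zip" does not
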
@@ -92,10 +99,12 @@ def batched_generate(
     for input in inputs:
         prompt, sampling_param, lora_req = input
         # Add requests to the engine and run the engine
-        llm._validate_and_add_requests(prompt,
-                                       sampling_param,
-                                       lora_request=lora_req,
-                                       prompt_adapter_request=None)
+        llm._validate_and_add_requests(
+            prompt,
+            sampling_param,
+            lora_request=lora_req,
+            prompt_adapter_request=None,
+        )

     outputs = llm._run_engine(use_tqdm=True)
     return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
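
`batched_generate` (above) goes through the private `LLM._validate_and_add_requests` and `LLM._run_engine` helpers so that every prompt is queued before a single engine run, which is what lets one run exercise the batched RoPE path. For comparison, here is a hedged sketch that drives the same `(prompt, params, lora_request)` tuples through vLLM's public `LLM.generate` API instead; it runs the engine once per prompt, so it would not reproduce the batching, and it is not code from this test:

from typing import List, Optional, Tuple

import vllm
from vllm import SamplingParams
from vllm.lora.request import LoRARequest


def unbatched_generate(
    llm: vllm.LLM,
    inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
) -> List[str]:
    """Sketch: issue each (prompt, params, lora_request) separately via the
    public API. Unlike batched_generate above, this does not queue everything
    before running the engine."""
    results: List[str] = []
    for prompt, sampling_param, lora_req in inputs:
        outputs = llm.generate(prompt, sampling_param, lora_request=lora_req)
        results.append(outputs[0].outputs[0].text.strip())
    return results
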
@@ -118,7 +127,8 @@ def lora_llm(long_context_infos):
         tensor_parallel_size=4,
         # FIXME enable async output processor
         disable_async_output_proc=True,
-        distributed_executor_backend="mp")
+        distributed_executor_backend="mp",
+    )
     yield llm
     del llm

@@ -127,9 +137,12 @@ def test_rotary_emb_replaced(dist_init):
     """Verify rotary emb in all the layers are replaced"""
     from vllm.engine.arg_utils import EngineArgs
     from vllm.worker.model_runner import ModelRunner
-    engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
-                             long_lora_scaling_factors=(4.0, ),
-                             enable_lora=True)
+
+    engine_args = EngineArgs(
+        "meta-llama/Llama-2-7b-hf",
+        long_lora_scaling_factors=(4.0, ),
+        enable_lora=True,
+    )
     engine_config = engine_args.create_engine_config()
     model_runner = ModelRunner(
         model_config=engine_config.model_config,
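
This hunk and the `lora_llm` fixture above use the same long-context LoRA knobs: `enable_lora=True` plus one `long_lora_scaling_factors` entry per context-length bucket (the fixture passes the analogous options to `vllm.LLM`; only the tail of that call is visible in this diff). A hedged, illustrative sketch of that wiring; the model name and the specific values are assumptions for illustration, not the fixture's exact arguments:

import vllm

# Illustrative values only; the real fixture derives its scaling factors from
# long_context_infos and also sets the options shown in the earlier hunk.
llm = vllm.LLM(
    "meta-llama/Llama-2-7b-hf",  # assumed base model, for illustration only
    enable_lora=True,
    max_loras=2,  # assumption: both 16K adapters resident at once
    long_lora_scaling_factors=(4.0, 8.0),  # assumed 16k and 32k buckets
    tensor_parallel_size=4,
    disable_async_output_proc=True,
    distributed_executor_backend="mp",
)
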
@@ -158,16 +171,18 @@ def test_rotary_emb_replaced(dist_init):
 @pytest.mark.skip_global_cleanup
 def test_batched_rope_kernel(lora_llm, long_context_infos):
     """We test the batched kernel by comparing the results of batched an
-    non-batched generation.
+    non-batched generation.
     """
     # Create non batched results first to compare against batched results
     non_batched_results: List[str] = []

     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
-        lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
-                       sampling_params,
-                       _create_lora_request(lora_id, long_context_infos))
+        lora_prompt = (
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )
         lora_output = generate(lora_llm, lora_prompt)
         non_batched_results.append(lora_output)

@@ -178,10 +193,11 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
                                 Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
-        batched_prompts.extend([
-            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
-             _create_lora_request(lora_id, long_context_infos))
-        ])
+        batched_prompts.extend([(
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )])
     batched_results = batched_generate(lora_llm, batched_prompts)

     # Results should be the same
@@ -203,10 +219,11 @@ def test_self_consistency(lora_llm, long_context_infos):
                                 Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
-        batched_prompts.extend([
-            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
-             _create_lora_request(lora_id, long_context_infos))
-        ])
+        batched_prompts.extend([(
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )])

     batched_results = batched_generate(lora_llm, batched_prompts)

@@ -217,19 +234,20 @@ def test_self_consistency(lora_llm, long_context_infos):
     for i in permutation:
         lora_id, info = list(long_context_infos.items())[i]
         context_len = info["context_length"]
-        batched_prompts.extend([
-            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
-             _create_lora_request(lora_id, long_context_infos))
-        ])
+        batched_prompts.extend([(
+            prompts_and_responses[context_len][0]["prompt"],
+            sampling_params,
+            _create_lora_request(lora_id, long_context_infos),
+        )])

     permutated_batched_results = batched_generate(lora_llm, batched_prompts)

     # Results should be the same
     for i in range(num_loras):
-        assert batched_results[i] == permutated_batched_results[
-            permutation[i]], (
-                f"Results should be the same:\n{batched_results[i]}"
-                f"\n{permutated_batched_results[permutation[i]]}")
+        assert (
+            batched_results[i] == permutated_batched_results[permutation[i]]
+        ), (f"Results should be the same:\n{batched_results[i]}"
+            f"\n{permutated_batched_results[permutation[i]]}")


 @pytest.mark.skip_global_cleanup
@@ -252,8 +270,11 @@ def test_quality(lora_llm, long_context_infos):
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         for prompt_and_response in prompts_and_responses[context_len]:
-            lora_prompt = (prompt_and_response["prompt"], sampling_params,
-                           _create_lora_request(lora_id, long_context_infos))
+            lora_prompt = (
+                prompt_and_response["prompt"],
+                sampling_params,
+                _create_lora_request(lora_id, long_context_infos),
+            )
             response = generate(lora_llm, lora_prompt)
             golden_answer = prompt_and_response["golden_answer"]
             score = evaluate_json_response(response, golden_answer)
@@ -266,7 +287,7 @@ def test_quality(lora_llm, long_context_infos):
 @pytest.mark.skip_global_cleanup
 def test_max_len(lora_llm, long_context_infos):
     """Test that we raise an ValueError when the input of a given LoRA
-    model exceeds the maximum length."""
+    model exceeds the maximum length."""
     # Since each LoRA model has a different maximum length, we need to
     # test each one separately
     for lora_id, info in long_context_infos.items():
@@ -286,12 +307,12 @@ def test_max_len(lora_llm, long_context_infos):
     for lora_id_with_bad_inputs in long_context_infos:
         for lora_id, info in long_context_infos.items():
             context_len = info["context_length"]
-            batched_prompts.extend([
-                (prompts_and_responses[context_len][0]["prompt"] *
-                 (2 if lora_id == lora_id_with_bad_inputs else 1),
-                 sampling_params,
-                 _create_lora_request(lora_id, long_context_infos))
-            ])
+            batched_prompts.extend([(
+                prompts_and_responses[context_len][0]["prompt"] *
+                (2 if lora_id == lora_id_with_bad_inputs else 1),
+                sampling_params,
+                _create_lora_request(lora_id, long_context_infos),
+            )])
     # Turn good prompt into bad prompt inside of batched prompts

     with pytest.raises(ValueError):
