Commit f1c0fc3

Migrate logits computation and gather to model_runner (vllm-proje…

esmeetu authored Mar 20, 2024
1 parent 6e435de commit f1c0fc3
Showing 35 changed files with 577 additions and 306 deletions.
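Every test diff below follows from one interface change: logits computation and gather move out of Sampler into a standalone LogitsProcessor invoked by the model runner, so Sampler now receives logits directly instead of hidden states plus an embedding. A minimal sketch of the resulting two-stage flow, using plain-PyTorch stand-ins rather than vLLM's actual classes (names and shapes here are illustrative only):

import torch

class ToyLogitsProcessor:
    # Stand-in mirroring LogitsProcessor(vocab_size, org_vocab_size) from the
    # diff: it projects hidden states onto the vocabulary via the lm_head weight.
    def __init__(self, vocab_size, org_vocab_size=None):
        self.vocab_size = vocab_size
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states, embedding, embedding_bias=None):
        logits = hidden_states @ embedding.t()
        if embedding_bias is not None:
            logits = logits + embedding_bias
        return logits[:, :self.vocab_size]

class ToySampler:
    # Stand-in mirroring the new argument-free Sampler(): it only consumes logits.
    def __call__(self, logits, temperature=0.0):
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)              # greedy
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

hidden = torch.randn(4, 16)             # [batch, hidden_size]
lm_head_weight = torch.randn(32, 16)    # [vocab_size, hidden_size]
logits = ToyLogitsProcessor(32)._get_logits(hidden, lm_head_weight)
tokens = ToySampler()(logits)           # the sampler never sees hidden states

The new test_logits_processor.py CI target added below exercises the extracted component directly.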
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -49,6 +49,9 @@ steps:
 - label: Samplers Test
   command: pytest -v -s samplers
 
+- label: LogitsProcessor Test
+  command: pytest -v -s test_logits_processor.py
+
 - label: Worker Test
   command: pytest -v -s worker
7 changes: 5 additions & 2 deletions tests/lora/conftest.py
@@ -13,6 +13,7 @@
 import vllm
 from vllm.config import LoRAConfig
 from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
             ("outact", nn.Sigmoid()),
             # Special handling for lm_head & sampler
             ("lm_head", ParallelLMHead(512, 10)),
-            ("sampler", Sampler(512))
+            ("logits_processor", LogitsProcessor(512)),
+            ("sampler", Sampler())
         ]))
     model.config = MagicMock()
     return model
@@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
             ("outact", nn.Sigmoid()),
             # Special handling for lm_head & sampler
             ("lm_head", ParallelLMHead(512, 10)),
-            ("sampler", Sampler(512))
+            ("logits_processor", LogitsProcessor(512)),
+            ("sampler", Sampler())
         ]))
     model.config = MagicMock()
    return model
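The fixture change is mechanical: each dummy model gains a logits_processor entry and Sampler loses its vocab-size argument. The same layout with plain nn.Module stand-ins (illustrative only; the real fixtures use vLLM's ParallelLMHead, LogitsProcessor, and Sampler):

from collections import OrderedDict
import torch.nn as nn

dummy = nn.Sequential(
    OrderedDict([
        ("lm_head", nn.Linear(16, 512)),      # plays ParallelLMHead(512, 10)
        ("logits_processor", nn.Identity()),  # plays LogitsProcessor(512)
        ("sampler", nn.Identity()),           # plays the now argument-free Sampler()
    ]))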
66 changes: 36 additions & 30 deletions tests/lora/test_layers.py
@@ -13,14 +13,14 @@
                               QKVParallelLinearWithLora,
                               VocabParallelEmbeddingWithLoRA,
                               RowParallelLinearWithLoRA,
-                              SamplerWithLoRA,
+                              LogitsProcessorWithLoRA,
                               LoRAMapping,
                               BaseLayerWithLoRA,
                               )
 from vllm.lora.models import (LoRALayerWeights, convert_mapping,
                               PackedLoRALayerWeights)
 from vllm.config import LoRAConfig
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear,
@@ -394,36 +394,37 @@ def create_random_embedding_layer():
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_lm_head_sampler(dist_init, num_loras, device) -> None:
+def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
 
     torch.set_default_device(device)
     max_loras = 8
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
                              lora_dtype=torch.float16)
 
-    def create_random_sampler_layer():
+    def _pretest():
         linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
                                 1024, 32000)
         linear.weight.data = torch.rand_like(linear.weight.data)
         linear.weight.data[:, 32000:] = 0
-        sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
-        lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
-                                       linear.weight.device)
-        lora_sampler.create_lora_weights(max_loras, lora_config)
+        logits_processor = LogitsProcessor(
+            32000 + lora_config.lora_extra_vocab_size, 32000)
+        lora_logits_processor = LogitsProcessorWithLoRA(
+            logits_processor, 1024, linear.weight.dtype, linear.weight.device)
+        lora_logits_processor.create_lora_weights(max_loras, lora_config)
 
-        return linear, sampler, lora_sampler
+        return linear, logits_processor, lora_logits_processor
 
     for i in range(10):
         set_random_seed(i)
 
         id_to_index = get_random_id_to_index(num_loras, max_loras)
-        linear, sampler, lora_sampler = create_random_sampler_layer()
+        linear, logits_processor, lora_logits_processor = _pretest()
 
         # NOTE: all the generated loras share the same embeddings tensor.
         lora_dict, _ = populate_loras(
             id_to_index,
-            layer=lora_sampler,
+            layer=lora_logits_processor,
             layer_weights=linear.weight,
             generate_embeddings_tensor=1024,
         )
@@ -447,34 +448,37 @@ def create_random_sampler_layer():
             32000,
             lora_config.lora_extra_vocab_size,
         )
-        lora_sampler.set_mapping(*mapping_info, )
+        lora_logits_processor.set_mapping(*mapping_info, )
 
-        lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
-                                               embedding=linear.weight,
-                                               embedding_bias=None)
+        lora_result = lora_logits_processor._get_logits(
+            hidden_states=torch.cat(inputs),
+            embedding=linear.weight,
+            embedding_bias=None)
 
         original_weight = linear.weight.clone()
 
-        linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
+        linear.weight[logits_processor.
+                      org_vocab_size:logits_processor.org_vocab_size +
                       embeddings_tensor_len] = embeddings_tensor
 
-        sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
+        logits_processor.org_vocab_size = (32000 +
+                                           lora_config.lora_extra_vocab_size)
         expected_results = []
         for input_, lora_id in zip(inputs, prompt_mapping):
             lora = lora_dict[lora_id]
-            result = sampler._get_logits(hidden_states=input_,
-                                         embedding=linear.weight,
-                                         embedding_bias=None)
+            result = logits_processor._get_logits(hidden_states=input_,
+                                                  embedding=linear.weight,
+                                                  embedding_bias=None)
             result[:, 32000 + embeddings_tensor_len:] = float("-inf")
             result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
-        sampler.org_vocab_size = 32000
+        logits_processor.org_vocab_size = 32000
 
         # Check that resetting the lora weights succeeds
 
         for slot_idx in range(max_loras):
-            lora_sampler.reset_lora(slot_idx)
+            lora_logits_processor.reset_lora(slot_idx)
 
         inputs, index_mapping, prompt_mapping = create_random_inputs(
             active_lora_ids=[0],
@@ -488,14 +492,16 @@ def create_random_sampler_layer():
         mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                        32000,
                                        lora_config.lora_extra_vocab_size)
-        lora_sampler.set_mapping(*mapping_info, )
-
-        lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
-                                               embedding=original_weight,
-                                               embedding_bias=None)[:, :32000]
-        expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
-                                              embedding=original_weight,
-                                              embedding_bias=None)
+        lora_logits_processor.set_mapping(*mapping_info, )
+
+        lora_result = lora_logits_processor._get_logits(
+            hidden_states=torch.cat(inputs),
+            embedding=original_weight,
+            embedding_bias=None)[:, :32000]
+        expected_result = logits_processor._get_logits(
+            hidden_states=torch.cat(inputs),
+            embedding=original_weight,
+            embedding_bias=None)
 
         rtol, atol = TOLERANCES[lora_result.dtype]
         assert torch.allclose(lora_result,
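The expected-value computation in this test reduces to three steps: take base logits from the (vocab-extended) lm_head weight, mask every position past the LoRA embeddings tensor to -inf, and add the low-rank delta on top. A self-contained numeric sketch of that recipe, with small illustrative shapes in place of the test's 32000-token vocabulary:

import torch

hidden_size, rank, vocab, extra = 64, 8, 100, 16
emb_len = 8                                        # rows taken by the embeddings tensor
x = torch.randn(4, hidden_size)                    # hidden states
weight = torch.randn(vocab + extra, hidden_size)   # lm_head incl. extra vocab
lora_a = torch.randn(hidden_size, rank)
lora_b = torch.randn(rank, vocab + extra)
scaling = 0.5

base = x @ weight.t()                              # logits_processor._get_logits(...)
base[:, vocab + emb_len:] = float("-inf")          # mask past the embeddings tensor
expected = base + x @ lora_a @ lora_b * scaling    # add the LoRA delta
print(expected.shape)                              # torch.Size([4, 116])

Note that adding the LoRA delta to a masked position leaves it at -inf, which is exactly the behavior the test relies on.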
95 changes: 20 additions & 75 deletions tests/samplers/test_sampler.py
@@ -15,17 +15,12 @@
 
 class MockLogitsSampler(Sampler):
 
-    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
-        super().__init__(vocab_size=vocab_size)
+    def __init__(self, fake_logits: torch.Tensor):
+        super().__init__()
         self.fake_logits = fake_logits
 
     def forward(self, *args, **kwargs):
-        with patch(
-                "vllm.model_executor.layers.sampler._prune_hidden_states",
-                lambda x, y: x), patch(
-                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
-                    lambda *args, **kwargs: self.fake_logits):
-            return super().forward(*args, **kwargs)
+        return super().forward(*args, **kwargs)
 
 
 def _prepare_test(
@@ -36,7 +31,7 @@ def _prepare_test(
     fake_logits = torch.full((batch_size, vocab_size),
                              1e-2,
                              dtype=input_tensor.dtype)
-    sampler = MockLogitsSampler(32000, fake_logits)
+    sampler = MockLogitsSampler(fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
     return input_tensor, fake_logits, sampler, model_runner
 
@@ -70,9 +65,7 @@ def _do_sample(
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    return sampler(embedding=None,
-                   hidden_states=input_tensor,
-                   sampling_metadata=sampling_metadata)
+    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
 
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
         batch_size)
 
     sampling_params = SamplingParams(temperature=0)
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
         temperature=1.0,
         n=random.randint(1, 10),
     )
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
@@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
         n=random.randint(1, 10),
         seed=random.randint(0, 10000),
     )
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=1.0,
         n=random.randint(1, 10),
         seed=random.randint(0, 10000),
     )
-    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       model_runner, sampling_params)
 
-    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                        model_runner, sampling_params)
 
     assert first_sampler_output == second_sampler_output
@@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
    batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=0,
         best_of=2,
         use_beam_search=True,
     )
-    _do_sample(batch_size, input_tensor, sampler, model_runner,
-               sampling_params)
+    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
     def test_sampling(model_runner: ModelRunner):
         sampling_metadata = model_runner._prepare_sample(
             seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
-        sampler_output = sampler(embedding=None,
-                                 hidden_states=input_tensor,
+        sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)
 
         for i, (sequence_output, metadata) in enumerate(
@@ -294,48 +283,6 @@ def test_sampling(model_runner: ModelRunner):
     del model_runner
 
 
-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_logits_processors(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
-
-    # This sample logits processor gives maximum score to the i-th token,
-    # where i is the length of the input sequence.
-    # We therefore expect the output token sequence to be [0, 1, 2, ...]
-    def pick_ith(token_ids, logits):
-        logits[len(token_ids)] = torch.finfo(logits.dtype).max
-        return logits
-
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0,
-                                               logits_processors=[pick_ith]),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for _, sequence_output in enumerate(sampler_output):
-        for idx, nth_output in enumerate(sequence_output.samples):
-            assert nth_output.output_token == idx
-
-    del model_runner
-
-
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_top_k_top_p(seed: int, device: str):
@@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
-    sampler = MockLogitsSampler(32000, fake_logits)
+    sampler = MockLogitsSampler(fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
 
     generation_model = GenerationMixin()
@@ -391,9 +338,7 @@ def mock_sample(probs, *args, **kwargs):
         return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
 
     with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
-        sampler(embedding=None,
-                hidden_states=input_tensor,
-                sampling_metadata=sampling_metadata)
+        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
     hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
     hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
     assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
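With logits computation gone from Sampler, MockLogitsSampler no longer patches _prune_hidden_states or Sampler._get_logits; the tests now pass fake_logits where hidden states used to go. A hedged sketch of the invariant the greedy tests check (the sampling_metadata plumbing through model_runner._prepare_sample is elided):

import torch

batch_size, vocab_size = 8, 32000
fake_logits = torch.full((batch_size, vocab_size), 1e-2)
fake_logits[:, 42] = 1.0                # make token 42 the greedy winner

# Under temperature=0, sampler(logits=fake_logits, sampling_metadata=...)
# should return the argmax of the supplied logits for every sequence.
expected = torch.argmax(fake_logits, dim=-1)
assert (expected == 42).all()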