Migrate logits computation and gather to model_runner #3233

Merged (22 commits, Mar 20, 2024)
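In short, this change moves logits computation (including the tensor-parallel gather) out of the Sampler and into a separate LogitsProcessor layer invoked from model_runner, so the Sampler now consumes precomputed logits. A minimal, self-contained sketch of that split, using toy torch stand-ins rather than the real vLLM classes (the class names below are illustrative only):

    import torch
    import torch.nn as nn

    # Toy stand-ins for illustration -- not vLLM's actual classes.
    class ToyLogitsProcessor(nn.Module):
        """Owns the vocab size and turns hidden states into logits."""

        def __init__(self, vocab_size: int):
            super().__init__()
            self.vocab_size = vocab_size

        def forward(self, hidden_states: torch.Tensor,
                    lm_head_weight: torch.Tensor) -> torch.Tensor:
            # hidden_states: [num_tokens, hidden_size]
            # lm_head_weight: [vocab_size, hidden_size]
            return hidden_states @ lm_head_weight.t()

    class ToySampler(nn.Module):
        """Consumes precomputed logits; no longer needs a vocab size."""

        def forward(self, logits: torch.Tensor) -> torch.Tensor:
            return torch.argmax(logits, dim=-1)  # greedy, for illustration

    hidden_states = torch.randn(4, 1024)
    lm_head_weight = torch.randn(32000, 1024)
    logits = ToyLogitsProcessor(32000)(hidden_states, lm_head_weight)
    next_tokens = ToySampler()(logits)
    assert logits.shape == (4, 32000) and next_tokens.shape == (4,)

This mirrors the test changes below: Sampler() is now constructed without a vocab size, and the sampler tests pass fake_logits directly instead of hidden states.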
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -49,6 +49,9 @@ steps:
- label: Samplers Test
command: pytest -v -s samplers

- label: LogitsProcessor Test
command: pytest -v -s test_logits_processor.py

- label: Worker Test
command: pytest -v -s worker

7 changes: 5 additions & 2 deletions tests/lora/conftest.py
@@ -13,6 +13,7 @@
import vllm
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512))
("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
]))
model.config = MagicMock()
return model
@@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512))
("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
]))
model.config = MagicMock()
return model
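The fixture change above captures the new constructor signatures: the vocab size moves from Sampler to LogitsProcessor. A minimal sketch, assuming a vLLM checkout that includes this PR:

    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.sampler import Sampler

    # Vocab size (plus any extra LoRA vocab) is configured on the logits
    # processor; the sampler no longer takes constructor arguments.
    logits_processor = LogitsProcessor(512)
    sampler = Sampler()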
66 changes: 36 additions & 30 deletions tests/lora/test_layers.py
@@ -13,14 +13,14 @@
QKVParallelLinearWithLora,
VocabParallelEmbeddingWithLoRA,
RowParallelLinearWithLoRA,
SamplerWithLoRA,
LogitsProcessorWithLoRA,
LoRAMapping,
BaseLayerWithLoRA,
)
from vllm.lora.models import (LoRALayerWeights, convert_mapping,
PackedLoRALayerWeights)
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
@@ -394,36 +394,37 @@ def create_random_embedding_layer():
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_sampler(dist_init, num_loras, device) -> None:
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:

torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)

def create_random_sampler_layer():
def _pretest():
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
1024, 32000)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, 32000:] = 0
sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
linear.weight.device)
lora_sampler.create_lora_weights(max_loras, lora_config)
logits_processor = LogitsProcessor(
32000 + lora_config.lora_extra_vocab_size, 32000)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
lora_logits_processor.create_lora_weights(max_loras, lora_config)

return linear, sampler, lora_sampler
return linear, logits_processor, lora_logits_processor

for i in range(10):
set_random_seed(i)

id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, sampler, lora_sampler = create_random_sampler_layer()
linear, logits_processor, lora_logits_processor = _pretest()

# NOTE: all the generated loras share the same embeddings tensor.
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_sampler,
layer=lora_logits_processor,
layer_weights=linear.weight,
generate_embeddings_tensor=1024,
)
@@ -447,34 +448,37 @@ def create_random_sampler_layer():
32000,
lora_config.lora_extra_vocab_size,
)
lora_sampler.set_mapping(*mapping_info, )
lora_logits_processor.set_mapping(*mapping_info, )

lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=linear.weight,
embedding_bias=None)
lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=linear.weight,
embedding_bias=None)

original_weight = linear.weight.clone()

linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
linear.weight[logits_processor.
org_vocab_size:logits_processor.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor

sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
logits_processor.org_vocab_size = (32000 +
lora_config.lora_extra_vocab_size)
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = sampler._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
sampler.org_vocab_size = 32000
logits_processor.org_vocab_size = 32000

# Check that resetting the lora weights succeeds

for slot_idx in range(max_loras):
lora_sampler.reset_lora(slot_idx)
lora_logits_processor.reset_lora(slot_idx)

inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
Expand All @@ -488,14 +492,16 @@ def create_random_sampler_layer():
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
32000,
lora_config.lora_extra_vocab_size)
lora_sampler.set_mapping(*mapping_info, )

lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)[:, :32000]
expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)
lora_logits_processor.set_mapping(*mapping_info, )

lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)[:, :32000]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)

rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
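The expected-result computation in this test applies the standard LoRA update to the lm_head logits: base logits from the lm_head weight, plus scaling * (hidden @ lora_a @ lora_b). A self-contained numeric sketch of that reference math with toy shapes (plain torch, not the vLLM layers):

    import torch

    num_tokens, hidden_size, rank, vocab_size = 4, 1024, 8, 32000
    hidden = torch.randn(num_tokens, hidden_size)
    lm_head_weight = torch.randn(vocab_size, hidden_size)
    lora_a = torch.randn(hidden_size, rank)   # down-projection
    lora_b = torch.randn(rank, vocab_size)    # up-projection into vocab space
    scaling = 0.5

    base_logits = hidden @ lm_head_weight.t()
    lora_logits = base_logits + hidden @ lora_a @ lora_b * scaling
    assert lora_logits.shape == (num_tokens, vocab_size)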
95 changes: 20 additions & 75 deletions tests/samplers/test_sampler.py
@@ -15,17 +15,12 @@

class MockLogitsSampler(Sampler):

def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size)
def __init__(self, fake_logits: torch.Tensor):
super().__init__()
self.fake_logits = fake_logits

def forward(self, *args, **kwargs):
with patch(
"vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x), patch(
"vllm.model_executor.layers.sampler.Sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
return super().forward(*args, **kwargs)


def _prepare_test(
@@ -36,7 +31,7 @@ def _prepare_test(
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner

@@ -70,9 +65,7 @@ def _do_sample(
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
return sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
batch_size)

sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)

for i in range(batch_size):
fake_logits[i, i] = 1e2
Expand All @@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)

sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)

second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)

assert first_sampler_output == second_sampler_output
@@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)

sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, input_tensor, sampler, model_runner,
sampling_params)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
@@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

for i, (sequence_output, metadata) in enumerate(
@@ -294,48 +283,6 @@ def test_sampling(model_runner: ModelRunner):
del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

# This sample logits processor gives maximum score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = torch.finfo(logits.dtype).max
return logits

seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for _, sequence_output in enumerate(sampler_output):
for idx, nth_output in enumerate(sequence_output.samples):
assert nth_output.output_token == idx

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
@@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
size=(batch_size, vocab_size),
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)

generation_model = GenerationMixin()
@@ -391,9 +338,7 @@ def mock_sample(probs, *args, **kwargs):
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]

with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
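With logits computation moved out of the sampler, these tests now build a fake_logits tensor and hand it straight to sampler(logits=..., sampling_metadata=...); the greedy case then simply expects the argmax over those logits. A toy illustration of that expectation (plain torch, independent of the vLLM test fixtures):

    import torch

    batch_size, vocab_size = 8, 32000
    fake_logits = torch.full((batch_size, vocab_size), 1e-2)
    for i in range(batch_size):
        fake_logits[i, i] = 1e2  # make token i the clear winner for row i

    # Greedy sampling over precomputed logits reduces to an argmax.
    expected = torch.argmax(fake_logits, dim=-1)
    assert expected.tolist() == list(range(batch_size))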