Skip to content

Commit

Permalink
[Lora] Support long context lora (vllm-project#4787)
Browse files Browse the repository at this point in the history
Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
  • Loading branch information
rkooo567 authored May 18, 2024
1 parent c0724fc commit 2e9a222
Show file tree
Hide file tree
Showing 25 changed files with 999 additions and 72 deletions.
16 changes: 15 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,23 @@ steps:

- label: LoRA Test %N
#mirror_hardwares: [amd]
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4

- label: LoRA Long Context (Distributed)
#mirror_hardwares: [amd]
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len

- label: Tensorizer Test
#mirror_hardwares: [amd]
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
Expand Down
5 changes: 2 additions & 3 deletions format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ mypy vllm/model_executor --config-file pyproject.toml


CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
'--skip' '*docs/source/_build/**,./tests/lora/data'
)

# check spelling of specified files
Expand All @@ -133,10 +133,9 @@ spell_check_changed() {
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"

if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell "${CODESPELL_EXCLUDES[@]}"
codespell "${CODESPELL_EXCLUDES[@]}"
fi
}

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ exclude = [

[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data"

[tool.isort]
use_parentheses = true
Expand Down
50 changes: 50 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model

LONG_LORA_INFOS = [{
"lora_id": 1,
"context_length": "16k",
}, {
"lora_id": 2,
"context_length": "16k",
}, {
"lora_id": 3,
"context_length": "32k",
}]


def cleanup():
destroy_model_parallel()
Expand Down Expand Up @@ -154,6 +165,45 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")


@pytest.fixture(scope="session")
def long_context_lora_files_32k():
return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")


# SANG-TODO Download long lora files.
@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2,
long_context_lora_files_32k):
cleanup()
infos = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
if lora_id == 1:
lora = long_context_lora_files_16k_1
elif lora_id == 2:
lora = long_context_lora_files_16k_2
elif lora_id == 3:
lora = long_context_lora_files_32k
else:
raise AssertionError("Unknown lora id")
infos[lora_id] = {
"context_length": lora_checkpoint_info["context_length"],
"lora": lora,
}
return infos


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
Expand Down
Empty file added tests/lora/data/__init__.py
Empty file.
97 changes: 97 additions & 0 deletions tests/lora/data/long_context_test_data.py

Large diffs are not rendered by default.

100 changes: 98 additions & 2 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
convert_mapping)
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights, convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.utils import set_random_seed
Expand Down Expand Up @@ -771,3 +773,97 @@ class FakeConfig:
expected_result,
rtol=rtol,
atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
(6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_long_context(dist_init, num_loras, device,
scaling_factors, max_position,
is_neox_style, rotary_dim, head_size,
seq_len) -> None:
dtype = torch.float16
seed = 0
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)

max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
long_lora_scaling_factors=scaling_factors,
lora_dtype=dtype)

if rotary_dim is None:
rotary_dim = head_size
base = 10000
batch_size = 5 * num_loras
num_heads = 7

# Verify lora is equivalent to linear scaling rotary embedding.
rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
)
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"type": "linear",
"factor": scaling_factors
})
linear_rope = linear_rope.to(dtype=dtype)
id_to_index = get_random_id_to_index(num_loras, max_loras)
_, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=batch_size,
input_size=(1, max_position),
input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
long_lora_context = LongContextLoRAContext(list(scaling_factors),
rotary_dim)

next_expected_offset = 0
# Make sure the offset is correct.
scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
for scaling_factor, offset in scaling_factor_to_offset.items():
assert offset == next_expected_offset
next_expected_offset += scaling_factor * max_position

for i in range(len(scaling_factors)):
long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
scaling_factors[i], 0)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
long_lora_context=long_lora_context,
)
lora_rope.set_mapping(*mapping_info)

positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)
ref_q, ref_k = linear_rope(positions, query, key)
actual_q, actual_k = lora_rope(positions, query, key)

torch.allclose(ref_q, actual_q)
torch.allclose(ref_k, actual_k)
Loading

0 comments on commit 2e9a222

Please sign in to comment.